#include "Profiler.h"
#include "mpi/MpiEnv.h"
#include "utils/Fmt.h"

#include <vtkm/thirdparty/diy/diy.h>
#include <vtkm/thirdparty/diy/mpi-cast.h>

#include <algorithm>
#include <cmath>
#include <numeric>
#include <stdexcept>
#include <utility>

namespace
{
/// Converts a duration in seconds to whole milliseconds, rounded to the nearest integer.
///
/// Rounding rather than truncating matters here. Binary floating point cannot represent
/// most decimal second values exactly, so e.g. 1.005 * 1000 evaluates to 1004.999...,
/// which a plain static_cast would report as 1004 ms. std::lround also avoids the
/// undefined behaviour of static_cast<int> on an out-of-range double.
int ToMs(vtkm::Float64 t)
{
  return static_cast<int>(std::lround(t * 1000.0));
}
} // namespace

namespace beams
{
/// Creates a frame labelled `name`. The same label is passed to the frame's Timer, and
/// the Mpi member is initialised from `mpi` for the later gather in Collect().
/// The timer is not started here; Profiler::StartFrame() starts it.
ProfilerFrame::ProfilerFrame(const std::string& name, const beams::mpi::MpiEnv& mpi)
  : Name(name), Timer(name), Mpi(mpi)
{
}

void ProfilerFrame::Collect()
{
  const auto& comm = this->Mpi.Comm;
  MPI_Comm mpiComm = vtkmdiy::mpi::mpi_cast(comm.handle());
  this->ElapsedTimes.resize(this->Mpi.Size);
  vtkm::Float64 rankTime = this->Timer.GetElapsedTime();
  MPI_Gather(&rankTime, 1, MPI_DOUBLE, this->ElapsedTimes.data(), 1, MPI_DOUBLE, 0, mpiComm);
  if (this->Mpi.Rank == 0)
  {
    auto minMaxTime = std::minmax_element(this->ElapsedTimes.begin(), this->ElapsedTimes.end());
    this->MinTime = *(minMaxTime.first);
    this->MaxTime = *(minMaxTime.second);
    this->AvgTime = std::accumulate(this->ElapsedTimes.begin(), this->ElapsedTimes.end(), 0.0) /
      static_cast<vtkm::Float64>(this->ElapsedTimes.size());
  }
}

/// Writes a single summary line for this frame to `stream`:
///   "<indent><Name>: Min = <ms> ms, Max = <ms> ms, Avg = <ms> ms"
/// The indent is two spaces per `level`. The millisecond values come from ToMs() applied
/// to the statistics stored by Collect(), which only fills them in on rank 0.
void ProfilerFrame::PrintSummary(std::ostream& stream, int level) const
{
  const std::string indent(level * 2, ' ');
  stream << fmt::format("{}{}: Min = {} ms, Max = {} ms, Avg = {} ms\n",
                        indent,
                        this->Name,
                        ToMs(this->MinTime),
                        ToMs(this->MaxTime),
                        ToMs(this->AvgTime));
}

/// Builds a profiler bound to `mpi` and immediately opens the root frame `name`.
/// StartFrame() makes that frame RootFrame and starts its timer.
Profiler::Profiler(const std::string& name, const beams::mpi::MpiEnv& mpi)
  : Mpi(mpi)
{
  const std::string& rootName = name;
  this->StartFrame(rootName);
}

/// Opens a new frame called `name` and starts its timer.
///
/// If no frame is currently open, the new frame becomes RootFrame, replacing any earlier
/// root. Otherwise it is appended to the SubFrames of the innermost open frame. In both
/// cases it is then pushed and becomes the innermost open frame.
void Profiler::StartFrame(const std::string& name)
{
  auto child = std::make_shared<ProfilerFrame>(name, this->Mpi);

  if (this->Empty())
  {
    this->RootFrame = child;
  }
  else
  {
    this->Peek()->SubFrames.push_back(child);
  }

  this->Push(child);
  child->Timer.Start();
}

void Profiler::EndFrame()
{
  auto topFrame = this->Pop();
  topFrame->Timer.Stop();
}

/// Runs ProfilerFrame::Collect() over the whole frame tree, starting at RootFrame.
/// Each frame's Collect() is an MPI collective, so every rank must call this with an
/// identically shaped frame tree.
void Profiler::Collect()
{
  this->CollectInternal(this->RootFrame);
}

/// Pre-order, depth-first traversal: collects `frame` first, then recurses into each of
/// its SubFrames in iteration order. The order is significant because every
/// ProfilerFrame::Collect() call is an MPI collective that must match across ranks.
void Profiler::CollectInternal(std::shared_ptr<ProfilerFrame> frame)
{
  frame->Collect();
  // Bind by const reference. Copying each shared_ptr here would cost an atomic
  // refcount increment/decrement per child for no benefit. This also matches
  // PrintSummaryInternal.
  for (const auto& subFrame : frame->SubFrames)
  {
    this->CollectInternal(subFrame);
  }
}

/// Pushes `frame` onto the stack of open frames, making it the innermost one.
/// `frame` is a by-value sink parameter, so it is moved into the stack rather than copied.
/// This avoids a redundant atomic refcount increment/decrement pair.
void Profiler::Push(std::shared_ptr<ProfilerFrame> frame)
{
  this->FrameStack.push(std::move(frame));
}

/// Returns the innermost open frame without removing it from the stack.
///
/// @throws std::logic_error if no frame is open. FrameStack appears to be a std::stack
///         (push/top/pop/empty), and top() on an empty std::stack is undefined behaviour.
///         An unbalanced EndFrame() would otherwise reach that through Pop().
std::shared_ptr<ProfilerFrame> Profiler::Peek()
{
  if (this->FrameStack.empty())
  {
    throw std::logic_error(
      "beams::Profiler: no open frame (EndFrame called without a matching StartFrame?)");
  }
  return this->FrameStack.top();
}

/// Removes the innermost open frame from the stack and hands it back to the caller.
/// The frame is read through Peek() before it is removed.
std::shared_ptr<ProfilerFrame> Profiler::Pop()
{
  std::shared_ptr<ProfilerFrame> innermost = this->Peek();
  this->FrameStack.pop();
  return innermost;
}

/// True when no frame is currently open: none has been started, or every started
/// frame has been ended.
bool Profiler::Empty() const
{
  const bool noOpenFrames = this->FrameStack.empty();
  return noOpenFrames;
}

/// Writes the indented summary of the whole frame tree to `stream`. The root frame is
/// printed at indentation level 0.
void Profiler::PrintSummary(std::ostream& stream)
{
  constexpr int rootLevel = 0;
  this->PrintSummaryInternal(this->RootFrame, stream, rootLevel);
}

/// Pre-order print: writes `frame` at indentation `level`, then each of its SubFrames,
/// recursively, one level deeper.
void Profiler::PrintSummaryInternal(std::shared_ptr<ProfilerFrame> frame,
                                    std::ostream& stream,
                                    int level)
{
  frame->PrintSummary(stream, level);

  const int childLevel = level + 1;
  for (const auto& child : frame->SubFrames)
  {
    this->PrintSummaryInternal(child, stream, childLevel);
  }
}
} // namespace beams