#include "Timings.h"
#include "Fmt.h"

#include <algorithm>
#include <cstddef>
#include <ctime>
#include <fstream>
#include <numeric>
#include <string>
#include <vector>

namespace beams
{
namespace utils
{
namespace
{
/// Converts a duration given in seconds to whole milliseconds.
///
/// The fractional millisecond part is truncated toward zero
/// (e.g. 0.0019 s -> 1 ms).
/// NOTE(review): a millisecond count that does not fit in an int makes the
/// cast undefined behaviour; timings are presumably far below that limit.
int ToMs(vtkm::Float64 t)
{
  const vtkm::Float64 milliseconds = t * 1000;
  return static_cast<int>(milliseconds);
}

std::vector<vtkm::Float64> RemoveOutliers(std::vector<vtkm::Float64>& timings)
{
  if (timings.size() < 4)
  {
    return timings;
  }

  std::sort(timings.begin(), timings.end());

  int q1 = static_cast<int>(timings.size()) / 4;
  int q3 = 3 * static_cast<int>(timings.size()) / 4;
  vtkm::Float64 iqr = timings[q3] - timings[q1];

  vtkm::Float64 lowerBound = timings[q1] - 1.5 * iqr;
  vtkm::Float64 upperBound = timings[q3] + 1.5 * iqr;

  std::vector<vtkm::Float64> filteredTimings;
  for (vtkm::Float64 timing : timings)
  {
    if (timing >= lowerBound && timing <= upperBound)
    {
      filteredTimings.push_back(timing);
    }
  }

  return filteredTimings;
}

/// Returns the current UTC time formatted as "YYYY-MM-DD-HH:MM:SS".
///
/// NOTE(review): despite the name this is not strict ISO 8601 (the date/time
/// separator is '-' rather than 'T' and there is no 'Z' suffix). The format is
/// kept as-is, presumably so rows appended to existing CSV files stay consistent.
///
/// @return The formatted timestamp, or an empty string if the time could not be
///         converted or formatted (previously an unchecked failure read an
///         indeterminate buffer).
std::string GetISOTimestamp()
{
  const std::time_t now = std::time(nullptr);
  const std::tm* utc = std::gmtime(&now);
  if (utc == nullptr)
  {
    return std::string();
  }

  char buff[sizeof "2011-10-08T07:07:09Z"];
  // strftime returns the number of characters written (excluding the NUL), or 0
  // if the result did not fit; in that case the buffer contents are indeterminate.
  const std::size_t written = std::strftime(buff, sizeof buff, "%F-%T", utc);
  return std::string(buff, written);
}
}

Timings::Timings(const std::string& sceneId, const std::vector<std::string>& labels)
  : SceneId(sceneId)
  , Labels(labels)
{
}

/// Records the measurements of one profiling iteration.
///
/// For each record whose name is one of the tracked labels, the sample is taken
/// as-is for root-only records, or gathered from all ranks via
/// CollectDistributedTimeStats otherwise. It is then filtered with
/// RemoveOutliers and appended to that label's accumulated samples. Records with
/// untracked names are skipped.
///
/// NOTE(review): the gather is an MPI collective, so presumably every rank must
/// pass the same sequence of non-root-only records here — confirm with callers.
void Timings::AddIteration(const std::vector<beams::profiling::Record>& times)
{
  for (const auto& record : times)
  {
    const bool tracked =
      std::find(this->Labels.begin(), this->Labels.end(), record.Name) != this->Labels.end();
    if (!tracked)
    {
      // Fmt::Println0("Ignoring unknown label: {}", record.Name);
      continue;
    }

    // Only one branch runs, so the collective gather happens only for
    // distributed (non-root-only) records.
    std::vector<vtkm::Float64> samples = record.RootOnly
      ? std::vector<vtkm::Float64>(1, record.Time)
      : Timings::CollectDistributedTimeStats(record.Time);
    samples = RemoveOutliers(samples);

    auto& accumulated = this->DistributedTimings[record.Name];
    accumulated.insert(accumulated.end(), samples.begin(), samples.end());
  }
}

void Timings::Save(const std::string& fileName)
{
  std::ofstream file;
  file.open(fileName, std::ios::app);
  bool hasHeader = file.tellp() != 0;
  if (!hasHeader)
  {
    file << "Time";
    file << ",";
    file << "Id";
    file << ",";
    file << "World";
    for (const auto& label : this->Labels)
    {
      file << "," << ("Min-" + label) << "," << ("Max-" + label) << "," << ("Avg-" + label);
    }
    file << ",";
    file << "Avg-Total";
    file << std::endl;
  }

  vtkm::Float64 avgTotal = 0.0f;
  auto mpi = beams::mpi::MpiEnv::GetInstance();
  file << GetISOTimestamp();
  file << ",";
  file << this->SceneId;
  file << ",";
  file << mpi->Size;
  for (const auto& label : this->Labels)
  {
    auto collectedTimes = this->DistributedTimings[label];
    collectedTimes = RemoveOutliers(collectedTimes);

    vtkm::Float64 minTime = *std::min_element(collectedTimes.begin(), collectedTimes.end());
    vtkm::Float64 maxTime = *std::max_element(collectedTimes.begin(), collectedTimes.end());
    vtkm::Float64 avgTime = std::accumulate(collectedTimes.begin(), collectedTimes.end(), 0.0) /
      static_cast<vtkm::Float64>(collectedTimes.size());
    avgTotal += avgTime;
    file << "," << ToMs(minTime) << "," << ToMs(maxTime) << "," << ToMs(avgTime);
  }

  file << "," << ToMs(avgTotal);
  file << std::endl;

  file.close();
}

/// Gathers one timing value from every rank of the MPI communicator onto rank 0.
///
/// @param time This rank's measurement.
/// @return On rank 0, a vector of comm.size() values filled by the gather
///         (presumably indexed by rank — confirm against vtkmdiy::mpi::gather).
///         On every other rank, a vector of comm.size() zeros, because those
///         ranks never receive the gathered data.
/// NOTE(review): gather is a collective operation; presumably every rank must
/// reach this call.
std::vector<vtkm::Float64> Timings::CollectDistributedTimeStats(vtkm::Float64 time)
{
  auto comm = beams::mpi::MpiEnv::GetInstance()->Comm;
  std::vector<vtkm::Float64> gathered(comm.size(), 0.0);
  if (comm.rank() != 0)
  {
    // Non-root ranks only contribute their value; `gathered` stays zero-filled.
    vtkmdiy::mpi::gather(comm, time, 0);
    return gathered;
  }

  vtkmdiy::mpi::gather(comm, time, gathered, 0);
  return gathered;
}
}
}