#include "MapperLitVolume.h"
#include "LitVolumeRenderer.h"
#include "compositing/Compositor.h"
#include "mpi/MpiEnv.h"
#include "utils/Fmt.h"

#include <vtkm/cont/EnvironmentTracker.h>
#include <vtkm/cont/Timer.h>
#include <vtkm/cont/TryExecute.h>
#include <vtkm/rendering/CanvasRayTracer.h>
#include <vtkm/rendering/raytracing/Camera.h>
#include <vtkm/rendering/raytracing/Logger.h>
#include <vtkm/rendering/raytracing/RayOperations.h>
#include <vtkm/thirdparty/diy/diy.h>
#include <vtkm/thirdparty/diy/mpi-cast.h>

#include <mpi.h>
#include <sstream>

#define DEFAULT_SAMPLE_DISTANCE -1.f

extern std::string P5CompositeLabel;

namespace beams
{
namespace rendering
{
// Private implementation state of MapperLitVolume, kept out of the header.
struct MapperLitVolume::InternalsType
{
  // Volume ray tracer that every setter on the mapper forwards to and that
  // RenderCells drives.
  beams::rendering::LitVolumeRenderer Tracer;
  // Compositor used by MapperLitVolume::Composite.
  vtkm::rendering::compositing::Compositor Compositor;
  VTKM_CONT
  InternalsType() {}
};

// Creates a mapper with a freshly allocated, default-constructed InternalsType
// (tracer plus compositor).
MapperLitVolume::MapperLitVolume()
  : Internals(new InternalsType)
{
}

// Nothing to release by hand; defined in this translation unit, where
// InternalsType is a complete type.
MapperLitVolume::~MapperLitVolume() = default;

/// Registers @p light with the underlying volume tracer.
void MapperLitVolume::AddLight(std::shared_ptr<Light> light)
{
  auto& tracer = this->Internals->Tracer;
  tracer.AddLight(light);
}

/// Forwards to the tracer's ClearLights().
void MapperLitVolume::ClearLights()
{
  auto& tracer = this->Internals->Tracer;
  tracer.ClearLights();
}

/// Stores @p densityScale in the tracer's DensityScale member.
void MapperLitVolume::SetDensityScale(vtkm::Float32 densityScale)
{
  auto& tracer = this->Internals->Tracer;
  tracer.DensityScale = densityScale;
}

/// Forwards the requested shadow-map size to the tracer.
void MapperLitVolume::SetShadowMapSize(vtkm::Id3 size)
{
  auto& tracer = this->Internals->Tracer;
  tracer.SetShadowMapSize(size);
}

/// Forwards the shadow-map step count to the tracer.
void MapperLitVolume::SetShadowMapNumSteps(vtkm::Id numSteps)
{
  auto& tracer = this->Internals->Tracer;
  tracer.SetShadowMapNumSteps(numSteps);
}

/// Forwards the shadow sample count to the tracer.
void MapperLitVolume::SetNumShadowSamples(vtkm::Id numSamples)
{
  auto& tracer = this->Internals->Tracer;
  tracer.SetNumShadowSamples(numSamples);
}

/// Renders the structured volume held by @p cellset into this mapper's canvas.
///
/// In order: clears the tracer's profiler records, validates the cell set type,
/// creates camera rays against the coordinate bounds, configures and runs the
/// lit-volume tracer, writes the traced color buffer into the canvas, composites
/// across MPI ranks when more than one rank is present, optionally blends the
/// background, and appends the elapsed time of that final
/// write/composite/blend phase to the tracer's profiler records.
///
/// @param cellset      Must hold a vtkm::cont::CellSetStructured<3>.
/// @param coords       Coordinate system; its bounds are used for ray creation.
/// @param scalarField  Field handed to the tracer.
/// @param camera       Camera used for ray generation, canvas mapping and
///                     compositing order.
/// @param scalarRange  Range handed to the tracer together with the field.
/// @throws vtkm::cont::ErrorBadValue if @p cellset is not a 3D structured cell set.
///
/// The color table argument is unused; the tracer receives this->ColorMap.
void MapperLitVolume::RenderCells(const vtkm::cont::UnknownCellSet& cellset,
                                  const vtkm::cont::CoordinateSystem& coords,
                                  const vtkm::cont::Field& scalarField,
                                  const vtkm::cont::ColorTable& vtkmNotUsed(colorTable),
                                  const vtkm::rendering::Camera& camera,
                                  const vtkm::Range& scalarRange)
{
  auto& tracer = this->Internals->Tracer;
  tracer.ProfilerTimes.clear();

  if (!cellset.IsType<vtkm::cont::CellSetStructured<3>>())
  {
    std::stringstream msg;
    // NOTE(review): typeid is applied to the UnknownCellSet wrapper, so this
    // presumably names the wrapper rather than the stored cell set type.
    std::string theType = typeid(cellset).name();
    msg << "Mapper volume: cell set type not currently supported\n";
    msg << "Type : " << theType << std::endl;
    throw vtkm::cont::ErrorBadValue(msg.str());
  }

  vtkm::rendering::raytracing::Logger* logger = vtkm::rendering::raytracing::Logger::GetInstance();
  logger->OpenLogEntry("mapper_volume");
  vtkm::cont::Timer totalTimer;
  totalTimer.Start();

  tracer.Canvas = this->Canvas;

  vtkm::rendering::raytracing::Camera rayCamera;
  vtkm::rendering::raytracing::Ray<vtkm::Float32> rays;

  const vtkm::Int32 width = static_cast<vtkm::Int32>(this->Canvas->GetWidth());
  const vtkm::Int32 height = static_cast<vtkm::Int32>(this->Canvas->GetHeight());

  rayCamera.SetParameters(camera, width, height);

  rayCamera.CreateRays(rays, coords.GetBounds());
  rays.Buffers.at(0).InitConst(0.f);
  // rays.Distance.AllocateAndFill(rays.NumRays, std::numeric_limits<vtkm::Float32>::infinity());
  vtkm::rendering::raytracing::RayOperations::MapCanvasToRays(rays, camera, *this->Canvas);

  // DEFAULT_SAMPLE_DISTANCE is the "not set" sentinel; only override the
  // tracer's own sample distance when the caller supplied one.
  if (this->SampleDistance != DEFAULT_SAMPLE_DISTANCE)
  {
    tracer.SetSampleDistance(this->SampleDistance);
  }

  tracer.SetBoundsMap(this->BoundsMap);
  tracer.SetDensityCorrectionRatio(this->DensityCorrectionRatio);
  tracer.SetData(
    coords, scalarField, cellset.AsCellSet<vtkm::cont::CellSetStructured<3>>(), scalarRange);
  tracer.SetColorMap(this->ColorMap);
  tracer.Render(rays);

  vtkm::cont::Timer phase5CompositeTimer;
  phase5CompositeTimer.Start();
  this->Canvas->WriteToCanvas(rays, rays.Buffers.at(0).Buffer, camera);

  // Compile-time selection of the multi-rank compositing path.
  constexpr int COMPOSITE_CUSTOM = 1;
  constexpr int COMPOSITE_VTKH = 2;
  constexpr int COMPOSITER = COMPOSITE_CUSTOM;

  auto mpi = beams::mpi::MpiEnv::GetInstance();
  if (mpi->Size > 1)
  {
    if (COMPOSITER == COMPOSITE_CUSTOM)
    {
      this->GlobalComposite(camera);
    }
    else if (COMPOSITER == COMPOSITE_VTKH)
    {
      this->Composite(camera);
    }
  }

  if (this->CompositeBackground)
  {
    this->Canvas->BlendBackground();
  }

  phase5CompositeTimer.Stop();
  auto& profilerTimes = tracer.GetProfilerTimes();
  profilerTimes.push_back({
    .Name = P5CompositeLabel,
    .Time = phase5CompositeTimer.GetElapsedTime(),
    .RootOnly = true,
  });

  // Balance the OpenLogEntry above so the logger's entry does not stay open
  // after every render.
  logger->CloseLogEntry(totalTimer.GetElapsedTime());
}

// Sort record used by MapperLitVolume::GlobalComposite: one entry per MPI rank.
// The field order matters because GlobalComposite fills it with a designated
// initializer.
struct DepthOrderOld
{
  // MPI rank (also used as the index into the bounds map's BlockBounds).
  int Rank;
  // Distance from the camera position to the center of that rank's block bounds.
  vtkm::Float32 Depth;
};

/// Appends every value of @p portal, in index order, to the end of @p vec.
///
/// Existing contents of @p vec are kept. The index uses the portal's own count
/// type so that portals with more values than fit in a 32-bit int are walked
/// correctly, and the destination is grown once up front.
///
/// @param portal  Object exposing GetNumberOfValues() and Get(index).
/// @param vec     Destination vector; values are converted to T on insertion.
template <typename Portal, typename T>
void CopyIntoVec(const Portal& portal, std::vector<T>& vec)
{
  auto numValues = portal.GetNumberOfValues();
  using IndexType = decltype(numValues);

  if (numValues <= 0)
  {
    return;
  }

  vec.reserve(vec.size() + static_cast<std::size_t>(numValues));
  for (IndexType index = 0; index < numValues; ++index)
  {
    vec.push_back(portal.Get(index));
  }
}

/// Returns true when the fourth component (alpha) of @p color is exactly zero.
bool IsTransparentColor(const vtkm::Vec4f_32 color)
{
  const auto alpha = color[3];
  return alpha == 0.0f;
}

/// Gathers every rank's canvas color buffer on rank 0 and blends them front to
/// back into rank 0's color buffer.
///
/// Ranks are ordered by the distance from the camera position to the center of
/// BlockBounds[rank], nearest first. Starting from the nearest rank's image,
/// each farther image is added as `front + back * (1 - front.alpha)` on all four
/// channels. Every rank takes part in the gather; only rank 0 blends and writes
/// the result back, other ranks return with their canvas untouched.
///
/// @param camera  Supplies the position used for the depth ordering.
void MapperLitVolume::GlobalComposite(const vtkm::rendering::Camera& camera)
{
  auto mpi = beams::mpi::MpiEnv::GetInstance();
  const auto& comm = mpi->Comm;
  MPI_Comm mpiComm = vtkmdiy::mpi::mpi_cast(comm.handle());
  const bool isRoot = (mpi->Rank == 0);

  // Depth-order the ranks by camera distance to each rank's block center.
  std::vector<DepthOrderOld> depthOrders;
  depthOrders.reserve(static_cast<std::size_t>(mpi->Size));
  const auto cameraPos = camera.GetPosition();
  for (int i = 0; i < mpi->Size; ++i)
  {
    const auto& bounds = this->BoundsMap->BlockBounds[i];
    const auto center = bounds.Center();
    const vtkm::Vec3f_32 center_32{ vtkm::Float32(center[0]),
                                    vtkm::Float32(center[1]),
                                    vtkm::Float32(center[2]) };
    const vtkm::Float32 depth = vtkm::Magnitude(cameraPos - center_32);
    depthOrders.push_back({ .Rank = i, .Depth = depth });
  }

  std::sort(depthOrders.begin(),
            depthOrders.end(),
            [](const DepthOrderOld& l, const DepthOrderOld& r) -> bool
            { return l.Depth < r.Depth; });

  auto colorBuffer = this->Canvas->GetColorBuffer();
  // MPI counts are int; buffer sizes and offsets below use size_t so that
  // (ranks * pixels) cannot overflow a 32-bit int.
  const int numPixels = static_cast<int>(colorBuffer.GetNumberOfValues());
  const std::size_t pixelCount = static_cast<std::size_t>(numPixels);

  std::vector<vtkm::Vec4f_32> colorVals;
  CopyIntoVec(colorBuffer.ReadPortal(), colorVals);

  std::vector<vtkm::Vec4f_32> gatheredColors;
  if (isRoot)
  {
    gatheredColors.resize(static_cast<std::size_t>(mpi->Size) * pixelCount);
  }

  // One RGBA pixel = four contiguous floats.
  MPI_Datatype vec4Type;
  MPI_Type_contiguous(4, MPI_FLOAT, &vec4Type);
  MPI_Type_commit(&vec4Type);

  // The receive arguments are only significant on the root rank.
  MPI_Gather(colorVals.data(),
             numPixels,
             vec4Type,
             isRoot ? gatheredColors.data() : nullptr,
             numPixels,
             vec4Type,
             0,
             mpiComm);

  // Release the committed datatype; otherwise one handle leaks per call.
  MPI_Type_free(&vec4Type);

  if (!isRoot)
    return;

  // Seed the result with the nearest rank's image.
  const vtkm::Vec4f_32* nearest =
    gatheredColors.data() + static_cast<std::size_t>(depthOrders[0].Rank) * pixelCount;
  std::vector<vtkm::Vec4f_32> finalBuffer(nearest, nearest + pixelCount);

  // Blend the remaining ranks behind it, nearest to farthest.
  for (std::size_t i = 1; i < depthOrders.size(); ++i)
  {
    const vtkm::Vec4f_32* behind =
      gatheredColors.data() + static_cast<std::size_t>(depthOrders[i].Rank) * pixelCount;
    for (std::size_t j = 0; j < pixelCount; ++j)
    {
      auto& front = finalBuffer[j];
      const auto& back = behind[j];
      // Computed from the alpha accumulated so far, before alpha is updated.
      const vtkm::Float32 remaining = 1.0f - front[3];

      front[0] = front[0] + back[0] * remaining;
      front[1] = front[1] + back[1] * remaining;
      front[2] = front[2] + back[2] * remaining;
      front[3] = front[3] + back[3] * remaining;
    }
  }

  vtkm::cont::Algorithm::Copy(vtkm::cont::make_ArrayHandle(finalBuffer, vtkm::CopyFlag::Off),
                              colorBuffer);
}

/// Computes, for every block index in [0, MPI size), its position in a
/// near-to-far ordering relative to @p camera.
///
/// The depth of block `dom` is FindMinDepth(camera, BlockBounds[dom]); the
/// depths are ranked by DepthSort.
///
/// @return Vector of length MPI size where entry `dom` is that block's position
///         in the depth ordering (0 = smallest depth).
std::vector<int> MapperLitVolume::FindVisibilityOrdering(const vtkm::rendering::Camera& camera)
{
  auto mpi = beams::mpi::MpiEnv::GetInstance();
  std::vector<int> visibilityOrders(mpi->Size);
  std::vector<vtkm::Float32> minDepths(mpi->Size);

  for (int dom = 0; dom < mpi->Size; ++dom)
  {
    const auto& bounds = this->BoundsMap->BlockBounds[dom];
    minDepths[dom] = this->FindMinDepth(camera, bounds);
  }

  this->DepthSort(mpi->Size, minDepths, visibilityOrders);
  return visibilityOrders;
}

// Working record for MapperLitVolume::DepthSort: one entry per domain.
struct VisOrdering
{
  // Owning rank; DepthSort sets this to 0 for every entry.
  int m_rank;
  // Index of the domain in the input depth array.
  int m_domain_index;
  // Position of the domain after sorting by m_minz.
  int m_order;
  // Depth value the domain is sorted by.
  vtkm::Float32 m_minz;
};

// Strict-weak ordering of VisOrdering records by ascending m_minz.
// The call operator is const so the comparator also works through const
// references and copies.
struct DepthOrder
{
  bool operator()(const VisOrdering& lhs, const VisOrdering& rhs) const
  {
    return lhs.m_minz < rhs.m_minz;
  }
};

// Strict-weak ordering of VisOrdering records by ascending m_rank, with ties
// broken by ascending m_domain_index. The call operator is const so the
// comparator also works through const references and copies.
struct RankOrder
{
  bool operator()(const VisOrdering& lhs, const VisOrdering& rhs) const
  {
    if (lhs.m_rank != rhs.m_rank)
    {
      return lhs.m_rank < rhs.m_rank;
    }
    return lhs.m_domain_index < rhs.m_domain_index;
  }
};

/// Ranks the first @p size entries of @p minDepths in ascending order.
///
/// On return, visOrderings[d] is the position of domain d in the ascending
/// depth ordering (0 = smallest depth). Both vectors must already hold at least
/// @p size elements; nothing is written when @p size is not positive.
///
/// @param size          Number of domains to rank.
/// @param minDepths     Depth per domain, indexed by domain.
/// @param visOrderings  Output: depth position per domain, indexed by domain.
void MapperLitVolume::DepthSort(int size,
                                std::vector<vtkm::Float32>& minDepths,
                                std::vector<int>& visOrderings)
{
  if (size <= 0)
  {
    return;
  }

  std::vector<VisOrdering> order(static_cast<std::size_t>(size));
  for (int i = 0; i < size; ++i)
  {
    order[i].m_rank = 0;
    order[i].m_domain_index = i;
    order[i].m_minz = minDepths[i];
  }
  std::sort(order.begin(), order.end(), DepthOrder());

  // Domain indices are unique, so each sorted position can be written straight
  // to its domain's slot; no second sort back into domain order is needed.
  for (int i = 0; i < size; ++i)
  {
    order[i].m_order = i;
    visOrderings[order[i].m_domain_index] = i;
  }
}

/// Returns the distance from the camera position to the center of @p bounds.
///
/// Despite the name, no per-corner minimum is taken: the value is the magnitude
/// of (center - camera position), with the center components first rounded to
/// Float32.
vtkm::Float32 MapperLitVolume::FindMinDepth(const vtkm::rendering::Camera& camera,
                                            const vtkm::Bounds& bounds)
{
  const vtkm::Vec<vtkm::Float64, 3> center = bounds.Center();

  // Each component is narrowed to Float32 and stored back in a Float64 vector.
  vtkm::Vec<vtkm::Float64, 3> fcenter;
  for (vtkm::IdComponent axis = 0; axis < 3; ++axis)
  {
    fcenter[axis] = static_cast<vtkm::Float32>(center[axis]);
  }

  vtkm::Vec<vtkm::Float32, 3> pos = camera.GetPosition();
  return static_cast<vtkm::Float32>(vtkm::Magnitude(fcenter - pos));
}

/// Composites this rank's canvas with the other ranks through the internal
/// Compositor in VIS_ORDER_BLEND mode, then copies the result into the canvas on
/// rank 0.
///
/// @param camera  Used to compute the visibility ordering.
void MapperLitVolume::Composite(const vtkm::rendering::Camera& camera)
{
  auto& compositor = this->Internals->Compositor;
  compositor.SetCompositeMode(vtkm::rendering::compositing::Compositor::VIS_ORDER_BLEND);

  // NOTE(review): the ordering is only handed to FMT_VEC0 (presumably debug
  // output); AddImage below receives the rank instead — confirm that is intended.
  std::vector<int> visibilityOrdering = this->FindVisibilityOrdering(camera);
  FMT_VEC0(visibilityOrdering);

  auto mpi = beams::mpi::MpiEnv::GetInstance();
  compositor.AddImage(*(this->Canvas), mpi->Rank);
  auto result = compositor.Composite();

  // Rank 0 has the composited result, so put it into the Canvas.
  auto comm = mpi->Comm;
  if (comm.rank() != 0)
  {
    return;
  }
  this->Canvas->CopyFrom(vtkm::cont::make_ArrayHandle(result.Pixels, vtkm::CopyFlag::On),
                         vtkm::cont::make_ArrayHandle(result.Depths, vtkm::CopyFlag::On));
}

/// Returns a heap-allocated copy of this mapper, made with the copy
/// constructor. The caller receives a raw pointer to the new object.
vtkm::rendering::Mapper* MapperLitVolume::NewCopy() const
{
  vtkm::rendering::Mapper* duplicate = new beams::rendering::MapperLitVolume(*this);
  return duplicate;
}

/// Returns a copy of the profiler records currently held by the tracer.
std::vector<beams::profiling::Record> MapperLitVolume::GetProfilerTimes()
{
  auto& tracer = this->Internals->Tracer;
  return tracer.GetProfilerTimes();
}

/// Forwards the clamp flag to the tracer.
void MapperLitVolume::SetUseClamp(bool useClamp)
{
  auto& tracer = this->Internals->Tracer;
  tracer.SetUseClamp(useClamp);
}

/// Forwards the Reinhard flag to the tracer.
void MapperLitVolume::SetUseReinhard(bool useReinhard)
{
  auto& tracer = this->Internals->Tracer;
  tracer.SetUseReinhard(useReinhard);
}
} // namespace rendering
} // namespace beams
