#include "LitVolumeRenderer.h"
#include "../Math.h"
#include "CellLocators.h"
#include "PointLight.h"
#include "ShadowVolume.h"
#include "mpi/Types.h"
#include "utils/Fmt.h"

#include <vtkm/cont/Algorithm.h>
#include <vtkm/cont/ArrayHandleCartesianProduct.h>
#include <vtkm/cont/ArrayHandleCounting.h>
#include <vtkm/cont/ArrayHandleUniformPointCoordinates.h>
#include <vtkm/cont/CellSetStructured.h>
#include <vtkm/cont/ColorTable.h>
#include <vtkm/cont/ErrorBadValue.h>
#include <vtkm/cont/ErrorExecution.h>
#include <vtkm/cont/Timer.h>
#include <vtkm/cont/TryExecute.h>
#include <vtkm/io/VTKDataSetWriter.h>
#include <vtkm/rendering/raytracing/Camera.h>
#include <vtkm/rendering/raytracing/Logger.h>
#include <vtkm/rendering/raytracing/Ray.h>
#include <vtkm/rendering/raytracing/RayTracingTypeDefs.h>
#include <vtkm/thirdparty/diy/diy.h>
#include <vtkm/thirdparty/diy/mpi-cast.h>
#include <vtkm/worklet/DispatcherMapField.h>
#include <vtkm/worklet/WorkletMapField.h>

#include <mpi.h>

#include <math.h>
#include <numeric>
#include <stdio.h>

extern std::string P1ShadowVolumeGenLabel;
extern std::string P2ShadowVolumeRenderLabel;
extern std::string P3ShadowVolumeUpdateLabel;
extern std::string P4VolumeRenderLabel;

namespace beams
{
namespace rendering
{
namespace
{
/// Ray-marching worklet for point-associated scalar fields on structured grids.
///
/// Each invocation marches one ray through the grid in steps of SampleDistance.
/// At each step it:
///   1. trilinearly interpolates the scalar field;
///   2. normalizes the scalar and maps it through the color map;
///   3. applies the density correction to alpha;
///   4. scales RGB by the estimate returned from the shadow-volume estimator;
///   5. composites front to back into the ray's RGBA slot of the color buffer.
template <typename DeviceAdapterTag, typename LocatorType, typename MapEstimatorType>
class Sampler : public vtkm::worklet::WorkletMapField
{
private:
  using ColorArrayHandle = typename vtkm::cont::ArrayHandle<vtkm::Vec4f_32>;
  using ColorArrayPortal = typename ColorArrayHandle::ReadPortalType;
  ColorArrayPortal ColorMap;
  // Index of the last color-map entry (entry count - 1).
  vtkm::Id ColorMapSize;
  vtkm::Float32 MinScalar;
  vtkm::Float32 SampleDistance;
  // 1 / (maxScalar - minScalar); 0 when the scalar range is degenerate.
  vtkm::Float32 InverseDeltaScalar;
  LocatorType Locator;
  vtkm::Float32 MeshEpsilon;
  MapEstimatorType MapEstimator;
  vtkm::Id NumShadowSamples;
  vtkm::Float32 DensityCorrectionRatio;
  bool UseClamp;
  bool UseReinhard;

public:
  /// @param colorMap RGBA lookup table, prepared for input on DeviceAdapterTag.
  /// @param minScalar lower end of the scalar range mapped onto the color map.
  /// @param maxScalar upper end of the scalar range mapped onto the color map.
  /// @param sampleDistance distance between consecutive samples along a ray.
  /// @param locator cell locator for the structured grid.
  /// @param meshEpsilon offset added to the ray's entry distance before the first sample.
  /// @param shadowMapEstimator execution-side estimator queried at every sample.
  /// @param numShadowSamples number of jittered estimator queries averaged per sample
  ///        (the first sample of a ray always uses a single query).
  /// @param densityCorrectionRatio exponent applied to (1 - alpha) of each sample.
  /// @param useClamp clamp the final RGBA to [0, 1].
  /// @param useReinhard when not clamping, tone-map RGB with c / (c + 1).
  /// @param token keeps the prepared color-map portal valid.
  VTKM_CONT
  Sampler(const ColorArrayHandle& colorMap,
          const vtkm::Float32& minScalar,
          const vtkm::Float32& maxScalar,
          const vtkm::Float32& sampleDistance,
          const LocatorType& locator,
          const vtkm::Float32& meshEpsilon,
          const MapEstimatorType& shadowMapEstimator,
          vtkm::Id numShadowSamples,
          vtkm::Float32 densityCorrectionRatio,
          bool useClamp,
          bool useReinhard,
          vtkm::cont::Token& token)
    : ColorMap(colorMap.PrepareForInput(DeviceAdapterTag(), token))
    , ColorMapSize(colorMap.GetNumberOfValues() - 1)
    , MinScalar(minScalar)
    , SampleDistance(sampleDistance)
    // With a degenerate range every scalar normalizes to 0 (the first color-map
    // entry) instead of being scaled by the unrelated value minScalar.
    , InverseDeltaScalar((maxScalar - minScalar) != 0.f ? 1.f / (maxScalar - minScalar) : 0.f)
    , Locator(locator)
    , MeshEpsilon(meshEpsilon)
    , MapEstimator(shadowMapEstimator)
    , NumShadowSamples(numShadowSamples)
    , DensityCorrectionRatio(densityCorrectionRatio)
    , UseClamp(useClamp)
    , UseReinhard(useReinhard)
  {
  }

  using ControlSignature =
    void(FieldIn, FieldIn, FieldIn, FieldIn, FieldIn, WholeArrayInOut, WholeArrayIn, FieldInOut);
  using ExecutionSignature = void(_1, _2, _3, _4, _5, _6, _7, _8, WorkIndex);

  template <typename Precision>
  VTKM_EXEC_CONT static bool ApproxEquals(Precision x, Precision y, Precision eps = 1e-5f)
  {
    return vtkm::Abs(x - y) <= eps;
  }

  /// March one ray and composite it into colorBuffer[pixelIndex * 4 .. + 3].
  ///
  /// A minDistance of -1 is the "ray missed the volume" flag (set by
  /// CalcRayStart); such rays leave the buffer untouched. `seed` is the
  /// per-ray RNG state used to jitter the shadow-estimate lookups.
  template <typename ScalarPortalType, typename ColorBufferType>
  VTKM_EXEC void operator()(const vtkm::Id& rayId,
                            const vtkm::Vec3f_32& rayDir,
                            const vtkm::Vec3f_32& rayOrigin,
                            const vtkm::Float32& minDistance,
                            const vtkm::Float32& maxDistance,
                            ColorBufferType& colorBuffer,
                            ScalarPortalType& scalars,
                            vtkm::Vec<vtkm::UInt32, 2>& seed,
                            const vtkm::Id& pixelIndex) const
  {
    vtkm::Vec4f_32 color;
    color[0] = colorBuffer.Get(pixelIndex * 4 + 0);
    color[1] = colorBuffer.Get(pixelIndex * 4 + 1);
    color[2] = colorBuffer.Get(pixelIndex * 4 + 2);
    color[3] = colorBuffer.Get(pixelIndex * 4 + 3);

    if (minDistance == -1.f)
    {
      return; //TODO: Compact? or just image subset...
    }

    //get the initial sample position;
    vtkm::Vec3f_32 sampleLocation;
    // find the distance to the first sample
    vtkm::Float32 distance = minDistance + MeshEpsilon;
    sampleLocation = rayOrigin + distance * rayDir;
    // since the calculations are slightly different, we could hit an
    // edge case where the first sample location may not be in the data set.
    // Thus, advance to the next sample location
    while (!Locator.IsInside(sampleLocation) && distance < maxDistance)
    {
      distance += SampleDistance;
      sampleLocation = rayOrigin + distance * rayDir;
    }
    /*
            7----------6
           /|         /|
          4----------5 |
          | |        | |
          | 3--------|-2    z y
          |/         |/     |/
          0----------1      |__ x
    */
    vtkm::Vec3f_32 bottomLeft(0, 0, 0);
    bool newCell = true;
    //check to see if we left the cell
    vtkm::Float32 tx = 0.f;
    vtkm::Float32 ty = 0.f;
    vtkm::Float32 tz = 0.f;
    vtkm::Float32 scalar0 = 0.f;
    vtkm::Float32 scalar1minus0 = 0.f;
    vtkm::Float32 scalar2minus3 = 0.f;
    vtkm::Float32 scalar3 = 0.f;
    vtkm::Float32 scalar4 = 0.f;
    vtkm::Float32 scalar5minus4 = 0.f;
    vtkm::Float32 scalar6minus7 = 0.f;
    vtkm::Float32 scalar7 = 0.f;

    vtkm::Id3 cell(0, 0, 0);
    vtkm::Vec3f_32 invSpacing(0.f, 0.f, 0.f);

    bool start = true;
    while (Locator.IsInside(sampleLocation) && distance < maxDistance)
    {
      // Parametric coordinates outside [0, 1] mean the sample left the cached cell.
      vtkm::Float32 mint = vtkm::Min(tx, vtkm::Min(ty, tz));
      vtkm::Float32 maxt = vtkm::Max(tx, vtkm::Max(ty, tz));
      if (maxt > 1.f || mint < 0.f)
        newCell = true;

      if (newCell)
      {
        vtkm::Vec<vtkm::Id, 8> cellIndices;
        Locator.LocateCell(cell, sampleLocation, invSpacing);
        Locator.GetCellIndices(cell, cellIndices);
        Locator.GetPoint(cellIndices[0], bottomLeft);

        scalar0 = vtkm::Float32(scalars.Get(cellIndices[0]));
        vtkm::Float32 scalar1 = vtkm::Float32(scalars.Get(cellIndices[1]));
        vtkm::Float32 scalar2 = vtkm::Float32(scalars.Get(cellIndices[2]));
        scalar3 = vtkm::Float32(scalars.Get(cellIndices[3]));
        scalar4 = vtkm::Float32(scalars.Get(cellIndices[4]));
        vtkm::Float32 scalar5 = vtkm::Float32(scalars.Get(cellIndices[5]));
        vtkm::Float32 scalar6 = vtkm::Float32(scalars.Get(cellIndices[6]));
        scalar7 = vtkm::Float32(scalars.Get(cellIndices[7]));

        // save ourselves a couple extra instructions
        scalar6minus7 = scalar6 - scalar7;
        scalar5minus4 = scalar5 - scalar4;
        scalar1minus0 = scalar1 - scalar0;
        scalar2minus3 = scalar2 - scalar3;

        tx = (sampleLocation[0] - bottomLeft[0]) * invSpacing[0];
        ty = (sampleLocation[1] - bottomLeft[1]) * invSpacing[1];
        tz = (sampleLocation[2] - bottomLeft[2]) * invSpacing[2];

        newCell = false;
      }

      // Trilinear interpolation of the eight corner scalars.
      vtkm::Float32 lerped76 = scalar7 + tx * scalar6minus7;
      vtkm::Float32 lerped45 = scalar4 + tx * scalar5minus4;
      vtkm::Float32 lerpedTop = lerped45 + ty * (lerped76 - lerped45);

      vtkm::Float32 lerped01 = scalar0 + tx * scalar1minus0;
      vtkm::Float32 lerped32 = scalar3 + tx * scalar2minus3;
      vtkm::Float32 lerpedBottom = lerped01 + ty * (lerped32 - lerped01);

      vtkm::Float32 finalScalar = lerpedBottom + tz * (lerpedTop - lerpedBottom);
      //normalize scalar
      finalScalar = (finalScalar - MinScalar) * InverseDeltaScalar;

      vtkm::Id colorIndex =
        static_cast<vtkm::Id>(finalScalar * static_cast<vtkm::Float32>(ColorMapSize));
      if (colorIndex < 0)
        colorIndex = 0;
      else if (colorIndex > ColorMapSize)
        colorIndex = ColorMapSize;

      vtkm::Vec4f_32 sampleColor = ColorMap.Get(colorIndex);

      //apply density correction
      sampleColor[3] = 1.0f - vtkm::Pow(1.0f - sampleColor[3], DensityCorrectionRatio);

      // The first sample of a ray always uses a single estimator query; later
      // samples average NumShadowSamples jittered queries when requested.
      vtkm::Vec3f opacityEstimate{ 1.0f, 1.0f, 1.0f };
      bool useMultiple = (NumShadowSamples > 1) && !start;
      start = false;
      if (!useMultiple)
      {
        opacityEstimate = this->MapEstimator.GetEstimateUsingVertices(sampleLocation);
      }
      else
      {
        opacityEstimate = vtkm::Vec3f{ 0.0f, 0.0f, 0.0f };
        for (vtkm::Id i = 0; i < NumShadowSamples; ++i)
        {
          vtkm::Vec3f delta{ (vtkm::rendering::raytracing::randomf(seed) * 2.0f) - 1.0f,
                             (vtkm::rendering::raytracing::randomf(seed) * 2.0f) - 1.0f,
                             (vtkm::rendering::raytracing::randomf(seed) * 2.0f) - 1.0f };
          auto opacitySampleLocation = sampleLocation + delta * this->MapEstimator.Eps;
          opacityEstimate += this->MapEstimator.GetEstimateUsingVertices(opacitySampleLocation);
        }

        vtkm::Float32 inv = 1.0f / static_cast<vtkm::Float32>(NumShadowSamples);
        opacityEstimate[0] *= inv;
        opacityEstimate[1] *= inv;
        opacityEstimate[2] *= inv;
      }

      sampleColor[0] = sampleColor[0] * opacityEstimate[0];
      sampleColor[1] = sampleColor[1] * opacityEstimate[1];
      sampleColor[2] = sampleColor[2] * opacityEstimate[2];

      sampleColor[0] = vtkm::Clamp(sampleColor[0], 0.0f, 1.0f);
      sampleColor[1] = vtkm::Clamp(sampleColor[1], 0.0f, 1.0f);
      sampleColor[2] = vtkm::Clamp(sampleColor[2], 0.0f, 1.0f);

      //composite (front to back)
      sampleColor[3] *= (1.f - color[3]);
      color[0] = color[0] + sampleColor[0] * sampleColor[3];
      color[1] = color[1] + sampleColor[1] * sampleColor[3];
      color[2] = color[2] + sampleColor[2] * sampleColor[3];
      color[3] = sampleColor[3] + color[3];

      //advance
      distance += SampleDistance;
      sampleLocation = sampleLocation + SampleDistance * rayDir;

      //this is linear could just do an addition
      tx = (sampleLocation[0] - bottomLeft[0]) * invSpacing[0];
      ty = (sampleLocation[1] - bottomLeft[1]) * invSpacing[1];
      tz = (sampleLocation[2] - bottomLeft[2]) * invSpacing[2];

      // Early ray termination once the pixel is fully opaque.
      if (color[3] >= 1.f)
        break;
    }

    if (this->UseClamp)
    {
      color[0] = vtkm::Clamp(color[0], 0.0f, 1.0f);
      color[1] = vtkm::Clamp(color[1], 0.0f, 1.0f);
      color[2] = vtkm::Clamp(color[2], 0.0f, 1.0f);
      color[3] = vtkm::Clamp(color[3], 0.0f, 1.0f);
    }
    else if (this->UseReinhard)
    {
      color[0] = color[0] / (color[0] + 1.0f);
      color[1] = color[1] / (color[1] + 1.0f);
      color[2] = color[2] / (color[2] + 1.0f);
      color[3] = vtkm::Clamp(color[3], 0.0f, 1.0f);
    }

    colorBuffer.Set(pixelIndex * 4 + 0, color[0]);
    colorBuffer.Set(pixelIndex * 4 + 1, color[1]);
    colorBuffer.Set(pixelIndex * 4 + 2, color[2]);
    colorBuffer.Set(pixelIndex * 4 + 3, color[3]);
  }
}; //Sampler

/// Ray-marching worklet for cell-associated scalar fields on structured grids.
///
/// Each cell has one scalar, so the color is looked up once per cell and then
/// composited front to back at every SampleDistance step while the ray stays
/// in that cell. Unlike Sampler, this path applies no density correction and
/// no shadow-volume lighting. The final RGBA is clamped to at most 1.
template <typename DeviceAdapterTag, typename LocatorType>
class SamplerCellAssoc : public vtkm::worklet::WorkletMapField
{
private:
  using ColorArrayHandle = typename vtkm::cont::ArrayHandle<vtkm::Vec4f_32>;
  using ColorArrayPortal = typename ColorArrayHandle::ReadPortalType;
  ColorArrayPortal ColorMap;
  // Index of the last color-map entry (entry count - 1).
  vtkm::Id ColorMapSize;
  vtkm::Float32 MinScalar;
  vtkm::Float32 SampleDistance;
  // 1 / (maxScalar - minScalar); 0 when the scalar range is degenerate.
  vtkm::Float32 InverseDeltaScalar;
  LocatorType Locator;
  vtkm::Float32 MeshEpsilon;

public:
  /// @param colorMap RGBA lookup table, prepared for input on DeviceAdapterTag.
  /// @param minScalar lower end of the scalar range mapped onto the color map.
  /// @param maxScalar upper end of the scalar range mapped onto the color map.
  /// @param sampleDistance distance between consecutive samples along a ray.
  /// @param locator cell locator for the structured grid.
  /// @param meshEpsilon offset added to the ray's entry distance before the first sample.
  /// @param token keeps the prepared color-map portal valid.
  VTKM_CONT
  SamplerCellAssoc(const ColorArrayHandle& colorMap,
                   const vtkm::Float32& minScalar,
                   const vtkm::Float32& maxScalar,
                   const vtkm::Float32& sampleDistance,
                   const LocatorType& locator,
                   const vtkm::Float32& meshEpsilon,
                   vtkm::cont::Token& token)
    : ColorMap(colorMap.PrepareForInput(DeviceAdapterTag(), token))
    , ColorMapSize(colorMap.GetNumberOfValues() - 1)
    , MinScalar(minScalar)
    , SampleDistance(sampleDistance)
    // With a degenerate range every scalar normalizes to 0 (the first color-map
    // entry) instead of being scaled by the unrelated value minScalar.
    , InverseDeltaScalar((maxScalar - minScalar) != 0.f ? 1.f / (maxScalar - minScalar) : 0.f)
    , Locator(locator)
    , MeshEpsilon(meshEpsilon)
  {
  }
  using ControlSignature = void(FieldIn, FieldIn, FieldIn, FieldIn, WholeArrayInOut, WholeArrayIn);
  using ExecutionSignature = void(_1, _2, _3, _4, _5, _6, WorkIndex);

  /// March one ray and composite it into colorBuffer[pixelIndex * 4 .. + 3].
  /// A minDistance of -1 flags a ray that missed the volume; it is left untouched.
  template <typename ScalarPortalType, typename ColorBufferType>
  VTKM_EXEC void operator()(const vtkm::Vec3f_32& rayDir,
                            const vtkm::Vec3f_32& rayOrigin,
                            const vtkm::Float32& minDistance,
                            const vtkm::Float32& maxDistance,
                            ColorBufferType& colorBuffer,
                            const ScalarPortalType& scalars,
                            const vtkm::Id& pixelIndex) const
  {
    vtkm::Vec4f_32 color;
    color[0] = colorBuffer.Get(pixelIndex * 4 + 0);
    color[1] = colorBuffer.Get(pixelIndex * 4 + 1);
    color[2] = colorBuffer.Get(pixelIndex * 4 + 2);
    color[3] = colorBuffer.Get(pixelIndex * 4 + 3);

    if (minDistance == -1.f)
      return; //TODO: Compact? or just image subset...
    //get the initial sample position;
    vtkm::Vec3f_32 sampleLocation;
    // find the distance to the first sample
    vtkm::Float32 distance = minDistance + MeshEpsilon;
    sampleLocation = rayOrigin + distance * rayDir;
    // since the calculations are slightly different, we could hit an
    // edge case where the first sample location may not be in the data set.
    // Thus, advance to the next sample location
    while (!Locator.IsInside(sampleLocation) && distance < maxDistance)
    {
      distance += SampleDistance;
      sampleLocation = rayOrigin + distance * rayDir;
    }

    /*
            7----------6
           /|         /|
          4----------5 |
          | |        | |
          | 3--------|-2    z y
          |/         |/     |/
          0----------1      |__ x
    */
    bool newCell = true;
    vtkm::Float32 tx = 2.f;
    vtkm::Float32 ty = 2.f;
    vtkm::Float32 tz = 2.f;
    vtkm::Vec4f_32 sampleColor(0.f, 0.f, 0.f, 0.f);
    vtkm::Vec3f_32 bottomLeft(0.f, 0.f, 0.f);
    vtkm::Vec3f_32 invSpacing(0.f, 0.f, 0.f);
    vtkm::Id3 cell(0, 0, 0);
    while (Locator.IsInside(sampleLocation) && distance < maxDistance)
    {
      // Parametric coordinates outside [0, 1] mean the sample left the cached cell.
      vtkm::Float32 mint = vtkm::Min(tx, vtkm::Min(ty, tz));
      vtkm::Float32 maxt = vtkm::Max(tx, vtkm::Max(ty, tz));
      if (maxt > 1.f || mint < 0.f)
        newCell = true;
      if (newCell)
      {
        Locator.LocateCell(cell, sampleLocation, invSpacing);
        vtkm::Id cellId = Locator.GetCellIndex(cell);

        const vtkm::Float32 cellScalar = vtkm::Float32(scalars.Get(cellId));
        vtkm::Float32 normalizedScalar = (cellScalar - MinScalar) * InverseDeltaScalar;
        vtkm::Id colorIndex =
          static_cast<vtkm::Id>(normalizedScalar * static_cast<vtkm::Float32>(ColorMapSize));
        if (colorIndex < 0)
          colorIndex = 0;
        if (colorIndex > ColorMapSize)
          colorIndex = ColorMapSize;
        sampleColor = ColorMap.Get(colorIndex);
        Locator.GetMinPoint(cell, bottomLeft);
        tx = (sampleLocation[0] - bottomLeft[0]) * invSpacing[0];
        ty = (sampleLocation[1] - bottomLeft[1]) * invSpacing[1];
        tz = (sampleLocation[2] - bottomLeft[2]) * invSpacing[2];
        newCell = false;
      }

      // just repeatably composite
      vtkm::Float32 alpha = sampleColor[3] * (1.f - color[3]);
      color[0] = color[0] + sampleColor[0] * alpha;
      color[1] = color[1] + sampleColor[1] * alpha;
      color[2] = color[2] + sampleColor[2] * alpha;
      color[3] = alpha + color[3];
      //advance
      distance += SampleDistance;
      sampleLocation = sampleLocation + SampleDistance * rayDir;

      // Early ray termination once the pixel is fully opaque.
      if (color[3] >= 1.f)
        break;
      tx = (sampleLocation[0] - bottomLeft[0]) * invSpacing[0];
      ty = (sampleLocation[1] - bottomLeft[1]) * invSpacing[1];
      tz = (sampleLocation[2] - bottomLeft[2]) * invSpacing[2];
    }
    color[0] = vtkm::Min(color[0], 1.f);
    color[1] = vtkm::Min(color[1], 1.f);
    color[2] = vtkm::Min(color[2], 1.f);
    color[3] = vtkm::Min(color[3], 1.f);

    colorBuffer.Set(pixelIndex * 4 + 0, color[0]);
    colorBuffer.Set(pixelIndex * 4 + 1, color[1]);
    colorBuffer.Set(pixelIndex * 4 + 2, color[2]);
    colorBuffer.Set(pixelIndex * 4 + 3, color[3]);
  }
}; //SamplerCell

/// Clips every ray against the axis-aligned bounds of the dataset (slab test).
///
/// On a hit:
///   - minDistance becomes the entry distance;
///   - maxDistance is lowered to the exit distance;
///   - distance is set to minDistance.
/// On a miss, minDistance is set to -1. The samplers check for that flag and
/// skip the ray.
class CalcRayStart : public vtkm::worklet::WorkletMapField
{
  vtkm::Float32 Xmin;
  vtkm::Float32 Ymin;
  vtkm::Float32 Zmin;
  vtkm::Float32 Xmax;
  vtkm::Float32 Ymax;
  vtkm::Float32 Zmax;

public:
  /// @param boundingBox spatial bounds of the dataset, narrowed to Float32.
  VTKM_CONT
  CalcRayStart(const vtkm::Bounds& boundingBox)
    : Xmin(static_cast<vtkm::Float32>(boundingBox.X.Min))
    , Ymin(static_cast<vtkm::Float32>(boundingBox.Y.Min))
    , Zmin(static_cast<vtkm::Float32>(boundingBox.Z.Min))
    , Xmax(static_cast<vtkm::Float32>(boundingBox.X.Max))
    , Ymax(static_cast<vtkm::Float32>(boundingBox.Y.Max))
    , Zmax(static_cast<vtkm::Float32>(boundingBox.Z.Max))
  {
  }

  VTKM_EXEC
  vtkm::Float32 rcp(vtkm::Float32 f) const { return 1.0f / f; }

  /// Reciprocal that clamps near-zero inputs to +/-1e-8 before inverting.
  /// The clamp keeps the sign of f: a tiny negative direction component must
  /// still produce negative slab distances. Without that, the entry/exit
  /// parameters of near-axis-parallel rays are flipped.
  VTKM_EXEC
  vtkm::Float32 rcp_safe(vtkm::Float32 f) const
  {
    const vtkm::Float32 tiny = 1e-8f;
    if (vtkm::Abs(f) < tiny)
    {
      f = (f < 0.f) ? -tiny : tiny;
    }
    return rcp(f);
  }

  // NOTE(review): minDistance (_2) is bound as FieldOut. VTK-m's direct-out
  // fetch does not load the existing rays.MinDistance value (it supplies a
  // default-constructed value), so the Max(..., minDistance) below effectively
  // clamps the entry distance to >= 0. Confirm intent before switching it to
  // FieldInOut.
  using ControlSignature = void(FieldIn, FieldOut, FieldInOut, FieldInOut, FieldIn);
  using ExecutionSignature = void(_1, _2, _3, _4, _5);
  template <typename Precision>
  VTKM_EXEC void operator()(const vtkm::Vec<Precision, 3>& rayDir,
                            vtkm::Float32& minDistance,
                            vtkm::Float32& distance,
                            vtkm::Float32& maxDistance,
                            const vtkm::Vec<Precision, 3>& rayOrigin) const
  {
    vtkm::Float32 dirx = static_cast<vtkm::Float32>(rayDir[0]);
    vtkm::Float32 diry = static_cast<vtkm::Float32>(rayDir[1]);
    vtkm::Float32 dirz = static_cast<vtkm::Float32>(rayDir[2]);
    vtkm::Float32 origx = static_cast<vtkm::Float32>(rayOrigin[0]);
    vtkm::Float32 origy = static_cast<vtkm::Float32>(rayOrigin[1]);
    vtkm::Float32 origz = static_cast<vtkm::Float32>(rayOrigin[2]);

    vtkm::Float32 invDirx = rcp_safe(dirx);
    vtkm::Float32 invDiry = rcp_safe(diry);
    vtkm::Float32 invDirz = rcp_safe(dirz);

    vtkm::Float32 odirx = origx * invDirx;
    vtkm::Float32 odiry = origy * invDiry;
    vtkm::Float32 odirz = origz * invDirz;

    // Ray parameters at which the ray crosses each pair of bounding planes.
    vtkm::Float32 xmin = Xmin * invDirx - odirx;
    vtkm::Float32 ymin = Ymin * invDiry - odiry;
    vtkm::Float32 zmin = Zmin * invDirz - odirz;
    vtkm::Float32 xmax = Xmax * invDirx - odirx;
    vtkm::Float32 ymax = Ymax * invDiry - odiry;
    vtkm::Float32 zmax = Zmax * invDirz - odirz;

    // Entry = latest of the per-axis near crossings; exit = earliest far crossing.
    minDistance = vtkm::Max(
      vtkm::Max(vtkm::Max(vtkm::Min(ymin, ymax), vtkm::Min(xmin, xmax)), vtkm::Min(zmin, zmax)),
      minDistance);
    vtkm::Float32 exitDistance =
      vtkm::Min(vtkm::Min(vtkm::Max(ymin, ymax), vtkm::Max(xmin, xmax)), vtkm::Max(zmin, zmax));
    maxDistance = vtkm::Min(maxDistance, exitDistance);
    if (maxDistance < minDistance)
    {
      minDistance = -1.f; //flag for miss
    }
    else
    {
      distance = minDistance;
    }
  }
}; //class CalcRayStart

/// MPI handles used when all-reducing the packed shadow-volume face coordinates.
///
/// Both handles are filled in by ConstructMpiTypes(). Whoever receives them owns
/// them and must release them with MPI_Type_free / MPI_Op_free. Until then they
/// default to the MPI null handles instead of indeterminate values.
struct MpiTypes
{
  // Contiguous datatype describing one vtkm::Vec3f_32 (three MPI_FLOATs).
  MPI_Datatype Vec3f_32 = MPI_DATATYPE_NULL;
  // Commutative user reduction that adds Vec3f_32 values component-wise.
  MPI_Op CoordinateReduceOp = MPI_OP_NULL;
};

} //namespace

/// Construct a renderer with no dataset bound.
///
/// SetData() must be called before Render(). ScalarField is reset to null here
/// so that a premature Render() can be detected. Otherwise Render() would
/// dereference an unset pointer.
LitVolumeRenderer::LitVolumeRenderer()
{
  IsSceneDirty = false;
  // Provisional; SetData() decides from the actual coordinate system.
  IsUniformDataSet = true;
  // A non-positive value makes Render() derive the step from the dataset extent.
  SampleDistance = -1.f;
  ShadowMapSize = { 16, 16, 16 };
  ScalarField = nullptr;
}

/// Replace the RGBA transfer function that the samplers map normalized scalars through.
/// The ArrayHandle is stored by value; VTK-m handles share their underlying buffer.
void LitVolumeRenderer::SetColorMap(const vtkm::cont::ArrayHandle<vtkm::Vec4f_32>& colorMap)
{
  this->ColorMap = colorMap;
}

/// Bind the structured dataset to be rendered and mark the scene dirty.
///
/// @param coords point coordinates. They must be either uniform
///        (ArrayHandleUniformPointCoordinates) or rectilinear
///        (CartesianArrayHandle), the only two layouts RenderOnDevice can consume.
/// @param scalarField field to volume render. Only its address is stored, so the
///        caller must keep it alive for every subsequent Render() call.
/// @param cellset structured 3D cell set of the dataset.
/// @param scalarRange scalar range mapped onto the color map.
/// @throws vtkm::cont::ErrorBadValue if the coordinates are neither uniform nor
///         rectilinear. Previously such coordinates were silently classified as
///         uniform and failed later with a bad-cast during rendering.
void LitVolumeRenderer::SetData(const vtkm::cont::CoordinateSystem& coords,
                                const vtkm::cont::Field& scalarField,
                                const vtkm::cont::CellSetStructured<3>& cellset,
                                const vtkm::Range& scalarRange)
{
  const auto& coordData = coords.GetData();
  const bool isRectilinear = coordData.IsType<CartesianArrayHandle>();
  const bool isUniform = coordData.IsType<vtkm::cont::ArrayHandleUniformPointCoordinates>();
  if (!isRectilinear && !isUniform)
  {
    throw vtkm::cont::ErrorBadValue(
      "LitVolumeRenderer supports only uniform or rectilinear coordinate systems");
  }

  IsUniformDataSet = isUniform;
  IsSceneDirty = true;
  SpatialExtent = coords.GetBounds();
  SpatialExtentMagnitude = beams::Math::BoundsMagnitude<vtkm::Float32>(this->SpatialExtent);
  CoordinateSystem = coords;
  ScalarField = &scalarField;
  CellSet = cellset;
  ScalarRange = scalarRange;
}

/// Volume render the bound dataset into `rays`, running on the first VTK-m
/// device that succeeds.
///
/// If no sample distance was set (value <= 0), it is derived once from the
/// dataset's spatial extent as extent / 200 and kept for later calls.
///
/// @throws vtkm::cont::ErrorBadValue if SetData() has not bound a scalar field,
///         or the field is neither point- nor cell-associated.
/// @throws vtkm::cont::ErrorExecution if rendering did not succeed on any device.
void LitVolumeRenderer::Render(vtkm::rendering::raytracing::Ray<vtkm::Float32>& rays)
{
  // NOTE(review): relies on ScalarField being null-initialized until SetData() runs.
  if (this->ScalarField == nullptr)
  {
    throw vtkm::cont::ErrorBadValue("LitVolumeRenderer::Render called before SetData");
  }

  const bool isSupportedField = ScalarField->IsCellField() || ScalarField->IsPointField();
  if (!isSupportedField)
  {
    throw vtkm::cont::ErrorBadValue("Field not accociated with cell set or points");
  }

  if (this->SampleDistance <= 0.f)
  {
    const vtkm::Float32 defaultNumberOfSamples = 200.f;
    this->SampleDistance = this->SpatialExtentMagnitude / defaultNumberOfSamples;
  }

  auto functor = [&](auto device)
  {
    using Device = typename std::decay_t<decltype(device)>;
    VTKM_IS_DEVICE_ADAPTER_TAG(Device);

    this->RenderOnDevice(rays, device);
    return true;
  };

  // TryExecute reports (rather than propagates) most device failures and
  // returns false when no enabled device completed the functor. Surface that
  // instead of silently returning with an unrendered image.
  if (!vtkm::cont::TryExecute(functor))
  {
    throw vtkm::cont::ErrorExecution("LitVolumeRenderer: rendering failed on all enabled devices");
  }
}

void LitVolumeRenderer::AddLight(std::shared_ptr<Light> light)
{
  using PLight = beams::rendering::PointLight<vtkm::Float32>;
  PLight* pLight = reinterpret_cast<PLight*>(light.get());
  Lights.AddLight(pLight->Position, pLight->Color, pLight->Intensity);
}

void LitVolumeRenderer::ClearLights()
{
  Lights.ClearLights();
}

void CoordinateReduceFn(void* in, void* inout, int* len, MPI_Datatype* datatype)
{
  auto inVec = reinterpret_cast<vtkm::Vec<vtkm::Float32, 3>*>(in);
  auto inoutVec = reinterpret_cast<vtkm::Vec<vtkm::Float32, 3>*>(inout);
  for (int i = 0; i < *len; ++i)
  {
    for (int j = 0; j < 3; ++j)
    {
      inoutVec[i][j] += inVec[i][j];
    }
  }
}

/// Create the MPI datatype and reduction op used to all-reduce Vec3f_32
/// coordinates.
///
/// The caller owns both returned handles and must release them with
/// MPI_Type_free (Vec3f_32) and MPI_Op_free (CoordinateReduceOp).
/// NOTE(review): RenderOnDevice frees only CoordinateReduceOp, so the
/// datatype appears to leak on every render.
MpiTypes ConstructMpiTypes()
{
  MpiTypes result;

  // One Vec3f_32 is three packed floats.
  MPI_Datatype* const vecType = &result.Vec3f_32;
  MPI_Type_contiguous(3, MPI_FLOAT, vecType);
  MPI_Type_commit(vecType);

  // Element-wise addition is commutative, so MPI may reorder the reduction.
  const int isCommutative = 1;
  MPI_Op_create(CoordinateReduceFn, isCommutative, &result.CoordinateReduceOp);

  return result;
}

template <typename Precision, typename Device>
void LitVolumeRenderer::RenderOnDevice(vtkm::rendering::raytracing::Ray<Precision>& rays, Device)
{
  auto mpi = beams::mpi::MpiEnv::GetInstance();
  vtkm::cont::Token token;

  vtkm::cont::Timer testTimer;
  testTimer.Start();
  vtkm::Float32 meshEpsilon = this->SpatialExtentMagnitude * 0.001f;
  vtkm::cont::Timer phase1ShadowMapTimer{ Device() };
  phase1ShadowMapTimer.Start();
  MPI_Comm mpiComm = vtkmdiy::mpi::mpi_cast(mpi->Comm.handle());

  auto bounds = this->SpatialExtent;

  auto opacityMap = ShadowVolume(bounds, this->ShadowMapSize, this->ShadowMapNumSteps);
  opacityMap.SetLight(Lights.Locations[0], Lights.Colors[0], Lights.Intensities[0]);
  opacityMap.SetDensityCorrectionRatio(this->DensityCorrectionRatio);

  if (this->IsUniformDataSet)
  {
    vtkm::cont::ArrayHandleUniformPointCoordinates vertices;
    vertices = this->CoordinateSystem.GetData()
                 .AsArrayHandle<vtkm::cont::ArrayHandleUniformPointCoordinates>();
    CellLocatorUniform<Device> locator(vertices, token);
    opacityMap.Build(locator, *ScalarField, ScalarRange, this->ColorMap, meshEpsilon, Device());
  }
  else
  {
    CartesianArrayHandle vertices;
    vertices = this->CoordinateSystem.GetDataAsDefaultFloat().AsArrayHandle<CartesianArrayHandle>();
    CellLocatorRectilinear<Device> locator(vertices, token);
    opacityMap.Build(locator, *ScalarField, ScalarRange, this->ColorMap, meshEpsilon, Device());
  }

  using PhotonMapEstimatorType = OpacityMapEstimator<Device>;
  MPI_Barrier(mpiComm);
  phase1ShadowMapTimer.Stop();
  this->ProfilerTimes.push_back({
    .Name = P1ShadowVolumeGenLabel,
    .Time = phase1ShadowMapTimer.GetElapsedTime(),
  });

  vtkm::cont::Timer phase2TotalTimer;
  phase2TotalTimer.Start();
  MpiTypes MPI_TYPES = ConstructMpiTypes();

  vtkm::cont::Timer phase2FaceOpacitiesTimer;
  phase2FaceOpacitiesTimer.Start();

  // Extract the 6 faces of the shadow map
  auto opacityMapPDims = opacityMap.Coordinates.GetDimensions();

  vtkm::Id leftFaceOpacitiesCount = opacityMapPDims[1] * opacityMapPDims[2] * 2;
  vtkm::Id rightFaceOpacitiesCount = opacityMapPDims[1] * opacityMapPDims[2] * 2;
  vtkm::Id bottomFaceOpacitiesCount = opacityMapPDims[0] * opacityMapPDims[2] * 2;
  vtkm::Id topFaceOpacitiesCount = opacityMapPDims[0] * opacityMapPDims[2] * 2;
  vtkm::Id frontFaceOpacitiesCount = opacityMapPDims[0] * opacityMapPDims[1] * 2;
  vtkm::Id backFaceOpacitiesCount = opacityMapPDims[0] * opacityMapPDims[1] * 2;
  vtkm::Id faceOpacitiesCount = leftFaceOpacitiesCount + rightFaceOpacitiesCount;
  faceOpacitiesCount += bottomFaceOpacitiesCount + topFaceOpacitiesCount;
  faceOpacitiesCount += frontFaceOpacitiesCount + backFaceOpacitiesCount;

  auto globalFaceOpacitiesCount = faceOpacitiesCount * mpi->Size;
  std::vector<vtkm::Vec3f_32> faceCoordinatesV(globalFaceOpacitiesCount, vtkm::Vec3f_32(0.0f));
  std::vector<Precision> faceOpacitiesV(globalFaceOpacitiesCount, 0.0f);
  int faceOpacitiesIndex = mpi->Rank * faceOpacitiesCount;
  auto opacitiesPortal = opacityMap.Opacities.ReadPortal();
  auto coordinatesPortal = opacityMap.Coordinates.ReadPortal();

  // Left face: Id = 0
  for (int z = 0; z < opacityMapPDims[2]; ++z)
  {
    for (int y = 0; y < opacityMapPDims[1]; ++y)
    {
      for (int x = 0; x < 2; ++x)
      {
        int index = opacityMapPDims[0] * opacityMapPDims[1] * z + opacityMapPDims[0] * y + x;
        faceCoordinatesV[faceOpacitiesIndex] = coordinatesPortal.Get(index);
        faceOpacitiesV[faceOpacitiesIndex++] = opacitiesPortal.Get(index);
      }
    }
  }

  // Right face: Id = 1
  for (int z = 0; z < opacityMapPDims[2]; ++z)
  {
    for (int y = 0; y < opacityMapPDims[1]; ++y)
    {
      for (int x = opacityMapPDims[0] - 2; x < opacityMapPDims[0]; ++x)
      {
        int index = opacityMapPDims[0] * opacityMapPDims[1] * z + opacityMapPDims[0] * y + x;
        faceCoordinatesV[faceOpacitiesIndex] = coordinatesPortal.Get(index);
        faceOpacitiesV[faceOpacitiesIndex++] = opacitiesPortal.Get(index);
      }
    }
  }

  // Bottom face: Id = 2
  for (int z = 0; z < opacityMapPDims[2]; ++z)
  {
    for (int y = 0; y < 2; ++y)
    {
      for (int x = 0; x < opacityMapPDims[0]; ++x)
      {
        int index = opacityMapPDims[0] * opacityMapPDims[1] * z + opacityMapPDims[0] * y + x;
        faceCoordinatesV[faceOpacitiesIndex] = coordinatesPortal.Get(index);
        faceOpacitiesV[faceOpacitiesIndex++] = opacitiesPortal.Get(index);
      }
    }
  }

  // Top face: Id = 3
  for (int z = 0; z < opacityMapPDims[2]; ++z)
  {
    for (int y = opacityMapPDims[1] - 2; y < opacityMapPDims[1]; ++y)
    {
      for (int x = 0; x < opacityMapPDims[0]; ++x)
      {
        int index = opacityMapPDims[0] * opacityMapPDims[1] * z + opacityMapPDims[0] * y + x;
        faceCoordinatesV[faceOpacitiesIndex] = coordinatesPortal.Get(index);
        faceOpacitiesV[faceOpacitiesIndex++] = opacitiesPortal.Get(index);
      }
    }
  }

  // Front face: Id = 4
  for (int z = 0; z < 2; ++z)
  {
    for (int y = 0; y < opacityMapPDims[1]; ++y)
    {
      for (int x = 0; x < opacityMapPDims[0]; ++x)
      {
        int index = opacityMapPDims[0] * opacityMapPDims[1] * z + opacityMapPDims[0] * y + x;
        faceCoordinatesV[faceOpacitiesIndex] = coordinatesPortal.Get(index);
        faceOpacitiesV[faceOpacitiesIndex++] = opacitiesPortal.Get(index);
      }
    }
  }

  // Back face: Id = 5
  for (int z = opacityMapPDims[2] - 2; z < opacityMapPDims[2]; ++z)
  {
    for (int y = 0; y < opacityMapPDims[1]; ++y)
    {
      for (int x = 0; x < opacityMapPDims[0]; ++x)
      {
        int index = opacityMapPDims[0] * opacityMapPDims[1] * z + opacityMapPDims[0] * y + x;
        faceCoordinatesV[faceOpacitiesIndex] = coordinatesPortal.Get(index);
        faceOpacitiesV[faceOpacitiesIndex++] = opacitiesPortal.Get(index);
      }
    }
  }

  /*
  MPI_Barrier(mpiComm);
  phase2FaceOpacitiesTimer.Stop();
  this->ProfilerTimes.push_back({
    .Name = "Phase 2: Face Opacities Pack",
    .Time = phase2FaceOpacitiesTimer.GetElapsedTime(),
  });
  */

  vtkm::cont::Timer phase2FaceOpacitiesReduceTimer;
  phase2FaceOpacitiesReduceTimer.Start();
  std::vector<Precision> globalFaceOpacitiesV(globalFaceOpacitiesCount, 0.0f);
  std::vector<vtkm::Vec3f_32> globalFaceCoordinatesV(globalFaceOpacitiesCount,
                                                     vtkm::Vec3f_32(0.0f));

  MPI_Allreduce(faceCoordinatesV.data(),
                globalFaceCoordinatesV.data(),
                globalFaceOpacitiesCount,
                MPI_TYPES.Vec3f_32,
                MPI_TYPES.CoordinateReduceOp,
                mpiComm);

  MPI_Allreduce(faceOpacitiesV.data(),
                globalFaceOpacitiesV.data(),
                globalFaceOpacitiesCount,
                MPI_FLOAT,
                MPI_SUM,
                mpiComm);
  MPI_Op_free(&MPI_TYPES.CoordinateReduceOp);
  vtkm::cont::ArrayHandle<Precision> globalFaceOpacities =
    vtkm::cont::make_ArrayHandle(globalFaceOpacitiesV, vtkm::CopyFlag::On);
  vtkm::cont::ArrayHandle<vtkm::Vec3f_32> globalFaceCoordinates =
    vtkm::cont::make_ArrayHandle(globalFaceCoordinatesV, vtkm::CopyFlag::On);
  MPI_Barrier(mpiComm);
  phase2FaceOpacitiesReduceTimer.Stop();
  this->ProfilerTimes.push_back({
    .Name = "Phase 2: Face Opacities Reduce",
    .Time = phase2FaceOpacitiesReduceTimer.GetElapsedTime(),
  });

  vtkm::cont::Timer phase2NonLocalHitsTimer;
  phase2NonLocalHitsTimer.Start();
  const bool useGlancingHits = true;
  vtkm::cont::ArrayHandle<vtkm::Id> hitCounts;
  vtkm::cont::ArrayHandle<vtkm::Id> hitOffsets;
  vtkm::cont::ArrayHandle<OpacityRayBlockHit> rayHits;
  GetNonLocalHits(opacityMap.Coordinates,
                  mpi,
                  Lights,
                  *(this->BoundsMap),
                  useGlancingHits,
                  hitCounts,
                  hitOffsets,
                  rayHits,
                  this->ProfilerTimes,
                  Device());
  vtkm::cont::Algorithm::Sort(rayHits, beams::rendering::HitSort());
  /*
  MPI_Barrier(mpiComm);
  phase2NonLocalHitsTimer.Stop();
  this->ProfilerTimes.push_back({
    .Name = "Phase 2: Non-local Hits",
    .Time = phase2NonLocalHitsTimer.GetElapsedTime(),
  });
  */

  MPI_Barrier(mpiComm);
  phase2TotalTimer.Stop();
  this->ProfilerTimes.push_back({
    .Name = P2ShadowVolumeRenderLabel,
    .Time = phase2TotalTimer.GetElapsedTime(),
  });

  vtkm::cont::Timer phase3Timer;
  phase3Timer.Start();
  UpdateOpacities(opacityMap.Dims,
                  opacityMap.Spacing,
                  opacityMap.Coordinates,
                  opacityMap.Opacities,
                  mpi,
                  Lights,
                  *(this->BoundsMap),
                  useGlancingHits,
                  hitCounts,
                  hitOffsets,
                  rayHits,
                  globalFaceCoordinates,
                  globalFaceOpacities,
                  this->ProfilerTimes,
                  Device());
  MPI_Barrier(mpiComm);
  phase3Timer.Stop();
  this->ProfilerTimes.push_back({
    .Name = P3ShadowVolumeUpdateLabel,
    .Time = phase3Timer.GetElapsedTime(),
  });
  testTimer.Stop();

  vtkm::cont::Timer phase4RenderTimer{ Device() };
  phase4RenderTimer.Start();
  vtkm::cont::Timer timer{ Device() };
  timer.Start();
  vtkm::worklet::DispatcherMapField<CalcRayStart> calcRayStartDispatcher(
    CalcRayStart(this->SpatialExtent));
  calcRayStartDispatcher.SetDevice(Device());
  calcRayStartDispatcher.Invoke(
    rays.Dir, rays.MinDistance, rays.Distance, rays.MaxDistance, rays.Origin);

  const bool isAssocPoints = ScalarField->IsPointField();
  vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::UInt32, 2>> seeds;
  seeds.Allocate(rays.Dir.GetNumberOfValues());
  vtkm::rendering::raytracing::seedRng(seeds);
  if (IsUniformDataSet)
  {
    vtkm::cont::ArrayHandleUniformPointCoordinates vertices;
    vertices =
      CoordinateSystem.GetData().AsArrayHandle<vtkm::cont::ArrayHandleUniformPointCoordinates>();
    CellLocatorUniform<Device> locator(vertices, token);

    if (isAssocPoints)
    {
      using SamplerType = Sampler<Device, CellLocatorUniform<Device>, PhotonMapEstimatorType>;
      vtkm::worklet::DispatcherMapField<SamplerType> samplerDispatcher(
        SamplerType(ColorMap,
                    vtkm::Float32(ScalarRange.Min),
                    vtkm::Float32(ScalarRange.Max),
                    SampleDistance,
                    locator,
                    meshEpsilon,
                    opacityMap.PrepareForExecution(Device(), token),
                    this->NumShadowSamples,
                    DensityCorrectionRatio,
                    this->UseClamp,
                    this->UseReinhard,
                    token));
      samplerDispatcher.SetDevice(Device());
      samplerDispatcher.Invoke(rays.PixelIdx,
                               rays.Dir,
                               rays.Origin,
                               rays.MinDistance,
                               rays.MaxDistance,
                               rays.Buffers.at(0).Buffer,
                               vtkm::rendering::raytracing::GetScalarFieldArray(*this->ScalarField),
                               seeds);
    }
    else
    {
      vtkm::worklet::DispatcherMapField<SamplerCellAssoc<Device, CellLocatorUniform<Device>>>(
        SamplerCellAssoc<Device, CellLocatorUniform<Device>>(ColorMap,
                                                             vtkm::Float32(ScalarRange.Min),
                                                             vtkm::Float32(ScalarRange.Max),
                                                             SampleDistance,
                                                             locator,
                                                             meshEpsilon,
                                                             token))
        .Invoke(rays.Dir,
                rays.Origin,
                rays.MinDistance,
                rays.MaxDistance,
                rays.Buffers.at(0).Buffer,
                vtkm::rendering::raytracing::GetScalarFieldArray(*this->ScalarField));
    }
  }
  else
  {
    CartesianArrayHandle vertices;
    vertices = CoordinateSystem.GetData().AsArrayHandle<CartesianArrayHandle>();
    CellLocatorRectilinear<Device> locator(vertices, token);
    if (isAssocPoints)
    {
      using SamplerType = Sampler<Device, CellLocatorRectilinear<Device>, PhotonMapEstimatorType>;
      vtkm::worklet::DispatcherMapField<SamplerType> samplerDispatcher(
        SamplerType(ColorMap,
                    vtkm::Float32(ScalarRange.Min),
                    vtkm::Float32(ScalarRange.Max),
                    SampleDistance,
                    locator,
                    meshEpsilon,
                    opacityMap.PrepareForExecution(Device(), token),
                    this->NumShadowSamples,
                    DensityCorrectionRatio,
                    this->UseClamp,
                    this->UseReinhard,
                    token));
      samplerDispatcher.SetDevice(Device());
      samplerDispatcher.Invoke(rays.PixelIdx,
                               rays.Dir,
                               rays.Origin,
                               rays.MinDistance,
                               rays.MaxDistance,
                               rays.Buffers.at(0).Buffer,
                               vtkm::rendering::raytracing::GetScalarFieldArray(*this->ScalarField),
                               seeds);
    }
    else
    {
      vtkm::worklet::DispatcherMapField<SamplerCellAssoc<Device, CellLocatorRectilinear<Device>>>
        rectilinearLocatorDispatcher(
          SamplerCellAssoc<Device, CellLocatorRectilinear<Device>>(ColorMap,
                                                                   vtkm::Float32(ScalarRange.Min),
                                                                   vtkm::Float32(ScalarRange.Max),
                                                                   SampleDistance,
                                                                   locator,
                                                                   meshEpsilon,
                                                                   token));
      rectilinearLocatorDispatcher.SetDevice(Device());
      rectilinearLocatorDispatcher.Invoke(
        rays.Dir,
        rays.Origin,
        rays.MinDistance,
        rays.MaxDistance,
        rays.Buffers.at(0).Buffer,
        vtkm::rendering::raytracing::GetScalarFieldArray(*this->ScalarField));
    }
  }

  phase4RenderTimer.Stop();
  this->ProfilerTimes.push_back({
    .Name = P4VolumeRenderLabel,
    .Time = phase4RenderTimer.GetElapsedTime(),
  });
}

/// Sets the sample distance that the renderer passes to its volume samplers.
///
/// @param distance Spacing between consecutive samples; must be strictly
///                 positive (and therefore not NaN).
/// @throws vtkm::cont::ErrorBadValue if @p distance is zero, negative, or NaN.
void LitVolumeRenderer::SetSampleDistance(const vtkm::Float32& distance)
{
  // Written as !(distance > 0) rather than (distance <= 0): every ordered
  // comparison involving NaN is false, so only this form rejects NaN instead
  // of silently storing it and propagating it into the samplers.
  if (!(distance > 0.f))
  {
    throw vtkm::cont::ErrorBadValue("Sample distance must be positive.");
  }
  this->SampleDistance = distance;
}
} // namespace rendering
} // namespace beams
