#pragma once

#include "../rendering/Random.h"
#include "../utils/Fmt.h"
#include "CellLocators.h"
#include "LightRayOperations.h"
#include "LightRays.h"

#include <vtkm/Bounds.h>
#include <vtkm/cont/Algorithm.h>
#include <vtkm/cont/ArrayHandleUniformPointCoordinates.h>
#include <vtkm/cont/ExecutionObjectBase.h>
#include <vtkm/exec/ParametricCoordinates.h>

#include <mpi.h>
#include <vtkm/thirdparty/diy/diy.h>
#include <vtkm/thirdparty/diy/mpi-cast.h>

namespace beams
{
namespace rendering
{
/// One intersection between a shadow ray (light -> local opacity-map sample
/// point) and the bounds of a block other than the local one.
///
/// Records are written by CalculateNonLocalBlockHits, read back by
/// OpacityUpdaterWithNonLocalHits, and ordered by HitSort (RayId, then RayT).
/// NOTE(review): if these records are ever exchanged as raw bytes between
/// ranks, member order/types are part of that format — keep them stable.
struct OpacityRayBlockHit
{
  // Index of the sample point (one ray per point) that produced this hit.
  int RayId;
  // Id of the non-local block whose bounds the ray crossed.
  int BlockId;
  // World-space point on the ray at parameter RayT.
  vtkm::Vec3f_32 Point;
  // Ray parameter of the hit (the exit value tMax of the segment/block test).
  vtkm::Float32 RayT;
  // Not written by CalculateNonLocalBlockHits nor read in this file;
  // presumably reserved for a later pass — TODO confirm.
  vtkm::Float32 Opacity;
  // Bounds face containing Point: 0/1 = X min/max, 2/3 = Y min/max,
  // 4/5 = Z min/max; -1 if no face could be matched.
  vtkm::Id Face;
};

template <typename Device>
struct OpacityMapEstimator
{
  using PointsArrayHandle = typename CellLocatorUniform<Device>::PointsArrayHandle;
  using PointsReadPortal = typename PointsArrayHandle::ReadPortalType;
  using OpacityArrayHandle = vtkm::cont::ArrayHandle<vtkm::Float32>;
  using OpacityReadPortal = typename OpacityArrayHandle::ReadPortalType;

  PointsReadPortal Locations;
  OpacityReadPortal Opacities;
  CellLocatorUniform<Device> Locator;
  vtkm::Vec3f_32 LightColor;
  vtkm::Float32 Eps;

  OpacityMapEstimator(const PointsArrayHandle& locations,
                      const OpacityArrayHandle& opacities,
                      const CellLocatorUniform<Device>& locator,
                      const vtkm::Vec3f_32 lightColor,
                      vtkm::cont::Token& token)
    : Locations(locations.PrepareForInput(Device(), token))
    , Opacities(opacities.PrepareForInput(Device(), token))
    , Locator(locator)
    , LightColor(lightColor)
  {
    this->Eps = vtkm::Magnitude(this->Locations.Get(0) - this->Locations.Get(1));
  }

  VTKM_EXEC
  inline vtkm::Vec3f GetEstimateUsingVertices(const vtkm::Vec3f& point) const
  {
    vtkm::Id3 cell;
    vtkm::Vec3f_32 invSpacing;
    this->Locator.LocateCell(cell, point, invSpacing);

    if (!this->IsValidCell(this->Locator.GetCellIndex(cell)))
    {
      return { 1.0f, 1.0f, 1.0f };
    }
    vtkm::Vec<vtkm::Id, 8> cellIndices;
    this->Locator.GetCellIndices(cell, cellIndices);

    vtkm::Vec3f minPoint = this->Locations.Get(cellIndices[0]);
    vtkm::Vec3f maxPoint = this->Locations.Get(cellIndices[6]);

    vtkm::VecAxisAlignedPointCoordinates<3> rPoints(minPoint, maxPoint - minPoint);

    vtkm::Vec3f pcoords;
    vtkm::exec::WorldCoordinatesToParametricCoordinates(
      rPoints, point, vtkm::CellShapeTagHexahedron(), pcoords);

    vtkm::Vec<vtkm::FloatDefault, 8> scalars;
    for (vtkm::Id i = 0; i < 8; ++i)
    {
      scalars[i] = this->Opacities.Get(cellIndices[i]);
    }
    vtkm::FloatDefault opacity;
    vtkm::exec::CellInterpolate(scalars, pcoords, vtkm::CellShapeTagHexahedron(), opacity);
    vtkm::Float32 transmittance = 1.0f - opacity;
    return transmittance * this->LightColor;
  }

  VTKM_EXEC
  inline bool IsValidCell(vtkm::Id cellId) const
  {
    return cellId >= 0 && cellId < Locations.GetNumberOfValues();
  }
};

/// Worklet that fills the opacity map.
///
/// Each invocation marches one light ray (origin -> destination grid point),
/// clipped to MapBounds. At each step it samples the scalar field, maps it
/// through the color map's alpha channel, and accumulates opacity as
/// o += a * (1 - o).
///
/// @tparam LocatorType locator for the scalar field's grid; must provide
///         IsInside, LocateCell, GetCellIndices and GetPoint.
template <typename LocatorType>
struct OpacityMapGenerator : public vtkm::worklet::WorkletMapField
{
  LocatorType Locator;

  /// @param locator                locator for the grid the `scalars` array lives on.
  /// @param stepSize               ray-march step length in world units.
  /// @param minScalar, maxScalar   scalar range mapped linearly onto the color map.
  /// @param maxDensity             stored; not read by operator().
  /// @param lightLoc               stored; not read by operator().
  /// @param mapBounds              bounds each ray segment is clipped against.
  /// @param densityCorrectionRatio exponent applied to each sample's transparency.
  /// @param epsilon                distance the march start is pulled back before the clipped entry.
  VTKM_CONT
  OpacityMapGenerator(const LocatorType& locator,
                      const vtkm::Float32& stepSize,
                      const vtkm::Float32& minScalar,
                      const vtkm::Float32& maxScalar,
                      const vtkm::Float32& maxDensity,
                      const vtkm::Vec3f& lightLoc,
                      const vtkm::Bounds& mapBounds,
                      const vtkm::Float32& densityCorrectionRatio,
                      const vtkm::Float32& epsilon)
    : Locator(locator)
    , StepSize(stepSize)
    , MinScalar(minScalar)
    , MaxDensity(maxDensity)
    , LightLoc(lightLoc)
    , MapBounds(mapBounds)
    , DensityCorrectionRatio(densityCorrectionRatio)
    , Epsilon(epsilon)
  {
    if ((maxScalar - minScalar) != 0.0f)
    {
      InverseDeltaScalar = 1.0f / (maxScalar - minScalar);
    }
    else
    {
      // NOTE(review): a degenerate range uses minScalar as the scale factor.
      // Samples equal to minScalar still normalize to 0, but any other value is
      // scaled arbitrarily — confirm intent.
      InverseDeltaScalar = minScalar;
    }
  }

  using ControlSignature = void(FieldIn rayOrigins,
                                FieldIn rayDirs,
                                FieldIn rayDests,
                                WholeArrayIn scalars,
                                WholeArrayIn colorMap,
                                FieldInOut opacities);
  using ExecutionSignature = void(_1, _2, _3, _4, _5, _6);

  /// Slab-test clip of the segment start->end against `aabb`.
  ///
  /// On return tMin/tMax are distances along the normalized segment direction,
  /// with tMax capped at the segment length. Returns true when tMin <= tMax.
  /// A zero-length segment yields NaN parameters.
  VTKM_EXEC_CONT
  bool SegmentAABB(const vtkm::Vec3f_32& start,
                   const vtkm::Vec3f_32& end,
                   vtkm::Bounds aabb,
                   vtkm::Float32& tMin,
                   vtkm::Float32& tMax,
                   vtkm::Float32 tEps = 1e-4f) const
  {
    vtkm::Vec3f_32 direction = end - start;
    vtkm::Float32 segmentLength = vtkm::Magnitude(direction);
    direction = direction / segmentLength;

    tMin = 0.0f;
    tMax = vtkm::Infinity32();

    vtkm::Float32 txMin = (aabb.X.Min - start[0]) / direction[0];
    vtkm::Float32 txMax = (aabb.X.Max - start[0]) / direction[0];
    tMin = vtkm::Max(tMin, vtkm::Min(txMin, txMax));
    tMax = vtkm::Min(tMax, vtkm::Max(txMin, txMax));

    vtkm::Float32 tyMin = (aabb.Y.Min - start[1]) / direction[1];
    vtkm::Float32 tyMax = (aabb.Y.Max - start[1]) / direction[1];
    tMin = vtkm::Max(tMin, vtkm::Min(tyMin, tyMax));
    tMax = vtkm::Min(tMax, vtkm::Max(tyMin, tyMax));

    vtkm::Float32 tzMin = (aabb.Z.Min - start[2]) / direction[2];
    vtkm::Float32 tzMax = (aabb.Z.Max - start[2]) / direction[2];
    tMin = vtkm::Max(tMin, vtkm::Min(tzMin, tzMax));
    tMax = vtkm::Min(tMax, vtkm::Max(tzMin, tzMax));

    // tMin starts at 0 and is only ever raised by Max above, so this branch
    // does not normally fire.
    if (tMin < 0.0f)
    {
      tMin = tEps;
    }

    if (tMax > segmentLength)
    {
      tMax = segmentLength;
    }

    return tMin <= tMax;
  }

  /// Marches from the light toward one map point and writes the accumulated
  /// opacity, clamped to [0, 1], to `opacity`.
  template <typename ScalarPortalType, typename ColorMapType>
  VTKM_EXEC void operator()(const vtkm::Vec3f& rayOrigin,
                            const vtkm::Vec3f& rayDir,
                            const vtkm::Vec3f& rayDest,
                            ScalarPortalType& scalars,
                            const ColorMapType& colorMap,
                            vtkm::Float32& opacity) const
  {
    opacity = 0.0f;

    const vtkm::Id colorMapSize = colorMap.GetNumberOfValues() - 1;
    // Degenerate inputs. An empty color map would make the clamped lookup
    // below read colorMap.Get(0) out of bounds. A non-positive (or NaN) step
    // never advances `distance`, so the march could spin forever.
    if (colorMapSize < 0 || !(this->StepSize > 0.0f))
    {
      return;
    }

    vtkm::Id3 cell;
    vtkm::Vec3f_32 invSpacing;

    // The hit/miss result is not needed: a miss leaves tMin > tMax, so the
    // march loop below does (almost) no work.
    vtkm::Float32 tMin, tMax;
    SegmentAABB(rayOrigin, rayDest, this->MapBounds, tMin, tMax);

    // Pull the start back slightly so samples right at the entry are not lost.
    tMin -= this->Epsilon;

    auto start = rayOrigin + tMin * rayDir;
    auto distance = tMin;
    auto sampleLocation = start;

    vtkm::Vec<vtkm::Id, 8> cellIndices;
    vtkm::Vec<vtkm::Float32, 8> values;
    vtkm::Float32 scalar = 0.f;
    while (distance < tMax)
    {
      if (!this->Locator.IsInside(sampleLocation))
      {
        distance += this->StepSize;
        sampleLocation += this->StepSize * rayDir;
        continue;
      }
      scalar = 0.0f;
      this->Locator.LocateCell(cell, sampleLocation, invSpacing);
      this->Locator.GetCellIndices(cell, cellIndices);
      for (vtkm::Int32 i = 0; i < 8; ++i)
      {
        vtkm::Id j = cellIndices[i];
        values[i] = static_cast<vtkm::Float32>(scalars.Get(j));
      }

      // Trilinear interpolation inside the axis-aligned cell spanned by
      // corners 0 and 6.
      vtkm::Vec3f_32 minPoint, maxPoint;
      this->Locator.GetPoint(cellIndices[0], minPoint);
      this->Locator.GetPoint(cellIndices[6], maxPoint);
      vtkm::VecAxisAlignedPointCoordinates<3> rPoints(minPoint, maxPoint - minPoint);
      vtkm::Vec<vtkm::Float32, 3> pcoords;
      vtkm::exec::WorldCoordinatesToParametricCoordinates(
        rPoints, sampleLocation, vtkm::CellShapeTagHexahedron(), pcoords);
      vtkm::exec::CellInterpolate(values, pcoords, vtkm::CellShapeTagHexahedron(), scalar);

      // Normalize into [0, 1] over the scalar range, then index the color map.
      scalar = (scalar - MinScalar) * InverseDeltaScalar;
      vtkm::Id colorIndex =
        static_cast<vtkm::Id>(scalar * static_cast<vtkm::Float32>(colorMapSize));
      constexpr vtkm::Id zero = 0;
      colorIndex = vtkm::Max(zero, vtkm::Min(colorMapSize, colorIndex));
      vtkm::Vec<vtkm::Float32, 4> sampleColor = colorMap.Get(colorIndex);

      vtkm::Float32 localOpacity = sampleColor[3];

      // Opacity correction: transparency raised to the density ratio.
      localOpacity = 1.0f - vtkm::Pow(1.0f - localOpacity, this->DensityCorrectionRatio);

      opacity = opacity + (1.0f - opacity) * localOpacity;

      distance += StepSize;
      sampleLocation += StepSize * rayDir;
      // Early ray termination once the ray is effectively opaque.
      if (opacity > 0.99f)
      {
        break;
      }
    }

    opacity = vtkm::Clamp(opacity, 0.0f, 1.0f);
  }

  vtkm::Float32 StepSize;
  vtkm::Float32 MinScalar;
  vtkm::Float32 InverseDeltaScalar;
  vtkm::Float32 MaxDensity;
  vtkm::Vec3f LightLoc;
  vtkm::Bounds MapBounds;
  vtkm::Float32 DensityCorrectionRatio;
  vtkm::Float32 Epsilon;
};

/// Control-side owner of a per-point opacity ("shadow") map over a uniform grid.
///
/// Typical use: construct, SetLight(), optionally SetDensityCorrectionRatio(),
/// then Build() to march light rays and fill Opacities. When passed to a
/// worklet as an ExecObject it yields an OpacityMapEstimator for the device.
struct ShadowVolume : public vtkm::cont::ExecutionObjectBase
{
  /// @param bounds     world-space bounds covered by the map.
  /// @param dims       cells per axis; the grid has dims + 1 points per axis.
  /// @param numSamples ray-march steps spanning the bounds' diagonal.
  VTKM_CONT
  ShadowVolume(vtkm::Bounds bounds, vtkm::Id3 dims, vtkm::Id numSamples)
    : Bounds(bounds)
    , Dims(dims)
    , NumSamples(numSamples)
    , DensityCorrectionRatio(1.0f)
  {
  }

  /// Sets the single light used by Build() and PrepareForExecution().
  VTKM_CONT
  void SetLight(vtkm::Vec3f_32 lightPosition,
                vtkm::Vec3f_32 lightColor,
                vtkm::Float32 lightIntensity)
  {
    this->LightPosition = lightPosition;
    this->LightColor = lightColor;
    this->LightIntensity = lightIntensity;
  }

  /// Exponent applied to per-sample transparency during Build().
  VTKM_CONT
  void SetDensityCorrectionRatio(vtkm::Float32 densityCorrectionRatio)
  {
    this->DensityCorrectionRatio = densityCorrectionRatio;
  }

  /// Builds the uniform grid over Bounds and fills Opacities by marching one
  /// ray per grid point from LightPosition through `field`, colored by `colorMap`.
  template <typename LocatorType, typename Device>
  VTKM_CONT void Build(const LocatorType& locator,
                       const vtkm::cont::Field& field,
                       const vtkm::Range& fieldRange,
                       const vtkm::cont::ArrayHandle<vtkm::Vec4f_32>& colorMap,
                       const vtkm::Float32 epsilon,
                       Device)
  {
    const vtkm::Vec3f_32 minExtent{ static_cast<vtkm::Float32>(this->Bounds.X.Min),
                                    static_cast<vtkm::Float32>(this->Bounds.Y.Min),
                                    static_cast<vtkm::Float32>(this->Bounds.Z.Min) };
    const vtkm::Vec3f_32 maxExtent{ static_cast<vtkm::Float32>(this->Bounds.X.Max),
                                    static_cast<vtkm::Float32>(this->Bounds.Y.Max),
                                    static_cast<vtkm::Float32>(this->Bounds.Z.Max) };
    this->Spacing = (maxExtent - minExtent) / vtkm::Vec3f_32(this->Dims);
    const vtkm::Id3 pointDims = this->Dims + vtkm::Id3{ 1, 1, 1 };

    this->Coordinates =
      vtkm::cont::ArrayHandleUniformPointCoordinates(pointDims, minExtent, this->Spacing);

    LightRays<vtkm::Float32, Device> lightRays =
      LightRayOperations::CreateRays(this->Coordinates, this->LightPosition, Device());

    // NumSamples steps across the diagonal of the bounds.
    vtkm::Float32 sampleDistance =
      vtkm::Magnitude(maxExtent - minExtent) / static_cast<vtkm::Float32>(this->NumSamples);
    vtkm::cont::Invoker invoker{ Device() };
    this->Opacities.Allocate(this->Coordinates.GetNumberOfValues());
    invoker(OpacityMapGenerator<LocatorType>(locator,
                                             sampleDistance,
                                             vtkm::Float32(fieldRange.Min),
                                             vtkm::Float32(fieldRange.Max),
                                             this->GetMaxAlpha(colorMap),
                                             this->LightPosition,
                                             this->Bounds,
                                             this->DensityCorrectionRatio,
                                             epsilon),
            lightRays.Origins,
            lightRays.Dirs,
            lightRays.Dests,
            vtkm::rendering::raytracing::GetScalarFieldArray(field),
            colorMap,
            this->Opacities);
  }

  /// Creates the device-side estimator; light color is pre-scaled by intensity.
  template <typename Device>
  VTKM_CONT OpacityMapEstimator<Device> PrepareForExecution(Device device,
                                                            vtkm::cont::Token& token) const
  {
    CellLocatorUniform<Device> locator(this->Coordinates, token);
    OpacityMapEstimator<Device> opacityMapEstimator(
      this->Coordinates, this->Opacities, locator, this->LightColor * this->LightIntensity, token);
    return opacityMapEstimator;
  }

private:
  /// Largest alpha (component 3) in the color map; 0 for an empty map.
  VTKM_CONT vtkm::Float32 GetMaxAlpha(
    const vtkm::cont::ArrayHandle<vtkm::Vec4f_32>& colorMap) const
  {
    vtkm::Float32 maxAlpha = 0.f;
    vtkm::Id size = colorMap.GetNumberOfValues();
    vtkm::cont::ArrayHandle<vtkm::Vec4f_32>::ReadPortalType portal = colorMap.ReadPortal();
    for (vtkm::Id i = 0; i < size; ++i)
    {
      maxAlpha = vtkm::Max(maxAlpha, portal.Get(i)[3]);
    }

    return maxAlpha;
  }

public:
  vtkm::Bounds Bounds;
  vtkm::Id3 Dims;
  vtkm::Id NumSamples;

  vtkm::Float32 DensityCorrectionRatio;

  // Zero-initialized so Build()/PrepareForExecution() never read indeterminate
  // values when SetLight() has not been called.
  vtkm::Vec3f_32 LightPosition{ 0.0f, 0.0f, 0.0f };
  vtkm::Vec3f_32 LightColor{ 0.0f, 0.0f, 0.0f };
  vtkm::Float32 LightIntensity = 0.0f;

  vtkm::cont::ArrayHandleUniformPointCoordinates Coordinates;
  vtkm::Vec3f_32 Spacing{ 0.0f, 0.0f, 0.0f };
  vtkm::cont::ArrayHandle<vtkm::Float32> Opacities;
};

/// Worklet: for each opacity-map sample point, asks the bounds map how many
/// blocks the segment from the light to that point intersects.
/// SelfBlockId and UseGlancingHits are forwarded unchanged, presumably so the
/// bounds map can skip this rank's own block and tangential touches — TODO
/// confirm against FindNumSegmentBlockIntersections.
struct CountNonLocalBlockHits : public vtkm::worklet::WorkletMapField
{
  using ControlSignature = void(FieldIn samplePoints, ExecObject boundMap, FieldOut numBlocks);
  using ExecutionSignature = void(_1, _2, _3);

  VTKM_CONT
  CountNonLocalBlockHits(const vtkm::Id& selfBlockId,
                         const vtkm::Vec3f& lightLoc,
                         bool useGlancingHits)
    : SelfBlockId(selfBlockId)
    , LightLocation(lightLoc)
    , UseGlancingHits(useGlancingHits)
  {
  }

  template <typename BoundsMapExec>
  VTKM_EXEC void operator()(const vtkm::Vec3f& samplePoint,
                            const BoundsMapExec& boundsMap,
                            vtkm::Id& numBlocks) const
  {
    // Every segment starts at the light and ends at the sample point.
    vtkm::Vec3f segmentStart = this->LightLocation;
    const bool glancing = this->UseGlancingHits;
    const vtkm::Id self = this->SelfBlockId;
    numBlocks =
      boundsMap.FindNumSegmentBlockIntersections(segmentStart, samplePoint, glancing, self);
  }

  vtkm::Id SelfBlockId;
  vtkm::Vec3f LightLocation;
  bool UseGlancingHits;
};

/// Less-than comparator for OpacityRayBlockHit: groups hits by RayId and,
/// within one ray, orders them by increasing RayT.
struct HitSort
{
  VTKM_EXEC_CONT bool operator()(const OpacityRayBlockHit& h1, const OpacityRayBlockHit& h2) const
  {
    const bool sameRay = (h1.RayId == h2.RayId);
    return sameRay ? (h1.RayT < h2.RayT) : (h1.RayId < h2.RayId);
  }
};

/// Worklet: for each opacity-map sample point (one ray from the light), writes
/// one OpacityRayBlockHit per non-local block whose bounds the light->sample
/// segment intersects. Records are written to hits[hitOffset], hits[hitOffset+1], ...
///
/// The number of records written for a ray must equal the count produced by
/// CountNonLocalBlockHits for the same inputs, because hitOffsets is computed
/// from those counts.
struct CalculateNonLocalBlockHits : public vtkm::worklet::WorkletMapField
{
  using ControlSignature = void(FieldIn samplePoints,
                                ExecObject boundMap,
                                FieldIn hitOffsets,
                                WholeArrayInOut hits);
  using ExecutionSignature = void(InputIndex, _1 samplePoint, _2 boundsMap, _3 hitOffset, _4 hits);

  VTKM_CONT
  CalculateNonLocalBlockHits(const vtkm::Id& selfBlockId,
                             const vtkm::Id& numBlocks,
                             const vtkm::Vec3f& lightLoc,
                             bool useGlancingHits)
    : SelfBlockId(selfBlockId)
    , NumBlocks(numBlocks)
    , LightLoc(lightLoc)
    , UseGlancingHits(useGlancingHits)
  {
  }

  template <typename BoundsMapExec, typename HitsPortal>
  VTKM_EXEC void operator()(vtkm::Id inputIndex,
                            const vtkm::Vec3f& samplePoint,
                            const BoundsMapExec& boundsMap,
                            const vtkm::Id& offset,
                            HitsPortal& hits) const
  {
    vtkm::Vec3f origin = this->LightLoc;
    vtkm::Vec3f dir = samplePoint - origin;
    vtkm::Normalize(dir);

    vtkm::Id hitOffset = offset;
    for (vtkm::Id block = 0; block < this->NumBlocks; ++block)
    {
      if (block == this->SelfBlockId)
      {
        continue;
      }

      vtkm::Float32 tMin, tMax;
      if (!boundsMap.FindSegmentBlockIntersections(block, origin, samplePoint, tMin, tMax))
      {
        continue;
      }

      // Optionally ignore segments that only graze the block (entry == exit).
      if ((!this->UseGlancingHits) && beams::Intersections::ApproxEquals(tMin, tMax))
      {
        continue;
      }

      OpacityRayBlockHit hit;
      hit.RayId = static_cast<int>(inputIndex);
      hit.BlockId = static_cast<int>(block);
      hit.RayT = tMax;
      hit.Point = origin + dir * tMax;
      // Not computed here; written explicitly so no indeterminate bytes are stored.
      hit.Opacity = 0.0f;
      hit.Face = this->ClassifyFace(hit.Point, boundsMap.Bounds.Get(block));

      hits.Set(hitOffset, hit);
      ++hitOffset;
    }
  }

  /// Returns the face of `bounds` that `point` lies on:
  /// 0/1 = X min/max, 2/3 = Y min/max, 4/5 = Z min/max.
  ///
  /// Planes are first tested with ApproxEquals in that order. If none
  /// matches (e.g. accumulated floating-point error), the nearest face plane
  /// is returned instead. The result is therefore always a valid index 0-5;
  /// a face of -1 would make consumers compute a negative offset into the
  /// per-face arrays.
  VTKM_EXEC vtkm::Id ClassifyFace(const vtkm::Vec3f_32& point, const vtkm::Bounds& bounds) const
  {
    const vtkm::Float32 planes[6] = { static_cast<vtkm::Float32>(bounds.X.Min),
                                      static_cast<vtkm::Float32>(bounds.X.Max),
                                      static_cast<vtkm::Float32>(bounds.Y.Min),
                                      static_cast<vtkm::Float32>(bounds.Y.Max),
                                      static_cast<vtkm::Float32>(bounds.Z.Min),
                                      static_cast<vtkm::Float32>(bounds.Z.Max) };

    for (vtkm::IdComponent face = 0; face < 6; ++face)
    {
      if (beams::Intersections::ApproxEquals(point[face / 2], planes[face]))
      {
        return face;
      }
    }

    vtkm::IdComponent nearest = 0;
    vtkm::Float32 nearestDistance = vtkm::Abs(point[0] - planes[0]);
    for (vtkm::IdComponent face = 1; face < 6; ++face)
    {
      const vtkm::Float32 distance = vtkm::Abs(point[face / 2] - planes[face]);
      if (distance < nearestDistance)
      {
        nearest = face;
        nearestDistance = distance;
      }
    }
    return nearest;
  }

  vtkm::Id SelfBlockId;
  vtkm::Id NumBlocks;
  vtkm::Vec3f LightLoc;
  bool UseGlancingHits;
};

/// Worklet: composites opacity contributed by non-local blocks into each
/// sample point's locally computed opacity.
///
/// For every hit recorded for a sample point, the opacity stored on the hit
/// face of the other block is trilinearly interpolated at the hit point. Hits
/// are combined with o += a * (1 - o), and the combined value is composited
/// into originalOpacity the same way.
///
/// Global face arrays are indexed as
///   BlockId * (6 * facePCount) + Face * facePCount + <point index in face slab>,
/// with facePCount = (Dx + 1) * (Dy + 1) * 2.
/// NOTE(review): that stride equals the real slab size of every face only
/// when Dx == Dy == Dz — confirm it matches the producer of the face arrays.
struct OpacityUpdaterWithNonLocalHits : public vtkm::worklet::WorkletMapField
{
  using ControlSignature = void(FieldIn samplePoints,
                                FieldInOut originalOpacities,
                                ExecObject boundMap,
                                FieldIn hitCounts,
                                FieldIn hitOffsets,
                                WholeArrayIn hits,
                                WholeArrayIn globalFaceCoords,
                                WholeArrayIn globalFaceOpacities);
  using ExecutionSignature = void(InputIndex,
                                  _1 samplePoint,
                                  _2 originalOpacity,
                                  _3 boundsMap,
                                  _4 hitCount,
                                  _5 hitOffset,
                                  _6 hits,
                                  _7 globalFaceCoords,
                                  _8 globalFaceOpacities);

  VTKM_CONT
  OpacityUpdaterWithNonLocalHits(const vtkm::Id& selfBlockId,
                                 const vtkm::Id& numBlocks,
                                 const vtkm::Vec3f& lightLoc,
                                 const vtkm::Id3& opacityMapDims,
                                 const vtkm::Vec3f_32& opacityMapSpacing,
                                 bool useGlancingHits)
    : SelfBlockId(selfBlockId)
    , NumBlocks(numBlocks)
    , LightLoc(lightLoc)
    , OpacityMapDims(opacityMapDims)
    , UseGlancingHits(useGlancingHits)
  {
    InvOpacityMapSpacing[0] = 1.0f / opacityMapSpacing[0];
    InvOpacityMapSpacing[1] = 1.0f / opacityMapSpacing[1];
    InvOpacityMapSpacing[2] = 1.0f / opacityMapSpacing[2];
  }

  template <typename BoundsMapExec,
            typename HitsPortal,
            typename FaceCoordsPortal,
            typename FaceOpacitiesPortal>
  VTKM_EXEC void operator()(vtkm::Id inputIndex,
                            const vtkm::Vec3f& samplePoint,
                            vtkm::Float32& originalOpacity,
                            const BoundsMapExec& boundsMap,
                            const vtkm::Id& hitCount,
                            const vtkm::Id& hitOffset,
                            const HitsPortal& hits,
                            const FaceCoordsPortal& globalFaceCoords,
                            const FaceOpacitiesPortal& globalFaceOpacities) const
  {
    // Entries reserved per face and per block in the global face arrays.
    const vtkm::Id facePCount = (this->OpacityMapDims[0] + 1) * (this->OpacityMapDims[1] + 1) * 2;
    const vtkm::Id blockFacePCount = facePCount * 6;
    const vtkm::Id3 gridPDims(
      this->OpacityMapDims[0] + 1, this->OpacityMapDims[1] + 1, this->OpacityMapDims[2] + 1);

    vtkm::Float32 otherOpacity = 0.0f;
    for (vtkm::Id i = 0; i < hitCount; ++i)
    {
      const auto hit = hits.Get(hitOffset + i);
      // An unclassified face (-1) or any out-of-range value would index
      // outside this block's face data.
      if (hit.Face < 0 || hit.Face > 5)
      {
        continue;
      }

      // A face slab is two points thick along its normal axis
      // (faces 0/1 -> X, 2/3 -> Y, 4/5 -> Z).
      vtkm::Id3 facePDims = gridPDims;
      facePDims[static_cast<vtkm::IdComponent>(hit.Face / 2)] = 2;

      const vtkm::Id faceOffset = hit.BlockId * blockFacePCount + hit.Face * facePCount;
      const vtkm::Vec3f_32 faceOrigin = globalFaceCoords.Get(faceOffset);

      const vtkm::Id3 cell =
        this->LocateFaceCell(hit.Point, faceOrigin, facePDims, this->InvOpacityMapSpacing);
      vtkm::Vec<vtkm::Id, 8> cellIndices;
      this->GetCellIndices(cell, facePDims, cellIndices);

      vtkm::Vec<vtkm::Float32, 8> values;
      for (vtkm::IdComponent j = 0; j < 8; ++j)
      {
        values[j] = globalFaceOpacities.Get(faceOffset + cellIndices[j]);
      }

      // Corners 0 and 6 are opposite corners of the axis-aligned cell.
      const vtkm::Vec3f_32 minPoint = globalFaceCoords.Get(faceOffset + cellIndices[0]);
      const vtkm::Vec3f_32 maxPoint = globalFaceCoords.Get(faceOffset + cellIndices[6]);
      vtkm::VecAxisAlignedPointCoordinates<3> rPoints(minPoint, maxPoint - minPoint);
      vtkm::Vec3f_32 pCoords(0.0f);
      vtkm::exec::WorldCoordinatesToParametricCoordinates(
        rPoints, hit.Point, vtkm::CellShapeTagHexahedron(), pCoords);

      vtkm::Float32 localOpacity = 0.0f;
      vtkm::exec::CellInterpolate(values, pCoords, vtkm::CellShapeTagHexahedron(), localOpacity);
      otherOpacity += localOpacity * (1.0f - otherOpacity);
    }
    originalOpacity += otherOpacity * (1.0f - originalOpacity);
    originalOpacity = vtkm::Clamp(originalOpacity, 0.0f, 1.0f);
  }

  /// Returns the per-axis index of the face-slab cell containing `point`.
  /// The slab's first point is `faceOrigin` and it has `facePDims` points per
  /// axis. Each index is clamped to [0, facePDims - 2] before truncation.
  /// Comparisons are explicit branches, so a NaN component is left unclamped.
  VTKM_EXEC vtkm::Id3 LocateFaceCell(const vtkm::Vec3f_32& point,
                                     const vtkm::Vec3f_32& faceOrigin,
                                     const vtkm::Id3& facePDims,
                                     const vtkm::Vec3f_32& faceInvSpacing) const
  {
    vtkm::Vec3f_32 cellTmp = point;
    cellTmp = cellTmp - faceOrigin;
    cellTmp = cellTmp * faceInvSpacing;
    if (cellTmp[0] < 0.0f)
    {
      cellTmp[0] = 0.0f;
    }
    if (cellTmp[1] < 0.0f)
    {
      cellTmp[1] = 0.0f;
    }
    if (cellTmp[2] < 0.0f)
    {
      cellTmp[2] = 0.0f;
    }
    if (cellTmp[0] >= vtkm::Float32(facePDims[0] - 1))
    {
      cellTmp[0] = vtkm::Float32(facePDims[0] - 2);
    }
    if (cellTmp[1] >= vtkm::Float32(facePDims[1] - 1))
    {
      cellTmp[1] = vtkm::Float32(facePDims[1] - 2);
    }
    if (cellTmp[2] >= vtkm::Float32(facePDims[2] - 1))
    {
      cellTmp[2] = vtkm::Float32(facePDims[2] - 2);
    }
    vtkm::Id3 cell = cellTmp;
    return cell;
  }

  /// Fills the point indices of the 8 hexahedron corners of `cell`.
  /// The slab is an x-fastest point grid of size `facePDims`; corners 0-3
  /// lie on layer k and corners 4-7 on layer k + 1.
  VTKM_EXEC void GetCellIndices(const vtkm::Id3& cell,
                                const vtkm::Id3& facePDims,
                                vtkm::Vec<vtkm::Id, 8>& cellIndices) const
  {
    cellIndices[0] = (cell[2] * facePDims[1] + cell[1]) * facePDims[0] + cell[0];
    cellIndices[1] = cellIndices[0] + 1;
    cellIndices[2] = cellIndices[1] + facePDims[0];
    cellIndices[3] = cellIndices[2] - 1;
    cellIndices[4] = cellIndices[0] + facePDims[0] * facePDims[1];
    cellIndices[5] = cellIndices[4] + 1;
    cellIndices[6] = cellIndices[5] + facePDims[0];
    cellIndices[7] = cellIndices[6] - 1;
  }

  vtkm::Id SelfBlockId;
  vtkm::Id NumBlocks;
  vtkm::Vec3f LightLoc;
  vtkm::Id3 OpacityMapDims;
  vtkm::Vec3f_32 InvOpacityMapSpacing;
  bool UseGlancingHits;
};

/// Finds, for every opacity-map point in `coords`, the non-local blocks that
/// the segment from the first light to that point crosses.
///
/// On return:
/// - hitCounts[i] is the number of hits for point i;
/// - hitOffsets is the exclusive prefix sum of hitCounts;
/// - `hits` holds the records for point i in
///   [hitOffsets[i], hitOffsets[i] + hitCounts[i]).
///
/// Phase timings are appended to `times` only when reportTimes is enabled.
template <typename Device>
void GetNonLocalHits(const vtkm::cont::ArrayHandleUniformPointCoordinates& coords,
                     const beams::mpi::MpiEnv* mpi,
                     const beams::rendering::Lights& lights,
                     const beams::rendering::BoundsMap& boundsMap,
                     bool useGlancingHits,
                     vtkm::cont::ArrayHandle<vtkm::Id>& hitCounts,
                     vtkm::cont::ArrayHandle<vtkm::Id>& hitOffsets,
                     vtkm::cont::ArrayHandle<OpacityRayBlockHit>& hits,
                     std::vector<beams::profiling::Record>& times,
                     Device)
{
  MPI_Comm mpiComm = vtkmdiy::mpi::mpi_cast(mpi->Comm.handle());
  vtkm::cont::Invoker invoker{ Device() };
  // Timing collection is compiled in but disabled.
  bool reportTimes = false;

  vtkm::cont::Timer countTimer;
  countTimer.Start();
  invoker(CountNonLocalBlockHits{ mpi->Rank, lights.Locations[0], useGlancingHits },
          coords,
          boundsMap,
          hitCounts);
  // ScanExclusive also returns the total of its input, so the hit total comes
  // with the offsets in one device pass. It is accumulated as vtkm::Id rather
  // than the `int` that Reduce(hitCounts, 0) would deduce.
  const vtkm::Id totalHitCount = vtkm::cont::Algorithm::ScanExclusive(hitCounts, hitOffsets);
  countTimer.Stop();
  if (reportTimes)
  {
    MPI_Barrier(mpiComm);
    times.push_back({ "Phase 2: CountNonLocalBlockHits", countTimer.GetElapsedTime() });
  }

  vtkm::cont::Timer hitTimer;
  hitTimer.Start();
  hits.Allocate(totalHitCount);
  invoker(CalculateNonLocalBlockHits{ mpi->Rank, mpi->Size, lights.Locations[0], useGlancingHits },
          coords,
          boundsMap,
          hitOffsets,
          hits);
  hitTimer.Stop();
  if (reportTimes)
  {
    MPI_Barrier(mpiComm);
    times.push_back({ "Phase 2: CalculateNonLocalBlockHits", hitTimer.GetElapsedTime() });
  }
}

/// Composites the non-local opacity contributions described by `hits` and the
/// global face arrays into `opacities` (one value per point of `coords`), using
/// OpacityUpdaterWithNonLocalHits.
/// The phase timing is appended to `times` only when reportTimes is enabled.
template <typename Device>
void UpdateOpacities(const vtkm::Id3 opacityMapDims,
                     const vtkm::Vec3f_32& opacityMapSpacing,
                     const vtkm::cont::ArrayHandleUniformPointCoordinates& coords,
                     vtkm::cont::ArrayHandle<vtkm::Float32>& opacities,
                     const beams::mpi::MpiEnv* mpi,
                     const beams::rendering::Lights& lights,
                     const beams::rendering::BoundsMap& boundsMap,
                     bool useGlancingHits,
                     vtkm::cont::ArrayHandle<vtkm::Id>& hitCounts,
                     vtkm::cont::ArrayHandle<vtkm::Id>& hitOffsets,
                     vtkm::cont::ArrayHandle<OpacityRayBlockHit>& hits,
                     vtkm::cont::ArrayHandle<vtkm::Vec3f_32>& globalFaceCoords,
                     vtkm::cont::ArrayHandle<vtkm::Float32>& globalFaceOpacities,
                     std::vector<beams::profiling::Record>& times,
                     Device)
{
  const bool reportTimes = false;
  MPI_Comm mpiComm = vtkmdiy::mpi::mpi_cast(mpi->Comm.handle());

  OpacityUpdaterWithNonLocalHits updater{ mpi->Rank,
                                          mpi->Size,
                                          lights.Locations[0],
                                          opacityMapDims,
                                          opacityMapSpacing,
                                          useGlancingHits };
  vtkm::cont::Invoker invoke{ Device() };

  vtkm::cont::Timer updateTimer;
  updateTimer.Start();
  invoke(updater,
         coords,
         opacities,
         boundsMap,
         hitCounts,
         hitOffsets,
         hits,
         globalFaceCoords,
         globalFaceOpacities);
  updateTimer.Stop();

  if (!reportTimes)
  {
    return;
  }
  MPI_Barrier(mpiComm);
  times.push_back({ "Phase 2: UpdateOpacities", updateTimer.GetElapsedTime() });
}
}
}