#pragma once

#include "../utils/Fmt.h"

#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandleCartesianProduct.h>
#include <vtkm/cont/ArrayHandleUniformPointCoordinates.h>
#include <vtkm/cont/Token.h>

namespace beams
{
namespace rendering
{
template <typename Device>
class CellLocatorUniform
{
public:
  using PointsArrayHandle = vtkm::cont::ArrayHandleUniformPointCoordinates;
  using PointsReadPortal = typename PointsArrayHandle::ReadPortalType;

  vtkm::Id3 PointDimensions;
  vtkm::Vec3f_32 Origin;
  vtkm::Vec3f_32 InvSpacing;
  vtkm::Vec3f_32 MaxPoint;
  PointsReadPortal Coordinates;

  CellLocatorUniform(const PointsArrayHandle& coordinates, vtkm::cont::Token& token)
    : Coordinates(coordinates.PrepareForInput(Device(), token))
  {
    Origin = Coordinates.GetOrigin();
    PointDimensions = Coordinates.GetDimensions();
    vtkm::Vec3f_32 spacing = Coordinates.GetSpacing();

    vtkm::Vec3f_32 unitLength;
    unitLength[0] = static_cast<vtkm::Float32>(PointDimensions[0] - 1);
    unitLength[1] = static_cast<vtkm::Float32>(PointDimensions[1] - 1);
    unitLength[2] = static_cast<vtkm::Float32>(PointDimensions[2] - 1);
    MaxPoint = Origin + spacing * unitLength;
    InvSpacing[0] = 1.f / spacing[0];
    InvSpacing[1] = 1.f / spacing[1];
    InvSpacing[2] = 1.f / spacing[2];
  }

  VTKM_EXEC
  inline bool IsInside(const vtkm::Vec3f_32& point) const
  {
    bool inside = true;
    if (point[0] < Origin[0] || point[0] > MaxPoint[0])
      inside = false;
    if (point[1] < Origin[1] || point[1] > MaxPoint[1])
      inside = false;
    if (point[2] < Origin[2] || point[2] > MaxPoint[2])
      inside = false;
    return inside;
  }

  VTKM_EXEC
  inline void GetCellIndices(const vtkm::Id3& cell, vtkm::Vec<vtkm::Id, 8>& cellIndices) const
  {
    cellIndices[0] = (cell[2] * PointDimensions[1] + cell[1]) * PointDimensions[0] + cell[0];
    cellIndices[1] = cellIndices[0] + 1;
    cellIndices[2] = cellIndices[1] + PointDimensions[0];
    cellIndices[3] = cellIndices[2] - 1;
    cellIndices[4] = cellIndices[0] + PointDimensions[0] * PointDimensions[1];
    cellIndices[5] = cellIndices[4] + 1;
    cellIndices[6] = cellIndices[5] + PointDimensions[0];
    cellIndices[7] = cellIndices[6] - 1;
  } // GetCellIndices

  VTKM_EXEC
  inline vtkm::Id GetCellIndex(const vtkm::Id3& cell) const
  {
    return (cell[2] * (PointDimensions[1] - 1) + cell[1]) * (PointDimensions[0] - 1) + cell[0];
  }

  VTKM_EXEC
  inline void LocateCell(vtkm::Id3& cell,
                         const vtkm::Vec3f_32& point,
                         vtkm::Vec3f_32& invSpacing) const
  {
    vtkm::Vec3f_32 temp = point;
    temp = temp - Origin;
    temp = temp * InvSpacing;
    if (temp[0] < 0.f)
      temp[0] = 0.f;
    if (temp[1] < 0.f)
      temp[1] = 0.f;
    if (temp[2] < 0.f)
      temp[2] = 0.f;
    if (temp[0] >= vtkm::Float32(PointDimensions[0] - 1))
      temp[0] = vtkm::Float32(PointDimensions[0] - 2);
    if (temp[1] >= vtkm::Float32(PointDimensions[1] - 1))
      temp[1] = vtkm::Float32(PointDimensions[1] - 2);
    if (temp[2] >= vtkm::Float32(PointDimensions[2] - 1))
      temp[2] = vtkm::Float32(PointDimensions[2] - 2);
    cell = temp;
    invSpacing = InvSpacing;
  }

  VTKM_EXEC
  inline void GetPoint(const vtkm::Id& index, vtkm::Vec3f_32& point) const
  {
    point = Coordinates.Get(index);
  }

  VTKM_EXEC
  inline void GetMinPoint(const vtkm::Id3& cell, vtkm::Vec3f_32& point) const
  {
    const vtkm::Id pointIndex =
      (cell[2] * PointDimensions[1] + cell[1]) * PointDimensions[0] + cell[0];
    point = Coordinates.Get(pointIndex);
  }

}; // class CellLocatorUniform

//
// Cell locator for a rectilinear grid whose point coordinates are an
// ArrayHandleCartesianProduct of three per-axis coordinate arrays. The
// constructor prepares the array for input on `Device` and caches the
// per-axis portals, point counts and first/last coordinate of each axis.
//
// NOTE(review): the constructor reads element 0 and element
// PointDimensions - 1 of every axis, so it presumably expects at least one
// coordinate per axis -- confirm against callers.
//
template <typename Device>
class CellLocatorRectilinear
{
public:
  using DefaultHandle = vtkm::cont::ArrayHandle<vtkm::FloatDefault>;
  using CartesianArrayHandle =
    vtkm::cont::ArrayHandleCartesianProduct<DefaultHandle, DefaultHandle, DefaultHandle>;
  using DefaultConstHandle = typename DefaultHandle::ReadPortalType;
  using CartesianConstPortal = typename CartesianArrayHandle::ReadPortalType;

  // Per-axis coordinate portals taken from the Cartesian product portal.
  DefaultConstHandle CoordPortals[3];
  // Read portal for the full Cartesian product of coordinates.
  CartesianConstPortal Coordinates;
  // Number of coordinate values along each axis.
  vtkm::Id3 PointDimensions;
  // First coordinate value of each axis.
  vtkm::Vec3f_32 MinPoint;
  // Last coordinate value of each axis.
  vtkm::Vec3f_32 MaxPoint;

  //
  // Prepares `coordinates` for input on `Device` (tied to `token`) and caches
  // the per-axis portals, sizes and bounds.
  //
  CellLocatorRectilinear(const CartesianArrayHandle& coordinates, vtkm::cont::Token& token)
    : Coordinates(coordinates.PrepareForInput(Device(), token))
  {
    CoordPortals[0] = Coordinates.GetFirstPortal();
    CoordPortals[1] = Coordinates.GetSecondPortal();
    CoordPortals[2] = Coordinates.GetThirdPortal();

    PointDimensions[0] = CoordPortals[0].GetNumberOfValues();
    PointDimensions[1] = CoordPortals[1].GetNumberOfValues();
    PointDimensions[2] = CoordPortals[2].GetNumberOfValues();

    MinPoint[0] = static_cast<vtkm::Float32>(CoordPortals[0].Get(0));
    MinPoint[1] = static_cast<vtkm::Float32>(CoordPortals[1].Get(0));
    MinPoint[2] = static_cast<vtkm::Float32>(CoordPortals[2].Get(0));

    MaxPoint[0] = static_cast<vtkm::Float32>(CoordPortals[0].Get(PointDimensions[0] - 1));
    MaxPoint[1] = static_cast<vtkm::Float32>(CoordPortals[1].Get(PointDimensions[1] - 1));
    MaxPoint[2] = static_cast<vtkm::Float32>(CoordPortals[2].Get(PointDimensions[2] - 1));
  }

  //
  // Returns true when every component of `point` lies in the closed range
  // [MinPoint, MaxPoint].
  //
  VTKM_EXEC
  inline bool IsInside(const vtkm::Vec3f_32& point) const
  {
    bool inside = true;
    if (point[0] < MinPoint[0] || point[0] > MaxPoint[0])
      inside = false;
    if (point[1] < MinPoint[1] || point[1] > MaxPoint[1])
      inside = false;
    if (point[2] < MinPoint[2] || point[2] > MaxPoint[2])
      inside = false;
    return inside;
  }

  //
  // Writes the flat point indices of the eight corners of `cell`.
  // Entries 0-3 belong to the layer at cell[2]; entries 4-7 are the same four
  // corners shifted by one full layer (PointDimensions[0] * PointDimensions[1]).
  //
  VTKM_EXEC
  inline void GetCellIndices(const vtkm::Id3& cell, vtkm::Vec<vtkm::Id, 8>& cellIndices) const
  {
    cellIndices[0] = (cell[2] * PointDimensions[1] + cell[1]) * PointDimensions[0] + cell[0];
    cellIndices[1] = cellIndices[0] + 1;
    cellIndices[2] = cellIndices[1] + PointDimensions[0];
    cellIndices[3] = cellIndices[2] - 1;
    cellIndices[4] = cellIndices[0] + PointDimensions[0] * PointDimensions[1];
    cellIndices[5] = cellIndices[4] + 1;
    cellIndices[6] = cellIndices[5] + PointDimensions[0];
    cellIndices[7] = cellIndices[6] - 1;
  } // GetCellIndices

  //
  // Finds, per axis, the cell interval containing `point` and the reciprocal
  // of that interval's length.
  //
  // A component at or below MinPoint maps to the first cell on that axis and
  // a component at or above MaxPoint maps to the last cell; `invSpacing` is
  // filled in for those boundary cases as well as for interior points.
  //
  // Assumes the coordinates of each axis are increasing -- this is not
  // checked here.
  //
  VTKM_EXEC
  inline void LocateCell(vtkm::Id3& cell,
                         const vtkm::Vec3f_32& point,
                         vtkm::Vec3f_32& invSpacing) const
  {
    for (vtkm::Int32 dim = 0; dim < 3; ++dim)
    {
      const vtkm::Id lastCell = PointDimensions[dim] - 2;
      if (lastCell < 0)
      {
        // Fewer than two coordinates on this axis: there is no interval to
        // measure, so avoid reading past the coordinate array.
        cell[dim] = 0;
        invSpacing[dim] = 0.f;
        continue;
      }

      vtkm::Id index = 0;
      if (point[dim] <= MinPoint[dim])
      {
        index = 0;
      }
      //
      // When searching for points, we consider the max value of the cell
      // to be a part of the next cell. If the point falls on the boundary of
      // the data set, then it is technically inside a cell. This checks for
      // that case
      //
      else if (point[dim] >= MaxPoint[dim])
      {
        index = lastCell;
      }
      else
      {
        // Forward linear scan from the first cell. The bound on `index` keeps
        // the scan inside the array even if no interval matches (for example
        // when the component is NaN and every comparison is false).
        while (index < lastCell &&
               point[dim] >= static_cast<vtkm::Float32>(CoordPortals[dim].Get(index + 1)))
        {
          ++index;
        }
      }

      const vtkm::Float32 minVal = static_cast<vtkm::Float32>(CoordPortals[dim].Get(index));
      const vtkm::Float32 maxVal =
        static_cast<vtkm::Float32>(CoordPortals[dim].Get(index + 1));
      cell[dim] = index;
      invSpacing[dim] = 1.f / (maxVal - minVal);
    }
  } // LocateCell

  //
  // Flattens the logical cell index `cell` using (PointDimensions - 1) as the
  // extent along the first two axes.
  //
  VTKM_EXEC
  inline vtkm::Id GetCellIndex(const vtkm::Id3& cell) const
  {
    return (cell[2] * (PointDimensions[1] - 1) + cell[1]) * (PointDimensions[0] - 1) + cell[0];
  }

  //
  // Reads the coordinates stored at flat point index `index` into `point`.
  //
  VTKM_EXEC
  inline void GetPoint(const vtkm::Id& index, vtkm::Vec3f_32& point) const
  {
    point = Coordinates.Get(index);
  }

  //
  // Reads into `point` the coordinates of the grid point whose logical index
  // equals `cell` (the same base index GetCellIndices stores in entry 0).
  //
  VTKM_EXEC
  inline void GetMinPoint(const vtkm::Id3& cell, vtkm::Vec3f_32& point) const
  {
    const vtkm::Id pointIndex =
      (cell[2] * PointDimensions[1] + cell[1]) * PointDimensions[0] + cell[0];
    point = Coordinates.Get(pointIndex);
  }
}; // class CellLocatorRectilinear
}
}
