//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//
//  Copyright 2015 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
//  Copyright 2015 UT-Battelle, LLC.
//  Copyright 2015 Los Alamos National Security.
//
//  Under the terms of Contract DE-NA0003525 with NTESS,
//  the U.S. Government retains certain rights in this software.
//
//  Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
//  Laboratory (LANL), the U.S. Government retains certain rights in
//  this software.
//============================================================================
#ifndef vtk_m_rendering_raytracing_MeshOracleBase
#define vtk_m_rendering_raytracing_MeshOracleBase

#include <sstream>
#include <type_traits>
#include <utility>

#include <vtkm/CellShape.h>
#include <vtkm/VirtualObjectBase.h>
#include <vtkm/cont/ArrayHandleCartesianProduct.h>
#include <vtkm/cont/DataSet.h>
#include <vtkm/cont/DeviceAdapterListTag.h>
#include <vtkm/cont/ErrorBadValue.h>
#include <vtkm/cont/VirtualObjectHandle.h>
#include <vtkm/exec/ParametricCoordinates.h>

namespace vtkm
{
namespace rendering
{
namespace raytracing
{

namespace detail
{
// NOTE(review): this namespace is empty in this header and nothing here
// references it; it looks like a leftover placeholder that could be removed.
}

// Abstract interface the ray tracer uses to locate points in a mesh and to
// query per-cell information. Concrete oracles (e.g. RectilinearOracle in this
// file) implement every pure virtual method; OracleHandle
// (a VirtualObjectHandle<MeshOracleBase>) wraps instances of this interface.
class VTKM_ALWAYS_EXPORT MeshOracleBase : public VirtualObjectBase
{
public:
  // Locates the cell containing a single-precision world-space point.
  //   point   - query point in world coordinates
  //   cellId  - output: index of the containing cell (RectilinearOracle writes
  //             -1 when the point is not inside the mesh)
  //   pcoords - output: parametric coordinates of `point` inside that cell
  VTKM_EXEC_CONT
  virtual void FindCell(const vtkm::Vec<vtkm::Float32, 3>& point,
                        vtkm::Id& cellId,
                        vtkm::Vec<Float32, 3>& pcoords) const = 0;

  // Double-precision overload of FindCell; same contract as above.
  VTKM_EXEC_CONT
  virtual void FindCell(const vtkm::Vec<vtkm::Float64, 3>& point,
                        vtkm::Id& cellId,
                        vtkm::Vec<Float64, 3>& pcoords) const = 0;

  // Writes the point indices of cell `cellId` into `cellIndices` (room for at
  // most 8) and returns how many indices were written.
  VTKM_EXEC_CONT
  virtual vtkm::Int32 GetCellIndices(vtkm::Id cellIndices[8], const vtkm::Id& cellId) const = 0;

  // Returns the VTK-m cell shape id (e.g. CELL_SHAPE_HEXAHEDRON) of `cellId`.
  VTKM_EXEC_CONT
  virtual vtkm::UInt8 GetCellShape(const vtkm::Id& cellId) const = 0;

  // Interpolates per-point `scalars` (up to 8 values, one per cell point) at
  // parametric location `pcoords`.
  // Only the cell set knows what the shape tag is.
  // This allows for a hex fast path without tons of templates
  VTKM_EXEC_CONT
  virtual vtkm::Float32 Interpolate(const vtkm::Vec<vtkm::Float32, 8>& scalars,
                                    const vtkm::Vec<vtkm::Float32, 3>& pcoords) const = 0;
};

// A simple concrete (non-virtual) type that wraps a MeshOracleBase pointer
// so it can be passed to worklets as an ExecObject.
class MeshOracleWrapper
{
private:
  MeshOracleBase* Sampler;

public:
  MeshOracleWrapper() {}

  MeshOracleWrapper(MeshOracleBase* sampler)
    : Sampler(sampler){};

  VTKM_EXEC_CONT
  void FindCell(const vtkm::Vec<vtkm::Float32, 3>& point,
                vtkm::Id& cellId,
                vtkm::Vec<Float32, 3>& pcoords) const
  {
    Sampler->FindCell(point, cellId, pcoords);
  }

  VTKM_EXEC_CONT
  void FindCell(const vtkm::Vec<vtkm::Float64, 3>& point,
                vtkm::Id& cellId,
                vtkm::Vec<Float64, 3>& pcoords) const
  {
    Sampler->FindCell(point, cellId, pcoords);
  }

  VTKM_EXEC_CONT
  vtkm::Int32 GetCellIndices(vtkm::Id cellIndices[8], const vtkm::Id& cellId) const
  {
    return Sampler->GetCellIndices(cellIndices, cellId);
  }

  VTKM_EXEC_CONT
  vtkm::UInt8 GetCellShape(const vtkm::Id& cellId) const { return Sampler->GetCellShape(cellId); }

  VTKM_EXEC_CONT
  vtkm::Float32 Interpolate(const vtkm::Vec<vtkm::Float32, 8>& scalars,
                            const vtkm::Vec<vtkm::Float32, 3>& pcoords) const
  {
    return Sampler->Interpolate(scalars, pcoords);
  }
};

// Mesh oracle for 3D rectilinear grids: Cartesian-product coordinates on a
// CellSetStructured<3>. Every cell is a hexahedron, so cell lookup is an
// independent binary search along each axis.
template <typename Device>
class VTKM_ALWAYS_EXPORT RectilinearOracle : public MeshOracleBase
{
protected:
  using DefaultHandle = vtkm::cont::ArrayHandle<vtkm::FloatDefault>;
  using CartesianArrayHandle =
    vtkm::cont::ArrayHandleCartesianProduct<DefaultHandle, DefaultHandle, DefaultHandle>;
  using DefaultConstHandle = typename DefaultHandle::ReadPortalType;
  using CartesianConstPortal = typename CartesianArrayHandle::ReadPortalType;

  // Per-axis coordinate portals (x, y, z) taken from Coordinates.
  DefaultConstHandle CoordPortals[3];
  // Portal over the full Cartesian-product coordinate array.
  CartesianConstPortal Coordinates;
  // Point-to-cell structured connectivity; provides the point dimensions.
  vtkm::exec::ConnectivityStructured<vtkm::TopologyElementTagPoint, vtkm::TopologyElementTagCell, 3>
    Conn;

  // Grid bounds (first and last coordinate on each axis), single precision.
  vtkm::Vec<vtkm::Float32, 3> MinPoint;
  vtkm::Vec<vtkm::Float32, 3> MaxPoint;
  // Number of points and cells along each axis (CellDims = PointDims - 1).
  vtkm::Id3 PointDims;
  vtkm::Id3 CellDims;

  VTKM_CONT RectilinearOracle() = default;

public:
  // Prepares coordinate and connectivity portals for `Device` and caches the
  // grid dimensions and bounds.
  VTKM_CONT
  RectilinearOracle(const CartesianArrayHandle& coordinates,
                    vtkm::cont::CellSetStructured<3>& cellset)
    : Coordinates(coordinates.PrepareForInput(Device()))
    , Conn(cellset.PrepareForInput(Device(),
                                   vtkm::TopologyElementTagPoint(),
                                   vtkm::TopologyElementTagCell()))
  {
    CoordPortals[0] = Coordinates.GetFirstPortal();
    CoordPortals[1] = Coordinates.GetSecondPortal();
    CoordPortals[2] = Coordinates.GetThirdPortal();
    PointDims = Conn.GetPointDimensions();

    CellDims[0] = PointDims[0] - 1;
    CellDims[1] = PointDims[1] - 1;
    CellDims[2] = PointDims[2] - 1;

    // Fetch the control-side portal once rather than once per bound.
    auto hostPortal = coordinates.GetPortalConstControl();

    MinPoint[0] = static_cast<vtkm::Float32>(hostPortal.GetFirstPortal().Get(0));
    MinPoint[1] = static_cast<vtkm::Float32>(hostPortal.GetSecondPortal().Get(0));
    MinPoint[2] = static_cast<vtkm::Float32>(hostPortal.GetThirdPortal().Get(0));

    MaxPoint[0] = static_cast<vtkm::Float32>(hostPortal.GetFirstPortal().Get(PointDims[0] - 1));
    MaxPoint[1] = static_cast<vtkm::Float32>(hostPortal.GetSecondPortal().Get(PointDims[1] - 1));
    MaxPoint[2] = static_cast<vtkm::Float32>(hostPortal.GetThirdPortal().Get(PointDims[2] - 1));
  }

  // Returns true when every component of `point` lies within
  // [MinPoint, MaxPoint] (inclusive). The comparisons are written positively
  // so a NaN component is reported as outside: every comparison with NaN is
  // false. The negated form used previously let NaN through, and the search
  // then indexed past the end of the coordinate arrays.
  template <typename T>
  VTKM_EXEC inline bool IsInside(const vtkm::Vec<T, 3>& point) const
  {
    for (vtkm::Int32 dim = 0; dim < 3; ++dim)
    {
      if (!(point[dim] >= MinPoint[dim] && point[dim] <= MaxPoint[dim]))
      {
        return false;
      }
    }
    return true;
  }

  // Interpolates the 8 corner values at `pcoords` using the hexahedron
  // interpolation, since every rectilinear cell is a hexahedron.
  VTKM_EXEC_CONT
  vtkm::Float32 Interpolate(const vtkm::Vec<vtkm::Float32, 8>& scalars,
                            const vtkm::Vec<vtkm::Float32, 3>& pcoords) const override
  {
    // This CellInterpolate overload requires a FunctorBase argument; a local
    // one is supplied because no worklet is available here.
    vtkm::exec::FunctorBase fake;
    vtkm::Float32 scalar;
    scalar = vtkm::exec::CellInterpolate(scalars, pcoords, vtkm::CellShapeTagHexahedron(), fake);
    return scalar;
  }

  // Flattens a logical (i, j, k) cell index into a linear cell id, with i
  // varying fastest.
  VTKM_EXEC
  inline vtkm::Id GetCellIndex(const vtkm::Vec<vtkm::Id, 3>& cell) const
  {
    return (cell[2] * (CellDims[1]) + cell[1]) * (CellDims[0]) + cell[0];
  }

  // Writes the 8 point ids of hexahedral cell `cellIndex`. Indices 0-3 are
  // (i,j,k), (i+1,j,k), (i+1,j+1,k), (i,j+1,k); indices 4-7 are the same
  // corners at k+1. Always returns 8.
  VTKM_EXEC_CONT
  vtkm::Int32 GetCellIndices(vtkm::Id cellIndices[8], const vtkm::Id& cellIndex) const override
  {
    vtkm::Id3 cellId;
    cellId[0] = cellIndex % CellDims[0];
    cellId[1] = (cellIndex / CellDims[0]) % CellDims[1];
    cellId[2] = cellIndex / (CellDims[0] * CellDims[1]);
    cellIndices[0] = (cellId[2] * PointDims[1] + cellId[1]) * PointDims[0] + cellId[0];
    cellIndices[1] = cellIndices[0] + 1;
    cellIndices[2] = cellIndices[1] + PointDims[0];
    cellIndices[3] = cellIndices[2] - 1;
    cellIndices[4] = cellIndices[0] + PointDims[0] * PointDims[1];
    cellIndices[5] = cellIndices[4] + 1;
    cellIndices[6] = cellIndices[5] + PointDims[0];
    cellIndices[7] = cellIndices[6] - 1;
    return 8;
  }

  // Every cell of a rectilinear grid is a hexahedron.
  VTKM_EXEC_CONT
  vtkm::UInt8 GetCellShape(const vtkm::Id& vtkmNotUsed(cellId)) const override
  {
    return vtkm::UInt8(CELL_SHAPE_HEXAHEDRON);
  }

  // Shared implementation of both FindCell overloads. Sets cellId to -1 and
  // leaves pcoords untouched when the point is outside the grid (or the grid
  // has no cells along some axis). Otherwise it writes the containing cell id
  // and the parametric coordinates.
  template <typename T>
  VTKM_EXEC_CONT void FindCellImpl(const vtkm::Vec<T, 3>& point,
                                   vtkm::Id& cellId,
                                   vtkm::Vec<T, 3>& pcoords) const
  {
    cellId = -1;

    // The per-axis search below assumes the point is within the grid bounds.
    if (!IsInside(point))
    {
      return;
    }

    vtkm::Vec<vtkm::Id, 3> cell;
    for (vtkm::Int32 dim = 0; dim < 3; ++dim)
    {
      const vtkm::Id numCells = CellDims[dim];
      if (numCells < 1)
      {
        // Degenerate axis (a single point): there is no cell to return, and
        // searching would read past the end of the coordinate array.
        return;
      }

      // Binary search for the first cell whose upper coordinate is >= val.
      vtkm::Id low = 0;
      vtkm::Id hi = numCells - 1;
      const T val = point[dim];

      while (low <= hi)
      {
        const vtkm::Id mid = (low + hi) / 2;
        const T pmid = static_cast<T>(CoordPortals[dim].Get(mid + 1));

        if (val <= pmid)
        {
          hi = mid - 1;
        }
        else
        {
          low = mid + 1;
        }
      }

      // MinPoint/MaxPoint are rounded to Float32, so a Float64 query slightly
      // beyond the last coordinate can pass IsInside and leave low == numCells.
      // Clamp to the last cell so the coordinate lookups below stay in range.
      cell[dim] = (low < numCells) ? low : numCells - 1;
    }

    vtkm::Vec<T, 3> minPoint;
    vtkm::Vec<T, 3> maxPoint;

    for (vtkm::Int32 i = 0; i < 3; ++i)
    {
      minPoint[i] = static_cast<T>(CoordPortals[i].Get(cell[i]));
      maxPoint[i] = static_cast<T>(CoordPortals[i].Get(cell[i] + 1));
    }
    vtkm::VecAxisAlignedPointCoordinates<3> rPoints(minPoint, maxPoint - minPoint);

    bool success;
    // This overload requires a FunctorBase argument; none is available here.
    vtkm::exec::FunctorBase fake;
    pcoords = vtkm::exec::WorldCoordinatesToParametricCoordinates(
      rPoints, point, vtkm::CellShapeTagHexahedron(), success, fake);
    cellId = GetCellIndex(cell);
    VTKM_ASSERT(success);
  }

  // Single-precision cell lookup; see FindCellImpl.
  VTKM_EXEC_CONT
  void FindCell(const vtkm::Vec<vtkm::Float32, 3>& point,
                vtkm::Id& cellId,
                vtkm::Vec<Float32, 3>& pcoords) const override
  {
    FindCellImpl(point, cellId, pcoords);
  }

  // Double-precision cell lookup; see FindCellImpl.
  VTKM_EXEC_CONT
  void FindCell(const vtkm::Vec<vtkm::Float64, 3>& point,
                vtkm::Id& cellId,
                vtkm::Vec<Float64, 3>& pcoords) const override
  {
    FindCellImpl(point, cellId, pcoords);
  }
};
//class VTKM_ALWAYS_EXPORT MeshConnStructured : public MeshConnectivityBase
//{
//protected:
//  typedef typename vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::Id, 4>> Id4Handle;
//  vtkm::Id3 CellDims;
//  vtkm::Id3 PointDims;
//
//  VTKM_CONT MeshConnStructured() = default;
//
//public:
//  VTKM_CONT
//  MeshConnStructured(const vtkm::Id3& cellDims, const vtkm::Id3& pointDims)
//    : CellDims(cellDims)
//    , PointDims(pointDims)
//  {
//  }
//
//  VTKM_EXEC_CONT
//  vtkm::Id GetConnectingCell(const vtkm::Id& cellId, const vtkm::Id& face) const override
//  {
//    //TODO: there is probably a better way to do this.
//    vtkm::Id3 logicalCellId;
//    logicalCellId[0] = cellId % CellDims[0];
//    logicalCellId[1] = (cellId / CellDims[0]) % CellDims[1];
//    logicalCellId[2] = cellId / (CellDims[0] * CellDims[1]);
//    if (face == 0)
//      logicalCellId[1] -= 1;
//    if (face == 2)
//      logicalCellId[1] += 1;
//    if (face == 1)
//      logicalCellId[0] += 1;
//    if (face == 3)
//      logicalCellId[0] -= 1;
//    if (face == 4)
//      logicalCellId[2] -= 1;
//    if (face == 5)
//      logicalCellId[2] += 1;
//    vtkm::Id nextCell =
//      (logicalCellId[2] * CellDims[1] + logicalCellId[1]) * CellDims[0] + logicalCellId[0];
//    bool validCell = true;
//    if (logicalCellId[0] >= CellDims[0])
//      validCell = false;
//    if (logicalCellId[1] >= CellDims[1])
//      validCell = false;
//    if (logicalCellId[2] >= CellDims[2])
//      validCell = false;
//    vtkm::Id minId = vtkm::Min(logicalCellId[0], vtkm::Min(logicalCellId[1], logicalCellId[2]));
//    if (minId < 0)
//      validCell = false;
//    if (!validCell)
//      nextCell = -1;
//    return nextCell;
//  }
//
//  VTKM_EXEC_CONT
//  vtkm::Int32 GetCellIndices(vtkm::Id cellIndices[8], const vtkm::Id& cellIndex) const override
//  {
//    vtkm::Id3 cellId;
//    cellId[0] = cellIndex % CellDims[0];
//    cellId[1] = (cellIndex / CellDims[0]) % CellDims[1];
//    cellId[2] = cellIndex / (CellDims[0] * CellDims[1]);
//    cellIndices[0] = (cellId[2] * PointDims[1] + cellId[1]) * PointDims[0] + cellId[0];
//    cellIndices[1] = cellIndices[0] + 1;
//    cellIndices[2] = cellIndices[1] + PointDims[0];
//    cellIndices[3] = cellIndices[2] - 1;
//    cellIndices[4] = cellIndices[0] + PointDims[0] * PointDims[1];
//    cellIndices[5] = cellIndices[4] + 1;
//    cellIndices[6] = cellIndices[5] + PointDims[0];
//    cellIndices[7] = cellIndices[6] - 1;
//    return 8;
//  }
//
//  VTKM_EXEC
//  vtkm::UInt8 GetCellShape(const vtkm::Id& vtkmNotUsed(cellId)) const override
//  {
//    return vtkm::UInt8(CELL_SHAPE_HEXAHEDRON);
//  }
//}; // MeshConnStructured
//
//template <typename Device>
//class VTKM_ALWAYS_EXPORT MeshConnUnstructured : public MeshConnectivityBase
//{
//protected:
//  using IdHandle = typename vtkm::cont::ArrayHandle<vtkm::Id>;
//  using UCharHandle = typename vtkm::cont::ArrayHandle<vtkm::UInt8>;
//  using IdConstPortal = typename IdHandle::ExecutionTypes<Device>::PortalConst;
//  using UCharConstPortal = typename UCharHandle::ExecutionTypes<Device>::PortalConst;
//
//  // Constant Portals for the execution Environment
//  //FaceConn
//  IdConstPortal FaceConnPortal;
//  IdConstPortal FaceOffsetsPortal;
//  //Cell Set
//  IdConstPortal CellConnPortal;
//  IdConstPortal CellOffsetsPortal;
//  UCharConstPortal ShapesPortal;
//
//  VTKM_CONT MeshConnUnstructured() = default;
//
//public:
//  VTKM_CONT
//  MeshConnUnstructured(const IdHandle& faceConnectivity,
//                       const IdHandle& faceOffsets,
//                       const IdHandle& cellConn,
//                       const IdHandle& cellOffsets,
//                       const UCharHandle& shapes)
//    : FaceConnPortal(faceConnectivity.PrepareForInput(Device()))
//    , FaceOffsetsPortal(faceOffsets.PrepareForInput(Device()))
//    , CellConnPortal(cellConn.PrepareForInput(Device()))
//    , CellOffsetsPortal(cellOffsets.PrepareForInput(Device()))
//    , ShapesPortal(shapes.PrepareForInput(Device()))
//  {
//  }
//
//  VTKM_EXEC_CONT
//  vtkm::Id GetConnectingCell(const vtkm::Id& cellId, const vtkm::Id& face) const override
//  {
//    BOUNDS_CHECK(FaceOffsetsPortal, cellId);
//    vtkm::Id cellStartIndex = FaceOffsetsPortal.Get(cellId);
//    BOUNDS_CHECK(FaceConnPortal, cellStartIndex + face);
//    return FaceConnPortal.Get(cellStartIndex + face);
//  }
//
//  //----------------------------------------------------------------------------
//  VTKM_EXEC
//  vtkm::Int32 GetCellIndices(vtkm::Id cellIndices[8], const vtkm::Id& cellId) const override
//  {
//    const vtkm::Int32 shapeId = static_cast<vtkm::Int32>(ShapesPortal.Get(cellId));
//    CellTables tables;
//    const vtkm::Int32 numIndices = tables.FaceLookUp(tables.CellTypeLookUp(shapeId), 2);
//    BOUNDS_CHECK(CellOffsetsPortal, cellId);
//    const vtkm::Id cellOffset = CellOffsetsPortal.Get(cellId);
//
//    for (vtkm::Int32 i = 0; i < numIndices; ++i)
//    {
//      BOUNDS_CHECK(CellConnPortal, cellOffset + i);
//      cellIndices[i] = CellConnPortal.Get(cellOffset + i);
//    }
//    return numIndices;
//  }
//
//  //----------------------------------------------------------------------------
//  VTKM_EXEC
//  vtkm::UInt8 GetCellShape(const vtkm::Id& cellId) const override
//  {
//    BOUNDS_CHECK(ShapesPortal, cellId)
//    return ShapesPortal.Get(cellId);
//  }
//
//}; // MeshConnUnstructured
//
//template <typename Device>
//class MeshConnSingleType : public MeshConnectivityBase
//{
//protected:
//  using IdHandle = typename vtkm::cont::ArrayHandle<vtkm::Id>;
//  using IdConstPortal = typename IdHandle::ExecutionTypes<Device>::PortalConst;
//
//  using CountingHandle = typename vtkm::cont::ArrayHandleCounting<vtkm::Id>;
//  using CountingPortal = typename CountingHandle::ExecutionTypes<Device>::PortalConst;
//  // Constant Portals for the execution Environment
//  IdConstPortal FaceConnPortal;
//  IdConstPortal CellConnectivityPortal;
//  CountingPortal CellOffsetsPortal;
//
//  vtkm::Int32 ShapeId;
//  vtkm::Int32 NumIndices;
//  vtkm::Int32 NumFaces;
//
//private:
//  VTKM_CONT
//  MeshConnSingleType() {}
//
//public:
//  VTKM_CONT
//  MeshConnSingleType(IdHandle& faceConn,
//                     IdHandle& cellConn,
//                     CountingHandle& cellOffsets,
//                     vtkm::Int32 shapeId,
//                     vtkm::Int32 numIndices,
//                     vtkm::Int32 numFaces)
//    : FaceConnPortal(faceConn.PrepareForInput(Device()))
//    , CellConnectivityPortal(cellConn.PrepareForInput(Device()))
//    , CellOffsetsPortal(cellOffsets.PrepareForInput(Device()))
//    , ShapeId(shapeId)
//    , NumIndices(numIndices)
//    , NumFaces(numFaces)
//  {
//  }
//
//  //----------------------------------------------------------------------------
//  //                       Execution Environment Methods
//  //----------------------------------------------------------------------------
//  VTKM_EXEC
//  vtkm::Id GetConnectingCell(const vtkm::Id& cellId, const vtkm::Id& face) const override
//  {
//    BOUNDS_CHECK(CellOffsetsPortal, cellId);
//    vtkm::Id cellStartIndex = cellId * NumFaces;
//    BOUNDS_CHECK(FaceConnPortal, cellStartIndex + face);
//    return FaceConnPortal.Get(cellStartIndex + face);
//  }
//
//  VTKM_EXEC
//  vtkm::Int32 GetCellIndices(vtkm::Id cellIndices[8], const vtkm::Id& cellId) const override
//  {
//    BOUNDS_CHECK(CellOffsetsPortal, cellId);
//    const vtkm::Id cellOffset = CellOffsetsPortal.Get(cellId);
//
//    for (vtkm::Int32 i = 0; i < NumIndices; ++i)
//    {
//      BOUNDS_CHECK(CellConnectivityPortal, cellOffset + i);
//      cellIndices[i] = CellConnectivityPortal.Get(cellOffset + i);
//    }
//
//    return NumIndices;
//  }
//
//  //----------------------------------------------------------------------------
//  VTKM_EXEC
//  vtkm::UInt8 GetCellShape(const vtkm::Id& vtkmNotUsed(cellId)) const override
//  {
//    return vtkm::UInt8(ShapeId);
//  }
//
//}; //MeshConn Single type specialization
//
// VirtualObjectHandle specialized for MeshOracleBase-derived oracles.
class VTKM_ALWAYS_EXPORT OracleHandle : public vtkm::cont::VirtualObjectHandle<MeshOracleBase>
{
private:
  using Superclass = vtkm::cont::VirtualObjectHandle<MeshOracleBase>;

public:
  // Default-constructs the underlying VirtualObjectHandle.
  OracleHandle() = default;

  // Binds an oracle to the handle.
  //   sampler          - oracle instance to bind
  //   acquireOwnership - when true (the default) the handle owns `sampler`
  //   deviceList       - device adapters the oracle is made available on
  template <typename SamplerType, typename DeviceAdapterList = VTKM_DEFAULT_DEVICE_ADAPTER_LIST_TAG>
  explicit OracleHandle(SamplerType* sampler,
                        bool acquireOwnership = true,
                        DeviceAdapterList deviceList = DeviceAdapterList())
    : Superclass(sampler, acquireOwnership, deviceList)
  {
  }
};

// Convenience factory: heap-allocates a copy (or moved-into instance) of `func`
// and returns an OracleHandle that owns it, bound to the adapters in `devices`.
template <typename SamplerType, typename DeviceAdapterList = VTKM_DEFAULT_DEVICE_ADAPTER_LIST_TAG>
VTKM_CONT OracleHandle make_OracleHandle(SamplerType&& func,
                                         DeviceAdapterList devices = DeviceAdapterList())
{
  // std::decay strips const/volatile as well as the reference. With
  // std::remove_reference, a const lvalue argument produced `new const T(...)`,
  // and the handle was then given a pointer-to-const oracle.
  // NOTE(review): VirtualObjectHandle is presumed to require a non-const pointer.
  using IFType = typename std::decay<SamplerType>::type;
  return OracleHandle(new IFType(std::forward<SamplerType>(func)), true, devices);
}
}
}
} //namespace vtkm::rendering::raytracing

#ifdef VTKM_CUDA

// CUDA seems to have a bug where it expects the template class VirtualObjectTransfer
// to be instantiated in a consistent order among all the translation units of an
// executable. Failing to do so results in random crashes and incorrect results.
// We work around this issue by explicitly instantiating VirtualObjectTransfer for
// the mesh oracle types defined here.

#include <vtkm/cont/internal/VirtualObjectTransferInstantiate.h>
// Explicit CUDA VirtualObjectTransfer instantiation for the rectilinear oracle.
VTKM_EXPLICITLY_INSTANTIATE_TRANSFER(
  vtkm::rendering::raytracing::RectilinearOracle<vtkm::cont::DeviceAdapterTagCuda>);
//VTKM_EXPLICITLY_INSTANTIATE_TRANSFER(
//  vtkm::rendering::raytracing::MeshConnUnstructured<vtkm::cont::DeviceAdapterTagCuda>);

#endif

#endif // vtk_m_rendering_raytracing_MeshOracleBase
