// SPDX-FileCopyrightText: Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
// SPDX-License-Identifier: BSD-3-Clause
#include "vtkAnariUnstructuredVolumeMapperNode.h"
#include "vtkAnariRendererNode.h"
#include "vtkAnariProfiling.h"

#include "vtkCell.h"
#include "vtkCellTypes.h"
#include "vtkColorTransferFunction.h"
#include "vtkDataArray.h"
#include "vtkDataSet.h"
#include "vtkFloatArray.h"
#include "vtkObjectFactory.h"
#include "vtkPiecewiseFunction.h"
#include "vtkPointData.h"
#include "vtkPoints.h"
#include "vtkRenderer.h"
#include "vtkUnstructuredGrid.h"
#include "vtkUnstructuredGridVolumeMapper.h"
#include "vtkVolume.h"
#include "vtkVolumeNode.h"
#include "vtkVolumeProperty.h"
#include "vtkArrayDispatch.h"
#include "vtkLogger.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <memory>
#include <string>
#include <vector>

#include <anari/anari_cpp.hpp>
#include <anari/anari_cpp/ext/std.h>

using vec3 = anari::std_types::vec3;

//============================================================================
namespace anari_unstructured
{
VTK_ABI_NAMESPACE_BEGIN
  /**
   * Sampled 1D transfer function uploaded to the ANARI "transferFunction1D"
   * volume: parallel colour and opacity tables plus the scalar interval the
   * samples span (defaults to [0, 1] until UpdateTransferFunction fills it).
   */
  struct TransferFunction
  {
    TransferFunction() = default;

    // RGB samples, evenly spaced over valueRange.
    std::vector<vec3> color;
    // Opacity samples, evenly spaced over valueRange.
    std::vector<float> opacity;
    // [min, max] scalar value mapped by the tables.
    float valueRange[2] = { 0, 1 };
  };

  struct UnstructuredSpatialFieldDataWorker
  {
    UnstructuredSpatialFieldDataWorker()
      : FieldAssociation(0)
      , VectorMode(0)
      , VectorComponent(0)
      , NumComponents(1)
      , NumElements(0)
      , AnariDevice(nullptr)
      , AnariSpatialField(nullptr)
    {
    }

    //------------------------------------------------------------------------------
    template <typename ScalarArray>
    void operator()(ScalarArray* scalars)
    {
      if (this->AnariDevice == nullptr || this->AnariSpatialField == nullptr)
      {
        vtkLogF(ERROR, "[ANARI::ERROR] %s\n", "UnstructuredSpatialFieldDataWorker not properly initialized");
        return;
      }

      if (this->NumElements <= 0)
      {
        return;
      }

      auto dataArray = anari::newArray1D(this->AnariDevice, ANARI_FLOAT32, this->NumElements);
      {
        auto dataArrayPtr = anari::map<float>(this->AnariDevice, dataArray);

        for (vtkIdType i=0; i<this->NumElements; i++)
        {
          double* const val = scalars->GetTuple(i);

          if(this->NumComponents == 1)
          {
            dataArrayPtr[i] = static_cast<float>(val[0]);
          }
          else if (this->VectorMode == 0 && this->NumComponents > 1) // vector magnitude
          {
            double mag = 0;

            for(int c = 0; c < this->NumComponents; c++)
            {
              mag += val[c] * val[c];
            }

            dataArrayPtr[i] = static_cast<float>(std::sqrt(mag));
          }
          else
          {
            dataArrayPtr[i] = static_cast<float>(val[this->VectorComponent]);
          }
        }

        anari::unmap(this->AnariDevice, dataArray);
      }

      if (this->FieldAssociation) // TODO: cell.data
      {
        anari::setAndReleaseParameter(this->AnariDevice, this->AnariSpatialField, "vertex.data", dataArray);
      }
      else
      {
        anari::setAndReleaseParameter(this->AnariDevice, this->AnariSpatialField, "vertex.data", dataArray);
      }
    }

    int FieldAssociation;
    int VectorMode;
    int VectorComponent;
    int NumComponents;
    vtkIdType NumElements;
    anari::Device AnariDevice;
    anari::SpatialField AnariSpatialField;
  };

  /**
   * Converts a VTK cell type to the code stored in the ANARI "cell.type" array.
   *
   * Supported linear 3D cells map to fixed codes:
   *   VTK_TETRA -> 10, VTK_HEXAHEDRON -> 12, VTK_WEDGE -> 13, VTK_PYRAMID -> 14.
   * These are NOT point counts. They coincide with VTK's own cell-type enum
   * values, which is the numbering the ANARI unstructured spatial field expects
   * (TODO confirm against the targeted ANARI spec version).
   * Any other cell type returns 255, meaning unsupported.
   *
   * @param id VTK cell type (VTKCellType value)
   * @return ANARI cell code, or 255 for unsupported types
   */
  uint8_t VTKCellTypeToAnari(const int id)
  {
    switch (id)
    {
      case VTK_TETRA:
        return 10;
      case VTK_HEXAHEDRON:
        return 12;
      case VTK_WEDGE:
        return 13;
      case VTK_PYRAMID:
        return 14;
      default:
        return 255;
    }
  }
VTK_ABI_NAMESPACE_END
} // anari_unstructured

VTK_ABI_NAMESPACE_BEGIN

class vtkAnariVolumeInternals
{
public:
  vtkAnariVolumeInternals(vtkAnariUnstructuredVolumeMapperNode *);
  ~vtkAnariVolumeInternals() = default;

  void UpdateTransferFunction(vtkVolume* const, const double, const double);

  void SetSpatialFieldConnectivityFromShareable(vtkUnstructuredGrid* const, anari::Device, anari::SpatialField);
  void SetSpatialFieldConnectivity(vtkUnstructuredGrid* const, anari::Device, anari::SpatialField);

  void StageVolume(const bool);

  vtkTimeStamp BuildTime;
  vtkTimeStamp PropertyTime;

  std::string LastArrayName;
  int LastArrayComponent;

  vtkAnariUnstructuredVolumeMapperNode* Owner;
  vtkAnariRendererNode* AnariRendererNode;
  anari::Volume AnariVolume;
  std::unique_ptr<anari_unstructured::TransferFunction> TransferFunction;
};

//----------------------------------------------------------------------------
vtkAnariVolumeInternals::vtkAnariVolumeInternals(vtkAnariUnstructuredVolumeMapperNode* owner)
  : LastArrayComponent(-2) // sentinel: no array component uploaded yet
  , Owner(owner)
  , AnariRendererNode(nullptr)
  , AnariVolume(nullptr)
{
  // BuildTime, PropertyTime, LastArrayName (empty) and TransferFunction
  // (null) are left default-constructed.
}

//----------------------------------------------------------------------------
void vtkAnariVolumeInternals::StageVolume(const bool changed)
{
  vtkAnariProfiling startProfiling("vtkAnariUnstructuredVolumeMapperNode::RenderVolumes", vtkAnariProfiling::GREEN);

  if (this->AnariRendererNode != nullptr)
  {
    this->AnariRendererNode->AddVolume(this->AnariVolume, changed);
  }
}

//------------------------------------------------------------------------------
/**
 * Resamples the volume's 1D colour and scalar-opacity transfer functions into
 * this->TransferFunction (replacing any previous one).
 *
 * @param vtkVol volume whose vtkVolumeProperty supplies the transfer functions
 * @param low    fallback lower bound of the sampled value range
 * @param high   fallback upper bound of the sampled value range
 *
 * In TF_1D mode the colour transfer function's range is sampled. When that
 * range is empty or inverted, or in any other mode, [low, high] is used. The
 * tables hold Owner->GetColorSize() RGB samples and Owner->GetOpacitySize()
 * opacity samples. A non-positive size leaves that table empty.
 */
void vtkAnariVolumeInternals::UpdateTransferFunction(
  vtkVolume* const vtkVol, const double low, const double high)
{
  this->TransferFunction.reset(new anari_unstructured::TransferFunction());
  anari_unstructured::TransferFunction& tf = *this->TransferFunction;

  vtkVolumeProperty* volProperty = vtkVol->GetProperty();
  const int transferFunctionMode = volProperty->GetTransferFunctionMode();

  if (transferFunctionMode == vtkVolumeProperty::TF_2D)
  {
    vtkWarningWithObjectMacro(
      this->Owner, << "ANARI currently doesn't support 2D transfer functions. "
                   << "Using default RGB and Scalar transfer functions.");
  }

  if (volProperty->HasGradientOpacity())
  {
    vtkWarningWithObjectMacro(this->Owner, << "ANARI currently doesn't support gradient opacity");
  }

  vtkColorTransferFunction* colorTF = volProperty->GetRGBTransferFunction(0);
  vtkPiecewiseFunction* opacityTF = volProperty->GetScalarOpacity(0);

  // Value range: start invalid so the fallback kicks in unless TF_1D supplies one.
  double tfRange[2] = { 0, -1 };

  if (transferFunctionMode == vtkVolumeProperty::TF_1D)
  {
    const double* const tfRangePtr = colorTF->GetRange();
    tfRange[0] = tfRangePtr[0];
    tfRange[1] = tfRangePtr[1];
  }

  if (tfRange[1] <= tfRange[0])
  {
    tfRange[0] = low;
    tfRange[1] = high;
  }

  tf.valueRange[0] = static_cast<float>(tfRange[0]);
  tf.valueRange[1] = static_cast<float>(tfRange[1]);

  // Opacity
  const int opacitySize = this->Owner->GetOpacitySize();
  if (opacitySize > 0)
  {
    tf.opacity.resize(static_cast<size_t>(opacitySize));
    opacityTF->GetTable(tfRange[0], tfRange[1], opacitySize, tf.opacity.data());
  }

  // Color: sample into a heap buffer rather than a variable-length stack array.
  // VLAs are not standard C++, and ColorSize is caller-controlled.
  const int colorSize = this->Owner->GetColorSize();
  if (colorSize > 0)
  {
    std::vector<float> colorTable(static_cast<size_t>(colorSize) * 3);
    colorTF->GetTable(tfRange[0], tfRange[1], colorSize, colorTable.data());

    tf.color.reserve(static_cast<size_t>(colorSize));
    for (size_t j = 0; j + 2 < colorTable.size(); j += 3)
    {
      tf.color.emplace_back(vec3{ colorTable[j], colorTable[j + 1], colorTable[j + 2] });
    }
  }
}

//----------------------------------------------------------------------------
/**
 * Uploads "cell.type", "cell.index" and "index" to the spatial field by reading
 * the grid's cell-type, offsets and connectivity arrays directly.
 *
 * @param dataSet           grid whose connectivity is uploaded
 * @param anariDevice       device owning the spatial field
 * @param anariSpatialField field receiving the three parameters
 *
 * cell.index holds one offset per cell. The offsets array stores one extra
 * trailing entry, so at most (offsets - 1) values are used, further limited to
 * the number of cell types.
 */
void vtkAnariVolumeInternals::SetSpatialFieldConnectivityFromShareable(vtkUnstructuredGrid* const dataSet,
                                                                       anari::Device anariDevice,
                                                                       anari::SpatialField anariSpatialField)
{
  // Cell Type (named to avoid shadowing the vtkCellTypes class)
  auto cellTypesArray = dataSet->GetCellTypesArray();
  const vtkIdType numCellTypes = cellTypesArray ? cellTypesArray->GetNumberOfValues() : 0;

  auto cellTypeArray = anari::newArray1D(anariDevice, ANARI_UINT8, numCellTypes);
  {
    auto cellTypeArrayPtr = anari::map<uint8_t>(anariDevice, cellTypeArray);

    for (vtkIdType i = 0; i < numCellTypes; i++)
    {
      cellTypeArrayPtr[i] = anari_unstructured::VTKCellTypeToAnari(cellTypesArray->GetValue(i));
    }

    anari::unmap(anariDevice, cellTypeArray);
  }

  anari::setAndReleaseParameter(anariDevice, anariSpatialField, "cell.type", cellTypeArray);

  // Cell Index
  auto cellArray = dataSet->GetCells();
  auto offsetsArray = cellArray->GetOffsetsArray();
  // Offsets describe [inclusive, exclusive) ranges, hence the "- 1". Clamp at
  // zero so an empty offsets array can never produce a negative (huge) size.
  const vtkIdType numOffsetCells = offsetsArray->GetNumberOfTuples() - 1;
  const vtkIdType numCellIndex =
    std::max<vtkIdType>(0, std::min(numOffsetCells, numCellTypes));

  auto cellIndexArray = anari::newArray1D(anariDevice, ANARI_UINT64, numCellIndex);
  {
    auto cellIndexArrayPtr = anari::map<uint64_t>(anariDevice, cellIndexArray);

    for (vtkIdType i = 0; i < numCellIndex; i++)
    {
      cellIndexArrayPtr[i] = static_cast<uint64_t>(offsetsArray->GetTuple1(i));
    }

    anari::unmap(anariDevice, cellIndexArray);
  }

  anari::setAndReleaseParameter(anariDevice, anariSpatialField, "cell.index", cellIndexArray);

  // Index
  auto connectivityArray = cellArray->GetConnectivityArray();
  const vtkIdType numIndex = connectivityArray->GetNumberOfTuples();

  auto indexArray = anari::newArray1D(anariDevice, ANARI_UINT64, numIndex);
  {
    auto indexArrayPtr = anari::map<uint64_t>(anariDevice, indexArray);

    for (vtkIdType i = 0; i < numIndex; i++)
    {
      indexArrayPtr[i] = static_cast<uint64_t>(connectivityArray->GetTuple1(i));
    }

    anari::unmap(anariDevice, indexArray);
  }

  anari::setAndReleaseParameter(anariDevice, anariSpatialField, "index", indexArray);
}

//----------------------------------------------------------------------------
/**
 * Uploads "cell.type", "cell.index" and "index" to the spatial field by walking
 * the grid's cells. Used when the cell-array storage is not shareable.
 *
 * @param dataSet           grid whose connectivity is uploaded
 * @param anariDevice       device owning the spatial field
 * @param anariSpatialField field receiving the three parameters
 *
 * For each cell, cell.index records where its point ids start in the index
 * array. Cells that VTKCellTypeToAnari() does not recognise get type code 255
 * and contribute no point ids. Render() rejects such grids before calling this.
 */
void vtkAnariVolumeInternals::SetSpatialFieldConnectivity(vtkUnstructuredGrid* const dataSet,
                                                          anari::Device anariDevice,
                                                          anari::SpatialField anariSpatialField)
{
  const vtkIdType numberOfCells = dataSet->GetNumberOfCells();

  std::vector<uint8_t> cellTypeVector(static_cast<size_t>(numberOfCells));
  std::vector<uint64_t> cellIndexVector(static_cast<size_t>(numberOfCells));
  std::vector<uint64_t> indexVector;

  // The total number of stored point ids is an upper bound on the index size.
  if (auto cells = dataSet->GetCells())
  {
    indexVector.reserve(static_cast<size_t>(cells->GetNumberOfConnectivityIds()));
  }

  for (vtkIdType cellId = 0; cellId < numberOfCells; ++cellId)
  {
    const size_t slot = static_cast<size_t>(cellId);
    cellIndexVector[slot] = indexVector.size();

    const uint8_t anariType = anari_unstructured::VTKCellTypeToAnari(dataSet->GetCellType(cellId));
    cellTypeVector[slot] = anariType;
    if (anariType == 255)
    {
      continue; // unsupported cell: no point ids
    }

    // Read the point ids straight from the cell array. GetCell() would also
    // build a complete vtkCell, including point coordinates, for every cell.
    vtkIdType numCellPoints = 0;
    const vtkIdType* cellPointIds = nullptr;
    dataSet->GetCellPoints(cellId, numCellPoints, cellPointIds);

    for (vtkIdType j = 0; j < numCellPoints; ++j)
    {
      indexVector.push_back(static_cast<uint64_t>(cellPointIds[j]));
    }
  }

  auto cellTypeArray = anari::newArray1D(anariDevice, cellTypeVector.data(), cellTypeVector.size());
  anari::setAndReleaseParameter(anariDevice, anariSpatialField, "cell.type", cellTypeArray);

  auto cellIndexArray = anari::newArray1D(anariDevice, cellIndexVector.data(), cellIndexVector.size());
  anari::setAndReleaseParameter(anariDevice, anariSpatialField, "cell.index", cellIndexArray);

  auto indexArray = anari::newArray1D(anariDevice, indexVector.data(), indexVector.size());
  anari::setAndReleaseParameter(anariDevice, anariSpatialField, "index", indexArray);
}

//============================================================================
// Standard VTK object-factory New() for this node type.
vtkStandardNewMacro(vtkAnariUnstructuredVolumeMapperNode);

//------------------------------------------------------------------------------
vtkAnariUnstructuredVolumeMapperNode::vtkAnariUnstructuredVolumeMapperNode()
  : ColorSize(128)
  , OpacitySize(128)
{
  // The internals keep a back-pointer to this node. They are owned by the node
  // and deleted in the destructor.
  vtkAnariVolumeInternals* const internals = new vtkAnariVolumeInternals{ this };
  this->Internal = internals;
}

//------------------------------------------------------------------------------
vtkAnariUnstructuredVolumeMapperNode::~vtkAnariUnstructuredVolumeMapperNode()
{
  // Release the internals allocated in the constructor.
  vtkAnariVolumeInternals* const internals = this->Internal;
  this->Internal = nullptr;
  delete internals;
}

//------------------------------------------------------------------------------
// Prints the superclass state followed by this node's transfer-function
// sampling sizes.
void vtkAnariUnstructuredVolumeMapperNode::PrintSelf(ostream& os, vtkIndent indent)
{
  this->Superclass::PrintSelf(os, indent);
  os << indent << "ColorSize: " << this->ColorSize << "\n";
  os << indent << "OpacitySize: " << this->OpacitySize << "\n";
}

//------------------------------------------------------------------------------
void vtkAnariUnstructuredVolumeMapperNode::Render(bool prepass)
{
  vtkAnariProfiling startProfiling("vtkAnariUnstructuredVolumeMapperNode::Render", vtkAnariProfiling::GREEN);

  if (prepass)
  {
    vtkUnstructuredGridVolumeMapper* mapper =
      vtkUnstructuredGridVolumeMapper::SafeDownCast(this->GetRenderable());

    if (!mapper)
    {
      vtkErrorMacro("invalid mapper");
      return;
    }

    mapper->GetInputAlgorithm()->UpdateInformation();
    mapper->GetInputAlgorithm()->Update();
    vtkUnstructuredGrid* dataSet = vtkUnstructuredGrid::SafeDownCast(mapper->GetDataSetInput());
    if (!dataSet)
    {
      return;
    }

    vtkVolumeNode* volNode = vtkVolumeNode::SafeDownCast(this->Parent);
    if (!volNode)
    {
      vtkErrorMacro("invalid volumeNode");
      return;
    }

    vtkVolume* vol = vtkVolume::SafeDownCast(volNode->GetRenderable());
    if (vol->GetVisibility() == false)
    {
      vtkDebugMacro(<< "Volume visibility off");
      return;
    }

    vtkVolumeProperty* const volProperty = vol->GetProperty();

    if (!volProperty)
    {
      // this is OK, happens in paraview client side for instance
      vtkDebugMacro(<< "Volume doesn't have property set");
      return;
    }

    this->Internal->AnariRendererNode =
      static_cast<vtkAnariRendererNode*>(this->GetFirstAncestorOfType("vtkAnariRendererNode"));
    auto anariDevice = this->Internal->AnariRendererNode->GetAnariDevice();

    int fieldAssociation;
    vtkDataArray* array =
      vtkDataArray::SafeDownCast(this->GetArrayToProcess(dataSet, fieldAssociation));
    if (!array)
    {
      // ok can happen in paraview client server mode for example
      vtkDebugMacro("VolumeMapper's Input has no data array!");
      return;
    }

    vtkUnsignedCharArray* distinctCellTypes = dataSet->GetDistinctCellTypesArray();
    for (vtkIdType cti = 0; cti < distinctCellTypes->GetNumberOfValues(); cti++)
    {
      auto ct = distinctCellTypes->GetValue(cti);
      if (ct != VTK_TETRA && ct != VTK_HEXAHEDRON && ct != VTK_WEDGE && ct != VTK_PYRAMID)
      {
        vtkWarningMacro("Unsupported voxel type " << ct);
        return;
      }
    }

    bool isNewVolume = false;

    if (this->Internal->AnariVolume == nullptr)
    {
      isNewVolume = true;
      this->Internal->AnariVolume = anari::newObject<anari::Volume>(anariDevice, "transferFunction1D");
    }

    auto anariVolume = this->Internal->AnariVolume;
    vtkColorTransferFunction* const ctf = volProperty->GetRGBTransferFunction(0);

    int const indep = volProperty->GetIndependentComponents();
    int const mode = indep ? ctf->GetVectorMode() : vtkScalarsToColors::COMPONENT;
    int const comp = indep ? ctf->GetVectorComponent() : 0;
    int const val = (mode << 6) | comp; // combine to compare as one

    if(mapper->GetDataSetInput()->GetMTime() > this->Internal->BuildTime ||
       this->Internal->LastArrayName != mapper->GetArrayName() ||
       this->Internal->LastArrayComponent != val)
    {
      this->Internal->LastArrayName = mapper->GetArrayName();
      this->Internal->LastArrayComponent = val;

      auto anariSpatialField =
        anari::newObject<anari::SpatialField>(anariDevice, "unstructured");

      vtkIdType numberOfPoints = dataSet->GetNumberOfPoints();

      // Vertex Position
      auto positionArray = anari::newArray1D(anariDevice, ANARI_FLOAT32_VEC3, numberOfPoints);
      {
        auto positionArrayPtr = anari::map<vec3>(anariDevice, positionArray);
        double point[3];

        for (vtkIdType i=0; i<numberOfPoints; i++)
        {
          dataSet->GetPoint(i, point);
          positionArrayPtr[i] = vec3{static_cast<float>(point[0]),
                                     static_cast<float>(point[1]),
                                     static_cast<float>(point[2])};
        }

        anari::unmap(anariDevice, positionArray);
      }

      anari::setAndReleaseParameter(anariDevice, anariSpatialField, "vertex.position", positionArray);

      // Set cell.type, cell.index, and index
      if (dataSet->GetCells()->IsStorageShareable())
      {
        this->Internal->SetSpatialFieldConnectivityFromShareable(dataSet, anariDevice, anariSpatialField);
      }
      else
      {
        this->Internal->SetSpatialFieldConnectivity(dataSet, anariDevice, anariSpatialField);
      }

      // Volume data
      anari_unstructured::UnstructuredSpatialFieldDataWorker worker;
      worker.FieldAssociation = fieldAssociation;
      worker.VectorMode = mode;
      worker.VectorComponent = comp;
      worker.NumComponents = array->GetNumberOfComponents();
      worker.NumElements = (fieldAssociation ? dataSet->GetNumberOfCells()
                                             : dataSet->GetNumberOfPoints());
      worker.AnariDevice = anariDevice;
      worker.AnariSpatialField = anariSpatialField;

      using Dispatcher = vtkArrayDispatch::DispatchByValueType<vtkTypeList::Create<double, float>>;

      if (!Dispatcher::Execute(array, worker))
      {
        worker(array);
      }

      anari::commitParameters(anariDevice, anariSpatialField);
      anari::setAndReleaseParameter(anariDevice, anariVolume, "field", anariSpatialField);
    }

    // Transfer function
    if (volProperty->GetMTime() > this->Internal->PropertyTime ||
        mapper->GetDataSetInput()->GetMTime() > this->Internal->BuildTime ||
        isNewVolume)
    {
      double scalarRange[2];
      array->GetRange(scalarRange, comp);

      if (mode == 0 && array->GetNumberOfComponents() > 1) // vector magnitude
      {
        double min = 0;
        double max = 0;

        for (int c = 0; c < array->GetNumberOfComponents(); c++)
        {
          double lmin = 0;
          double lmax = 0;
          double cRange[2];
          array->GetRange(cRange, c);
          double ldist = cRange[0] * cRange[0];
          double rdist = cRange[1] * cRange[1];
          lmin = std::min(ldist, rdist);
          if (cRange[0] < 0 && cRange[1] > 0)
          {
            lmin = 0;
          }
          lmax = std::max(ldist, rdist);
          min += lmin;
          max += lmax;
        }
        scalarRange[0] = std::sqrt(min);
        scalarRange[1] = std::sqrt(max);
      }

      this->Internal->UpdateTransferFunction(vol, scalarRange[0], scalarRange[1]);
      anari_unstructured::TransferFunction* transferFunction = this->Internal->TransferFunction.get();

      anariSetParameter(anariDevice, anariVolume, "valueRange", ANARI_FLOAT32_BOX1, transferFunction->valueRange);

      auto array1DColor = anari::newArray1D(
        anariDevice, transferFunction->color.data(), transferFunction->color.size());
      anari::setAndReleaseParameter(anariDevice, anariVolume, "color", array1DColor);

      auto array1DOpacity = anari::newArray1D(
        anariDevice, transferFunction->opacity.data(), transferFunction->opacity.size());
      anari::setAndReleaseParameter(anariDevice, anariVolume, "opacity", array1DOpacity);

      anari::commitParameters(anariDevice, anariVolume);
      this->Internal->PropertyTime.Modified();
    }

    this->Internal->StageVolume(isNewVolume);
    this->RenderTime = volNode->GetMTime();
    this->Internal->BuildTime.Modified();
  }
}

VTK_ABI_NAMESPACE_END
