// SPDX-FileCopyrightText: Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
// SPDX-License-Identifier: BSD-3-Clause
#include "vtkAnariAMRVolumeMapperNode.h"
#include "vtkAnariRendererNode.h"
#include "vtkAnariProfiling.h"

#include "vtkAMRBox.h"
#include "vtkAMRInformation.h"
#include "vtkDataArray.h"
#include "vtkImageData.h"
#include "vtkObjectFactory.h"
#include "vtkOverlappingAMR.h"
#include "vtkRenderer.h"
#include "vtkSmartPointer.h"
#include "vtkUniformGridAMRDataIterator.h"
#include "vtkVolume.h"
#include "vtkVolumeMapper.h"
#include "vtkVolumeNode.h"
#include "vtkVolumeProperty.h"
#include "vtkLogger.h"
#include "vtkColorTransferFunction.h"
#include "vtkPiecewiseFunction.h"

#include <algorithm>
#include <array>
#include <cstddef>
#include <memory>
#include <vector>

#include <anari/anari_cpp.hpp>
#include <anari/anari_cpp/ext/std.h>

using vec3 = anari::std_types::vec3;
using ivec3 = anari::std_types::ivec3;
using box3i = std::array<ivec3, 2>;

//============================================================================
namespace anari_amr
{
VTK_ABI_NAMESPACE_BEGIN
  /**
   * CPU-side samples of a 1D transfer function, laid out as the ANARI
   * "transferFunction1D" volume parameters consume them: one RGB entry per
   * color sample, one scalar per opacity sample, and the [min, max] scalar
   * range both tables span (defaults to [0, 1]).
   */
  struct TransferFunction
  {
    TransferFunction() = default;

    std::vector<vec3> color;              // sampled RGB colors
    std::vector<float> opacity;           // sampled scalar opacities
    float valueRange[2] = { 0.0f, 1.0f }; // scalar domain covered by the samples
  };
VTK_ABI_NAMESPACE_END
} // anari_amr

VTK_ABI_NAMESPACE_BEGIN

/**
 * Private state for vtkAnariAMRVolumeMapperNode: the ANARI volume handle,
 * the most recently sampled transfer function, and the timestamps Render()
 * compares against to decide what must be rebuilt.
 */
class vtkAnariAMRVolumeInternals
{
public:
  explicit vtkAnariAMRVolumeInternals(vtkAnariAMRVolumeMapperNode* owner);
  // NOTE(review): the defaulted destructor does not release AnariVolume;
  // releasing it needs a device handle that may already be gone at this point.
  // Confirm the ANARI device owns/cleans up this object.
  ~vtkAnariAMRVolumeInternals() = default;

  /**
   * Replaces TransferFunction with freshly sampled color/opacity tables taken
   * from the volume's property.
   */
  void UpdateTransferFunction(vtkVolume* const vtkVol);

  /**
   * Forwards AnariVolume (and the `changed` flag) to the renderer node's
   * AddVolume(), if a renderer node has been found.
   */
  void StageVolume(const bool changed);

  // Modified at the end of each prepass render; compared with the input
  // data's MTime to decide whether the AMR spatial field must be rebuilt.
  vtkTimeStamp BuildTime;
  // Modified whenever the transfer function is re-uploaded; compared with the
  // volume property's MTime.
  vtkTimeStamp PropertyTime;

  vtkAnariAMRVolumeMapperNode* Owner;      // non-owning back pointer
  vtkAnariRendererNode* AnariRendererNode; // non-owning, looked up during Render()
  anari::Volume AnariVolume;               // created lazily during Render()
  std::unique_ptr<anari_amr::TransferFunction> TransferFunction;
};

//----------------------------------------------------------------------------
vtkAnariAMRVolumeInternals::vtkAnariAMRVolumeInternals(
  vtkAnariAMRVolumeMapperNode* owner)
  : BuildTime()
  , PropertyTime()
  , Owner(owner)
  , AnariRendererNode(nullptr)
  , AnariVolume(nullptr)
  , TransferFunction(nullptr)
{
}

//----------------------------------------------------------------------------
void vtkAnariAMRVolumeInternals::StageVolume(const bool changed)
{
  vtkAnariProfiling startProfiling("vtkAnariAMRVolumeMapperNode::RenderVolumes", vtkAnariProfiling::GREEN);

  if (this->AnariRendererNode != nullptr)
  {
    this->AnariRendererNode->AddVolume(this->AnariVolume, changed);
  }
}

//------------------------------------------------------------------------------
void vtkAnariAMRVolumeInternals::UpdateTransferFunction(vtkVolume* const vtkVol)
{
  // Rebuilds this->TransferFunction from the volume's property.
  // - Samples component 0's color and scalar-opacity functions into tables
  //   of Owner->GetColorSize() / Owner->GetOpacitySize() entries.
  // - The samples cover the color function's range in TF_1D mode, and
  //   [0, 1] otherwise.
  // - 2D transfer functions and gradient opacity only produce a warning.
  this->TransferFunction.reset(new anari_amr::TransferFunction());
  vtkVolumeProperty* volProperty = vtkVol->GetProperty();
  const int transferFunctionMode = volProperty->GetTransferFunctionMode();

  if (transferFunctionMode == vtkVolumeProperty::TF_2D)
  {
    vtkWarningWithObjectMacro(
      this->Owner, << "ANARI currently doesn't support 2D transfer functions. "
                   << "Using default RGB and Scalar transfer functions.");
  }

  if (volProperty->HasGradientOpacity())
  {
    vtkWarningWithObjectMacro(this->Owner, << "ANARI currently doesn't support gradient opacity");
  }

  vtkColorTransferFunction* colorTF = volProperty->GetRGBTransferFunction(0);
  vtkPiecewiseFunction* opacityTF = volProperty->GetScalarOpacity(0);

  // Value range: the scalar domain both tables are sampled over.
  double tfRange[2] = { 0, 1 };

  if (transferFunctionMode == vtkVolumeProperty::TF_1D)
  {
    const double* const tfRangePtr = colorTF->GetRange();
    tfRange[0] = tfRangePtr[0];
    tfRange[1] = tfRangePtr[1];
  }

  this->TransferFunction->valueRange[0] = static_cast<float>(tfRange[0]);
  this->TransferFunction->valueRange[1] = static_cast<float>(tfRange[1]);

  // Opacity: sample directly into the destination vector. A negative size is
  // clamped to zero. GetTable is skipped for an empty table, so we never
  // index into an empty vector.
  const int opacitySize = std::max(0, static_cast<int>(this->Owner->GetOpacitySize()));
  std::vector<float>& opacity = this->TransferFunction->opacity;
  opacity.resize(static_cast<std::size_t>(opacitySize));
  if (opacitySize > 0)
  {
    opacityTF->GetTable(tfRange[0], tfRange[1], opacitySize, opacity.data());
  }

  // Color: GetTable writes interleaved RGB triples. Use a heap buffer: a
  // variable-length array is not standard C++ and can overflow the stack
  // for large table sizes.
  const int colorSize = std::max(0, static_cast<int>(this->Owner->GetColorSize()));
  std::vector<float> colorTable(static_cast<std::size_t>(colorSize) * 3);
  if (colorSize > 0)
  {
    colorTF->GetTable(tfRange[0], tfRange[1], colorSize, colorTable.data());
  }

  std::vector<vec3>& color = this->TransferFunction->color;
  color.reserve(static_cast<std::size_t>(colorSize));
  for (std::size_t j = 0; j + 2 < colorTable.size(); j += 3)
  {
    color.emplace_back(vec3{ colorTable[j], colorTable[j + 1], colorTable[j + 2] });
  }
}

//============================================================================
// Defines vtkAnariAMRVolumeMapperNode::New() via VTK's standard factory macro.
vtkStandardNewMacro(vtkAnariAMRVolumeMapperNode);

//------------------------------------------------------------------------------
vtkAnariAMRVolumeMapperNode::vtkAnariAMRVolumeMapperNode()
{
  // Private state owned by this node; deleted in the destructor.
  auto* internals = new vtkAnariAMRVolumeInternals(this);
  this->Internal = internals;
}

//------------------------------------------------------------------------------
vtkAnariAMRVolumeMapperNode::~vtkAnariAMRVolumeMapperNode()
{
  // Internal was allocated with new in the constructor and is owned here.
  delete this->Internal;
  this->Internal = nullptr;
}

//------------------------------------------------------------------------------
void vtkAnariAMRVolumeMapperNode::PrintSelf(ostream& os, vtkIndent indent)
{
  // This node prints no state of its own; defer entirely to the superclass.
  Superclass::PrintSelf(os, indent);
}

//------------------------------------------------------------------------------
void vtkAnariAMRVolumeMapperNode::Render(bool prepass)
{
  vtkAnariProfiling startProfiling("vtkAnariAMRVolumeMapperNode::Render", vtkAnariProfiling::GREEN);

  if (prepass)
  {
    vtkVolumeMapper* mapper = vtkVolumeMapper::SafeDownCast(this->GetRenderable());

    if (!mapper)
    {
      vtkErrorMacro("invalid mapper");
      return;
    }

    if(mapper->GetDataSetInput() == nullptr)
    {
      vtkDebugMacro("No scalar input for the AMR Volume");
      return;
    }

    mapper->GetInputAlgorithm()->UpdateInformation();
    mapper->GetInputAlgorithm()->Update();

    vtkVolumeNode* volNode = vtkVolumeNode::SafeDownCast(this->Parent);
    if (!volNode)
    {
      vtkErrorMacro("invalid volumeNode");
      return;
    }

    vtkVolume* vol = vtkVolume::SafeDownCast(volNode->GetRenderable());
    if (vol->GetVisibility() == false)
    {
      vtkDebugMacro(<< "Volume visibility off");
      return;
    }

    vtkVolumeProperty* const volProperty = vol->GetProperty();

    if (!volProperty)
    {
      // this is OK, happens in paraview client side for instance
      vtkDebugMacro(<< "Volume doesn't have property set");
      return;
    }

    this->Internal->AnariRendererNode =
      static_cast<vtkAnariRendererNode*>(this->GetFirstAncestorOfType("vtkAnariRendererNode"));
    auto anariDevice = this->Internal->AnariRendererNode->GetAnariDevice();
    auto amr = vtkOverlappingAMR::SafeDownCast(mapper->GetInputDataObject(0, 0));

    if (!amr)
    {
      vtkErrorMacro("couldn't get amr data\n");
      return;
    }

    bool isNewVolume = false;

    if (this->Internal->AnariVolume == nullptr)
    {
      isNewVolume = true;
      this->Internal->AnariVolume = anari::newObject<anari::Volume>(anariDevice, "transferFunction1D");
    }

    auto anariVolume = this->Internal->AnariVolume;

    if (mapper->GetDataSetInput()->GetMTime() > this->Internal->BuildTime || isNewVolume)
    {
      auto anariSpatialField = anari::newObject<anari::SpatialField>(anariDevice, "amr");
      unsigned int lastLevel = 0;

      std::vector<anari::Array3D> brickDataVector;
      std::vector<float> cellWidthVector;
      std::vector<box3i> blockBoundsVector;
      std::vector<int> blockLevelVector;

      auto amrInfo = amr->GetAMRInfo();
      vtkSmartPointer<vtkUniformGridAMRDataIterator> amrDataIterator;
      amrDataIterator.TakeReference(vtkUniformGridAMRDataIterator::SafeDownCast(amr->NewIterator()));

      for (amrDataIterator->InitTraversal(); !amrDataIterator->IsDoneWithTraversal(); amrDataIterator->GoToNextItem())
      {
        auto level = amrDataIterator->GetCurrentLevel();

        if (level < lastLevel)
        {
          vtkErrorMacro("ANARI requires level info be ordered lowest to highest");
        }

        lastLevel = level;
        auto index = amrDataIterator->GetCurrentIndex();
        auto imageData = vtkImageData::SafeDownCast(amrDataIterator->GetCurrentDataObject());

        if (!imageData)
        {
          vtkErrorMacro("No current data object present");
          return;
        }

        const vtkAMRBox& box = amrInfo->GetAMRBox(level, index);
        auto loCorner = box.GetLoCorner();
        auto hiCorner = box.GetHiCorner();

        ivec3 lo = { loCorner[0], loCorner[1], loCorner[2] };
        ivec3 hi = { hiCorner[0], hiCorner[1], hiCorner[2] };

        int dim[3] = { hi[0] - lo[0] + 1,
                       hi[1] - lo[1] + 1,
                       hi[2] - lo[2] + 1 };

        mapper->SetScalarMode(VTK_SCALAR_MODE_USE_CELL_FIELD_DATA);
        int fieldAssociation;
        auto cellArray = vtkDataArray::SafeDownCast(this->GetArrayToProcess(imageData, fieldAssociation));

        if (!cellArray)
        {
          vtkErrorMacro("could not get cell array");
          return;
        }

        auto brickData = anari::newArray3D(anariDevice, ANARI_FLOAT32, dim[0], dim[1], dim[2]);
        {
          int totalSize = dim[0] * dim[1] * dim[2];
          auto brickDataPtr = anari::map<float>(anariDevice, brickData);

          for(vtkIdType i=0; i<cellArray->GetNumberOfTuples() && i<totalSize; i++)
          {
            brickDataPtr[i] = static_cast<float>(cellArray->GetTuple(i)[0]);
          }

          anari::unmap(anariDevice, brickData);
        }

        brickDataVector.emplace_back(brickData);
        blockLevelVector.emplace_back(level);
        box3i blockBounds = { lo, hi };
        blockBoundsVector.emplace_back(blockBounds);
      }

      // Cell Width
      double spacing[3] = { 1.0, 1.0, 1.0 };

      for (unsigned int i = 0; i < amrInfo->GetNumberOfLevels(); i++)
      {
        amrInfo->GetSpacing(i, spacing);
        cellWidthVector.emplace_back(spacing[0]);
      }

      anari::setParameterArray1D(anariDevice, anariSpatialField, "cellWidth", ANARI_FLOAT32,
                                 cellWidthVector.data(), cellWidthVector.size());
      vtkDebugMacro(<< "[ANARI::AMR] cell width count: " << cellWidthVector.size());

      // Block Bounds
      anari::setParameterArray1D(anariDevice, anariSpatialField, "block.bounds", ANARI_INT32_BOX3,
                                 blockBoundsVector.data(), blockBoundsVector.size());
      vtkDebugMacro(<< "[ANARI::AMR] block bounds count: " << blockBoundsVector.size());

      // Block Level
      anari::setParameterArray1D(anariDevice, anariSpatialField, "block.level", ANARI_INT32,
                                 blockLevelVector.data(), blockLevelVector.size());
      vtkDebugMacro(<< "[ANARI::AMR] block level count: " << blockLevelVector.size());

      // Block Data
      anari::setParameterArray1D(anariDevice, anariSpatialField, "block.data", ANARI_ARRAY3D,
                                 brickDataVector.data(), brickDataVector.size());
      vtkDebugMacro(<< "[ANARI::AMR] block data count: " << brickDataVector.size());

      for (auto brickData : brickDataVector)
      {
        anari::release(anariDevice, brickData);
      }

      anari::commitParameters(anariDevice, anariSpatialField);
      anari::setAndReleaseParameter(anariDevice, anariVolume, "field", anariSpatialField);
      anari::commitParameters(anariDevice, anariVolume);
    }

    // Transfer function
    if (volProperty->GetMTime() > this->Internal->PropertyTime || isNewVolume)
    {
      this->Internal->UpdateTransferFunction(vol);
      anari_amr::TransferFunction* transferFunction = this->Internal->TransferFunction.get();

      anariSetParameter(anariDevice, anariVolume, "valueRange", ANARI_FLOAT32_BOX1, transferFunction->valueRange);

      auto array1DColor = anari::newArray1D(
        anariDevice, transferFunction->color.data(), transferFunction->color.size());
      anari::setAndReleaseParameter(anariDevice, anariVolume, "color", array1DColor);

      auto array1DOpacity = anari::newArray1D(
        anariDevice, transferFunction->opacity.data(), transferFunction->opacity.size());
      anari::setAndReleaseParameter(anariDevice, anariVolume, "opacity", array1DOpacity);

      anari::commitParameters(anariDevice, anariVolume);
      this->Internal->PropertyTime.Modified();
    }

    this->Internal->StageVolume(isNewVolume);
    this->RenderTime = volNode->GetMTime();
    this->Internal->BuildTime.Modified();
  }
}

VTK_ABI_NAMESPACE_END
