// SPDX-FileCopyrightText: Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
// SPDX-License-Identifier: BSD-3-Clause
#include "vtkAnariAMRVolumeMapperNode.h"
#include "vtkAnariRendererNode.h"
#include "vtkAnariProfiling.h"

#include "vtkAMRBox.h"
#include "vtkAMRInformation.h"
#include "vtkDataArray.h"
#include "vtkImageData.h"
#include "vtkObjectFactory.h"
#include "vtkOverlappingAMR.h"
#include "vtkRenderer.h"
#include "vtkSmartPointer.h"
#include "vtkUniformGridAMRDataIterator.h"
#include "vtkVolume.h"
#include "vtkVolumeMapper.h"
#include "vtkVolumeNode.h"
#include "vtkVolumeProperty.h"
#include "vtkLogger.h"
#include "vtkColorTransferFunction.h"
#include "vtkPiecewiseFunction.h"

#include <algorithm>
#include <array>
#include <limits>
#include <memory>
#include <vector>

#include <anari/anari_cpp.hpp>
#include <anari/anari_cpp/ext/std.h>

using vec3 = anari::std_types::vec3;
using ivec3 = anari::std_types::ivec3;

// Integer axis-aligned box stored as two corners:
// bounds[0] holds the per-axis minimum, bounds[1] the per-axis maximum.
struct box3i
{
  std::array<ivec3, 2> bounds;

  // Grow this box in place so that it also encloses `other`.
  void extend(const box3i& other)
  {
    for (int axis = 0; axis < 3; ++axis)
    {
      bounds[0][axis] = std::min(bounds[0][axis], other.bounds[0][axis]);
      bounds[1][axis] = std::max(bounds[1][axis], other.bounds[1][axis]);
    }
  }
};

//============================================================================
namespace anari_amr
{
VTK_ABI_NAMESPACE_BEGIN
  // Sampled 1D transfer function: a table of RGB colors, a table of
  // opacities and the scalar range the tables were sampled over.
  struct TransferFunction
  {
    TransferFunction() = default;

    // RGB samples, one vec3 per table entry.
    std::vector<vec3> color{};
    // Opacity samples, one float per table entry.
    std::vector<float> opacity{};
    // {min, max} of the scalar range; starts out as [0, 1].
    float valueRange[2] = { 0, 1 };
  };
VTK_ABI_NAMESPACE_END
} // anari_amr

VTK_ABI_NAMESPACE_BEGIN

// Private implementation state of vtkAnariAMRVolumeMapperNode: the ANARI
// volume handle, the cached transfer function and the time stamps used to
// decide when either has to be rebuilt.
class vtkAnariAMRVolumeInternals
{
public:
  // `explicit` so a mapper-node pointer never converts silently into an
  // internals object.
  explicit vtkAnariAMRVolumeInternals(vtkAnariAMRVolumeMapperNode *);
  ~vtkAnariAMRVolumeInternals() = default;

  // Re-sample the color and opacity tables of the volume's property into
  // this->TransferFunction.
  void UpdateTransferFunction(vtkVolume* const);
  // Hand the current ANARI volume to the renderer node (if one is known).
  void StageVolume();

  // Stamped at the end of a completed Synchronize(); compared with the AMR
  // data's MTime to detect stale ANARI data.
  vtkTimeStamp BuildTime;
  // Stamped after the transfer function was uploaded; compared with the
  // volume property's MTime.
  vtkTimeStamp PropertyTime;

  // Back-pointer to the owning node; used for warnings and to query the
  // color/opacity table sizes. Not owned.
  vtkAnariAMRVolumeMapperNode* Owner;
  // Renderer node looked up during Synchronize(). Not owned.
  vtkAnariRendererNode* AnariRendererNode;
  // ANARI volume handle; created and released in Synchronize().
  // NOTE(review): nothing in this class releases the handle on destruction.
  anari::Volume AnariVolume;
  // Most recently sampled transfer function (null until the first update).
  std::unique_ptr<anari_amr::TransferFunction> TransferFunction;
};

//----------------------------------------------------------------------------
vtkAnariAMRVolumeInternals::vtkAnariAMRVolumeInternals(
  vtkAnariAMRVolumeMapperNode* owner)
  : BuildTime()
  , PropertyTime()
  , Owner(owner)
  , AnariRendererNode(nullptr)
  , AnariVolume(nullptr)
  , TransferFunction(nullptr)
{
}

//----------------------------------------------------------------------------
// Pass the current ANARI volume to the renderer node so it takes part in the
// next frame. Does nothing when no renderer node has been recorded yet or
// when there is no volume to stage.
void vtkAnariAMRVolumeInternals::StageVolume()
{
  vtkAnariProfiling startProfiling("vtkAnariAMRVolumeMapperNode::RenderVolumes", vtkAnariProfiling::GREEN);

  // Synchronize() can leave AnariVolume null while the renderer node is set
  // (volume released after visibility was switched off, or no AMR input);
  // a null handle must not be handed to the renderer.
  if (this->AnariRendererNode == nullptr || this->AnariVolume == nullptr)
  {
    return;
  }

  this->AnariRendererNode->AddVolume(this->AnariVolume);
}

//------------------------------------------------------------------------------
// Rebuild this->TransferFunction from the property of `vtkVol`.
//
// The color and scalar-opacity functions of component 0 are sampled into
// tables whose sizes come from the owning node (GetColorSize() /
// GetOpacitySize()). In 1D transfer function mode the sampled range is the
// range of the color function; otherwise [0, 1] is used. 2D transfer
// functions and gradient opacity are not supported and only raise a warning.
//
// `vtkVol` and its property must be non-null; the caller checks both.
void vtkAnariAMRVolumeInternals::UpdateTransferFunction(vtkVolume* const vtkVol)
{
  this->TransferFunction = std::make_unique<anari_amr::TransferFunction>();
  vtkVolumeProperty* volProperty = vtkVol->GetProperty();
  const int transferFunctionMode = volProperty->GetTransferFunctionMode();

  if (transferFunctionMode == vtkVolumeProperty::TF_2D)
  {
    vtkWarningWithObjectMacro(
      this->Owner, << "ANARI currently doesn't support 2D transfer functions. "
                   << "Using default RGB and Scalar transfer functions.");
  }

  if (volProperty->HasGradientOpacity())
  {
    vtkWarningWithObjectMacro(this->Owner, << "ANARI currently doesn't support gradient opacity");
  }

  vtkColorTransferFunction* colorTF = volProperty->GetRGBTransferFunction(0);
  vtkPiecewiseFunction* opacityTF = volProperty->GetScalarOpacity(0);

  // Value Range
  double tfRange[2] = { 0, 1 };

  if (transferFunctionMode == vtkVolumeProperty::TF_1D)
  {
    double* const tfRangePtr = colorTF->GetRange();
    tfRange[0] = tfRangePtr[0];
    tfRange[1] = tfRangePtr[1];
  }

  this->TransferFunction->valueRange[0] = static_cast<float>(tfRange[0]);
  this->TransferFunction->valueRange[1] = static_cast<float>(tfRange[1]);

  // Opacity. A non-positive size yields an empty table instead of indexing
  // into an empty vector.
  const int opacitySize = std::max<int>(0, this->Owner->GetOpacitySize());
  std::vector<float>& opacityTable = this->TransferFunction->opacity;
  opacityTable.resize(static_cast<std::size_t>(opacitySize));

  if (opacitySize > 0)
  {
    opacityTF->GetTable(tfRange[0], tfRange[1], opacitySize, opacityTable.data());
  }

  // Color. The interleaved RGB scratch buffer lives on the heap: a
  // runtime-sized stack array is not standard C++.
  const int colorSize = std::max<int>(0, this->Owner->GetColorSize());
  std::vector<float> rgbTable(static_cast<std::size_t>(colorSize) * 3u);

  if (colorSize > 0)
  {
    colorTF->GetTable(tfRange[0], tfRange[1], colorSize, rgbTable.data());
  }

  std::vector<vec3>& colorTable = this->TransferFunction->color;
  colorTable.reserve(static_cast<std::size_t>(colorSize));

  for (std::size_t j = 0; j + 2 < rgbTable.size(); j += 3)
  {
    colorTable.emplace_back(vec3{ rgbTable[j], rgbTable[j + 1], rgbTable[j + 2] });
  }
}

//============================================================================
// vtkStandardNewMacro (vtkObjectFactory.h) provides the static New()
// implementation for this class.
vtkStandardNewMacro(vtkAnariAMRVolumeMapperNode);

//------------------------------------------------------------------------------
// Allocate the private implementation object; it keeps a back-pointer to
// this node and is deleted in the destructor.
vtkAnariAMRVolumeMapperNode::vtkAnariAMRVolumeMapperNode()
{
  this->Internal = new vtkAnariAMRVolumeInternals{ this };
}

//------------------------------------------------------------------------------
// Free the private implementation object allocated in the constructor.
vtkAnariAMRVolumeMapperNode::~vtkAnariAMRVolumeMapperNode()
{
  auto* const internals = this->Internal;
  delete internals;
}

//------------------------------------------------------------------------------
// Bring the ANARI volume in line with the VTK mapper/volume pair.
//
// Work is done in the prepass only. The mapper's input pipeline is updated,
// then:
//  - if the volume is visible, an ANARI "transferFunction1D" volume is
//    created on first use; its "amr" spatial field is rebuilt whenever the
//    overlapping AMR input is newer than the last build, and the transfer
//    function is re-uploaded whenever the volume property or the AMR input
//    changed;
//  - if the volume is not visible, a previously created ANARI volume is
//    released.
//
// Every failure path reports an error (or a debug message for a missing
// property) and returns without stamping BuildTime, so the next call tries
// again. ANARI objects created during a failed rebuild are released before
// returning.
void vtkAnariAMRVolumeMapperNode::Synchronize(bool prepass)
{
  vtkAnariProfiling startProfiling("vtkAnariAMRVolumeMapperNode::Synchronize", vtkAnariProfiling::GREEN);

  if (!prepass)
  {
    return;
  }

  vtkVolumeMapper* mapper = vtkVolumeMapper::SafeDownCast(this->GetRenderable());

  if (!mapper)
  {
    vtkErrorMacro("invalid mapper");
    return;
  }

  // A mapper without an input connection has no upstream algorithm.
  auto* inputAlgorithm = mapper->GetInputAlgorithm();

  if (!inputAlgorithm)
  {
    vtkErrorMacro("mapper has no input algorithm");
    return;
  }

  inputAlgorithm->UpdateInformation();
  inputAlgorithm->Update();

  vtkVolumeNode* volNode = vtkVolumeNode::SafeDownCast(this->Parent);
  if (!volNode)
  {
    vtkErrorMacro("invalid volumeNode");
    return;
  }

  vtkVolume* vol = vtkVolume::SafeDownCast(volNode->GetRenderable());

  if (!vol)
  {
    vtkErrorMacro("invalid volume");
    return;
  }

  if (vol->GetVisibility())
  {
    vtkVolumeProperty* const volProperty = vol->GetProperty();

    if (!volProperty)
    {
      // this is OK, happens in paraview client side for instance
      vtkDebugMacro(<< "Volume doesn't have property set");
      return;
    }

    this->Internal->AnariRendererNode =
      static_cast<vtkAnariRendererNode*>(this->GetFirstAncestorOfType("vtkAnariRendererNode"));

    if (this->Internal->AnariRendererNode == nullptr)
    {
      vtkErrorMacro("invalid renderer node");
      return;
    }

    auto anariDevice = this->Internal->AnariRendererNode->GetAnariDevice();
    auto amr = vtkOverlappingAMR::SafeDownCast(mapper->GetInputDataObject(0, 0));

    if (!amr)
    {
      vtkErrorMacro("couldn't get amr data\n");
      return;
    }

    if (this->Internal->AnariVolume == nullptr)
    {
      this->Internal->AnariVolume = anari::newObject<anari::Volume>(anariDevice, "transferFunction1D");
    }

    auto anariVolume = this->Internal->AnariVolume;

    if (amr->GetMTime() > this->Internal->BuildTime)
    {
      auto anariSpatialField = anari::newObject<anari::SpatialField>(anariDevice, "amr");
      unsigned int lastLevel = 0;

      std::vector<anari::Array3D> brickDataVector;
      std::vector<ivec3> blockStartVector;
      std::vector<int> blockLevelVector;

      // Releases the ANARI objects created so far; called before bailing out
      // of the rebuild so an error does not leak device objects.
      const auto discardStagedObjects = [&]() {
        for (auto stagedBrick : brickDataVector)
        {
          anari::release(anariDevice, stagedBrick);
        }

        anari::release(anariDevice, anariSpatialField);
      };

      auto amrInfo = amr->GetAMRInfo();
      vtkSmartPointer<vtkUniformGridAMRDataIterator> amrDataIterator;
      amrDataIterator.TakeReference(vtkUniformGridAMRDataIterator::SafeDownCast(amr->NewIterator()));

      if (!amrDataIterator)
      {
        vtkErrorMacro("could not create amr data iterator");
        discardStagedObjects();
        return;
      }

      for (amrDataIterator->InitTraversal(); !amrDataIterator->IsDoneWithTraversal(); amrDataIterator->GoToNextItem())
      {
        auto level = amrDataIterator->GetCurrentLevel();

        if (level < lastLevel)
        {
          vtkErrorMacro("ANARI requires level info be ordered lowest to highest");
        }

        lastLevel = level;
        auto index = amrDataIterator->GetCurrentIndex();
        auto imageData = vtkImageData::SafeDownCast(amrDataIterator->GetCurrentDataObject());

        if (!imageData)
        {
          vtkErrorMacro("No current data object present");
          discardStagedObjects();
          return;
        }

        const vtkAMRBox& box = amrInfo->GetAMRBox(level, index);
        auto loCorner = box.GetLoCorner();
        auto hiCorner = box.GetHiCorner();

        ivec3 lo = { loCorner[0], loCorner[1], loCorner[2] };
        ivec3 hi = { hiCorner[0], hiCorner[1], hiCorner[2] };

        // Corners are inclusive cell indices, hence the +1.
        int dim[3] = { hi[0] - lo[0] + 1,
                       hi[1] - lo[1] + 1,
                       hi[2] - lo[2] + 1 };

        blockStartVector.emplace_back(lo);

        mapper->SetScalarMode(VTK_SCALAR_MODE_USE_CELL_FIELD_DATA);
        int fieldAssociation = 0;
        auto cellArray = vtkDataArray::SafeDownCast(this->GetArrayToProcess(imageData, fieldAssociation));

        if (!cellArray)
        {
          vtkErrorMacro("could not get cell array");
          discardStagedObjects();
          return;
        }

        auto brickData = anari::newArray3D(anariDevice, ANARI_FLOAT32, dim[0], dim[1], dim[2]);
        {
          // Computed in vtkIdType so large blocks do not overflow int.
          const vtkIdType totalSize = static_cast<vtkIdType>(dim[0]) * dim[1] * dim[2];
          const vtkIdType copyCount = std::min<vtkIdType>(cellArray->GetNumberOfTuples(), totalSize);
          auto brickDataPtr = anari::map<float>(anariDevice, brickData);

          // Only the first component of each tuple is uploaded.
          for (vtkIdType i = 0; i < copyCount; i++)
          {
            brickDataPtr[i] = static_cast<float>(cellArray->GetComponent(i, 0));
          }

          // An array shorter than the block would otherwise leave the tail of
          // the mapped memory uninitialized.
          if (copyCount < totalSize)
          {
            std::fill(brickDataPtr + copyCount, brickDataPtr + totalSize, 0.0f);
          }

          anari::unmap(anariDevice, brickData);
        }

        brickDataVector.emplace_back(brickData);
        blockLevelVector.emplace_back(level);
      }

      // Grid Origin
      double bounds[6];
      amr->GetBounds(bounds);
      vec3 gridOrigin = { static_cast<float>(bounds[0]),
                          static_cast<float>(bounds[2]),
                          static_cast<float>(bounds[4]) };
      anari::setParameter(anariDevice, anariSpatialField, "origin", gridOrigin);

      // Grid Spacing
      double spacing[3];
      amrInfo->GetSpacing(0, spacing);
      vec3 gridSpacing = { static_cast<float>(spacing[0]),
                           static_cast<float>(spacing[1]),
                           static_cast<float>(spacing[2]) };
      anari::setParameter(anariDevice, anariSpatialField, "spacing", gridSpacing);

      // Block Start
      anari::setParameterArray1D(anariDevice, anariSpatialField, "block.start", ANARI_INT32_VEC3,
                                blockStartVector.data(), blockStartVector.size());

      // Block Level
      anari::setParameterArray1D(anariDevice, anariSpatialField, "block.level", ANARI_INT32,
                                blockLevelVector.data(), blockLevelVector.size());
      vtkDebugMacro(<< "[ANARI::AMR] block level count: " << blockLevelVector.size());

      // Refinement Ratio
      std::vector<unsigned int> refinementRatioVector;

      if (!amrInfo->HasRefinementRatio())
      {
        amrInfo->GenerateRefinementRatio();
      }

      for (unsigned int i = 0; i < amr->GetNumberOfLevels(); i++)
      {
        refinementRatioVector.emplace_back(amrInfo->GetRefinementRatio(i));
      }

      anari::setParameterArray1D(anariDevice, anariSpatialField, "refinementRatio", ANARI_UINT32,
                                refinementRatioVector.data(), refinementRatioVector.size());

      // Block Data
      anari::setParameterArray1D(anariDevice, anariSpatialField, "block.data", ANARI_ARRAY3D,
                                brickDataVector.data(), brickDataVector.size());
      vtkDebugMacro(<< "[ANARI::AMR] block data count: " << brickDataVector.size());

      // The local brick handles are no longer needed once they have been set
      // as the "block.data" parameter.
      for (auto brickData : brickDataVector)
      {
        anari::release(anariDevice, brickData);
      }

      anari::commitParameters(anariDevice, anariSpatialField);
      anari::setAndReleaseParameter(anariDevice, anariVolume, "value", anariSpatialField);
      anari::commitParameters(anariDevice, anariVolume);
    }

    // Transfer function
    if (volProperty->GetMTime() > this->Internal->PropertyTime ||
        amr->GetMTime() > this->Internal->BuildTime)
    {
      this->Internal->UpdateTransferFunction(vol);
      anari_amr::TransferFunction* transferFunction = this->Internal->TransferFunction.get();

      anariSetParameter(anariDevice, anariVolume, "valueRange", ANARI_FLOAT32_BOX1, transferFunction->valueRange);

      auto array1DColor = anari::newArray1D(
        anariDevice, transferFunction->color.data(), transferFunction->color.size());
      anari::setAndReleaseParameter(anariDevice, anariVolume, "color", array1DColor);

      auto array1DOpacity = anari::newArray1D(
        anariDevice, transferFunction->opacity.data(), transferFunction->opacity.size());
      anari::setAndReleaseParameter(anariDevice, anariVolume, "opacity", array1DOpacity);

      anari::commitParameters(anariDevice, anariVolume);
      this->Internal->PropertyTime.Modified();
    }
  }
  else
  {
    vtkDebugMacro(<< "Volume visibility off");

    if (this->Internal->AnariVolume == nullptr)
    {
      // Nothing was ever created, so there is nothing to release.
      return;
    }

    this->Internal->AnariRendererNode =
      static_cast<vtkAnariRendererNode*>(this->GetFirstAncestorOfType("vtkAnariRendererNode"));

    if (this->Internal->AnariRendererNode == nullptr)
    {
      // Without the renderer node there is no device to release through.
      vtkErrorMacro("invalid renderer node");
      return;
    }

    auto anariDevice = this->Internal->AnariRendererNode->GetAnariDevice();
    anari::release(anariDevice, this->Internal->AnariVolume);
    this->Internal->AnariVolume = nullptr;
  }

  this->RenderTime = volNode->GetMTime();
  this->Internal->BuildTime.Modified();
}

//------------------------------------------------------------------------------
// Stage the ANARI volume with the renderer during the prepass; the postpass
// does nothing.
void vtkAnariAMRVolumeMapperNode::Render(bool prepass)
{
  vtkAnariProfiling startProfiling("vtkAnariAMRVolumeMapperNode::Render", vtkAnariProfiling::GREEN);

  if (prepass)
  {
    this->Internal->StageVolume();
  }
}

VTK_ABI_NAMESPACE_END
