//=============================================================================
//
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//
//  Copyright 2016 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
//  Copyright 2016 UT-Battelle, LLC.
//  Copyright 2016 Los Alamos National Security.
//
//  Under the terms of Contract DE-NA0003525 with NTESS,
//  the U.S. Government retains certain rights in this software.
//  Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
//  Laboratory (LANL), the U.S. Government retains certain rights in
//  this software.
//
//=============================================================================
#include <vtkm/rendering/raytracing/MinMaxVoxelGrid.h>

#include <vtkm/cont/ArrayHandleCounting.h>
#include <vtkm/rendering/raytracing/BVHTraverser.h>
#include <vtkm/rendering/raytracing/RayTracingTypeDefs.h>
#include <vtkm/worklet/DispatcherMapField.h>
#include <vtkm/worklet/DispatcherMapTopology.h>
#include <vtkm/worklet/WorkletMapField.h>
#include <vtkm/worklet/WorkletMapTopology.h>

namespace vtkm
{
namespace rendering
{
namespace raytracing
{

/// Views the same 32 bits either as a Float32 or as an Int32.
///
/// Used in this file to keep Float32 min/max values inside Int32 atomic arrays
/// (updated via compare-and-swap), because floating-point atomics are not
/// available here.
///
/// NOTE(review): writing one member and reading the other is type punning that
/// ISO C++ leaves undefined; this relies on compiler support for union punning
/// (documented as an extension by GCC/Clang).
union Flint32
{
  vtkm::Float32 Float;
  vtkm::Int32 Int;
};

/// Point-to-cell worklet that scatters each cell's opacity range into a uniform
/// voxel grid.
///
/// For every cell it
///   1. gathers the cell's spatial bounds and scalar min/max from its points,
///   2. normalizes that scalar range against [minScalar, maxScalar], maps it onto
///      color-map indices, and takes the min/max of the 4th (alpha) component over
///      every color-map entry in between -- the endpoints alone do not bound the
///      opacity when the color map is not monotonic,
///   3. atomically folds that opacity range into every voxel the cell's bounds
///      overlap.
///
/// The per-voxel results live in Int32 atomic arrays holding Float32 bit patterns
/// (see Min() / Max()).
class AtomicIntersectBins : public vtkm::worklet::WorkletMapPointToCell
{
protected:
  vtkm::Id3 Dims;                      // grid resolution (voxels per axis)
  vtkm::Vec<vtkm::Float32, 3> Origin;  // world-space min corner of voxel (0,0,0)
  vtkm::Vec<vtkm::Float32, 3> Spacing; // world-space extent of one voxel per axis
  vtkm::Float32 MinScalar;
  vtkm::Float32 InverseDeltaScalar;

public:
  using ControlSignature = void(CellSetIn cellset,
                                WholeArrayIn,
                                WholeArrayIn,
                                AtomicArrayInOut,
                                AtomicArrayInOut,
                                WholeArrayIn);
  using ExecutionSignature = void(PointIndices, PointCount, _2, _3, _4, _5, _6);

  /// @param dims      voxel grid resolution
  /// @param origin    world-space min corner of the grid
  /// @param spacing   world-space voxel size per axis
  /// @param minScalar low end of the scalar range mapped onto the color map
  /// @param maxScalar high end of that scalar range
  VTKM_CONT
  AtomicIntersectBins(vtkm::Id3 dims,
                      vtkm::Vec<vtkm::Float32, 3> origin,
                      vtkm::Vec<vtkm::Float32, 3> spacing,
                      vtkm::Float32 minScalar,
                      vtkm::Float32 maxScalar)
    : Dims(dims)
    , Origin(origin)
    , Spacing(spacing)
    , MinScalar(minScalar)
  {
    // For a zero-width range every in-range value equals minScalar, so the
    // normalized value (value - minScalar) * factor is 0 for any finite factor;
    // minScalar is kept as the fallback factor to preserve existing behavior.
    if ((maxScalar - minScalar) != 0.f)
      InverseDeltaScalar = 1.f / (maxScalar - minScalar);
    else
      InverseDeltaScalar = minScalar;
  }

  /// Atomically lowers the Float32 stored (as its Int32 bit pattern) at @p index
  /// to @p val when @p val is smaller, using a compare-and-swap loop.
  template <typename AtomicType>
  VTKM_EXEC void Min(AtomicType& atomic, const vtkm::Float32& val, const vtkm::Id& index) const
  {
    Flint32 desired;
    desired.Float = val;
    // First guess: the slot still holds its initial +inf. A failed CAS returns
    // the actual current value, which becomes the next guess.
    Flint32 expected;
    expected.Float = vtkm::Infinity32();
    while (true)
    {
      Flint32 current;
      current.Int = atomic.CompareAndSwap(index, desired.Int, expected.Int);
      // Done when our value was written, or the stored value is already <= val.
      if (current.Int == expected.Int || !(current.Float > desired.Float))
      {
        return;
      }
      expected = current;
    }
  }

  /// Atomically raises the Float32 stored (as its Int32 bit pattern) at @p index
  /// to @p val when @p val is larger, using a compare-and-swap loop.
  template <typename AtomicType>
  VTKM_EXEC void Max(AtomicType& atomic, const vtkm::Float32& val, const vtkm::Id& index) const
  {
    Flint32 desired;
    desired.Float = val;
    // First guess: the slot still holds its initial -inf. A failed CAS returns
    // the actual current value, which becomes the next guess.
    Flint32 expected;
    expected.Float = vtkm::NegativeInfinity32();
    while (true)
    {
      Flint32 current;
      current.Int = atomic.CompareAndSwap(index, desired.Int, expected.Int);
      // Done when our value was written, or the stored value is already >= val.
      if (current.Int == expected.Int || !(current.Float < desired.Float))
      {
        return;
      }
      expected = current;
    }
  }

  /// Maps a world-space coordinate on @p axis to a voxel index clamped to
  /// [0, Dims[axis] - 1]. Clamping happens before the integer conversion so
  /// Float32 rounding in Origin/Spacing (or a coordinate outside the grid) can
  /// never yield an out-of-range index and an out-of-bounds atomic write.
  VTKM_EXEC vtkm::Id CoordToVoxel(vtkm::Float64 coord, vtkm::IdComponent axis) const
  {
    vtkm::Float64 voxel = (coord - Origin[axis]) / Spacing[axis];
    voxel = vtkm::Max(0.0, vtkm::Min(static_cast<vtkm::Float64>(Dims[axis] - 1), voxel));
    return static_cast<vtkm::Id>(voxel);
  }

  template <typename PointVecType,
            typename CoordsType,
            typename ScalarType,
            typename AtomicPortalType,
            typename ColorType>
  VTKM_EXEC void operator()(const PointVecType& pointIndices,
                            const vtkm::IdComponent& numPoints,
                            const CoordsType& coords,
                            const ScalarType& scalars,
                            const AtomicPortalType& mins,
                            const AtomicPortalType& maxs,
                            const ColorType& colorMap) const
  {
    // 1. Spatial bounds and scalar range of this cell.
    vtkm::Bounds bounds;
    vtkm::Float32 scalarMin = vtkm::Infinity32();
    vtkm::Float32 scalarMax = vtkm::NegativeInfinity32();
    for (vtkm::IdComponent i = 0; i < numPoints; ++i)
    {
      bounds.Include(coords.Get(pointIndices[i]));
      const vtkm::Float32 value = vtkm::Float32(scalars.Get(pointIndices[i]));
      scalarMin = vtkm::Min(scalarMin, value);
      scalarMax = vtkm::Max(scalarMax, value);
    }

    // 2. The color map's opacity can vary arbitrarily between the cell's scalar
    //    extremes, so walk every color-map entry the range touches.
    const vtkm::Id lastColor = colorMap.GetNumberOfValues() - 1;
    const vtkm::Float32 normMin = (scalarMin - MinScalar) * InverseDeltaScalar;
    const vtkm::Float32 normMax = (scalarMax - MinScalar) * InverseDeltaScalar;

    vtkm::Id minIndex = static_cast<vtkm::Id>(normMin * static_cast<vtkm::Float32>(lastColor));
    minIndex = vtkm::Max(vtkm::Id(0), vtkm::Min(lastColor, minIndex));
    vtkm::Id maxIndex = static_cast<vtkm::Id>(normMax * static_cast<vtkm::Float32>(lastColor));
    maxIndex = vtkm::Max(vtkm::Id(0), vtkm::Min(lastColor, maxIndex));

    vtkm::Float32 alphaMin = vtkm::Infinity32();
    vtkm::Float32 alphaMax = vtkm::NegativeInfinity32();
    for (vtkm::Id c = minIndex; c <= maxIndex; ++c)
    {
      const vtkm::Float32 alpha = colorMap.Get(c)[3];
      alphaMin = vtkm::Min(alphaMin, alpha);
      alphaMax = vtkm::Max(alphaMax, alpha);
    }

    // 3. Fold the opacity range into every voxel overlapped by the cell's bounds.
    const vtkm::Id3 lo(CoordToVoxel(bounds.X.Min, 0),
                       CoordToVoxel(bounds.Y.Min, 1),
                       CoordToVoxel(bounds.Z.Min, 2));
    const vtkm::Id3 hi(CoordToVoxel(bounds.X.Max, 0),
                       CoordToVoxel(bounds.Y.Max, 1),
                       CoordToVoxel(bounds.Z.Max, 2));

    for (vtkm::Id z = lo[2]; z <= hi[2]; ++z)
    {
      for (vtkm::Id y = lo[1]; y <= hi[1]; ++y)
      {
        for (vtkm::Id x = lo[0]; x <= hi[0]; ++x)
        {
          const vtkm::Id idx = x + Dims[0] * (y + Dims[1] * z);
          Min(mins, alphaMin, idx);
          Max(maxs, alphaMax, idx);
        }
      }
    }
  }
};

class MinMaxCast : public vtkm::worklet::WorkletMapField
{
public:
  VTKM_CONT
  MinMaxCast() {}
  using ControlSignature = void(FieldIn, FieldOut);
  using ExecutionSignature = void(_1, _2);

  VTKM_EXEC
  void operator()(const vtkm::Int32& in, vtkm::Float32& out) const
  {
    Flint32 flint;
    flint.Int = in;
    out = flint.Float;
    //std::cout<<" out "<<out<<"\n";
  }
}; //class

/// Field worklet that converts per-voxel scalar min/max values into opacity
/// values, in place. Each value is normalized against [minScalar, maxScalar] and
/// replaced by the 4th (alpha) component of the matching color-map entry.
///
/// NOTE(review): only the two endpoint entries are sampled, so with a
/// non-monotonic opacity map the results need not bound the opacity over the
/// whole [min, max] interval.
class TranslateDensity : public vtkm::worklet::WorkletMapField
{
protected:
  vtkm::Float32 MinScalar;
  vtkm::Float32 InverseDeltaScalar;

public:
  /// @param minScalar low end of the scalar range mapped onto the color map
  /// @param maxScalar high end of that scalar range
  VTKM_CONT
  TranslateDensity(vtkm::Float32 minScalar, vtkm::Float32 maxScalar)
    : MinScalar(minScalar)
  {
    // For a zero-width range every in-range value equals minScalar, so the
    // normalized value is 0 for any finite factor; minScalar is kept as the
    // fallback factor to preserve existing behavior.
    if ((maxScalar - minScalar) != 0.f)
      InverseDeltaScalar = 1.f / (maxScalar - minScalar);
    else
      InverseDeltaScalar = minScalar;
  }

  using ControlSignature = void(FieldInOut, FieldInOut, WholeArrayIn);
  using ExecutionSignature = void(_1, _2, _3);

  template <typename ColorPortal>
  VTKM_EXEC void operator()(vtkm::Float32& minValue,
                            vtkm::Float32& maxValue,
                            const ColorPortal& colorMap) const
  {
    // Normalize the scalars into [0, 1] (values outside the range are clamped
    // below via the index).
    minValue = (minValue - MinScalar) * InverseDeltaScalar;
    maxValue = (maxValue - MinScalar) * InverseDeltaScalar;

    // Scale and clamp by the last valid index, not the entry count: a normalized
    // value of 1 must land on the final entry, and clamping to the count would
    // still allow Get() one past the end.
    const vtkm::Id lastIndex = colorMap.GetNumberOfValues() - 1;

    vtkm::Id minIndex = static_cast<vtkm::Id>(minValue * static_cast<vtkm::Float32>(lastIndex));
    minIndex = vtkm::Max(vtkm::Id(0), vtkm::Min(lastIndex, minIndex));

    vtkm::Id maxIndex = static_cast<vtkm::Id>(maxValue * static_cast<vtkm::Float32>(lastIndex));
    maxIndex = vtkm::Max(vtkm::Id(0), vtkm::Min(lastIndex, maxIndex));

    minValue = colorMap.Get(minIndex)[3];
    maxValue = colorMap.Get(maxIndex)[3];
  }
}; //class

/// Builds the min/max opacity grid.
///
/// Lays a Dims-resolution voxel grid over the coordinate bounds, padded on every
/// side by a small epsilon so boundary points fall inside. It then runs
/// AtomicIntersectBins over every cell to record, per voxel, the min and max
/// color-map opacity of the cells overlapping it. Results are stored in MinField /
/// MaxField as Float32. Voxels overlapped by no cell keep +inf / -inf respectively.
///
/// @param cellset     cells to bin
/// @param coords      point coordinates of @p cellset
/// @param field       point scalar field mapped through @p colorMap
/// @param colorMap    RGBA color map; only the 4th (alpha) component is used
/// @param scalarRange scalar range that maps onto the full color map
void MinMaxVoxelGrid::Construct(
  const vtkm::cont::DynamicCellSet& cellset,
  const vtkm::cont::CoordinateSystem& coords,
  const vtkm::cont::Field& field,
  const vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::Float32, 4>>& colorMap,
  const vtkm::Range& scalarRange)
{
  // Grid geometry. Each axis is padded by a relative epsilon (with a tiny
  // absolute floor for flat axes) to make the bounds conservative.
  const vtkm::Bounds bounds = coords.GetBounds();
  const vtkm::Range axisRange[3] = { bounds.X, bounds.Y, bounds.Z };

  vtkm::Vec<vtkm::Float32, 3> vdims;
  vtkm::Vec<vtkm::Float32, 3> size;
  vtkm::Vec<vtkm::Float32, 3> eps;
  for (vtkm::IdComponent axis = 0; axis < 3; ++axis)
  {
    vdims[axis] = static_cast<vtkm::Float32>(Dims[axis]);
    size[axis] = static_cast<vtkm::Float32>(axisRange[axis].Max - axisRange[axis].Min);
    eps[axis] = vtkm::Max(1e-16f, size[axis] * 1e-4f);
    size[axis] += 2.f * eps[axis];
    Origin[axis] = static_cast<vtkm::Float32>(axisRange[axis].Min) - eps[axis];
  }
  Spacing = size / vdims;

  const vtkm::Id gridSize = Dims[0] * Dims[1] * Dims[2];

  // vtkm does not support floating point atomics, so the per-voxel min/max are
  // kept as Float32 bit patterns in Int32 arrays and updated with
  // compare-and-swap. Initialize them to +inf / -inf.
  vtkm::cont::ArrayHandle<vtkm::Int32> iMinValues;
  vtkm::cont::ArrayHandle<vtkm::Int32> iMaxValues;

  Flint32 posInf;
  posInf.Float = vtkm::Infinity32();
  vtkm::cont::ArrayHandleConstant<vtkm::Int32> iInf(posInf.Int, gridSize);
  vtkm::cont::Algorithm::Copy(iInf, iMinValues);

  Flint32 negInf;
  negInf.Float = vtkm::NegativeInfinity32();
  vtkm::cont::ArrayHandleConstant<vtkm::Int32> iNInf(negInf.Int, gridSize);
  vtkm::cont::Algorithm::Copy(iNInf, iMaxValues);

  vtkm::worklet::DispatcherMapTopology<AtomicIntersectBins>(
    AtomicIntersectBins(
      Dims, Origin, Spacing, vtkm::Float32(scalarRange.Min), vtkm::Float32(scalarRange.Max)))
    .Invoke(cellset,
            coords,
            vtkm::rendering::raytracing::GetScalarFieldArray(field),
            iMinValues,
            iMaxValues,
            colorMap);

  // Decode the Int32 bit patterns back into Float32 values.
  vtkm::worklet::DispatcherMapField<MinMaxCast> castDispatcher;
  castDispatcher.Invoke(iMinValues, MinField);
  castDispatcher.Invoke(iMaxValues, MaxField);

  // AtomicIntersectBins already mapped the scalars through the color map, so
  // MinField / MaxField hold opacities; TranslateDensity must not be applied on
  // top of them.

  IsConstructed = true;
}

MinMaxVoxelGrid::MinMaxVoxelGrid()
  : Dims(64, 64, 64)
  , IsConstructed(false)
{
}

/// Sets the voxel grid resolution (voxels per axis) used by Construct().
///
/// NOTE(review): only Dims is updated. IsConstructed and any previously built
/// MinField / MaxField are left untouched, so they no longer match the new
/// resolution until Construct() runs again.
void MinMaxVoxelGrid::SetDims(vtkm::Id3 dims)
{
  this->Dims = dims;
}

/// Returns the voxel grid resolution (voxels per axis).
vtkm::Id3 MinMaxVoxelGrid::GetDims()
{
  const vtkm::Id3 resolution = this->Dims;
  return resolution;
}

}
}
} //namespace vtkm::rendering::raytracing
