//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//============================================================================

#include <vtkm/cont/ArrayCopy.h>
#include <vtkm/cont/DataSet.h>
#include <vtkm/cont/Timer.h>
#include <vtkm/filter/uncertainty/UniformMC.h>
#include <vtkm/worklet/WorkletMapTopology.h>

//#ifdef VTKM_CUDA
#if  defined(VTKM_KOKKOS_CUDA) || defined(VTKM_CUDA)
#include <thrust/device_vector.h>
#include <thrust/random/linear_congruential_engine.h>
#include <thrust/random/uniform_real_distribution.h>
#else
#include <random>
#endif

namespace vtkm
{
namespace worklet
{
namespace uniform
{
class UniformMC : public vtkm::worklet::WorkletVisitCellsWithPoints
{

public:
  UniformMC(double isovalue)
    : m_isovalue(isovalue){}

  using ControlSignature =
    void(CellSetIn, FieldInPoint, FieldInPoint, FieldOutCell, FieldOutCell, FieldOutCell);
  using ExecutionSignature = void(_2, _3, _4, _5, _6);
  using InputDomain = _1;

  template <typename InPointFieldMinType,
            typename InPointFieldMaxType,
            typename OutCellFieldType1,
            typename OutCellFieldType2,
            typename OutCellFieldType3>

  VTKM_EXEC void operator()(const InPointFieldMinType& inPointFieldVecMin,
                            const InPointFieldMaxType& inPointFieldVecMax,
                            OutCellFieldType1& outNonCrossProb,
                            OutCellFieldType2& outCrossProb,
                            OutCellFieldType3& outEntropyProb) const
  {
    vtkm::IdComponent numPoints = inPointFieldVecMin.GetNumberOfComponents();

    if (numPoints != 8)
    {
      printf("this is the 3d version for 8 vertices\n");
      return;
    }

    vtkm::FloatDefault minV = 0.0;
    vtkm::FloatDefault maxV = 0.0;
    vtkm::FloatDefault uniformDistValue = 0.0;
    vtkm::FloatDefault numSample = 2;
    vtkm::FloatDefault numCrossing = 0;
    //vtkm::FloatDefault nonCrossProb = 0;
    vtkm::FloatDefault crossProb = 0;

    vtkm::IdComponent zeroFlag;
    vtkm::IdComponent oneFlag;
    vtkm::Float64 base = 2.0;
    vtkm::Float64 totalSum = 0.0;
    vtkm::IdComponent nonZeroCase = 0;

    vtkm::FloatDefault entropyValue = 0;
    vtkm::FloatDefault templog = 0;
    vtkm::FloatDefault value = 0.0;

    vtkm::Vec<vtkm::IdComponent, 256> probHistogram;

    for (vtkm::IdComponent k = 0; k < 256; k++)
    {
      probHistogram[k] = 0;
    }


#if defined(VTKM_CUDA) || defined(VTKM_CUDA)
    thrust::minstd_rand rng;
    for (vtkm::IdComponent i = 0; i < numSample; ++i)
    {
      zeroFlag = 0;
      oneFlag = 0;
      totalSum = 0.0;
      //numTime = 0;
      for (vtkm::IdComponent pointIndex = 0; pointIndex < numPoints; ++pointIndex)
      {
        minV = static_cast<vtkm::FloatDefault>(inPointFieldVecMin[pointIndex]);
        maxV = static_cast<vtkm::FloatDefault>(inPointFieldVecMax[pointIndex]);
        thrust::uniform_real_distribution<vtkm::FloatDefault> genNum(minV, maxV);
        uniformDistValue = genNum(rng);

        if (uniformDistValue <= this->m_isovalue) // 0 <- Zeroflag 1 Oneflag 1
        {
          zeroFlag = 1; // Detected zero
        }
        else
        {
          oneFlag = 1;                             // Detected one
          totalSum += vtkm::Pow(base, pointIndex); // From binary to decimal
        }
      }

      if ((oneFlag == 1) and (zeroFlag == 1))
      {
        numCrossing += 1;
      }

      if ((totalSum >= 0) and (totalSum <= 255))
      {
        probHistogram[totalSum] += 1;
      }
    }

    for (vtkm::IdComponent i = 0; i < 256; i++)
    {
      templog = 0;
      value = static_cast<vtkm::FloatDefault>(probHistogram[i] / numSample);
      if (probHistogram[i] > 0.00001)
      {
        nonZeroCase++;
        templog = vtkm::Log2(value);
      }
      entropyValue = entropyValue + (-value) * templog;
    }

#else
    std::random_device rd;
    std::mt19937 gen(rd());
    for (vtkm::IdComponent i = 0; i < numSample; ++i)
    {
      zeroFlag = 0;
      oneFlag = 0;
      totalSum = 0.0;
      //numTime = 0;
      for (vtkm::IdComponent pointIndex = 0; pointIndex < numPoints; ++pointIndex)
      {
        minV = static_cast<vtkm::FloatDefault>(inPointFieldVecMin[pointIndex]);
        maxV = static_cast<vtkm::FloatDefault>(inPointFieldVecMax[pointIndex]);
        std::uniform_real_distribution<vtkm::FloatDefault> genNum(minV, maxV);
        uniformDistValue = genNum(gen);

        if (uniformDistValue <= this->m_isovalue) // 0 <- Zeroflag 1 Oneflag 1
        {
          zeroFlag = 1; // Detected zero
        }
        else
        {
          oneFlag = 1;                             // Detected one
          totalSum += vtkm::Pow(base, pointIndex); // From binary to decimal
        }
      }

      if ((oneFlag == 1) && (zeroFlag == 1))
      {
        numCrossing += 1;
      }

      if ((totalSum >= 0) && (totalSum <= 255))
      {
        probHistogram[totalSum] += 1;
      }
    }

    for (vtkm::IdComponent i = 0; i < 256; i++)
    {
      templog = 0;
      value = static_cast<vtkm::FloatDefault>(probHistogram[i] / numSample);
      if (probHistogram[i] > 0.00001)
      {
        nonZeroCase++;
        templog = vtkm::Log2(value);
      }
      entropyValue = entropyValue + (-value) * templog;
    }
#endif


    crossProb = numCrossing / numSample;
    //nonCrossProb = 1 - crossProb;
    outNonCrossProb = nonZeroCase;
    outCrossProb = crossProb;
    outEntropyProb = entropyValue;
  }

private:
  double m_isovalue;
};
}
}
}

namespace vtkm
{
namespace filter
{
namespace uncertainty
{
/// Default-constructs the filter, naming the crossing-probability output field.
UniformMC::UniformMC()
{
  const char* const defaultCrossName = "cross_probability";
  this->SetCrossProbabilityName(defaultCrossName);
}
/// Runs the uniform Monte Carlo worklet over a 3D structured dataset.
///
/// Active field 0 supplies the per-point lower bounds and active field 1 the
/// per-point upper bounds. The max field is converted to the array type of the min
/// field (shallowly when possible). The returned dataset carries three new cell
/// fields: the crossing probability, the number of observed nonzero-probability
/// cases, and the entropy.
/// @throws vtkm::cont::ErrorBadType if the cell set is not CellSetStructured<3>.
VTKM_CONT vtkm::cont::DataSet UniformMC::DoExecute(const vtkm::cont::DataSet& input)
{
  const vtkm::cont::Field lowerBoundField = this->GetFieldFromDataSet(0, input);
  const vtkm::cont::Field upperBoundField = this->GetFieldFromDataSet(1, input);

  // Populated by the type-resolved dispatch below.
  vtkm::cont::UnknownArrayHandle crossOut;
  vtkm::cont::UnknownArrayHandle nonzeroCountOut;
  vtkm::cont::UnknownArrayHandle entropyOut;

  const auto& inputCells = input.GetCellSet();
  if (!inputCells.IsType<vtkm::cont::CellSetStructured<3>>())
  {
    throw vtkm::cont::ErrorBadType("Uncertain contour only works for CellSetStructured<3>.");
  }
  vtkm::cont::CellSetStructured<3> structuredCells;
  inputCells.AsCellSet(structuredCells);

  auto runWorklet = [&](auto lowerBounds) {
    using BoundsArrayType = std::decay_t<decltype(lowerBounds)>;
    using ScalarType = typename BoundsArrayType::ValueType;

    BoundsArrayType upperBounds;
    vtkm::cont::ArrayCopyShallowIfPossible(upperBoundField.GetData(), upperBounds);

    vtkm::cont::ArrayHandle<ScalarType> crossArray;
    vtkm::cont::ArrayHandle<ScalarType> nonzeroCountArray;
    vtkm::cont::ArrayHandle<ScalarType> entropyArray;

    vtkm::worklet::uniform::UniformMC worklet{ this->IsoValue };
    this->Invoke(worklet,
                 structuredCells,
                 lowerBounds,
                 upperBounds,
                 nonzeroCountArray,
                 crossArray,
                 entropyArray);

    crossOut = crossArray;
    nonzeroCountOut = nonzeroCountArray;
    entropyOut = entropyArray;
  };
  this->CastAndCallScalarField(lowerBoundField, runWorklet);

  vtkm::cont::DataSet output = this->CreateResult(input);
  output.AddCellField(this->GetCrossProbabilityName(), crossOut);
  output.AddCellField(this->GetNumberNonzeroProbabilityName(), nonzeroCountOut);
  output.AddCellField(this->GetEntropyName(), entropyOut);
  return output;
}
}
}
}
