//
//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//============================================================================

#include <cstddef>
#include <iostream>
#include <random>
#include <string>
#include <vector>

#include <vtkm/Math.h>
#include <vtkm/cont/DataSet.h>
#include <vtkm/cont/DataSetBuilderUniform.h>
#include <vtkm/cont/testing/Testing.h>
#include <vtkm/filter/uncertainty/ContourUncertainUniform.h>
#include <vtkm/filter/uncertainty/UniformMC.h>
#include <vtkm/io/VTKDataSetReader.h>
#include <vtkm/io/VTKDataSetWriter.h>

namespace
{
/// Builds a uniform test data set carrying two point fields, "ensemble_min" and
/// "ensemble_max", that describe a per-point value range.
///
/// For every point, two values are drawn uniformly from [-20, 20). The smaller
/// one is stored in "ensemble_min" and the larger in "ensemble_max", so
/// ensemble_min <= ensemble_max holds at every point.
///
/// @tparam T   Value type stored in the two point fields.
/// @param dims Point dimensions of the uniform grid (defaults to 25x25x25).
/// @param seed Seed for the pseudo-random generator. A fixed default keeps the
///             generated data, and therefore any test failure, reproducible.
/// @return The uniform data set with both point fields attached.
template <typename T>
vtkm::cont::DataSet MakeContourUncertainUniformTestDataSet(
  const vtkm::Id3& dims = vtkm::Id3(25, 25, 25),
  std::mt19937::result_type seed = 5489u)
{
  const vtkm::Id numPoints = dims[0] * dims[1] * dims[2];
  vtkm::cont::DataSetBuilderUniform dataSetBuilder;
  vtkm::cont::DataSet dataSet = dataSetBuilder.Create(dims);

  std::vector<T> ensemble_max;
  std::vector<T> ensemble_min;
  if (numPoints > 0)
  {
    ensemble_max.reserve(static_cast<std::size_t>(numPoints));
    ensemble_min.reserve(static_cast<std::size_t>(numPoints));
  }

  // Deterministic seed (rather than std::random_device) so that a failing run
  // can be reproduced exactly.
  std::mt19937 gen(seed);
  std::uniform_real_distribution<vtkm::FloatDefault> valueDist(-20, 20);
  for (vtkm::Id i = 0; i < numPoints; ++i)
  {
    const vtkm::FloatDefault a = valueDist(gen);
    const vtkm::FloatDefault b = valueDist(gen);
    // Order the two samples so the range is always well formed.
    ensemble_max.push_back(static_cast<T>(vtkm::Max(a, b)));
    ensemble_min.push_back(static_cast<T>(vtkm::Min(a, b)));
  }

  dataSet.AddPointField("ensemble_max", ensemble_max);
  dataSet.AddPointField("ensemble_min", ensemble_min);
  return dataSet;
}

void TestUncertaintyGeneral(vtkm::FloatDefault isoValue)
{

  vtkm::cont::DataSet input = MakeContourUncertainUniformTestDataSet<vtkm::FloatDefault>();

  vtkm::filter::uncertainty::ContourUncertainUniform filter;
  filter.SetIsoValue(isoValue);
  filter.SetCrossProbabilityName("CrossProbablity");
  filter.SetNumberNonzeroProbabilityName("NonzeroProbablity");
  filter.SetEntropyName("Entropy");
  filter.SetMinField("ensemble_min");
  filter.SetMaxField("ensemble_max");
  vtkm::cont::DataSet output = filter.Execute(input);

  vtkm::filter::uncertainty::UniformMC filter_mc;
  filter_mc.SetIsoValue(isoValue);
  // filter_mc.SetIterValue(0);
  filter_mc.SetCrossProbabilityName("CrossProbablityMC");
  filter_mc.SetNumberNonzeroProbabilityName("NonzeroProbablityMC");
  filter_mc.SetEntropyName("EntropyMC");
  filter_mc.SetMinField("ensemble_min");
  filter_mc.SetMaxField("ensemble_max");
  vtkm::cont::DataSet output_mc = filter_mc.Execute(input);


  // Closed Form
  vtkm::cont::Field CrossProb = output.GetField("CrossProbablity");
  vtkm::cont::ArrayHandle<vtkm::FloatDefault> crossProbArray;
  CrossProb.GetData().AsArrayHandle(crossProbArray);
  vtkm::cont::ArrayHandle<vtkm::FloatDefault>::ReadPortalType CrossPortal = crossProbArray.ReadPortal();

  vtkm::cont::Field NonzeroProb = output.GetField("NonzeroProbablity");
  vtkm::cont::ArrayHandle<vtkm::Id> NonzeroProbArray;
  NonzeroProb.GetData().AsArrayHandle(NonzeroProbArray);
  vtkm::cont::ArrayHandle<vtkm::Id>::ReadPortalType NonzeroPortal = NonzeroProbArray.ReadPortal();

  vtkm::cont::Field entropy = output.GetField("Entropy");
  vtkm::cont::ArrayHandle<vtkm::FloatDefault> EntropyArray;
  entropy.GetData().AsArrayHandle(EntropyArray);
  vtkm::cont::ArrayHandle<vtkm::FloatDefault>::ReadPortalType EntropyPortal = EntropyArray.ReadPortal();

  // MC
  vtkm::cont::Field CrossProbMC = output_mc.GetField("CrossProbablityMC");
  vtkm::cont::ArrayHandle<vtkm::FloatDefault> crossProbMCArray;
  CrossProbMC.GetData().AsArrayHandle(crossProbMCArray);
  vtkm::cont::ArrayHandle<vtkm::FloatDefault>::ReadPortalType CrossMCPortal = crossProbMCArray.ReadPortal();
  vtkm::cont::Field NonzeroMCProb = output_mc.GetField("NonzeroProbablityMC");
  vtkm::cont::ArrayHandle<vtkm::FloatDefault> NonzeroProbMCArray;
  NonzeroMCProb.GetData().AsArrayHandle(NonzeroProbMCArray);
  vtkm::cont::ArrayHandle<vtkm::FloatDefault>::ReadPortalType NonzeroMCPortal = NonzeroProbMCArray.ReadPortal();

  vtkm::cont::Field entropyMC = output_mc.GetField("EntropyMC");
  vtkm::cont::ArrayHandle<vtkm::FloatDefault> EntropyMCArray;
  entropyMC.GetData().AsArrayHandle(EntropyMCArray);
  vtkm::cont::ArrayHandle<vtkm::FloatDefault>::ReadPortalType EntropyMCPortal = EntropyMCArray.ReadPortal();

  for (vtkm::Id i = 0; i < crossProbArray.GetNumberOfValues(); ++i)
  {
      vtkm::FloatDefault CrossProbValue = CrossPortal.Get(i);
      vtkm::Id NonzeroProbValue = NonzeroPortal.Get(i);
      vtkm::FloatDefault EntropyValue = EntropyPortal.Get(i);

      // std::cout << CrossProbValue << ' ' << NonzeroProbValue << ' ' << EntropyValue << std::endl;

      vtkm::FloatDefault CrossProbMCValue = CrossMCPortal.Get(i);
      vtkm::FloatDefault NonzeroProbMCValue = NonzeroMCPortal.Get(i);
      vtkm::FloatDefault EntropyMCValue = EntropyMCPortal.Get(i);

      // std::cout << CrossProbMCValue << ' ' << NonzeroProbMCValue << ' ' << EntropyMCValue << std::endl;
      if ((vtkm::Abs(CrossProbMCValue - CrossProbValue) > 0.05) || (vtkm::Abs(NonzeroProbMCValue - NonzeroProbValue) > 0.05) || (vtkm::Abs(EntropyMCValue - EntropyValue) > 0.05))
      {
          std::cout << "Failed" << std::endl;
      }

  }
}

void TestContourUncertainUniform()
{
  vtkm::FloatDefault isoValue = 0;
  TestUncertaintyGeneral(isoValue);
}
}
/// Test executable entry point. It forwards the command-line arguments and
/// TestContourUncertainUniform to vtkm::cont::testing::Testing::Run and returns
/// whatever status that call produces.
int UnitTestContourUncertainUniform(int argc, char* argv[])
{
  auto* const testBody = &TestContourUncertainUniform;
  return vtkm::cont::testing::Testing::Run(testBody, argc, argv);
}
