/*=========================================================================

  Program:   Visualization Toolkit
  Module:    vtkmFeatureAnalysis.cxx

  Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
  All rights reserved.
  See Copyright.txt or http://www.kitware.com/Copyright.htm for details.

     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notice for more information.

=========================================================================*/
#include "vtkmFeatureAnalysis.h"

#include "vtkInformation.h"
#include "vtkInformationVector.h"
#include "vtkLogger.h"
#include "vtkMPIController.h"
#include "vtkPartitionedDataSet.h"
#include "vtkPolyData.h"
#include "vtkImageData.h"

#include <vtkm/cont/Initialize.h>
#include <vtkm/cont/DataSetBuilderUniform.h>
#include <vtkm/cont/ArrayRangeCompute.h>
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandleTransform.h>
#include <vtkm/cont/ArrayPortalToIterators.h>
#include <vtkm/worklet/AverageByKey.h>
#include <vtkm/cont/DataSetFieldAdd.h>
#include <vtkm/cont/EnvironmentTracker.h>
#include <vtkm/cont/Initialize.h>

#include <vtkm/cont/Timer.h>
#include <vtkm/cont/ArrayHandleSOA.h>
#include <vtkm/filter/ParticleDensityNearestGridPoint.h>

#include "vtkmlib/ArrayConverters.h"
#include "vtkmlib/PolyDataConverter.h"
#include "vtkmlib/ImageDataConverter.h"

#include <mpi.h>

/////////////////////////////////////////////////////////////////////
// Includes new filter and worklets developed for this algorithm
#include <vtkm/filter/SLIC.h>
#include <vtkm/filter/FieldGaussianSimilarity.h>

#include <algorithm>
#include <iostream>
#include <limits>
#include <vector>



// VTK object-factory boilerplate: provides the standard vtkmFeatureAnalysis::New().
vtkStandardNewMacro(vtkmFeatureAnalysis);

//------------------------------------------------------------------------------
// Defaults: target Gaussian parameters (2.5, 10.0), an 8x8x8 initial SLIC
// cluster size and a 128x16x128 density histogram.
vtkmFeatureAnalysis::vtkmFeatureAnalysis()
  : FeatureGaussian{ 2.5, 10.0 }
  , ClusterBlockSize{ 8, 8, 8 }
  , HistDims{ 128, 16, 128 }
{
  // Bind to the process-wide controller. SafeDownCast yields nullptr when that
  // controller is not an MPI controller; RequestData reports that as an error.
  vtkMultiProcessController* globalController =
    vtkMultiProcessController::GetGlobalController();
  this->Controller = vtkMPIController::SafeDownCast(globalController);
}

//------------------------------------------------------------------------------
// No explicit cleanup is performed here.
vtkmFeatureAnalysis::~vtkmFeatureAnalysis() = default;

//------------------------------------------------------------------------------
// Print the superclass state followed by this filter's own parameters.
// The parameters are the ones that control RequestData: target Gaussian,
// initial SLIC cluster size, histogram dimensions and the MPI controller.
void vtkmFeatureAnalysis::PrintSelf(ostream& os, vtkIndent indent)
{
  this->Superclass::PrintSelf(os, indent);
  os << indent << "FeatureGaussian: (" << this->FeatureGaussian[0] << ", "
     << this->FeatureGaussian[1] << ")\n";
  os << indent << "ClusterBlockSize: (" << this->ClusterBlockSize[0] << ", "
     << this->ClusterBlockSize[1] << ", " << this->ClusterBlockSize[2] << ")\n";
  os << indent << "HistDims: (" << this->HistDims[0] << ", " << this->HistDims[1] << ", "
     << this->HistDims[2] << ")\n";
  os << indent << "Controller: " << this->Controller << "\n";
}

//------------------------------------------------------------------------------
int vtkmFeatureAnalysis::FillOutputPortInformation(
  int vtkNotUsed(port), vtkInformation* info)
{
  // Whatever the port, the data object produced is a vtkPartitionedDataSet.
  const char* const outputTypeName = "vtkPartitionedDataSet";
  info->Set(vtkDataObject::DATA_TYPE_NAME(), outputTypeName);
  return 1;
}

//------------------------------------------------------------------------------
int vtkmFeatureAnalysis::RequestData(
  vtkInformation* request, vtkInformationVector** inputVector,
  vtkInformationVector* outputVector)
{
  if (! this->Controller)
  {
    vtkErrorMacro(<< "MPIController is not initialized");
    return 0;
  }
  int world_size = this->Controller->GetNumberOfProcesses(),
    rank = this->Controller->GetLocalProcessId();

  // Collect information:
  vtkInformation* inInfo = inputVector[0]->GetInformationObject(0);
  vtkInformation* outInfo = outputVector->GetInformationObject(0);

  vtkPolyData* input = vtkPolyData::GetData(inInfo);
  //vtkLog(INFO, "rank: " << rank << " number of points: " << input->GetNumberOfPoints());
  vtkPartitionedDataSet* output =
    vtkPartitionedDataSet::SafeDownCast(outInfo->Get(
                                          vtkDataObject::DATA_OBJECT()));
  output->SetNumberOfPartitions(1);

  //// how to query dimenison so we can get xdim,ydim,zdim from data set?
  vtkm::Id xdim=this->HistDims[0];
  vtkm::Id ydim=this->HistDims[1];
  vtkm::Id zdim=this->HistDims[2];
  vtkm::cont::DataSet data_local = tovtkm::Convert(input, tovtkm::FieldsFlag::Points);

  ///////////////////////////////////////////////////////////////////////
  //1. Compute density field using MPI distributed communication

  vtkm::cont::Timer timer_density;
  timer_density.Start();

  //First compute global particle bounds from local bounds
  double* lb = input->GetBounds();
  vtkm::Bounds local_bounds(lb[0], lb[1], lb[2], lb[3], lb[4], lb[5]);
  double min_bounds[3] = {local_bounds.X.Min, local_bounds.Y.Min, local_bounds.Z.Min};
  double max_bounds[3] = {local_bounds.X.Max, local_bounds.Y.Max, local_bounds.Z.Max};
  double global_min_bounds[3],global_max_bounds[3];
  this->Controller->AllReduce(min_bounds,global_min_bounds,3,vtkCommunicator::MIN_OP);
  this->Controller->AllReduce(max_bounds,global_max_bounds,3,vtkCommunicator::MAX_OP);

  //Global bound
  vtkm::Bounds bounds;
  bounds.X.Min = global_min_bounds[0];
  bounds.X.Max = global_max_bounds[0];
  bounds.Y.Min = global_min_bounds[1];
  bounds.Y.Max = global_max_bounds[1];
  bounds.Z.Min = global_min_bounds[2];
  bounds.Z.Max = global_max_bounds[2];

  //Execute density filter
  vtkm::filter::ParticleDensityNearestGridPoint particleDensity{
    vtkm::Id3{ xdim, ydim, zdim },
    bounds
  };
  particleDensity.SetComputeNumberDensity(true);
  particleDensity.SetDivideByVolume(false);

  auto local_density_field = particleDensity.Execute(data_local);

  //get the density field out to an vtkm arrayhandle
  vtkm::cont::ArrayHandle<vtkm::Float32> vtkm_density_arr_local;
  local_density_field.GetCellField("density").GetData().AsArrayHandle<vtkm::Float32>(vtkm_density_arr_local);

  //Get the raw C pointer so we can use MPI
  const float* c_arr_local_density = vtkm::cont::ArrayHandleBasic<vtkm::Float32>(vtkm_density_arr_local).GetReadPointer();
  float *c_arr_global_density;
  c_arr_global_density = (float *)malloc(xdim*ydim*zdim*sizeof(float));
  // Reduce the local field into a global field
  this->Controller->Reduce(c_arr_local_density,c_arr_global_density,xdim*ydim*zdim,vtkCommunicator::SUM_OP,0);

  timer_density.Stop();
  vtkm::Float64 elapsedTime = timer_density.GetElapsedTime();
  //std::cout << "At Rank: "<<rank<<" Density field is created and time taken: " << elapsedTime <<std::endl;


  double global_density_time;
  this->Controller->Reduce(&elapsedTime, &global_density_time, 1, vtkCommunicator::MAX_OP,0);

  //Rest of the algorithm is run on a single MPI process
  if (rank==0)
  {

    std::cout << "Max Density time: "<<global_density_time<<" secs" <<std::endl;

    //Parameter:: SLIC cluster size
    vtkm::Id blockXSize=this->ClusterBlockSize[0];
    vtkm::Id blockYSize=this->ClusterBlockSize[1];
    vtkm::Id blockZSize=this->ClusterBlockSize[2];
    //Parameter:: SLIC weight in distance function
    vtkm::Float64 weight=0.1;
    //Parameter:: SLIC halt condition
    vtkm::Float64 halt_cond=0.2;
    //Parameter:: SLIC iteration limit
    vtkm::Id iter_limit=50;
    //Parameter:: Target feature distribution to be searched.
    vtkm::Pair<vtkm::Float32,vtkm::Float32> feature_Gauss = vtkm::make_Pair(
      this->FeatureGaussian[0], this->FeatureGaussian[1]);

    //Create a uniform grid dataset from global density C array so it can be passed to SLIC filter.
    vtkm::Vec<vtkm::Float64, 3> origin{bounds.X.Min, bounds.Y.Min, bounds.Z.Min};
    vtkm::Vec<vtkm::Float64, 3> spacing{
      (bounds.X.Max - bounds.X.Min) / vtkm::Float64(xdim-1),
      (bounds.Y.Max - bounds.Y.Min) / vtkm::Float64(ydim-1),
      (bounds.Z.Max - bounds.Z.Min) / vtkm::Float64(zdim-1)};

    auto density_field = vtkm::cont::DataSetBuilderUniform::Create(
      //vtkm::Id3{ xdim, ydim, zdim } + vtkm::Id3{ 1, 1, 1 }, origin, spacing);
      vtkm::Id3{ xdim, ydim, zdim }, origin, spacing);

    //Create array handle from the global density C array pointer
    vtkm::cont::ArrayHandle <vtkm::Float32 > density_arr_handle =
      vtkm::cont::make_ArrayHandle(c_arr_global_density, xdim*ydim*zdim, vtkm::CopyFlag::On);

    //Add density arrayhanlde to field
    //density_field.AddField(vtkm::cont::make_FieldCell("density", density_arr_handle));
    density_field.AddField(vtkm::cont::make_FieldPoint("density", density_arr_handle));

    // ///////////////////////////////////////////////////////////////////////
    // //Write density field output
    // std::stringstream ss;
    // ss<<rank;
    // std::string fname = "density_field.vtk";
    // vtkm::io::VTKDataSetWriter writer(fname);
    // writer.WriteDataSet(density_field);
    // ///////////////////////////////////////////////////////////////////////

    // //2. Compute slic
    vtkm::cont::Timer timer_slic;
    timer_slic.Start();

    std::string fieldname2 = "density";

    vtkm::filter::SLIC slic;
    slic.SetFieldDimension(vtkm::Id3(xdim,ydim,zdim));
    slic.SetInitClusterSize(vtkm::Id3(blockXSize,blockYSize,blockZSize));
    slic.SetWeight(weight);
    slic.SetHaltCond(halt_cond);
    slic.SetMaxIter(iter_limit);
    slic.SetSlicFieldName(fieldname2);
    slic.SetActiveField(fieldname2);
    vtkm::cont::DataSet outSlicField = slic.Execute(density_field);

    timer_slic.Stop();
    vtkm::Float64 elapsedTime1 = timer_slic.GetElapsedTime();
    std::cout << "Slic time: " << elapsedTime1 <<std::endl;

    // // ///////////////////////////////////////////////////////////////////////
    // // //Write final output field
    // // vtkm::io::VTKDataSetWriter writer1("slic_field.vtk");
    // // writer1.WriteDataSet(outSlicField);


    // ///////////////////////////////////////////////////////////////////////
    // //3. Compute statistical feature similarity field
    vtkm::cont::Timer timer_sim;
    timer_sim.Start();

    std::string fieldname1 = "ClusterIds";

    vtkm::filter::FieldGaussianSimilarity gsimilarity;
    gsimilarity.SetActiveField(fieldname1);
    gsimilarity.SetFieldNames(fieldname1,fieldname2);
    gsimilarity.SetFeatureGaussian(feature_Gauss);
    vtkm::cont::DataSet finalOutField = gsimilarity.Execute(outSlicField);

    timer_sim.Stop();
    vtkm::Float64 elapsedTime2 = timer_sim.GetElapsedTime();
    std::cout << "Similarity field generation time: " << elapsedTime2 <<std::endl;

    std::cout<< "Total time: "<<global_density_time + elapsedTime1 + elapsedTime2 <<std::endl;
    vtkNew<vtkImageData> out;
    if (! fromvtkm::Convert(finalOutField, out.Get(), out.Get()))
    {
      vtkErrorMacro(<< "Cannot convert vtkm::cont::DataSet to vtkImageData");
      return 0;
    }
    output->SetPartition(0, out);
  }
  return 1;
}
