#include "vtkComputeMoments.h"
#include "vtkMomentsHelper.h"
//#include "vtkMomentsTensor.h"

#include "vtkmlib/ImageDataConverter.h"
#include "vtkmlib/Storage.h"
#include "vtkmFilterPolicy.h"

#include <vtkm/VecTraits.h>
#include <vtkm/cont/ErrorBadValue.h>
#include <vtkm/filter/ComputeMoments.h>

#include <vtkDataArray.h>
#include <vtkDoubleArray.h>
#include <vtkImageData.h>
#include <vtkPointData.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>

namespace
{

/// Filter policy for the VTK-m moments path: restricts field value types to
/// Float32/Float64 scalars and Float32/Float64 Vecs with 2, 3, 4, 6 or 9
/// components. The cell-set lists are taken unchanged from
/// vtkmInputFilterPolicy.
struct MyPolicy : public vtkm::filter::PolicyBase<MyPolicy>
{
  using FieldTypeList = vtkm::ListTagBase<vtkm::Float32,
                                          vtkm::Float64,
                                          vtkm::Vec<vtkm::Float32, 2>,
                                          vtkm::Vec<vtkm::Float64, 2>,
                                          vtkm::Vec<vtkm::Float32, 3>,
                                          vtkm::Vec<vtkm::Float64, 3>,
                                          vtkm::Vec<vtkm::Float32, 4>,
                                          vtkm::Vec<vtkm::Float64, 4>,
                                          vtkm::Vec<vtkm::Float32, 6>,
                                          vtkm::Vec<vtkm::Float64, 6>,
                                          vtkm::Vec<vtkm::Float32, 9>,
                                          vtkm::Vec<vtkm::Float64, 9>
                                         >;

  using StructuredCellSetList = vtkmInputFilterPolicy::StructuredCellSetList;
  using UnstructuredCellSetList = vtkmInputFilterPolicy::UnstructuredCellSetList;
  using AllCellSetList = vtkmInputFilterPolicy::AllCellSetList;
};

/// CastAndCall functor that copies one component of every value of a VTK-m
/// array into a vtkDoubleArray (value i of the array goes to value i of `out`).
///
/// `idx` points at the field-tensor indices of the requested component. How
/// they are flattened into a component index depends on the number of
/// components of T (see the switch below).
///
/// Throws vtkm::cont::ErrorBadValue when the component count is unsupported,
/// when the flattened component index is out of range, or when `out` holds
/// fewer values than the VTK-m array.
struct ExtractComponentImpl
{
  template <typename T, typename S>
  void operator()(const vtkm::cont::ArrayHandle<T, S>& field,
                  const int* idx,
                  vtkDoubleArray* out) const
  {
    const vtkm::Id numValues = field.GetNumberOfValues();
    if (numValues == 0)
    {
      // Nothing to copy; this also keeps portal.Get(0) below in bounds.
      return;
    }
    if (out->GetNumberOfValues() < numValues)
    {
      // vtkDoubleArray::SetValue does no range checking, so refuse to overrun.
      throw vtkm::cont::ErrorBadValue(
        "ExtractComponent: output array has fewer values than the VTK-m field");
    }

    auto portal = field.GetPortalConstControl();
    const vtkm::IdComponent numComps =
      vtkm::VecTraits<T>::GetNumberOfComponents(portal.Get(0));

    vtkm::IdComponent compIdx = 0;
    switch (numComps)
    {
      case 1:
        compIdx = 0;
        break;
      case 2:
      case 3:
        compIdx = static_cast<vtkm::IdComponent>(idx[0]);
        break;
      case 4:
        compIdx = static_cast<vtkm::IdComponent>(idx[1] * 2 + idx[0]);
        break;
      case 6:
      case 9:
        // NOTE(review): case 4 treats idx[0] as the fastest-varying index,
        // while this branch treats idx[1] as fastest. Confirm which layout
        // matches the non-VTK-m path for 9-component fields. With 6 components
        // this formula can exceed 5; the range check below reports that
        // instead of reading past the end of the Vec.
        compIdx = static_cast<vtkm::IdComponent>(idx[0] * 3 + idx[1]);
        break;
      default:
        throw vtkm::cont::ErrorBadValue(
          "ExtractComponent: unsupported number of components: " + std::to_string(numComps));
    }

    if (compIdx < 0 || compIdx >= numComps)
    {
      throw vtkm::cont::ErrorBadValue("ExtractComponent: component index " +
                                      std::to_string(compIdx) + " out of range for " +
                                      std::to_string(numComps) + " components");
    }

    for (vtkm::Id i = 0; i < numValues; ++i)
    {
      out->SetValue(
        i, static_cast<double>(vtkm::VecTraits<T>::GetComponent(portal.Get(i), compIdx)));
    }
  }
};

/// Copies the component of `field` selected by the tensor indices `idx` into
/// `out`, dispatching on the field's value type through MyPolicy.
///
/// Throws vtkm::cont::ErrorBadValue if `out` is null or is not a
/// vtkDoubleArray (vtkDoubleArray::SafeDownCast returns null in both cases).
void ExtractComponent(const vtkm::cont::Field& field, const int* idx, vtkDataArray* out)
{
  vtkDoubleArray* outDoubles = vtkDoubleArray::SafeDownCast(out);
  if (!outDoubles)
  {
    throw vtkm::cont::ErrorBadValue(
      "ExtractComponent: output array is missing or is not a vtkDoubleArray");
  }

  vtkm::cont::CastAndCall(vtkm::filter::ApplyPolicy(field, MyPolicy{}),
                          ExtractComponentImpl{},
                          idx,
                          outDoubles);
}

} // anonymous namespace

void vtkComputeMoments::ComputeVtkm(
  int radiusIndex, vtkImageData* grid, vtkImageData* field, vtkImageData* output)
{
  std::cout << "vtkComputeMoments::ComputeVtkm \n";

  int gridDims[3];
  grid->GetDimensions(gridDims);
  const bool is2D = gridDims[2] == 1;

  if (grid != field)
  {
    int fieldDims[3];
    double gridSpacing[3], fieldSpacing[3];
    grid->GetSpacing(gridSpacing);
    field->GetDimensions(fieldDims);
    field->GetSpacing(fieldSpacing);

    for (int i = 0; i < 3; ++i)
    {
      if (gridDims[i] != fieldDims[i] || gridSpacing[i] != fieldSpacing[i])
      {
        vtkErrorMacro(<< "The structure of grid and field must be the same for VTK-m");
        return;
      }
    }
  }

  vtkm::Vec<vtkm::Int32, 3> discreteRadius{0};
  for (int d = 0; d < this->Dimension; ++d)
  {
    discreteRadius[d] = this->Radii.at(radiusIndex) / (field->GetSpacing()[d] - 1e-10);
  }

  auto fieldArray = field->GetPointData()->GetArray(this->NameOfPointData.c_str());

  try
  {
    // convert the input dataset to a vtkm::cont::DataSet
    vtkm::cont::DataSet in = tovtkm::Convert(field);
    auto field = tovtkm::Convert(fieldArray, vtkDataObject::FIELD_ASSOCIATION_POINTS);
    in.AddField(field);
    //in.PrintSummary(std::cout);

    vtkm::filter::ComputeMoments computeMoments;
    computeMoments.SetOrder(this->Order);
    computeMoments.SetRadius(discreteRadius);
    computeMoments.SetActiveField(this->NameOfPointData);

    vtkm::cont::DataSet out = computeMoments.Execute(in, MyPolicy{});
    //out.PrintSummary(std::cout);

    std::vector<int> indices;
    for (int order = 0; order <= this->Order; ++order)
    {
      const int maxR = is2D ? 0 : order; // 2D grids don't use r
      for (int r = 0; r <= maxR; ++r)
      {
        const int qMax = order - r;
        for (int q = 0; q <= qMax; ++q)
        {
          const int p = order - r - q;

          indices.resize(order);

          // Fill indices according to pqr values:
          if (!indices.empty())
          {
            auto iter = indices.begin();
            iter = std::fill_n(iter, p, 0);
            iter = std::fill_n(iter, q, 1);
            iter = std::fill_n(iter, r, 2);
            assert(iter == indices.end());
          }

          auto vtkmFieldName = std::string("index");
          for (int i : indices)
          {
            vtkmFieldName += std::to_string(i);
          }

//          std::cerr << "Order: " << order << " "
//                    << "pqr: " << p << "x" << q << "x" << r << " "
//                    << "Field name: " << vtkmFieldName << "\n";

          auto field = out.GetField(vtkmFieldName);

          auto numComps = static_cast<int>(std::pow(this->Dimension, this->FieldRank));
          for (int c = 0; c < numComps; ++c)
          {
            indices.resize(order);
            for (int i = 0; i < this->FieldRank; ++i)
            {
              indices.push_back((c / static_cast<int>(std::pow(this->Dimension, i))) % this->Dimension);
            }

            auto vtkFieldName =
                vtkMomentsHelper::getFieldNameFromTensorIndices(this->Radii.at(radiusIndex),
                                                                indices,
                                                                this->FieldRank);
            ExtractComponent(field,
                             indices.data() + order,
                             output->GetPointData()->GetArray(vtkFieldName.c_str()));
          }
        }
      }
    }
  }
  catch (const vtkm::cont::Error& e)
  {
    vtkErrorMacro(<< "VTK-m error: " << e.GetMessage());
    return;
  }
}
