#include <vtkm/cont/ArrayHandleGroupVec.h>
#include <vtkm/cont/ArrayHandlePermutation.h>
#include <vtkm/cont/CellSetSingleType.h>

#include <vtkm/exec/CellEdge.h>

#include <vtkm/worklet/DispatcherMapTopology.h>
#include <vtkm/worklet/ScatterCounting.h>
#include <vtkm/worklet/WorkletMapTopology.h>

#include <vtkm/filter/FilterDataSet.h>

#include <vtkm/cont/testing/MakeTestDataSet.h>
#include <vtkm/cont/testing/Testing.h>

namespace vtkm {
namespace worklet {

namespace {

struct ExtractEdges
{
  ////
  //// BEGIN-EXAMPLE GenerateMeshConstantShapeCount.cxx
  ////
  // First pass: report how many edges each input cell has. This count is the
  // number of line cells the second pass generates for that cell.
  struct CountEdges : vtkm::worklet::WorkletMapPointToCell
  {
    typedef void ControlSignature(CellSetIn cellSet, FieldOut<> numEdges);
    typedef _2 ExecutionSignature(CellShape, PointCount);
    using InputDomain = _1;

    // Execution-environment only (matching EdgeIndices::operator()): the
    // vtkm::exec edge helper is an exec function and receives this worklet.
    template<typename CellShapeTag>
    VTKM_EXEC
    vtkm::IdComponent operator()(CellShapeTag shape,
                                 vtkm::IdComponent numPoints) const
    {
      return vtkm::exec::CellEdgeNumberOfEdges(numPoints, shape, *this);
    }
  };
  ////
  //// END-EXAMPLE GenerateMeshConstantShapeCount.cxx
  ////

  ////
  //// BEGIN-EXAMPLE GenerateMeshConstantShapeGenIndices.cxx
  ////
  // Second pass: with a ScatterCounting built from the per-cell edge counts,
  // this worklet is visited once per output edge. VisitIndex selects which
  // edge of the input cell to emit.
  class EdgeIndices : public vtkm::worklet::WorkletMapPointToCell
  {
  public:
    typedef void ControlSignature(CellSetIn cellSet, FieldOut<> connectivityOut);
    typedef void ExecutionSignature(CellShape, PointIndices, _2, VisitIndex);
    using InputDomain = _1;

    using ScatterType = vtkm::worklet::ScatterCounting;
    VTKM_CONT
    ScatterType GetScatter() const { return this->Scatter; }

    VTKM_CONT
    explicit EdgeIndices(const ScatterType& scatter) : Scatter(scatter) {  }

    template<typename CellShapeTag, typename PointIndexVecType>
    VTKM_EXEC
    void operator()(CellShapeTag shape,
                    const PointIndexVecType& pointIndices,
                    vtkm::Vec<vtkm::Id, 2>& connectivityOut,
                    vtkm::IdComponent visitIndex) const
    {
      // Local (cell-relative) indices of the two edge end points...
      vtkm::Vec<vtkm::IdComponent, 2> localEdgeIndices =
          vtkm::exec::CellEdgeLocalIndices(pointIndices.GetNumberOfComponents(),
                                           visitIndex,
                                           shape,
                                           *this);
      // ...converted to global point ids through the cell's point list.
      connectivityOut[0] = pointIndices[localEdgeIndices[0]];
      connectivityOut[1] = pointIndices[localEdgeIndices[1]];
    }

  private:
    ScatterType Scatter;
  };
  ////
  //// END-EXAMPLE GenerateMeshConstantShapeGenIndices.cxx
  ////

  ////
  //// BEGIN-EXAMPLE GenerateMeshConstantShapeInvoke.cxx
  ////
  // Builds a cell set of line cells, one per edge of every cell in inCellSet.
  // The output references the same points as the input. This call also records
  // the output-to-input cell map that ProcessCellField uses later.
  template<typename CellSetType, typename Device>
  VTKM_CONT
  vtkm::cont::CellSetSingleType<> Run(const CellSetType& inCellSet, Device)
  {
    VTKM_IS_DYNAMIC_OR_STATIC_CELL_SET(CellSetType);

    vtkm::cont::ArrayHandle<vtkm::IdComponent> edgeCounts;
    vtkm::worklet::DispatcherMapTopology<CountEdges, Device> countEdgeDispatcher;
    countEdgeDispatcher.Invoke(inCellSet, edgeCounts);

    vtkm::worklet::ScatterCounting scatter(edgeCounts, Device());
    this->OutputToInputCellMap =
        scatter.GetOutputToInputMap(inCellSet.GetNumberOfCells());

    // Each output edge writes two consecutive connectivity entries.
    vtkm::cont::ArrayHandle<vtkm::Id> connectivityArray;
    vtkm::worklet::DispatcherMapTopology<EdgeIndices, Device>
        edgeIndicesDispatcher((EdgeIndices(scatter)));
    edgeIndicesDispatcher.Invoke(
          inCellSet, vtkm::cont::make_ArrayHandleGroupVec<2>(connectivityArray));

    vtkm::cont::CellSetSingleType<> outCellSet(inCellSet.GetName());
    outCellSet.Fill(inCellSet.GetNumberOfPoints(),
                    vtkm::CELL_SHAPE_LINE,
                    2,
                    connectivityArray);

    return outCellSet;
  }
  ////
  //// END-EXAMPLE GenerateMeshConstantShapeInvoke.cxx
  ////

  ////
  //// BEGIN-EXAMPLE GenerateMeshConstantShapeMapCellField.cxx
  ////
  // Maps a field over the input cells onto the generated edges: each edge
  // takes the value of the cell it came from. Must be called after Run, which
  // fills OutputToInputCellMap.
  template <typename ValueType, typename Storage, typename Device>
  VTKM_CONT
  vtkm::cont::ArrayHandle<ValueType> ProcessCellField(
      const vtkm::cont::ArrayHandle<ValueType, Storage>& inCellField,
      Device) const
  {
    vtkm::cont::ArrayHandle<ValueType> outCellField;
    vtkm::cont::DeviceAdapterAlgorithm<Device>::Copy(
        vtkm::cont::make_ArrayHandlePermutation(this->OutputToInputCellMap,
                                                inCellField),
        outCellField);
    return outCellField;
  }
  ////
  //// END-EXAMPLE GenerateMeshConstantShapeMapCellField.cxx
  ////

private:
  // For each output edge, the index of the input cell that produced it (set by Run).
  vtkm::worklet::ScatterCounting::OutputToInputMapType OutputToInputCellMap;
};

} // anonymous namespace

}
} // namespace vtkm::worklet

namespace vtkm {
namespace filter {

//// PAUSE-EXAMPLE
namespace {

//// RESUME-EXAMPLE
class ExtractEdges : public vtkm::filter::FilterDataSet<ExtractEdges>
{
public:
  /// Replaces the cells of the active cell set with line cells, one per cell
  /// edge, and copies the input coordinate systems to the result.
  template <typename Policy, typename Device>
  VTKM_CONT
  vtkm::filter::Result DoExecute(const vtkm::cont::DataSet& inData,
                                 vtkm::filter::PolicyBase<Policy> policy,
                                 Device);

  /// Passes point fields through unchanged and expands cell fields so each
  /// generated edge carries its source cell's value. Returns false for any
  /// other field association.
  template <typename T, typename StorageType, typename Policy, typename Device>
  VTKM_CONT
  bool DoMapField(vtkm::filter::Result& result,
                  const vtkm::cont::ArrayHandle<T, StorageType>& input,
                  const vtkm::filter::FieldMetadata& fieldMeta,
                  const vtkm::filter::PolicyBase<Policy>& policy,
                  Device);

private:
  // Shared between DoExecute (which runs it) and DoMapField (which reuses the
  // cell map it records).
  vtkm::worklet::ExtractEdges Worklet;
};

//// PAUSE-EXAMPLE
} // anonymous namespace
//// RESUME-EXAMPLE
}
} // namespace vtkm::filter


namespace vtkm {
namespace filter {

//// PAUSE-EXAMPLE
namespace {

//// RESUME-EXAMPLE
template <typename Policy, typename Device>
inline VTKM_CONT vtkm::filter::Result
ExtractEdges::DoExecute(const vtkm::cont::DataSet& inData,
                        vtkm::filter::PolicyBase<Policy> policy,
                        Device)
{
  VTKM_IS_DEVICE_ADAPTER_TAG(Device);

  // Turn every cell of the active cell set into its edges (line cells).
  const vtkm::cont::DynamicCellSet& activeCellSet =
      inData.GetCellSet(this->GetActiveCellSetIndex());
  vtkm::cont::CellSetSingleType<> edgeCellSet =
      this->Worklet.Run(vtkm::filter::ApplyPolicy(activeCellSet, policy), Device());

  vtkm::cont::DataSet edgeData;
  edgeData.AddCellSet(edgeCellSet);

  // The edges index the original points, so every coordinate system is
  // carried over as-is.
  for (vtkm::IdComponent csIndex = 0;
       csIndex < inData.GetNumberOfCoordinateSystems();
       ++csIndex)
  {
    edgeData.AddCoordinateSystem(inData.GetCoordinateSystem(csIndex));
  }

  return vtkm::filter::Result(edgeData);
}

template <typename T, typename StorageType, typename Policy, typename Device>
inline VTKM_CONT
bool ExtractEdges::DoMapField(vtkm::filter::Result& result,
                              const vtkm::cont::ArrayHandle<T, StorageType>& input,
                              const vtkm::filter::FieldMetadata& fieldMeta,
                              const vtkm::filter::PolicyBase<Policy>&,
                              Device)
{
  // Point fields stay valid because the edges reference the original points.
  if (fieldMeta.IsPointField())
  {
    result.GetDataSet().AddField(fieldMeta.AsField(input));
    return true;
  }

  // Cell fields are expanded: each edge inherits its source cell's value.
  if (fieldMeta.IsCellField())
  {
    vtkm::cont::ArrayHandle<T> edgeValues =
        this->Worklet.ProcessCellField(input, Device());
    result.GetDataSet().AddField(fieldMeta.AsField(edgeValues));
    return true;
  }

  // Any other association is not handled by this filter.
  return false;
}

//// PAUSE-EXAMPLE
} // anonymous namespace

//// RESUME-EXAMPLE
}
} // namespace vtkm::filter

namespace {

void CheckOutput(const vtkm::cont::CellSetSingleType<>& cellSet)
{
  // One line cell per edge of every input cell. The terms are presumably the
  // edge counts of the hexahedron (12), pyramid (8), tetrahedron (6) and
  // wedge (9) in Make3DExplicitDataSet5 — confirm against MakeTestDataSet.
  const vtkm::Id numCells = cellSet.GetNumberOfCells();
  std::cout << "Num cells: " << numCells << std::endl;
  VTKM_TEST_ASSERT(numCells == 12+8+6+9, "Wrong # of cells.");

  auto connectivity = cellSet.GetConnectivityArray(vtkm::TopologyElementTagPoint(),
                                                   vtkm::TopologyElementTagCell());
  std::cout << "Connectivity:" << std::endl;
  vtkm::cont::printSummary_ArrayHandle(connectivity, std::cout, true);

  // (connectivity position, expected point id) spot checks: the first two
  // edges and the final edge (positions 68/69 of the 35 * 2 entries).
  const vtkm::Id spotChecks[][2] = {
    { 0, 0 }, { 1, 1 }, { 2, 1 }, { 3, 5 }, { 68, 9 }, { 69, 10 }
  };
  auto connectivityPortal = connectivity.GetPortalConstControl();
  for (const auto& check : spotChecks)
  {
    VTKM_TEST_ASSERT(connectivityPortal.Get(check[0]) == check[1], "Bad edge index");
  }
}

void TryWorklet()
{
  std::cout << std::endl << "Trying calling worklet." << std::endl;
  vtkm::cont::DataSet inDataSet =
      vtkm::cont::testing::MakeTestDataSet().Make3DExplicitDataSet5();
  vtkm::cont::CellSetExplicit<> inCellSet;
  inDataSet.GetCellSet().CopyTo(inCellSet);

  vtkm::worklet::ExtractEdges worklet;
  vtkm::cont::CellSetSingleType<> outCellSet =
      worklet.Run(inCellSet, VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
  CheckOutput(outCellSet);
}

void TryFilter()
{
  std::cout << std::endl << "Trying calling filter." << std::endl;
  vtkm::cont::DataSet inDataSet =
      vtkm::cont::testing::MakeTestDataSet().Make3DExplicitDataSet5();

  vtkm::filter::ExtractEdges filter;

  vtkm::filter::Result result = filter.Execute(inDataSet);
  VTKM_TEST_ASSERT(result.IsValid(), "Execute failed.");

  for (vtkm::IdComponent fieldIndex = 0;
       fieldIndex < inDataSet.GetNumberOfFields();
       ++fieldIndex)
  {
    filter.MapFieldOntoOutput(result, inDataSet.GetField(fieldIndex));
  }

  vtkm::cont::DataSet outDataSet = result.GetDataSet();
  vtkm::cont::CellSetSingleType<> outCellSet;
  outDataSet.GetCellSet().CopyTo(outCellSet);
  CheckOutput(outCellSet);

  vtkm::cont::Field outCellField = outDataSet.GetField("cellvar");
  VTKM_TEST_ASSERT(
        outCellField.GetAssociation() == vtkm::cont::Field::ASSOC_CELL_SET,
        "Cell field not cell field.");
  vtkm::cont::ArrayHandle<vtkm::Float32> outCellData;
  outCellField.GetData().CopyTo(outCellData);
  std::cout << "Cell field:" << std::endl;
  vtkm::cont::printSummary_ArrayHandle(outCellData, std::cout, true);
  VTKM_TEST_ASSERT(outCellData.GetNumberOfValues() == outCellSet.GetNumberOfCells(),
                   "Bad size of field.");

  auto cellFieldPortal = outCellData.GetPortalConstControl();
  VTKM_TEST_ASSERT(test_equal(cellFieldPortal.Get(0), 100.1), "Bad field value.");
  VTKM_TEST_ASSERT(test_equal(cellFieldPortal.Get(1), 100.1), "Bad field value.");
  VTKM_TEST_ASSERT(test_equal(cellFieldPortal.Get(34), 130.5), "Bad field value.");
}

// Runs both variants in order: the worklet called directly, then the filter.
void DoTest()
{
  using TestFunction = void (*)();
  const TestFunction testSequence[] = { TryWorklet, TryFilter };
  for (TestFunction runTest : testSequence)
  {
    runTest();
  }
}

} // anonymous namespace

// Test entry point. Command-line arguments are ignored; returns whatever
// Testing::Run reports for DoTest.
int GenerateMeshConstantShape(int /*argc*/, char* /*argv*/[])
{
  const int status = vtkm::cont::testing::Testing::Run(DoTest);
  return status;
}
