// SPDX-FileCopyrightText: Copyright (c) Kitware, Inc.
// SPDX-License-Identifier: BSD-3-Clause
#include "vtkClipDataSet.h"
#include "vtkDIYAggregateDataSetFilter.h"
#include "vtkGroupDataSetsFilter.h"
#include "vtkImageData.h"
#include "vtkMPIController.h"
#include "vtkMultiProcessController.h"
#include "vtkPlane.h"
#include "vtkRTAnalyticSource.h"
#include "vtkRedistributeDataSetFilter.h"
#include "vtkUnstructuredGrid.h"
#include <vtk_mpi.h>

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <vector>

/**
 * Regression test: redistribute, then aggregate, a partitioned dataset
 * collection across MPI ranks.
 *
 * Each rank generates one piece of a wavelet image. The pieces are wrapped in
 * a vtkPartitionedDataSetCollection (via vtkGroupDataSetsFilter),
 * redistributed with vtkRedistributeDataSetFilter and then aggregated onto
 * `numTargetProcs` ranks with vtkDIYAggregateDataSetFilter. Rank 0 gathers the
 * per-rank point counts and checks that the expected number of ranks hold data.
 *
 * @param argc  argument count, forwarded to MPI initialization.
 * @param argv  argument vector, forwarded to MPI initialization.
 * @return EXIT_SUCCESS if the number of non-empty ranks matches the
 *         expectation, EXIT_FAILURE otherwise. Every rank returns the same value.
 */
int TestRedistributeAndAggregate(int argc, char* argv[])
{
  vtkMPIController* Controller = vtkMPIController::New();
  Controller->Initialize(&argc, &argv, 0);
  vtkMultiProcessController::SetGlobalController(Controller);

  const int Rank = Controller->GetLocalProcessId();
  const int NumberOfProcessors = Controller->GetNumberOfProcesses();

  // Create a wavelet source; each rank generates its own piece.
  vtkNew<vtkRTAnalyticSource> waveletSource;
  waveletSource->SetWholeExtent(0, 58, 0, 56, 0, 50);
  waveletSource->UpdatePiece(Rank, NumberOfProcessors, 0);

  // Print the initial number of points on each rank (element type 0).
  std::cout << "WAVELET: rank " << Rank << " has "
            << waveletSource->GetOutput()->GetNumberOfElements(0) << " points" << std::endl;

  // Wrap the data in a partitioned dataset collection: the hang this test
  // guards against has only been observed with data of this type.
  vtkNew<vtkGroupDataSetsFilter> groupFilter;
  groupFilter->SetOutputTypeToPartitionedDataSetCollection();
  groupFilter->SetInputConnection(waveletSource->GetOutputPort(0));
  groupFilter->UpdatePiece(Rank, NumberOfProcessors, 0);

  // Print the number of points on each rank after grouping.
  std::cout << "GROUPED: rank " << Rank << " has "
            << groupFilter->GetOutput()->GetNumberOfElements(0) << " points" << std::endl;

  const int numTargetProcs = 2;

  // First redistribute the data onto all processes. The hang has not been
  // reproduced without this filter, even though the data is likely to be on
  // all processes already.
  vtkNew<vtkRedistributeDataSetFilter> redistributeFilter;
  redistributeFilter->SetController(Controller);
  redistributeFilter->SetInputConnection(groupFilter->GetOutputPort());
  redistributeFilter->SetNumberOfPartitions(-1);

  vtkNew<vtkDIYAggregateDataSetFilter> aggregateFilter;
  aggregateFilter->SetInputConnection(redistributeFilter->GetOutputPort());
  aggregateFilter->SetNumberOfTargetProcesses(numTargetProcs);
  aggregateFilter->Update();

  // Gather the post-aggregation point count of every rank on destProcId.
  const int destProcId = 0;
  std::vector<vtkIdType> pointCount(NumberOfProcessors, 0);
  vtkIdType numPoints = aggregateFilter->GetOutput()->GetNumberOfElements(0);
  Controller->Gather(&numPoints, pointCount.data(), 1, destProcId);

  // Data cannot end up on more ranks than exist, so clamp the expectation
  // when the test runs with fewer processes than numTargetProcs.
  const int expectedNonEmptyRanks = std::min(numTargetProcs, NumberOfProcessors);

  // Use an explicit check rather than assert(): assert() is compiled out
  // under NDEBUG and cannot influence the test's exit status.
  int retVal = EXIT_SUCCESS;
  if (Rank == destProcId)
  {
    int nonEmptyRanks = 0;

    // Print the number of points on each rank after redistribution/aggregation.
    for (int r = 0; r < NumberOfProcessors; ++r)
    {
      std::cout << "AGGREGATED: rank " << r << " has " << pointCount[r] << " points" << std::endl;
      if (pointCount[r] > 0)
      {
        ++nonEmptyRanks;
      }
    }

    if (nonEmptyRanks != expectedNonEmptyRanks)
    {
      std::cerr << "post: Wrong number of non-empty ranks: expected " << expectedNonEmptyRanks
                << ", got " << nonEmptyRanks << std::endl;
      retVal = EXIT_FAILURE;
    }
  }

  // Share the verdict so that every rank reports the same exit status.
  Controller->Broadcast(&retVal, 1, destProcId);

  // Do not leave the global controller pointing at an object about to be deleted.
  vtkMultiProcessController::SetGlobalController(nullptr);
  Controller->Finalize();
  Controller->Delete();
  return retVal;
}
