// SPDX-FileCopyrightText: Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
// SPDX-License-Identifier: BSD-3-Clause

#include "vtkCellData.h"
#include "vtkHyperTreeGrid.h"
#include "vtkHyperTreeGridNonOrientedCursor.h"
#include "vtkHyperTreeGridSource.h"
#include "vtkLogger.h"
#include "vtkMPIController.h"
#include "vtkStringFormatter.h"

#include <array>
#include <cstdint>
#include <cstdlib>
#include <sstream>
#include <string>

namespace
{
/// Number of MPI ranks this test is written for.
constexpr int NB_PROCS{ 3 };

/**
 * One vtkHyperTreeGridSource configuration plus the rank each of its trees should land on.
 *
 * `NbTrees` is the number of trees described by the configuration. `ExpectedProcess[i]` is
 * the rank on which tree `i` must be unmasked; a value of -1 matches no rank, meaning the
 * tree must be masked everywhere.
 */
template <int NbTrees>
struct SourceConfig
{
  unsigned int Depth;                     // passed to SetMaxDepth
  unsigned int BranchFactor;              // passed to SetBranchFactor
  std::array<unsigned int, 3> Dimensions; // passed to SetDimensions
  std::array<double, 3> GridScale;        // passed to SetGridScale
  std::string Descriptor;                 // passed to SetDescriptor
  std::string Mask;                       // passed to SetMask; empty string disables the mask
  std::array<int8_t, NbTrees> ExpectedProcess; // expected owning rank per tree (-1: none)
};
} // anonymous namespace

template <int NbTrees>
bool TestSource(const SourceConfig<NbTrees>& config, int myRank, int nbRanks)
{
  // Create HTG Source with process selection
  vtkNew<vtkHyperTreeGridSource> htGrid;
  htGrid->SetDebug(true);
  htGrid->SetMaxDepth(config.Depth);
  htGrid->SetBranchFactor(config.BranchFactor);
  htGrid->SetDimensions(config.Dimensions.data());
  htGrid->SetGridScale(config.GridScale.data());
  htGrid->SetDescriptor(config.Descriptor.c_str());
  if (!config.Mask.empty())
  {
    htGrid->SetUseMask(true);
    htGrid->SetMask(config.Mask.c_str());
  }

  htGrid->UpdatePiece(myRank, nbRanks, 0);
  auto htg = htGrid->GetHyperTreeGridOutput();

  // Test that the right trees appear on selected process
  vtkIdType inIndex = 0;
  vtkHyperTreeGrid::vtkHyperTreeGridIterator it;
  htg->InitializeTreeIterator(it);
  bool success = true;
  vtkNew<vtkHyperTreeGridNonOrientedCursor> cursor;
  while (it.GetNextTree(inIndex))
  {
    htg->InitializeNonOrientedCursor(cursor, inIndex, true);
    if ((config.ExpectedProcess[inIndex] == myRank) == cursor->IsMasked())
    {
      vtkErrorWithObjectMacro(
        nullptr, "Tree #" << inIndex << " does not appear on the right process");
      success = false;
    }
  }

  return success;
}

/**
 * Test that the total bounding box generated by the distributed source is correct
 */
bool TestSourceBoundingBox(int myRank, int nbRanks, vtkMPIController* controller)
{
  vtkNew<vtkHyperTreeGridSource> htGrid;
  htGrid->SetDebug(true);
  htGrid->SetMaxDepth(3);
  htGrid->SetBranchFactor(2);
  htGrid->SetDimensions(4, 3, 2);
  htGrid->SetGridScale(1, 1, 1);
  htGrid->SetDescriptor("0RR1R.R2R|........ ........ ........ ........ .......R| ........");
  htGrid->SetMask("111011|01010000 11110000 11110000 11000000 11000001| 11000000");
  htGrid->SetUseMask(true);
  htGrid->UpdatePiece(myRank, nbRanks, 0);

  std::array<double, 6> localBounds, totalMinBounds, totalMaxBounds,
    expectedBounds{ 0.5, 3, 0, 1.75, 0, 0.75 };
  auto htg = htGrid->GetHyperTreeGridOutput();
  htg->GetBounds(localBounds.data());

  controller->AllReduce(localBounds.data(), totalMinBounds.data(), 6, vtkCommunicator::MIN_OP);
  controller->AllReduce(localBounds.data(), totalMaxBounds.data(), 6, vtkCommunicator::MAX_OP);
  std::array<double, 6> totalBounds{ totalMinBounds[0], totalMaxBounds[1], totalMinBounds[2],
    totalMaxBounds[3], totalMinBounds[4], totalMaxBounds[5] };

  bool success = true;
  if (myRank == 0)
  {
    for (int i = 0; i < 6; i++)
    {
      if (totalBounds[i] != expectedBounds[i])
      {
        vtkErrorWithObjectMacro(nullptr,
          "Bound #" << i << " of HTG distributed source is incorrect: expected "
                    << expectedBounds[i] << " but got " << totalBounds[i] << " instead.");
        success = false;
      }
    }
  }

  std::array<int, 3> expectedNumberOfTrees{ 2, 3, 1 };
  if (expectedNumberOfTrees[myRank] != htg->GetNumberOfNonEmptyTrees())
  {
    vtkErrorWithObjectMacro(nullptr,
      "Expected to get " << expectedNumberOfTrees[myRank] << " non-empty trees on rank " << myRank
                         << " but got " << htg->GetNumberOfNonEmptyTrees() << " instead.");
    success = false;
  }

  return success;
}

/**
 * Check that a descriptor assigning a tree to a process id greater than the number of ranks
 * (here "3." with NB_PROCS ranks) is reported as an error rather than crashing.
 *
 * While the source updates, error-level vtkLogger messages are captured into a string stream
 * through a temporary callback, and stderr logging is turned off. The captured text must
 * contain "Can not assign tree to piece".
 *
 * @param myRank  rank of the calling process (piece requested from the source)
 * @param nbRanks total number of ranks (number of pieces)
 * @return true if the expected error message was logged on this rank.
 */
bool TestHighPieceIdDoesNotCrash(int myRank, int nbRanks)
{
  std::ostringstream logStream;
  // Hand the callback a std::ostream* explicitly: a void* may only be converted back to the
  // exact pointer type it was created from, and the callback reads it as std::ostream*.
  std::ostream* logSink = &logStream;

  auto error_message_callback = [](void* userData, const vtkLogger::Message& message)
  {
    std::ostream& s = *static_cast<std::ostream*>(userData);
    // No flush needed: the text is only read back through logStream.str().
    s << message.preamble << message.message << '\n';
  };

  // NOTE(review): GetCurrentVerbosityCutoff presumably reports the cutoff over all log sinks,
  // not only stderr; confirm that restoring it below yields the intended stderr verbosity.
  auto previous_verbosity = vtkLogger::GetCurrentVerbosityCutoff();
  vtkLogger::AddCallback("logStream", error_message_callback, logSink, vtkLogger::VERBOSITY_ERROR);
  vtkLogger::SetStderrVerbosity(vtkLogger::VERBOSITY_OFF);

  vtkNew<vtkHyperTreeGridSource> htGrid;
  htGrid->SetDebug(true);
  htGrid->SetMaxDepth(1);
  htGrid->SetBranchFactor(2);
  htGrid->SetDimensions(1, 1, 1);
  htGrid->SetDescriptor("3.");
  htGrid->UpdatePiece(myRank, nbRanks, 0);

  vtkLogger::RemoveCallback("logStream");
  vtkLogger::SetStderrVerbosity(previous_verbosity);

  if (logStream.str().find("Can not assign tree to piece") == std::string::npos)
  {
    vtkErrorWithObjectMacro(nullptr,
      "Using a process id in descriptor greater than the number of ranks did not raise an error. "
      "Error log should read 'Can not assign tree to piece'");
    return false;
  }

  return true;
}

int TestHyperTreeGridSourceDistributed(int argc, char* argv[])
{
  // Initialize MPI Controller
  vtkNew<vtkMPIController> controller;
  controller->Initialize(&argc, &argv);
  vtkMultiProcessController::SetGlobalController(controller);

  int myRank = controller->GetLocalProcessId();
  int nbRanks = controller->GetNumberOfProcesses();

  if (nbRanks != NB_PROCS)
  {
    vtkErrorWithObjectMacro(nullptr, "Expected " << NB_PROCS << " processes, got " << nbRanks);
    controller->Finalize();
    return EXIT_FAILURE;
  }

  std::string threadName = "rank-" + vtk::to_string(controller->GetLocalProcessId());
  vtkLogger::SetThreadName(threadName);

  SourceConfig<6> source1{ 6, 2, { 3, 4, 1 }, { 1.5, 1., 10. },
    "0RR1RR0R.|.... .R.. RRRR R... R...|.R.. ...R ..RR .R.. R... .... ....|.... "
    "...R ..R. .... .R.. R...|.... .... .R.. ....|....",
    "111111|1111 1111 1111 1111 1111|1111 1111 1111 1111 1111 1111 1111|1111 "
    "1111 1111 1111 1111 1111|1111 1111 1111 1111|1111",
    { 0, 0, 1, 1, 0, 0 } };

  SourceConfig<4> source2{
    1,
    1,
    { 3, 3, 1 },
    { 1.5, 1., 10. },
    "0..2.1.",
    "1011",
    { 0, -1, 2, 1 },
  };

  bool success = true;
  success &= ::TestSource(source1, myRank, nbRanks);
  success &= ::TestSource(source2, myRank, nbRanks);

  // Default to 0, ignore chars at the end
  source2.Descriptor = "...2.101";
  source2.Mask = "1011";
  source2.ExpectedProcess = { 0, -1, 0, 2 };
  success &= ::TestSource(source2, myRank, nbRanks);

  source2.Descriptor = ".1.0.2.";
  source2.Mask = "1011";
  source2.ExpectedProcess = { 0, -1, 0, 2 };
  success &= ::TestSource(source2, myRank, nbRanks);

  success &= ::TestSourceBoundingBox(myRank, nbRanks, controller);
  success &= ::TestHighPieceIdDoesNotCrash(myRank, nbRanks);

  controller->Finalize();
  return success ? EXIT_SUCCESS : EXIT_FAILURE;
}
