#include "MpiEnv.h"

#include <vtkm/Math.h>
#include <vtkm/cont/EnvironmentTracker.h>
#include <vtkm/cont/ErrorBadValue.h>
#include <vtkm/thirdparty/diy/mpi-cast.h>

#include <limits.h>
#include <unistd.h>

namespace
{
// Sentinel stored in MpiEnv::Rank by the constructor's initializer list; it is
// overwritten with the communicator's rank in the constructor body.
constexpr int UNINTIALIZED_RANK{ -1 };
} // namespace

namespace beams
{
namespace mpi
{
// Non-owning pointer to the most recently constructed MpiEnv (the constructor
// assigns `this`); null until the first MpiEnv is created.
MpiEnv* MpiEnv::Instance{ nullptr };

// Builds the MPI environment wrapper.
//
// Initializes Env from argc/argv, registers the default-constructed Comm with
// VTK-m's EnvironmentTracker, caches the raw MPI communicator handle, this
// host's name, and this process's rank and communicator size, then publishes
// the new object through MpiEnv::Instance.
//
// The topology starts as TopologyShape::Unknown with all extents and
// coordinates zero; call one of the Reshape* methods to lay the ranks out.
MpiEnv::MpiEnv(int argc, char* argv[])
  : Env(argc, argv)
  , Comm()
  , Shape(TopologyShape::Unknown)
  , Rank(UNINTIALIZED_RANK)
  , Size(0)
  , XLength(0)
  , YLength(0)
  , ZLength(0)
  , XRank(0)
  , YRank(0)
  , ZRank(0)
{
  vtkm::cont::EnvironmentTracker::SetCommunicator(this->Comm);
  this->RawComm = vtkmdiy::mpi::mpi_cast(this->Comm.handle());

  // HOST_NAME_MAX does not count the terminating NUL, and POSIX leaves it
  // unspecified whether a truncated name is NUL-terminated. Reserve one extra
  // byte, zero-initialize, and terminate explicitly. If gethostname fails, fall
  // back to an empty name instead of copying whatever the buffer holds.
  char hostname[HOST_NAME_MAX + 1] = {};
  if (gethostname(hostname, sizeof(hostname)) != 0)
  {
    hostname[0] = '\0';
  }
  hostname[sizeof(hostname) - 1] = '\0';
  this->Hostname = hostname;

  this->Rank = this->Comm.rank();
  this->Size = this->Comm.size();
  MpiEnv::Instance = this;
}

// Tears down the environment.
//
// The constructor stores `this` in the raw, non-owning MpiEnv::Instance
// pointer. Clear that pointer here so it does not dangle after destruction.
// Clear it only if it still refers to this object, in case a later MpiEnv
// has replaced it.
MpiEnv::~MpiEnv()
{
  if (MpiEnv::Instance == this)
  {
    MpiEnv::Instance = nullptr;
  }
}

void MpiEnv::ReshapeCustom(int xLength, int yLength, int zLength)
{
  this->Shape = TopologyShape::Custom;
  this->XLength = xLength;
  this->YLength = yLength;
  this->ZLength = zLength;
  this->XRank = this->Rank / (this->YLength * this->ZLength);
  this->YRank = (this->Rank % (this->YLength * this->ZLength)) / this->ZLength;
  this->ZRank = this->Rank % this->ZLength;
}

// Lays every rank out along X in a single row.
//
// The X extent becomes Size and the Y and Z extents collapse to 1. Each
// rank's X coordinate is therefore its rank, and its Y and Z coordinates
// are 0.
void MpiEnv::ReshapeAsLine()
{
  this->Shape = TopologyShape::Line;

  // Extents: Size x 1 x 1.
  this->XLength = this->Size;
  this->YLength = 1;
  this->ZLength = 1;

  // Coordinates: with unit Y/Z extents, the rank maps straight onto X.
  this->XRank = this->Rank;
  this->YRank = 0;
  this->ZRank = 0;
}

void MpiEnv::ReshapeAsRectangle()
{
  vtkm::Id sizeSqrt = static_cast<vtkm::Id>(vtkm::Sqrt(this->Size));
  vtkm::Id yLength = sizeSqrt;
  vtkm::Id xLength = this->Size / sizeSqrt;
  if ((xLength * yLength) != this->Size)
  {
    throw vtkm::cont::ErrorBadValue(
      fmt::format("Cannot shape {} ranks into a rectangle topology", this->Size));
  }

  this->Shape = TopologyShape::Rectangle;
  this->YLength = yLength;
  this->XLength = xLength;
  this->ZLength = 1;
  this->XRank = this->Rank / this->YLength;
  this->YRank = this->Rank - (this->XRank * this->YLength);
  this->ZRank = 0;
}

void MpiEnv::ReshapeAsCuboid()
{
  vtkm::Id sizeCbrt = static_cast<vtkm::Id>(vtkm::Cbrt(this->Size));
  vtkm::Id zLength = sizeCbrt;
  vtkm::Id yLength = sizeCbrt;
  vtkm::Id xLength = sizeCbrt;
  if ((xLength * yLength * zLength) != this->Size)
  {
    throw vtkm::cont::ErrorBadValue(
      fmt::format("Cannot shape {} ranks into a cube topology", this->Size));
  }

  this->Shape = TopologyShape::Cuboid;
  this->ZLength = zLength;
  this->YLength = yLength;
  this->XLength = xLength;
  vtkm::Id area = this->XLength * this->YLength;
  this->ZRank = this->Rank / area;
  this->XRank = (this->Rank % area) / this->YLength;
  this->YRank = (this->Rank % area) - (this->XRank * this->YLength);
}
}
} // beams::mpi
