#include "RangeMap.h"

#include <vtkm/cont/AssignerPartitionedDataSet.h>
#include <vtkm/cont/EnvironmentTracker.h>

VTKM_THIRDPARTY_PRE_INCLUDE
#include <vtkm/thirdparty/diy/diy.h>
#include <vtkm/thirdparty/diy/mpi-cast.h>
VTKM_THIRDPARTY_POST_INCLUDE

namespace beams
{
namespace rendering
{
// Builds the per-rank range table from this rank's local range.
//
// Build() calls vtkmdiy::mpi::all_reduce on mpi.Comm, so this constructor is
// collective: every rank of that communicator must construct a RangeMap.
//
// @param mpi        MPI environment; Size determines the number of table
//                   entries and Rank selects this rank's slot.
// @param localRange Range contributed by this rank.
RangeMap::RangeMap(const beams::mpi::MpiEnv& mpi, const vtkm::Range& localRange)
  : Mpi(mpi)
  , Ranges(static_cast<std::size_t>(mpi.Size), vtkm::Range())
{
  // No need to store localRange into Ranges[mpi.Rank] first: Build()
  // overwrites every entry of Ranges (this rank's included) with the
  // reduced values.
  this->Build(localRange);
}

// Exchanges every rank's range and recomputes Ranges and GlobalRange.
//
// Each rank writes its [Min, Max] pair into its own two-element slot of an
// otherwise zero-filled buffer. Summing those buffers across all ranks
// therefore reproduces every rank's pair in every rank's output, which is an
// all-gather expressed as an all-reduce. This is an MPI collective on Mpi.Comm.
//
// @param localRange Range contributed by the calling rank.
void RangeMap::Build(const vtkm::Range& localRange)
{
  const std::size_t slotCount = static_cast<std::size_t>(this->Mpi.Size * 2);
  std::vector<vtkm::Float64> contribution(slotCount, 0.0);
  std::vector<vtkm::Float64> gathered(slotCount);

  // Only this rank's slot is non-zero, so the sum leaves it unchanged.
  const std::size_t mySlot = static_cast<std::size_t>(this->Mpi.Rank * 2);
  contribution[mySlot] = localRange.Min;
  contribution[mySlot + 1] = localRange.Max;

  vtkmdiy::mpi::all_reduce(this->Mpi.Comm, contribution, gathered, std::plus<vtkm::Float64>{});

  // Unpack the (Min, Max) pairs and accumulate them into the global range.
  vtkm::Range combined;
  for (std::size_t rank = 0; rank < this->Ranges.size(); ++rank)
  {
    vtkm::Range& entry = this->Ranges[rank];
    entry = vtkm::Range(gathered[2 * rank], gathered[2 * rank + 1]);
    combined.Include(entry);
  }
  this->GlobalRange = combined;
}

// Returns the range recorded for the given rank by Build().
//
// @param rank Rank whose range is requested; valid values are
//             [0, number of entries in Ranges).
// @return     Copy of that rank's range.
// @throws std::out_of_range if rank is negative or too large. A negative
//         value wraps to a huge std::size_t under the cast, so it is rejected
//         by the same bounds check instead of indexing out of bounds.
vtkm::Range RangeMap::GetRange(vtkm::Id rank) const
{
  return this->Ranges.at(static_cast<std::size_t>(rank));
}

// Returns the combined range that Build() accumulated, via Range::Include,
// from every rank's range.
vtkm::Range RangeMap::GetGlobalRange() const
{
  const vtkm::Range& combined = GlobalRange;
  return combined;
}
} // namespace rendering
} // namespace beams