//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//============================================================================

#ifndef vtk_m_filter_flow_internal_ParticleMessenger_h
#define vtk_m_filter_flow_internal_ParticleMessenger_h

#include <vtkm/Particle.h>
#include <vtkm/filter/flow/internal/BoundsMap.h>
#include <vtkm/filter/flow/internal/Messenger.h>
#include <vtkm/filter/flow/vtkm_filter_flow_export.h>

#include <vtkm/filter/flow/Tracer.h> //DRP

#include <list>
#include <map>
#include <memory>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>

namespace vtkm
{
namespace filter
{
namespace flow
{
namespace internal
{

/// Exchanges particles -- each paired with the block IDs it is headed for --
/// and termination-count messages between ranks.
///
/// The communication mode is chosen with the SetUse*Communication() setters,
/// which write the CommType / AsyncCommParams state inherited from Messenger.
/// When only one rank exists, Exchange() short-circuits to SerialExchange(),
/// which simply hands the outgoing particles back as incoming ones.
template <typename ParticleType>
class VTKM_FILTER_FLOW_EXPORT ParticleMessenger : public vtkm::filter::flow::internal::Messenger
{
  // (sendRank, message payload)
  using MsgCommType = std::pair<int, std::vector<int>>;

  // (particle, block IDs the particle should be routed to)
  using ParticleCommType = std::pair<ParticleType, std::vector<vtkm::Id>>;

  // (sendRank, every ParticleCommType received from that rank)
  using ParticleRecvCommType = std::pair<int, std::vector<ParticleCommType>>;

public:
  VTKM_CONT ParticleMessenger(vtkmdiy::mpi::communicator& comm,
                              const vtkm::filter::flow::internal::BoundsMap& bm);
  VTKM_CONT ~ParticleMessenger() = default;

  VTKM_CONT void SetUseSyncCommunication() { this->CommType = Messenger::SYNC; }
  VTKM_CONT void SetUseAsyncProbeCommunication() { this->CommType = Messenger::ASYNC_PROBE; }

  /// Select asynchronous communication and record the receive-buffer sizing
  /// used by RegisterMessages().
  /// @param msgSz        message size parameter (RegisterMessages passes msgSz + 1
  ///                     to CalcMessageBufferSize).
  /// @param numParticles particles per particle receive buffer.
  /// @param numReceivers receive requests registered per tag.
  /// @param numBlockIds  block IDs per particle assumed when sizing particle buffers.
  VTKM_CONT void SetUseAsyncCommunication(int msgSz = 1,
                                          int numParticles = 128,
                                          int numReceivers = 64,
                                          int numBlockIds = 2)
  {
    this->CommType = Messenger::ASYNC;
    this->AsyncCommParams.MessageSize = msgSz;
    this->AsyncCommParams.NumberOfParticles = numParticles;
    this->AsyncCommParams.NumberOfReceivers = numReceivers;
    this->AsyncCommParams.NumberOfBlockIds = numBlockIds;
  }

  /// Send outgoing particles / termination counts and collect whatever has arrived.
  ///
  /// @param sendParticleInfo     per destination rank, incremented by the number of
  ///                             particles sent to it (multi-rank MPI path only).
  /// @param outData              particles to send; outData[i] goes to rank outRanks[i].
  /// @param outRanks             destination rank for each entry of outData (same length).
  /// @param outBlockIDsMap       block IDs for each outgoing particle, keyed by particle ID.
  /// @param numLocalTerm         locally terminated particle count; broadcast when > 0.
  /// @param inData               received particles are appended here.
  /// @param inDataBlockIDsMap    cleared, then filled with block IDs of received particles.
  /// @param numTerminateMessages reset, then set to the sum of received termination counts.
  /// @param blockAndWait         forwarded to the receive call.
  /// @param ifTracingParticle    only read by instrumented builds.
  VTKM_CONT void Exchange(std::unordered_map<int, int>& sendParticleInfo,
                          const std::vector<ParticleType>& outData,
                          const std::vector<vtkm::Id>& outRanks,
                          const std::unordered_map<vtkm::Id, std::vector<vtkm::Id>>& outBlockIDsMap,
                          vtkm::Id numLocalTerm,
                          std::vector<ParticleType>& inData,
                          std::unordered_map<vtkm::Id, std::vector<vtkm::Id>>& inDataBlockIDsMap,
                          vtkm::Id& numTerminateMessages,
                          bool blockAndWait = false,
                          bool ifTracingParticle = false);

protected:
#ifdef VTKM_ENABLE_MPI
  // First int of a control message that carries a termination count.
  static constexpr int MSG_TERMINATE = 1;

  enum { MESSAGE_TAG = 0x42000, PARTICLE_TAG = 0x42001 };

  // Register receive buffers for both tags (async mode only).
  VTKM_CONT void RegisterMessages();

  // Send particles to one rank.
  template <typename P,
            template <typename, typename>
            class Container,
            typename Allocator = std::allocator<P>>
  VTKM_CONT inline void SendParticles(int dst, Container<P, Allocator>& c);

  // Send one particle batch per destination rank in the map.
  template <typename P,
            template <typename, typename>
            class Container,
            typename Allocator = std::allocator<P>>
  VTKM_CONT inline void SendParticles(std::unordered_map<int, Container<P, Allocator>>& m);

  // Send/Recv control messages.
  VTKM_CONT void SendMsg(int dst, const std::vector<int>& msg);
  VTKM_CONT void SendAllMsg(const std::vector<int>& msg);
  VTKM_CONT bool RecvMsg(std::vector<MsgCommType>& msgs)
  {
    return this->RecvAny(&msgs, nullptr, false);
  }

  // Receive messages and/or particles; a null output pointer disables that tag.
  VTKM_CONT bool RecvAny(std::vector<MsgCommType>* msgs,
                         std::vector<ParticleRecvCommType>* recvParticles,
                         bool blockAndWait);
  const vtkm::filter::flow::internal::BoundsMap& BoundsMap;

#endif

  VTKM_CONT void SerialExchange(
    const std::vector<ParticleType>& outData,
    const std::vector<vtkm::Id>& outRanks,
    const std::unordered_map<vtkm::Id, std::vector<vtkm::Id>>& outBlockIDsMap,
    vtkm::Id numLocalTerm,
    std::vector<ParticleType>& inData,
    std::unordered_map<vtkm::Id, std::vector<vtkm::Id>>& inDataBlockIDsMap,
    bool blockAndWait) const;

  // Receive-buffer size (bytes) for a batch of nParticles particles, each with
  // numBlockIds block IDs, as serialized by SendParticles().
  static std::size_t CalcParticleBufferSize(std::size_t nParticles, std::size_t numBlockIds = 2);

private:
  // Set once Initialize() has run, so buffer registration happens only once.
  bool Initialized = false;

  VTKM_CONT void Initialize()
  {
    if (this->Initialized)
      return;

#ifdef VTKM_ENABLE_MPI
    // RegisterMessages() only exists in MPI builds; without this guard any
    // instantiation in a non-MPI build would fail to compile.
    if (this->CommType == Messenger::ASYNC)
      this->RegisterMessages();
#endif
    this->Initialized = true;
  }
};

//methods

/// Construct a messenger over `comm`.
/// In MPI builds the bounds map is kept by reference, so it must outlive this
/// object. Without MPI there is no member to bind, and the argument is ignored.
template <typename ParticleType>
VTKM_CONT ParticleMessenger<ParticleType>::ParticleMessenger(
  vtkmdiy::mpi::communicator& comm,
  const vtkm::filter::flow::internal::BoundsMap& boundsMap)
  : Messenger(comm)
#ifdef VTKM_ENABLE_MPI
  , BoundsMap(boundsMap)
#endif
{
#ifndef VTKM_ENABLE_MPI
  // Silence the unused-parameter warning in serial builds.
  (void)boundsMap;
#endif
}

/// Size in bytes of a serialized particle batch, as produced by SendParticles():
/// the sender rank followed by a std::vector of (particle, block-ID vector) pairs.
/// Each particle is assumed to carry exactly nBlockIds block IDs.
template <typename ParticleType>
std::size_t ParticleMessenger<ParticleType>::CalcParticleBufferSize(std::size_t nParticles,
                                                                    std::size_t nBlockIds)
{
  const std::size_t pSize = ParticleType::Sizeof();

#ifndef NDEBUG
  // Cross-check Sizeof() against an actual diy serialization of a particle.
  // If this fires, the particle class changed without updating Sizeof().
  vtkmdiy::MemoryBuffer buff;
  ParticleType p;
  vtkmdiy::save(buff, p);
  VTKM_ASSERT(pSize == buff.size());
#endif

  return
    // sender rank
    sizeof(int)
    // element count of std::vector<ParticleCommType>
    + sizeof(std::size_t)
    // nParticles serialized particles
    + nParticles * pSize
    // element count of each particle's std::vector<vtkm::Id> of block IDs
    + nParticles * sizeof(std::size_t)
    // nBlockIds block IDs per particle
    + nParticles * nBlockIds * sizeof(vtkm::Id);
}

/// Single-rank "exchange": every outgoing particle is appended to inData, and its
/// block IDs are copied into inDataBlockIDsMap. Rank, termination count and
/// blocking arguments are irrelevant with one rank and are ignored.
/// @throws std::out_of_range if a particle's ID has no entry in outBlockIDsMap.
template <typename ParticleType>
VTKM_CONT void ParticleMessenger<ParticleType>::SerialExchange(
  const std::vector<ParticleType>& outData,
  const std::vector<vtkm::Id>& vtkmNotUsed(outRanks),
  const std::unordered_map<vtkm::Id, std::vector<vtkm::Id>>& outBlockIDsMap,
  vtkm::Id vtkmNotUsed(numLocalTerm),
  std::vector<ParticleType>& inData,
  std::unordered_map<vtkm::Id, std::vector<vtkm::Id>>& inDataBlockIDsMap,
  bool vtkmNotUsed(blockAndWait)) const
{
  inData.reserve(inData.size() + outData.size());
  for (const auto& p : outData)
  {
    const auto id = p.GetID();
    // at() throws instead of dereferencing end() for a missing ID. Look it up
    // before touching either output so a failure leaves them unchanged.
    const auto& bids = outBlockIDsMap.at(id);
    inData.emplace_back(p);
    inDataBlockIDsMap[id] = bids;
  }
}

/// Send outgoing particles and the local termination count, then receive any
/// incoming particles and termination messages.
///
/// inDataBlockIDsMap is cleared and numTerminateMessages is reset on entry.
/// Received particles are appended to inData. With a single rank this
/// delegates to SerialExchange().
/// @throws std::out_of_range if an outgoing particle ID is missing from outBlockIDsMap.
template <typename ParticleType>
VTKM_CONT void ParticleMessenger<ParticleType>::Exchange(
  std::unordered_map<int, int>& sendParticleInfo,
  const std::vector<ParticleType>& outData,
  const std::vector<vtkm::Id>& outRanks,
  const std::unordered_map<vtkm::Id, std::vector<vtkm::Id>>& outBlockIDsMap,
  vtkm::Id numLocalTerm,
  std::vector<ParticleType>& inData,
  std::unordered_map<vtkm::Id, std::vector<vtkm::Id>>& inDataBlockIDsMap,
  vtkm::Id& numTerminateMessages,
  bool blockAndWait,
  bool ifTracingParticle)
{
  // Only read in MPI and/or instrumented builds; avoid unused-parameter warnings.
  (void)sendParticleInfo;
  (void)ifTracingParticle;

  this->Initialize();

  VTKM_ASSERT(outData.size() == outRanks.size());

  numTerminateMessages = 0;
  inDataBlockIDsMap.clear();

  if (this->GetNumRanks() == 1)
  {
    this->SerialExchange(
      outData, outRanks, outBlockIDsMap, numLocalTerm, inData, inDataBlockIDsMap, blockAndWait);
    return;
  }

#ifdef VTKM_ENABLE_MPI

  // dstRank -> vector of (particle, blockIDs)
  std::unordered_map<int, std::vector<ParticleCommType>> sendData;

  PTRACER(vtkm::filter::flow::GetTracer().Get()->TimeTraceToBuffer("SendDataStart"));
#ifdef VTKm_INSTRUMENT_PARTICLE_ADVECTION
  if (ifTracingParticle)
    vtkm::filter::flow::GetTracer().Get()->ParticleEventToBuffer(vtkm::filter::flow::GetTracer().Get()->GetIterationStep(), this->GetRank(), vtkm::filter::flow::GetTracer().Get()->GetTraceParticleId(), "ParticleSendBegin", vtkm::filter::flow::GetTracer().Get()->GetElapsedTime());
#endif

  // Bucket outgoing particles (with their block IDs) by destination rank, and
  // count how many particles go to each rank.
  const std::size_t numP = outData.size();
  for (std::size_t i = 0; i < numP; i++)
  {
    const auto& particle = outData[i];
    // Rank keys are int; make the vtkm::Id -> int narrowing explicit.
    const int dst = static_cast<int>(outRanks[i]);
    // at() throws instead of dereferencing end() when a particle has no block IDs.
    const auto& bids = outBlockIDsMap.at(particle.GetID());
    sendData[dst].emplace_back(particle, bids);
    ++sendParticleInfo[dst];
  }

  // Do all the sends first.
  if (numLocalTerm > 0)
    this->SendAllMsg({ MSG_TERMINATE, static_cast<int>(numLocalTerm) });
  this->SendParticles(sendData);
  this->CheckPendingSendRequests();

#ifdef VTKm_INSTRUMENT_PARTICLE_ADVECTION
  if (ifTracingParticle)
    vtkm::filter::flow::GetTracer().Get()->ParticleEventToBuffer(vtkm::filter::flow::GetTracer().Get()->GetIterationStep(), this->GetRank(), vtkm::filter::flow::GetTracer().Get()->GetTraceParticleId(), "ParticleSendEnd", vtkm::filter::flow::GetTracer().Get()->GetElapsedTime());
#endif
  PTRACER(vtkm::filter::flow::GetTracer().Get()->TimeTraceToBuffer("SendDataEnd"));
  PTRACER(vtkm::filter::flow::GetTracer().Get()->TimeTraceToBuffer("RecvDataStart"));
  PTRACER(vtkm::filter::flow::GetTracer().Get()->TimeTraceToBuffer("SyncCommStart"));

  // Check if we have anything coming in.
  std::vector<ParticleRecvCommType> particleData;
  std::vector<MsgCommType> msgData;
  if (this->RecvAny(&msgData, &particleData, blockAndWait))
  {
    for (const auto& it : particleData)
      for (const auto& v : it.second)
      {
        const auto& p = v.first;
        const auto& bids = v.second;
        inData.emplace_back(p);
        inDataBlockIDsMap[p.GetID()] = bids;
#ifdef VTKm_INSTRUMENT_PARTICLE_ADVECTION
        if (vtkm::filter::flow::GetTracer().Get()->IfTracingCustomized(p.GetID()))
        {
          double elapsedTime = vtkm::filter::flow::GetTracer().Get()->GetElapsedTime();
          vtkm::filter::flow::GetTracer().Get()->ParticleDetailsToBuffer(vtkm::filter::flow::GetTracer().Get()->GetIterationStep(),
                                            -1,
                                            vtkm::filter::flow::GetTracer().Get()->GetTraceParticleId(),
                                            "RECVOK",
                                            elapsedTime,
                                            0,
                                            0,
                                            0);
        }
#endif
      }

    for (const auto& m : msgData)
    {
      // A termination message is {MSG_TERMINATE, count}. Check the length so a
      // short/malformed payload is skipped instead of read out of bounds.
      if (m.second.size() >= 2 && m.second[0] == MSG_TERMINATE)
        numTerminateMessages += static_cast<vtkm::Id>(m.second[1]);
    }
  }
  PTRACER(vtkm::filter::flow::GetTracer().Get()->TimeTraceToBuffer("RecvDataEnd"));
  PTRACER(vtkm::filter::flow::GetTracer().Get()->TimeTraceToBuffer("SyncCommEnd"));
#endif
}


#ifdef VTKM_ENABLE_MPI

/// Register receive buffers for the control-message and particle tags, sized
/// from AsyncCommParams, then initialize the buffers.
template <typename ParticleType>
VTKM_CONT void ParticleMessenger<ParticleType>::RegisterMessages()
{
  const auto& params = this->AsyncCommParams;

  // Per-receive buffer sizes for each tag.
  const std::size_t msgBytes = CalcMessageBufferSize(params.MessageSize + 1);
  const std::size_t particleBytes =
    CalcParticleBufferSize(static_cast<std::size_t>(params.NumberOfParticles),
                           static_cast<std::size_t>(params.NumberOfBlockIds));

  // Both tags share the same receive count. (An earlier variant used
  // std::min(64, this->GetNumRanks() - 1) here instead.)
  const int recvCount = params.NumberOfReceivers;

  this->RegisterTag(ParticleMessenger::MESSAGE_TAG, recvCount, msgBytes);
  this->RegisterTag(ParticleMessenger::PARTICLE_TAG, recvCount, particleBytes);

  this->InitializeBuffers();
}

/// Send an int control message to rank `dst` on MESSAGE_TAG.
/// Wire format (read back by RecvAny): sender rank, then the message vector.
template <typename ParticleType>
VTKM_CONT void ParticleMessenger<ParticleType>::SendMsg(int dst, const std::vector<int>& msg)
{
  const auto sender = this->GetRank();

  vtkmdiy::MemoryBuffer payload;
  vtkmdiy::save(payload, sender);
  vtkmdiy::save(payload, msg);

  this->SendData(dst, ParticleMessenger::MESSAGE_TAG, payload);
}

/// Send `msg` to every rank except this one.
template <typename ParticleType>
VTKM_CONT void ParticleMessenger<ParticleType>::SendAllMsg(const std::vector<int>& msg)
{
#ifdef VTKm_INSTRUMENT_PARTICLE_ADVECTION
  vtkm::filter::flow::GetTracer().Get()->AlgorithmRecorderV("SEND_TERM", {msg[1]});
#endif
  const auto self = this->GetRank();
  const auto numRanks = this->GetNumRanks();
  for (int rank = 0; rank < numRanks; ++rank)
  {
    // Skip ourselves.
    if (rank == self)
      continue;
    this->SendMsg(rank, msg);
  }
}

/// Receive pending control messages and/or particle batches.
///
/// A tag is listened on only when the matching output pointer is non-null, and
/// each non-null output vector is cleared first.
/// @return false if no output was requested or RecvData reported nothing; true otherwise.
template <typename ParticleType>
VTKM_CONT bool ParticleMessenger<ParticleType>::RecvAny(
  std::vector<MsgCommType>* msgs,
  std::vector<ParticleRecvCommType>* recvParticles,
  bool blockAndWait)
{
  std::set<int> tags;
  if (msgs != nullptr)
  {
    tags.insert(ParticleMessenger::MESSAGE_TAG);
    msgs->clear();
  }
  if (recvParticles != nullptr)
  {
    tags.insert(ParticleMessenger::PARTICLE_TAG);
    recvParticles->clear();
  }

  if (tags.empty())
    return false;

  std::vector<std::pair<int, vtkmdiy::MemoryBuffer>> buffers;
  std::vector<std::pair<int, int>> bufferRecvTimeWindows; //DRP
  if (!this->RecvData(tags, buffers, bufferRecvTimeWindows, blockAndWait))
    return false;

  PTRACER(auto recvTime = vtkm::filter::flow::GetTracer().Get()->GetElapsedTime());
#ifdef VTKm_INSTRUMENT_PARTICLE_ADVECTION
  // Flattened (sendRank, numParticles) pairs for the algorithm recorder.
  std::vector<int> particleRecvData;
  // Index of the current particle buffer into bufferRecvTimeWindows.
  int pidx = 0;
#endif

  for (auto& buff : buffers)
  {
    if (buff.first == ParticleMessenger::MESSAGE_TAG)
    {
      // Wire format (see SendMsg): sender rank, then the int payload.
      int sendRank = 0;
      std::vector<int> m;
      vtkmdiy::load(buff.second, sendRank);
      vtkmdiy::load(buff.second, m);
#ifdef VTKm_INSTRUMENT_PARTICLE_ADVECTION
      std::vector<int> recvData = {sendRank, m[0]};
      vtkm::filter::flow::GetTracer().Get()->AlgorithmRecorderV("RECEIVE_TERM_rank_n", recvData);
#endif
      // The payload is not read again below, so move it instead of copying.
      msgs->emplace_back(sendRank, std::move(m));
    }
    else if (buff.first == ParticleMessenger::PARTICLE_TAG)
    {
      // Wire format (see SendParticles): sender rank, then vector<ParticleCommType>.
      int sendRank = 0;
      std::vector<ParticleCommType> particles;

      vtkmdiy::load(buff.second, sendRank);
      vtkmdiy::load(buff.second, particles);
#ifdef VTKm_INSTRUMENT_PARTICLE_ADVECTION
      for (auto& p : particles)
      {
        p.first.RecvT = recvTime;
        p.first.W += (recvTime - p.first.SendT);
      }
      if (this->UsingAsyncCommunication())
      {
        if (bufferRecvTimeWindows.empty())
          throw vtkm::cont::ErrorFilterExecution("Buffer time window is empty");
        if (pidx >= static_cast<int>(bufferRecvTimeWindows.size()))
          throw vtkm::cont::ErrorFilterExecution("Buffer time window is wrong size");
        auto recvTimeWindow = bufferRecvTimeWindows[pidx];
        for (auto& p : particles)
        {
          p.first.RecvTB0 = recvTimeWindow.first;
          p.first.RecvTB1 = recvTimeWindow.second;
          p.first.WB += (p.first.RecvTB1 - p.first.RecvTB0);
        }
      }

      particleRecvData.push_back(sendRank);
      particleRecvData.push_back(static_cast<int>(particles.size()));
      pidx++;
#endif
      // All per-particle edits are done; hand the batch over without copying.
      recvParticles->emplace_back(sendRank, std::move(particles));
    }
  }
#ifdef VTKm_INSTRUMENT_PARTICLE_ADVECTION
  if (!particleRecvData.empty())
    vtkm::filter::flow::GetTracer().Get()->AlgorithmRecorderV("RECEIVE_PARTICLES_rank_np", particleRecvData);

  // recvParticles is null when called via RecvMsg(); guard before dereferencing.
  if (recvParticles != nullptr &&
      vtkm::filter::flow::GetTracer().Get()->GetTraceParticleId() >= 0 && !recvParticles->empty())
  {
    auto recvTimeEnd = vtkm::filter::flow::GetTracer().Get()->GetElapsedTime();
    for (const auto& pv : *recvParticles)
    {
      for (const auto& p : pv.second)
        if (p.first.GetID() == vtkm::filter::flow::GetTracer().Get()->GetTraceParticleId())
        {
          vtkm::filter::flow::GetTracer().Get()->ParticleEventToBuffer(vtkm::filter::flow::GetTracer().Get()->GetIterationStep(),
                                          this->GetRank(),
                                          p.first.GetID(),
                                          "DESERIALIZE_PARTICLES",
                                          recvTimeEnd-recvTime);
          break;
        }
    }
  }
#endif

  return true;
}

/// Serialize the particle container `c` and send it to rank `dst` on PARTICLE_TAG.
/// Wire format (read back by RecvAny): sender rank, then the serialized container.
/// A self-send is logged as an error and dropped; an empty container is not sent.
template <typename ParticleType>
template <typename P, template <typename, typename> class Container, typename Allocator>
VTKM_CONT inline void ParticleMessenger<ParticleType>::SendParticles(int dst,
                                                                     Container<P, Allocator>& c)
{
  const auto self = this->GetRank();

  if (dst == self)
  {
    VTKM_LOG_S(vtkm::cont::LogLevel::Error, "Error. Sending a particle to yourself.");
    return;
  }

  if (c.empty())
    return;

  vtkmdiy::MemoryBuffer payload;
  vtkmdiy::save(payload, self);

#ifdef VTKm_INSTRUMENT_PARTICLE_ADVECTION
  // Stamp every outgoing particle with the send time before serializing.
  const auto sendTime = vtkm::filter::flow::GetTracer().Get()->GetElapsedTime();
  for (auto& entry : c)
    entry.first.SendT = sendTime;
#endif

  vtkmdiy::save(payload, c);
  this->SendData(dst, ParticleMessenger::PARTICLE_TAG, payload);
}

/// Send one particle batch per destination rank in `m`; empty batches are skipped.
template <typename ParticleType>
template <typename P, template <typename, typename> class Container, typename Allocator>
VTKM_CONT inline void ParticleMessenger<ParticleType>::SendParticles(
  std::unordered_map<int, Container<P, Allocator>>& m)
{
#ifdef VTKm_INSTRUMENT_PARTICLE_ADVECTION
  // Record flattened (destRank, numParticles) pairs for every destination.
  std::vector<int> commData;
  for (const auto& entry : m)
  {
    commData.push_back(entry.first);
    commData.push_back(static_cast<int>(entry.second.size()));
  }
  if (!commData.empty())
    vtkm::filter::flow::GetTracer().Get()->AlgorithmRecorderV("SEND_PARTICLES_rank_np", commData);
#endif

  for (auto& entry : m)
  {
    auto& batch = entry.second;
    if (batch.empty())
      continue;
    this->SendParticles(entry.first, batch);
  }
}
#endif

}
}
}
} // vtkm::filter::flow::internal

#endif // vtk_m_filter_flow_internal_ParticleMessenger_h
