//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//============================================================================

#ifndef vtk_m_filter_flow_internal_AdvectAlgorithm_h
#define vtk_m_filter_flow_internal_AdvectAlgorithm_h

#include <vtkm/cont/PartitionedDataSet.h>
#include <vtkm/filter/flow/Tracer.h>
#include <vtkm/filter/flow/internal/BoundsMap.h>
#include <vtkm/filter/flow/internal/DataSetIntegrator.h>
#include <vtkm/filter/flow/internal/ParticleMessenger.h>

#include <algorithm>
#include <cstdlib>
#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace vtkm
{
namespace filter
{
namespace flow
{
namespace internal
{

/// Distributed particle advection driver.
///
/// Each rank holds the blocks in `Blocks`. Seeds are kept by the first rank that owns the
/// first block containing them; particles are then advected block by block and exchanged
/// between ranks through a `ParticleMessenger` until the communicator-wide number of
/// terminated particles reaches the communicator-wide number of particles. Timing and
/// trace events are recorded through the global `g_Tracer`.
///
/// @tparam DSIType      dataset-integrator type stored in `Blocks`.
/// @tparam ResultType   result type template (not referenced directly in this class).
/// @tparam ParticleType particle type being advected.
template <typename DSIType, template <typename> class ResultType, typename ParticleType>
class AdvectAlgorithm
{
public:
  /// @param bm            block-to-bounds/rank lookup shared by all ranks.
  /// @param blocks        the blocks owned by this rank (copied into `Blocks`).
  /// @param useAsyncComm  forwarded to the ParticleMessenger.
  AdvectAlgorithm(const vtkm::filter::flow::internal::BoundsMap& bm,
                  std::vector<DSIType>& blocks,
                  bool useAsyncComm)
    : Blocks(blocks)
    , BoundsMap(bm)
    , NumRanks(this->Comm.size()) // Comm is declared (and initialized) before NumRanks/Rank.
    , Rank(this->Comm.rank())
    , UseAsynchronousCommunication(useAsyncComm)
  {
  }

  // The class has virtual member functions and is meant to be specialized; make
  // destruction through a base pointer well-defined.
  virtual ~AdvectAlgorithm() = default;

  /// Configure step count, step size and seeds, then run the advection loop.
  void Execute(vtkm::Id numSteps,
               vtkm::FloatDefault stepSize,
               const vtkm::cont::ArrayHandle<ParticleType>& seeds)
  {
    this->SetNumberOfSteps(numSteps);
    this->SetStepSize(stepSize);
    this->SetSeeds(seeds);
    this->Go();
  }

  /// Collect one output partition per block that produced a result.
  vtkm::cont::PartitionedDataSet GetOutput() const
  {
    vtkm::cont::PartitionedDataSet output;

    for (const auto& b : this->Blocks)
    {
      vtkm::cont::DataSet ds;
      if (b.template GetOutput<ParticleType>(ds))
        output.AppendPartition(ds);
    }

    return output;
  }

  void SetStepSize(vtkm::FloatDefault stepSize) { this->StepSize = stepSize; }
  void SetNumberOfSteps(vtkm::Id numSteps) { this->NumberOfSteps = numSteps; }

  /// Replace the current particles with the seeds that this rank is responsible for.
  ///
  /// A seed is kept only if it lies in at least one block and this rank is the first rank
  /// returned for the first such block.
  void SetSeeds(const vtkm::cont::ArrayHandle<ParticleType>& seeds)
  {
    this->ClearParticles();

    vtkm::Id n = seeds.GetNumberOfValues();
    auto portal = seeds.ReadPortal();

    std::vector<std::vector<vtkm::Id>> blockIDs;
    std::vector<ParticleType> particles;
    for (vtkm::Id i = 0; i < n; i++)
    {
      const ParticleType p = portal.Get(i);
      std::vector<vtkm::Id> ids = this->BoundsMap.FindBlocks(p.GetPosition());

      //Note: For duplicate blocks, this will give the seeds to the rank that is first in the list.
      if (!ids.empty())
      {
        auto ranks = this->BoundsMap.FindRank(ids[0]);
        if (!ranks.empty() && this->Rank == ranks[0])
        {
          particles.emplace_back(p);
          blockIDs.emplace_back(std::move(ids)); // ids is not used after this point.
        }
      }
    }
    this->SetSeedArray(particles, blockIDs);
  }

  /// Advect all the particles.
  ///
  /// Each round advects one batch of active particles (all in the same block), records
  /// tracer events, then exchanges particles/termination counts with the other ranks.
  /// Throws vtkm::cont::ErrorFilterExecution if more terminations are counted than
  /// particles exist.
  virtual void Go()
  {
    vtkm::filter::flow::internal::ParticleMessenger<ParticleType> messenger(
      this->Comm, this->UseAsynchronousCommunication, this->BoundsMap, 1, 128);

    vtkm::Id nLocal = static_cast<vtkm::Id>(this->Active.size() + this->Inactive.size());
    this->ComputeTotalNumParticles(nLocal);

    //Double barriers to sync clocks.
    this->Comm.barrier();
    this->Comm.barrier();
    g_Tracer->StartTimer();
    g_Tracer->TimeTraceToBuffer("GoStart");
    vtkm::Id round = 0;
    while (this->TotalNumTerminatedParticles < this->TotalNumParticles)
    {
      //Number of steps taken by this rank during this round (reported to the tracer).
      vtkm::Id totalAdvectedSteps = 0;

      g_Tracer->TimeTraceToBuffer("AdvectStart");

      std::vector<ParticleType> v;
      vtkm::Id numTerm = 0, blockId = -1;

      bool ifTracingParticle = false;
      g_Tracer->SetBegOverheadStart();
      g_Tracer->SetTracingStatus(false);
      if (this->GetActiveParticles(v, blockId))
      {
        auto& block = this->GetDataSet(blockId);
        DSIHelperInfoType bb =
          DSIHelperInfo<ParticleType>(v, this->BoundsMap, this->ParticleBlockIDsMap);

        //Step counters before advection; compared with the counters afterwards.
        vtkm::Id num0 = 0, num1 = 0, numSmall0 = 0, numSmall1 = 0;
        for (const auto& p : v)
        {
          num0 += p.GetNumberOfSteps();
          numSmall0 += p.GetNumSmallSteps();
        }

        //Size of the batch handed to the block (v itself is not modified below).
        const vtkm::Id numParticles = static_cast<vtkm::Id>(v.size());
        vtkm::Id tracingPid = g_Tracer->GetTraceParticleId();
        if (tracingPid >= 0)
        {
          //if dumping particles
          double aliveTime = g_Tracer->GetElapsedTime();
          for (const auto& p : v)
          {
            if (g_Tracer->IfTracingCustomized(p.GetID()))
            {
              //For the inline case the rank id is the same as the process id;
              //for the in-transit case they differ. Positions are recorded as well.
              g_Tracer->ParticleInBlockToBuffer(g_Tracer->GetIterationStep(),
                                                this->Rank,
                                                p.GetID(),
                                                p.GetNumberOfSteps(),
                                                p.GetPosition(),
                                                aliveTime);
              const vtkm::Id stepsBefore = p.GetNumberOfSteps();

              double elapsedTime = g_Tracer->GetElapsedTime();
              g_Tracer->ParticleDetailsToBuffer(g_Tracer->GetIterationStep(),
                                                this->Rank,
                                                p.GetID(),
                                                "ADVECTSTART",
                                                elapsedTime,
                                                stepsBefore,
                                                numParticles);
              ifTracingParticle = true;
              g_Tracer->SetTracingStatus(true);
            }
          }
        }

        block.Advect(bb, this->StepSize, this->NumberOfSteps);

        if (tracingPid >= 0)
        {
          for (const auto& p : bb.Get<DSIHelperInfo<ParticleType>>().Particles)
          {
            if (g_Tracer->IfTracingCustomized(p.GetID()))
            {
              const vtkm::Id stepsAfter = p.GetNumberOfSteps();
              double elapsedTime = g_Tracer->GetElapsedTime();
              g_Tracer->ParticleDetailsToBuffer(g_Tracer->GetIterationStep(),
                                                this->Rank,
                                                p.GetID(),
                                                "ADVECTEND",
                                                elapsedTime,
                                                stepsAfter,
                                                numParticles);
            }
          }
        }

        g_Tracer->SetEndOverheadStart();
        numTerm = this->UpdateResult(bb.Get<DSIHelperInfo<ParticleType>>());

        //Step counters after advection (bb now holds the advected particles).
        for (const auto& p : bb.Get<DSIHelperInfo<ParticleType>>().Particles)
        {
          num1 += p.GetNumberOfSteps();
          numSmall1 += p.GetNumSmallSteps();
        }
        totalAdvectedSteps = num1 - num0;
        std::string infoStr = "ParticleAdvectInfo_" + std::to_string(numParticles) + "_" +
          std::to_string(totalAdvectedSteps) + "_" + std::to_string(numSmall1 - numSmall0);
        g_Tracer->TimeTraceToBuffer(infoStr);
        g_Tracer->SetEndOverheadEnd();

        //Update particle info..
        //NOTE(review): `v` is a local copy that is discarded at the end of this iteration;
        //the particles that continue are the ones re-inserted into Active/Inactive by
        //UpdateResult above, so these timer updates may never be observed. Confirm intent.
        this->UpdateParticleTimers(v);
      }
      else
        g_Tracer->SetBegOverheadEnd();

      g_Tracer->TimeTraceToBuffer("AdvectEnd");

      vtkm::Id numTermMessages = 0;
      g_Tracer->TimeTraceToBuffer("CommStart");
      if (ifTracingParticle)
      {
        double elapsedTime = g_Tracer->GetElapsedTime();
        g_Tracer->ParticleDetailsToBuffer(
          g_Tracer->GetIterationStep(), this->Rank, -1, "GANG_COMM_START", elapsedTime, 0, 0);
      }
      std::unordered_map<int, int> sendParticleInfo;
      this->Communicate(messenger, numTerm, numTermMessages, sendParticleInfo);
      g_Tracer->TimeTraceToBuffer("CommEnd");
      if (ifTracingParticle)
      {
        double elapsedTime = g_Tracer->GetElapsedTime();
        g_Tracer->ParticleDetailsToBuffer(
          g_Tracer->GetIterationStep(), this->Rank, -1, "GANG_COMM_END", elapsedTime, 0, 0);
      }
      this->TotalNumTerminatedParticles += (numTerm + numTermMessages);
      if (this->TotalNumTerminatedParticles > this->TotalNumParticles)
        throw vtkm::cont::ErrorFilterExecution("Particle count error");

      round++;
      g_Tracer->CounterToBuffer(round, totalAdvectedSteps, sendParticleInfo);
    }
  }

  /// Drop all active/inactive particles and their block assignments.
  virtual void ClearParticles()
  {
    this->Active.clear();
    this->Inactive.clear();
    this->ParticleBlockIDsMap.clear();
  }

  /// Sum `numLocal` over all ranks (when MPI is enabled) into TotalNumParticles.
  void ComputeTotalNumParticles(const vtkm::Id& numLocal)
  {
    long long total = static_cast<long long>(numLocal);
#ifdef VTKM_ENABLE_MPI
    MPI_Comm mpiComm = vtkmdiy::mpi::mpi_cast(this->Comm.handle());
    MPI_Allreduce(MPI_IN_PLACE, &total, 1, MPI_LONG_LONG, MPI_SUM, mpiComm);
#endif
    this->TotalNumParticles = static_cast<vtkm::Id>(total);
  }

  /// Return the local block whose ID is `id`; throws ErrorFilterExecution if none matches.
  DataSetIntegrator<DSIType>& GetDataSet(vtkm::Id id)
  {
    for (auto& it : this->Blocks)
      if (it.GetID() == id)
        return it;

    throw vtkm::cont::ErrorFilterExecution("Bad block");
  }

  /// Record the block list for each seed and append the seeds to Active.
  /// `particles[i]` is associated with `blockIds[i]`.
  virtual void SetSeedArray(const std::vector<ParticleType>& particles,
                            const std::vector<std::vector<vtkm::Id>>& blockIds)
  {
    VTKM_ASSERT(particles.size() == blockIds.size());

    auto pit = particles.begin();
    auto bit = blockIds.begin();
    while (pit != particles.end() && bit != blockIds.end())
    {
      this->ParticleBlockIDsMap[pit->GetID()] = *bit;
      ++pit;
      ++bit;
    }

    //Update Active
    this->Active.insert(this->Active.end(), particles.begin(), particles.end());
  }

  /// Move the next batch of active particles into `particles`.
  ///
  /// The batch consists of every active particle whose first block ID equals that of the
  /// first active particle; `blockId` is set to that block. Other particles stay in Active
  /// (relative order preserved) for later rounds. Returns false if Active is empty.
  virtual bool GetActiveParticles(std::vector<ParticleType>& particles, vtkm::Id& blockId)
  {
    g_Tracer->TimeTraceToBuffer("GetActiveParticlesStart");

    particles.clear();
    blockId = -1;
    if (this->Active.empty())
      return false;

    const vtkm::Id targetBlock = this->ParticleBlockIDsMap[this->Active.front().GetID()][0];
    blockId = targetBlock;
    auto inTargetBlock = [this, targetBlock](const ParticleType& p) {
      return this->ParticleBlockIDsMap[p.GetID()][0] == targetBlock;
    };

    if (std::all_of(this->Active.begin(), this->Active.end(), inTargetBlock))
    {
      //Common case (e.g. one block per rank): hand over the whole list without copying.
      particles.swap(this->Active);
    }
    else
    {
      //Particles of other blocks must not be advected in this block; keep them for later.
      auto split = std::stable_partition(this->Active.begin(), this->Active.end(), inTargetBlock);
      particles.assign(std::make_move_iterator(this->Active.begin()),
                       std::make_move_iterator(split));
      this->Active.erase(this->Active.begin(), split);
    }

    g_Tracer->TimeTraceToBuffer("GetActiveParticlesStop");
    return !particles.empty();
  }

  /// Send inactive particles to their destination ranks and receive incoming ones.
  ///
  /// @param numLocalTerminations  terminations on this rank during the current round.
  /// @param numTermMessages       set to the termination count reported by the messenger.
  /// @param sendParticleInfo      filled by the messenger; key is the destination id,
  ///                              value is the number of particles sent there.
  void Communicate(vtkm::filter::flow::internal::ParticleMessenger<ParticleType>& messenger,
                   vtkm::Id numLocalTerminations,
                   vtkm::Id& numTermMessages,
                   std::unordered_map<int, int>& sendParticleInfo)
  {
    std::vector<ParticleType> outgoing;
    std::vector<vtkm::Id> outgoingRanks;
    this->GetOutgoingParticles(outgoing, outgoingRanks);

    std::vector<ParticleType> incoming;
    std::unordered_map<vtkm::Id, std::vector<vtkm::Id>> incomingBlockIDs;
    numTermMessages = 0;
    bool block = false;
#ifdef VTKM_ENABLE_MPI
    block = this->GetBlockAndWait(messenger.UsingSyncCommunication(), numLocalTerminations);
#endif
    messenger.Exchange(sendParticleInfo,
                       outgoing,
                       outgoingRanks,
                       this->ParticleBlockIDsMap,
                       numLocalTerminations,
                       incoming,
                       incomingBlockIDs,
                       numTermMessages,
                       block);

    //Cleanup what was sent.
    for (const auto& p : outgoing)
      this->ParticleBlockIDsMap.erase(p.GetID());

    this->UpdateActive(incoming, incomingBlockIDs);
  }

  /// Route every inactive particle to the rank owning its next block.
  ///
  /// Particles whose destination rank is this rank go straight back into Active; the
  /// rest are returned in `outgoing` with the matching destination in `outgoingRanks`.
  /// Every routed particle has its traveled-block count incremented; outgoing ones also
  /// have their communication count incremented. Inactive is empty afterwards.
  void GetOutgoingParticles(std::vector<ParticleType>& outgoing,
                            std::vector<vtkm::Id>& outgoingRanks)
  {
    outgoing.clear();
    outgoingRanks.clear();

    outgoing.reserve(this->Inactive.size());
    outgoingRanks.reserve(this->Inactive.size());

    std::vector<ParticleType> particlesStaying;
    std::unordered_map<vtkm::Id, std::vector<vtkm::Id>> particlesStayingBlockIDs;
    for (auto& p : this->Inactive)
    {
      const auto& bid = this->ParticleBlockIDsMap[p.GetID()];
      VTKM_ASSERT(!bid.empty());
      auto ranks = this->BoundsMap.FindRank(bid[0]);
      VTKM_ASSERT(!ranks.empty());

      //Decide where it should go. If the block is duplicated on several ranks, pick one
      //of the owning ranks at random. The random number is an index into `ranks`, not a
      //rank id itself.
      vtkm::Id destRank = ranks[0];
      if (ranks.size() > 1)
        destRank = ranks[static_cast<std::size_t>(std::rand()) % ranks.size()];

      //The particle moves on to another block either way.
      p.AddNumTraveledBlocks();
      if (destRank == this->Rank)
      {
        //The destination block is on this rank: no communication needed.
        particlesStaying.emplace_back(p);
        particlesStayingBlockIDs[p.GetID()] = bid;
      }
      else
      {
        //One more communication for this particle.
        p.AddNumComm();
        outgoing.emplace_back(p);
        outgoingRanks.emplace_back(destRank);
      }
    }
    this->Inactive.clear();
    VTKM_ASSERT(outgoing.size() == outgoingRanks.size());
    VTKM_ASSERT(particlesStaying.size() == particlesStayingBlockIDs.size());
    if (!particlesStaying.empty())
      this->UpdateActive(particlesStaying, particlesStayingBlockIDs);
  }

  /// Append `particles` to Active and record their block IDs.
  virtual void UpdateActive(const std::vector<ParticleType>& particles,
                            const std::unordered_map<vtkm::Id, std::vector<vtkm::Id>>& idsMap)
  {
    this->Update(this->Active, particles, idsMap);
  }

  /// Append `particles` to Inactive and record their block IDs.
  virtual void UpdateInactive(const std::vector<ParticleType>& particles,
                              const std::unordered_map<vtkm::Id, std::vector<vtkm::Id>>& idsMap)
  {
    this->Update(this->Inactive, particles, idsMap);
  }

  /// Append `particles` to `arr` and overwrite ParticleBlockIDsMap with `idsMap` entries.
  void Update(std::vector<ParticleType>& arr,
              const std::vector<ParticleType>& particles,
              const std::unordered_map<vtkm::Id, std::vector<vtkm::Id>>& idsMap)
  {
    VTKM_ASSERT(particles.size() == idsMap.size());

    arr.insert(arr.end(), particles.begin(), particles.end());
    for (const auto& it : idsMap)
      this->ParticleBlockIDsMap[it.first] = it.second;
  }

  /// Distribute the result of one advection batch: in-bounds particles go to Active,
  /// out-of-bounds particles go to Inactive, terminated IDs are dropped from the block map.
  /// @return the number of terminated particles in `stuff`.
  vtkm::Id UpdateResult(const DSIHelperInfo<ParticleType>& stuff)
  {
    this->UpdateActive(stuff.InBounds.Particles, stuff.InBounds.BlockIDs);
    this->UpdateInactive(stuff.OutOfBounds.Particles, stuff.OutOfBounds.BlockIDs);

    vtkm::Id numTerm = static_cast<vtkm::Id>(stuff.TermID.size());
    //Update terminated particles.
    for (const auto& id : stuff.TermID)
      this->ParticleBlockIDsMap.erase(id);

    return numTerm;
  }

  /// Decide whether the messenger exchange may block waiting for incoming messages.
  virtual bool GetBlockAndWait(const bool& syncComm, const vtkm::Id& numLocalTerm)
  {
    bool haveNoWork = this->Active.empty() && this->Inactive.empty();

    //With synchronous communication we should only block and wait if we have no particles.
    if (syncComm)
    {
      return haveNoWork;
    }
    else
    {
      //With asynchronous communication, there are two cases where blocking would deadlock:
      //1. There are particles on this rank.
      //2. numLocalTerm + this->TotalNumTerminatedParticles == this->TotalNumParticles
      //If neither is true, we can safely block and wait for communication to come in.

      if (haveNoWork &&
          (numLocalTerm + this->TotalNumTerminatedParticles < this->TotalNumParticles))
        return true;

      return false;
    }
  }

  /// Add the tracer's most recent begin-overhead, end-overhead and advect intervals
  /// to the BO/EO/A fields of each particle in `particles`.
  void UpdateParticleTimers(std::vector<ParticleType>& particles) const
  {
    auto boT = g_Tracer->BegOverhead[1] - g_Tracer->BegOverhead[0];
    auto eoT = g_Tracer->EndOverhead[1] - g_Tracer->EndOverhead[0];
    auto advT = g_Tracer->Advect[1] - g_Tracer->Advect[0];

    for (auto& p : particles)
    {
      p.BO += boT;
      p.EO += eoT;
      p.A += advT;
    }
  }

  //Member data. Declaration order matters: Comm must precede NumRanks and Rank,
  //which the constructor initializes from it.
  std::vector<ParticleType> Active;
  std::vector<DSIType> Blocks;
  vtkm::filter::flow::internal::BoundsMap BoundsMap;
  vtkmdiy::mpi::communicator Comm = vtkm::cont::EnvironmentTracker::GetCommunicator();
  std::vector<ParticleType> Inactive;
  vtkm::Id NumberOfSteps = 0;
  vtkm::Id NumRanks = 0;
  //{particleId : {block IDs}}
  std::unordered_map<vtkm::Id, std::vector<vtkm::Id>> ParticleBlockIDsMap;
  vtkm::Id Rank = 0;
  vtkm::FloatDefault StepSize = 0;
  vtkm::Id TotalNumParticles = 0;
  vtkm::Id TotalNumTerminatedParticles = 0;
  bool UseAsynchronousCommunication = true;
};

}
}
}
} //vtkm::filter::flow::internal

#endif //vtk_m_filter_flow_internal_AdvectAlgorithm_h
