/*=========================================================================

  Copyright (c) Kitware, Inc.
  All rights reserved.
  See Copyright.txt or http://www.paraview.org/HTML/Copyright.html for details.

     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notice for more information.

=========================================================================*/

#include "AdiosPipeline.h"

#include <catalyst.h>
#include <catalyst_api.h>

#include <algorithm>

// #ifdef USE_MPI
#include <mpi.h>
// #endif

namespace
{
//----------------------------------------------------------------------------
// Used to debug conduit_node, print only the conduit hierarchy name without any details.
void PrintHierarchy(const conduit_cpp::Node root, const std::string& spacing = "")
{
  auto nbChild = root.number_of_children();
  for (conduit_index_t i = 0; i < nbChild; i++)
  {
    const auto& child = root.child(i);
    std::cout << spacing << child.name() << std::endl;
    PrintHierarchy(child, spacing + " ");
  }
}

//----------------------------------------------------------------------------
// Convenient conduit method to recursively extract a specific node by name from a parent node.
bool FindNode(conduit_cpp::Node node, conduit_cpp::Node& output, const std::string nodeName)
{
  auto nbChild = node.number_of_children();
  for (conduit_index_t i = 0; i < nbChild; i++)
  {
    conduit_cpp::Node child = node.child(i);
    if (child.name() == nodeName)
    {
      output = child;
      return true;
    }

    if (::FindNode(child, output, nodeName))
    {
      return true;
    }
  }

  return false;
}
}

//----------------------------------------------------------------------------
// Stores the ADIOS2 XML configuration path and the catalyst_initialize parameters for later use.
// The ADIOS2 objects are created by the first Execute(), because the node describing the
// variables is not available at this point.
// Returns false if `adiosFileName` is empty; when built with USE_MPI the whole MPI job is also
// aborted in that case.
bool AdiosPipeline::Initialize(
  const std::string& adiosFileName, const std::string& CatalystInitializeParametersAsString)
{
  if (adiosFileName.empty())
  {
    // Report the problem before aborting: MPI_Abort is not expected to return, so a message
    // printed after it would never be seen.
    std::cerr << "File path specify to locate adios xml is missing." << std::endl;
#ifdef USE_MPI
    MPI_Abort(MPI_COMM_WORLD, -1);
#endif
    return false;
  }

  this->AdiosFileName = adiosFileName;
  this->CatalystInitializeParametersAsString = CatalystInitializeParametersAsString;
  this->FirstExecute = true;

  // Do nothing more here because we don't have a conduit_node with the variables description yet.
  return true;
}

//----------------------------------------------------------------------------
bool AdiosPipeline::Execute(int timestep, conduit_cpp::Node root)
{
  // Because we didn't have the root node during the catalyst initialize method, we initialize all
  // variables before we treat the first execute call.
  if (this->FirstExecute)
  {
    if (!this->InitializeVariables(root))
    {
      std::cerr << "Can't initialize all variables correctly." << std::endl;
      return false;
    }

    this->FirstExecute = false;
  }

  conduit_cpp::Node fields;
  if (!::FindNode(root, fields, "fields"))
  {
    std::cerr << "Doesn't find node named 'fields'" << std::endl;
    return false;
  }

  return this->Put(fields, timestep);
}

//----------------------------------------------------------------------------
// Stores the catalyst_finalize parameters (when non-empty) as the "CatalystFinalizeParameters"
// attribute, closes the writer, and forgets every variable definition together with the cached
// Shape/Start/Count selection. Always returns true.
//
// IO and Writer are only created by InitializeVariables() during the first Execute(). If no step
// was ever executed they are still empty handles, so they are checked before use instead of
// letting ADIOS2 throw on a null handle.
bool AdiosPipeline::Finalize(const std::string& catalystFinalizeParametersAsString)
{
  if (this->IO && !catalystFinalizeParametersAsString.empty())
  {
    // NOTE(review): this attribute is defined after the last EndStep(); confirm the engine in use
    // still persists attributes defined at that point.
    this->IO.DefineAttribute<std::string>(
      "CatalystFinalizeParameters", catalystFinalizeParametersAsString);
  }

  if (this->Writer)
  {
    this->Writer.Close();
  }

  this->VarsChar.clear();
  this->VarsDouble.clear();
  this->VarsFloat.clear();
  this->VarsInt.clear();
  this->VarsLongInt.clear();
  this->VarsShort.clear();
  this->VarsUInt.clear();
  this->VarsULongInt.clear();
  this->VarsUShort.clear();

  this->Count.clear();
  this->Shape.clear();
  this->Start.clear();

  return true;
}

//----------------------------------------------------------------------------
// Creates the ADIOS2 objects and declares one ADIOS2 variable per field component of `root`:
//  1. reads the global array shape from the first 'dims' node under `root`;
//  2. splits that array into one block per rank (one block at the origin without MPI) and stores
//     the global shape and this rank's offset/extent in Shape/Start/Count;
//  3. creates the ADIOS2 instance from AdiosFileName, declares the "Writer" IO and opens "gs.bp";
//  4. defines the "step" variable plus one variable per field. Non-scalar fields are split into
//     one variable per component, named "<field>_<component>".
// Returns true only when every field component was mapped to an ADIOS2 variable. On success the
// catalyst_initialize parameters are also stored as the "CatalystInitializeParameters" attribute.
bool AdiosPipeline::InitializeVariables(const conduit_cpp::Node& root)
{
  try
  {
#ifdef USE_MPI
    int procs = 0;

    // Initialize MPI
    MPI_Comm comm = MPI_COMM_WORLD;
    // NOTE(review): MPI::ERRORS_THROW_EXCEPTIONS and MPI::Exception belong to the MPI C++
    // bindings, which were removed in MPI-3.0.
    MPI_Comm_set_errhandler(comm, MPI::ERRORS_THROW_EXCEPTIONS);
    MPI_Comm_size(comm, &procs);
#endif

    const std::vector<int> dims = this->GetDimensions(root);
    if (dims.empty())
    {
      std::cerr << "We doesn't find node 'dims' inside the node " << root.name()
                << ". We need it to be able to define each adios variables." << std::endl;
      return false;
    }
    const size_t nbDims = dims.size();

    // Number of processes along each dimension of the process grid.
    std::vector<int> procGrid(nbDims, 1);
    // Coordinate of this rank in the process grid; a lone process owns the block at the origin.
    std::vector<int> coords(nbDims, 0);

#ifdef USE_MPI
    {
      const int nbDimsAsInt = static_cast<int>(nbDims);
      // MPI_Dims_create only chooses the entries that are 0 on input.
      std::fill(procGrid.begin(), procGrid.end(), 0);
      MPI_Dims_create(procs, nbDimsAsInt, procGrid.data());

      // A Cartesian communicator gives this rank's coordinate in the process grid.
      std::vector<int> periods(nbDims, 1);
      MPI_Comm cartComm;
      MPI_Cart_create(comm, nbDimsAsInt, procGrid.data(), periods.data(), 0, &cartComm);
      int cartRank = 0;
      MPI_Comm_rank(cartComm, &cartRank);
      MPI_Cart_coords(cartComm, cartRank, nbDimsAsInt, coords.data());
      MPI_Comm_free(&cartComm);
    }
#endif

    // Global shape, offset and extent of this rank's block (DefineVariable() only accepts
    // containers of size_t). Each dimension is split as evenly as possible: the first
    // (globalSize % nbProcs) blocks along it receive one extra element.
    this->Shape.assign(nbDims, 0);
    this->Start.assign(nbDims, 0);
    this->Count.assign(nbDims, 0);
    for (size_t i = 0; i < nbDims; i++)
    {
      const auto globalSize = static_cast<size_t>(dims[i]);
      const auto nbProcs = static_cast<size_t>(procGrid[i]);
      const auto coord = static_cast<size_t>(coords[i]);
      const size_t remainder = globalSize % nbProcs;

      this->Shape[i] = globalSize;
      this->Count[i] = globalSize / nbProcs + (coord < remainder ? 1 : 0);
      this->Start[i] = globalSize / nbProcs * coord + std::min(remainder, coord);
    }

    // Initialize Adios2
#ifdef USE_MPI
    this->Adios = adios2::ADIOS(this->AdiosFileName, comm);
#else
    this->Adios = adios2::ADIOS(this->AdiosFileName);
#endif

    this->IO = this->Adios.DeclareIO("Writer");
    this->Writer = this->IO.Open("gs.bp", adios2::Mode::Write);

    this->Timestep = this->IO.DefineVariable<int>("step");

    // Other Variables are always stored inside the node fields
    conduit_cpp::Node fields;
    if (!::FindNode(root, fields, "fields"))
    {
      std::cerr << "Doesn't find node named 'fields'" << std::endl;
      return false;
    }

    // Every field component must be mapped to an ADIOS2 variable. A component that could not be
    // defined here would make every later Put() fail. An empty 'fields' node is also treated as
    // a failure.
    bool successfullyInitialized = fields.number_of_children() > 0;
    for (conduit_index_t i = 0; i < fields.number_of_children(); i++)
    {
      // Every child should be a variable
      const auto& variable = fields.child(i);
      conduit_cpp::Node values;
      if (!::FindNode(variable, values, "values"))
      {
        std::cerr << "Doesn't find node named 'values'" << std::endl;
        return false;
      }

      // Need to split by component if the value isn't a scalar field
      if (values.number_of_children() == 0)
      {
        successfullyInitialized &= this->FillVariables(values, variable.name());
      }
      else
      {
        for (conduit_index_t componentID = 0; componentID < values.number_of_children();
             componentID++)
        {
          const auto& component = values.child(componentID);
          successfullyInitialized &=
            this->FillVariables(component, variable.name() + "_" + component.name());
        }
      }
    }

    if (successfullyInitialized)
    {
      // Set all attributes
      this->IO.DefineAttribute<std::string>(
        "CatalystInitializeParameters", this->CatalystInitializeParametersAsString);
    }

    return successfullyInitialized;
  }
#ifdef USE_MPI
  catch (const MPI::Exception& mpiError)
  {
    std::cerr << mpiError.Get_error_string() << std::endl;
    MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
    return false;
  }
#endif
  catch (const std::exception& error)
  {
    std::cerr << error.what() << std::endl;
    return false;
  }
}

//----------------------------------------------------------------------------
// Returns, in child order, the values held by the children of the first node named 'dims' found
// under `data`. Each value is read as uint64 and converted to int.
// Returns an empty vector, after printing an error, when no 'dims' node exists.
std::vector<int> AdiosPipeline::GetDimensions(const conduit_cpp::Node& data)
{
  conduit_cpp::Node dimsNode;
  if (!::FindNode(data, dimsNode, "dims"))
  {
    std::cerr << "Doesn't find node named 'dims'" << std::endl;
    return {};
  }

  const auto nbDims = dimsNode.number_of_children();
  std::vector<int> result;
  result.reserve(static_cast<size_t>(nbDims));
  for (conduit_index_t dimID = 0; dimID < nbDims; ++dimID)
  {
    // NOTE(review): the uint64 value is narrowed to int here.
    result.push_back(static_cast<int>(dimsNode.child(dimID).as_uint64()));
  }
  return result;
}

//----------------------------------------------------------------------------
// Writes one ADIOS2 step: the "step" variable holding `timestep`, then every field (or field
// component) found under `fields`. Stops at the first failure, but always ends the step it began
// so that the engine is not left with an open step. Returns true only if every Put succeeded.
bool AdiosPipeline::Put(const conduit_cpp::Node& fields, int timestep)
{
  this->Writer.BeginStep();
  this->Writer.Put<int>(this->Timestep, &timestep);

  bool successfullyPut = true;
  for (conduit_index_t i = 0; successfullyPut && i < fields.number_of_children(); i++)
  {
    // Every child should be a variable
    const auto& variable = fields.child(i);
    conduit_cpp::Node values;
    if (!::FindNode(variable, values, "values"))
    {
      std::cerr << "Doesn't find node named 'values'" << std::endl;
      // Do not return here: the step opened above must still be closed.
      successfullyPut = false;
      break;
    }

    // The variable isn't a scalar field: it is split by component
    if (values.number_of_children() == 0)
    {
      successfullyPut &= this->PutVariables(values, variable.name());
    }
    else
    {
      for (conduit_index_t componentID = 0; componentID < values.number_of_children();
           componentID++)
      {
        const auto& component = values.child(componentID);
        successfullyPut &= this->PutVariables(component, variable.name() + "_" + component.name());
      }
    }
  }

  this->Writer.EndStep();

  return successfullyPut;
}

//----------------------------------------------------------------------------
// Defines in IO an ADIOS2 variable named `name` whose element type matches the conduit dtype of
// `values`:
//   int8 -> char, int16 -> short, int32 -> int, int64 -> long int,
//   uint16 -> unsigned short, uint32 -> unsigned int, uint64 -> unsigned long int,
//   float32 -> float, float64 -> double.
// The variable uses the global Shape and this rank's Start/Count, and is recorded under `name` in
// the Vars* map for that type. Returns false, after printing an error, for any other dtype.
bool AdiosPipeline::FillVariables(const conduit_cpp::Node& values, const std::string& name)
{
  const auto dtypeName = values.dtype().name();

  if (dtypeName == "int8")
  {
    this->VarsChar.emplace(
      name, this->IO.DefineVariable<char>(name, this->Shape, this->Start, this->Count));
    return true;
  }
  if (dtypeName == "int16")
  {
    this->VarsShort.emplace(
      name, this->IO.DefineVariable<short>(name, this->Shape, this->Start, this->Count));
    return true;
  }
  if (dtypeName == "int32")
  {
    this->VarsInt.emplace(
      name, this->IO.DefineVariable<int>(name, this->Shape, this->Start, this->Count));
    return true;
  }
  if (dtypeName == "int64")
  {
    this->VarsLongInt.emplace(
      name, this->IO.DefineVariable<long int>(name, this->Shape, this->Start, this->Count));
    return true;
  }
  if (dtypeName == "uint16")
  {
    this->VarsUShort.emplace(
      name, this->IO.DefineVariable<unsigned short>(name, this->Shape, this->Start, this->Count));
    return true;
  }
  if (dtypeName == "uint32")
  {
    this->VarsUInt.emplace(
      name, this->IO.DefineVariable<unsigned int>(name, this->Shape, this->Start, this->Count));
    return true;
  }
  if (dtypeName == "uint64")
  {
    this->VarsULongInt.emplace(name,
      this->IO.DefineVariable<unsigned long int>(name, this->Shape, this->Start, this->Count));
    return true;
  }
  if (dtypeName == "float32")
  {
    this->VarsFloat.emplace(
      name, this->IO.DefineVariable<float>(name, this->Shape, this->Start, this->Count));
    return true;
  }
  if (dtypeName == "float64")
  {
    this->VarsDouble.emplace(
      name, this->IO.DefineVariable<double>(name, this->Shape, this->Start, this->Count));
    return true;
  }

  std::cerr << "values type " << dtypeName << " isn't currently supported." << std::endl;
  return false;
}

//----------------------------------------------------------------------------
bool AdiosPipeline::PutVariables(const conduit_cpp::Node& values, const std::string& name)
{
  auto conduitType = values.dtype().name();

  if (strcmp(conduitType.c_str(), "int8") == 0)
  {
    this->Writer.Put(this->VarsChar.find(name.c_str())->second, values.as_char8_str());
  }
  else if (strcmp(conduitType.c_str(), "int16") == 0)
  {
    this->Writer.Put(this->VarsShort.find(name.c_str())->second, values.as_int16_ptr());
  }
  else if (strcmp(conduitType.c_str(), "int32") == 0)
  {
    this->Writer.Put(this->VarsInt.find(name.c_str())->second, values.as_int32_ptr());
  }
  else if (strcmp(conduitType.c_str(), "int64") == 0)
  {
    this->Writer.Put(this->VarsLongInt.find(name.c_str())->second, values.as_int64_ptr());
  }
  else if (strcmp(conduitType.c_str(), "uint16") == 0)
  {
    this->Writer.Put(this->VarsUShort.find(name.c_str())->second, values.as_uint16_ptr());
  }
  else if (strcmp(conduitType.c_str(), "uint32") == 0)
  {
    this->Writer.Put(this->VarsUInt.find(name.c_str())->second, values.as_uint32_ptr());
  }
  else if (strcmp(conduitType.c_str(), "uint64") == 0)
  {
    this->Writer.Put(this->VarsULongInt.find(name.c_str())->second, values.as_uint64_ptr());
  }
  else if (strcmp(conduitType.c_str(), "float32") == 0)
  {
    this->Writer.Put(this->VarsFloat.find(name.c_str())->second, values.as_float32_ptr());
  }
  else if (strcmp(conduitType.c_str(), "float64") == 0)
  {
    this->Writer.Put(this->VarsDouble.find(name.c_str())->second, values.as_float64_ptr());
  }
  else
  {
    std::cerr << "values type " << conduitType << " isn't currently supported." << std::endl;
    return false;
  }

  return true;
}
