#ifndef beams_fmt_utils_h
#define beams_fmt_utils_h

#include "../mpi/MpiEnv.h"

#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/Timer.h>

#include "fmt/core.h"

#include <cstdio>
#include <functional>
#include <sstream>
#include <string>
#include <vector>

// Debug-printing shorthands built on the Fmt struct below. Each macro
// stringizes its argument with `#`, so the printed label is the source
// expression itself. Naming:
//   ...0   - only rank 0 prints (Fmt::*0 variants),
//   ..._F  - print every element instead of an abbreviated summary.
// NOTE: every expansion already ends in ';'. A trailing ';' at the call site
// is therefore an empty statement, and `if (c) FMT_VAR(x); else ...` does not
// compile; wrap such uses in braces.

// Prints "<expr> = <value>" via Fmt::Println / Fmt::Println0. The stringized
// expression is passed as a format argument instead of being spliced into the
// format string, so braces in it (e.g. FMT_VAR(Foo{}.x)) are printed literally
// rather than being parsed as {fmt} replacement fields.
#define FMT_VAR(x) Fmt::Println("{} = {}", #x, x);

#define FMT_VAR0(x) Fmt::Println0("{} = {}", #x, x);

// Prints a std::vector labelled with its expression (abbreviated / full).
#define FMT_VEC(v) Fmt::PrintVectorln(#v, v, false);

#define FMT_VEC0(v) Fmt::PrintVectorln0(#v, v, false);

#define FMT_VEC_F(v) Fmt::PrintVectorln(#v, v, true);

#define FMT_VEC0_F(v) Fmt::PrintVectorln0(#v, v, true);

// Prints a vtkm::cont::ArrayHandle summary labelled with its expression.
#define FMT_ARR(a) Fmt::PrintArrayHandleln(#a, a);

#define FMT_ARR0(a) Fmt::PrintArrayHandleln0(#a, a);

#define FMT_ARR_F(a) Fmt::PrintArrayHandleln(#a, a, true);

#define FMT_ARR0_F(a) Fmt::PrintArrayHandleln0(#a, a, true);

// Prints a vtkm::cont::Timer's elapsed time in whole milliseconds.
#define FMT_TMR(x) Fmt::PrintTimerln(#x, x);

#define FMT_TMR0(x) Fmt::PrintTimerln0(#x, x);


/// Rank-aware logging helpers built on {fmt}. All output goes to stderr.
///
/// Families:
///   - Raw*   : print the formatted message as-is (no rank prefix).
///   - Print* : print "<rank>: <message>".
/// Suffixes:
///   - ln : a trailing newline is appended.
///   - 0  : only rank 0 prints.
///   - r  : only the rank passed as the first argument prints.
///
/// Call Initialize() first. If the stored MPI environment pointer is null
/// (e.g. Initialize() was never called), the current rank is treated as 0
/// instead of dereferencing a null pointer.
struct Fmt
{
  /// Records the MPI environment used to look up the current rank.
  /// The pointer is stored, not owned, and is dereferenced by later rank
  /// checks, so it must stay valid for as long as Fmt is used.
  static void Initialize(beams::mpi::MpiEnv* mpi) { Fmt::Mpi = mpi; }

  /// Formats `args` with `format_str` and prints the result to stderr.
  template <typename S, typename... Args>
  static void RawPrint(const S& format_str, Args&&... args)
  {
    fmt::print(stderr, format_str, args...);
  }

  /// Same as RawPrint(), but only on rank 0.
  template <typename S, typename... Args>
  static void RawPrint0(const S& format_str, Args&&... args)
  {
    if (Fmt::CurrentRank() == 0)
    {
      fmt::print(stderr, format_str, args...);
    }
  }

  /// Same as RawPrint(), with a newline appended to the format string.
  template <typename S, typename... Args>
  static void RawPrintln(const S& format_str, Args&&... args)
  {
    Fmt::RawPrint(format_str + std::string("\n"), args...);
  }

  /// Same as RawPrint0(), with a newline appended to the format string.
  template <typename S, typename... Args>
  static void RawPrintln0(const S& format_str, Args&&... args)
  {
    Fmt::RawPrint0(format_str + std::string("\n"), args...);
  }

  /// Formats the message and prints it to stderr as "<rank>: <message>".
  /// The message is formatted first so the prefix and body go out in a single
  /// fmt::print call.
  template <typename S, typename... Args>
  static void Print(const S& format_str, Args&&... args)
  {
    const std::string message = fmt::format(format_str, args...);
    fmt::print(stderr, "{}: {}", Fmt::CurrentRank(), message);
  }

  /// Same as Print(), with a newline appended to the format string.
  template <typename S, typename... Args>
  static void Println(const S& format_str, Args&&... args)
  {
    Fmt::Print(format_str + std::string("\n"), args...);
  }

  /// Same as Println(), but only on rank 0.
  template <typename S, typename... Args>
  static void Println0(const S& format_str, Args&&... args)
  {
    if (Fmt::CurrentRank() == 0)
    {
      Fmt::Println(format_str, args...);
    }
  }

  /// Same as Println(), but only on the given `rank`.
  template <typename S, typename... Args>
  static void Printlnr(int rank, const S& format_str, Args&&... args)
  {
    if (Fmt::CurrentRank() == rank)
    {
      Fmt::Println(format_str, args...);
    }
  }

  /// Prints "<rank>: <name>: <elapsed> ms", where elapsed is truncated to
  /// whole milliseconds (see ToMs()).
  template <typename S>
  static void PrintTimerln(const S& name, const vtkm::cont::Timer& timer)
  {
    // `name` is a format argument, not part of the format string, so braces in
    // it cannot be misinterpreted by {fmt}.
    Fmt::Println("{}: {} ms", name, Fmt::ToMs(timer.GetElapsedTime()));
  }

  /// Same as PrintTimerln(), but only rank 0 prints. The elapsed time is still
  /// queried on every rank, as before.
  template <typename S>
  static void PrintTimerln0(const S& name, const vtkm::cont::Timer& timer)
  {
    Fmt::Println0("{}: {} ms", name, Fmt::ToMs(timer.GetElapsedTime()));
  }

  /// Prints "<rank>: <name>: " followed by VTK-m's printSummary_ArrayHandle
  /// output for `array` (every value if `full`, otherwise an abbreviated
  /// summary). No newline is added here. NOTE(review): this relies on
  /// printSummary_ArrayHandle ending its output with one; confirm against the
  /// VTK-m version in use.
  template <typename S, typename ArrayHandleType>
  static void PrintArrayHandleln(const S& name, const ArrayHandleType& array, bool full = false)
  {
    std::stringstream ss;
    ss << name << ": ";
    vtkm::cont::printSummary_ArrayHandle(array, ss, full);
    // The assembled text is data, not a format string: it may contain braces.
    Fmt::Print("{}", ss.str());
  }

  /// Same as PrintArrayHandleln(), but only on rank 0.
  template <typename S, typename ArrayHandleType>
  static void PrintArrayHandleln0(const S& name, const ArrayHandleType& array, bool full = false)
  {
    if (Fmt::CurrentRank() == 0)
    {
      Fmt::PrintArrayHandleln(name, array, full);
    }
  }

  /// Same as PrintArrayHandleln(), but only on the given `rank`.
  template <typename S, typename ArrayHandleType>
  static void PrintArrayHandlelnr(int rank,
                                  const S& name,
                                  const ArrayHandleType& array,
                                  bool full = false)
  {
    if (Fmt::CurrentRank() == rank)
    {
      Fmt::PrintArrayHandleln(name, array, full);
    }
  }

  /// Prints "<rank>: <name>: [v0, v1, ...]" followed by a newline.
  /// All elements are printed when `full` is true or the vector has at most 7
  /// elements. Otherwise only the first three and last three are printed,
  /// separated by " ... ". Values are written with the same per-value
  /// formatter VTK-m uses for ArrayHandle summaries.
  template <typename T, typename S>
  static void PrintVectorln(const S& name, const std::vector<T>& array, bool full = false)
  {
    using IsVec = typename vtkm::internal::SafeVecTraits<T>::HasMultipleComponents;

    std::stringstream out;
    // Writes array[index] to `out` using VTK-m's value formatter.
    const auto printValue = [&](std::size_t index) {
      vtkm::cont::detail::printSummary_ArrayHandle_Value(array.at(index), out, IsVec());
    };

    const std::size_t sz = array.size();
    out << name << ": [";
    if (full || sz <= 7)
    {
      for (std::size_t i = 0; i < sz; ++i)
      {
        if (i != 0)
        {
          out << ", ";
        }
        printValue(i);
      }
    }
    else
    {
      // sz > 7 here, so indices 0..2 and sz-3..sz-1 are distinct and valid.
      printValue(0);
      out << ", ";
      printValue(1);
      out << ", ";
      printValue(2);
      out << " ... ";
      printValue(sz - 3);
      out << ", ";
      printValue(sz - 2);
      out << ", ";
      printValue(sz - 1);
    }
    out << "]\n";
    // The assembled text is data, not a format string: it may contain braces
    // (e.g. a stringized expression such as `std::vector<int>{1, 2}`).
    Fmt::Print("{}", out.str());
  }

  /// Same as PrintVectorln(), but only on rank 0.
  template <typename T, typename S>
  static void PrintVectorln0(const S& name, const std::vector<T>& array, bool full = false)
  {
    if (Fmt::CurrentRank() == 0)
    {
      Fmt::PrintVectorln(name, array, full);
    }
  }

  /// Returns a {fmt} format string for fixed-point output with `numDecimals`
  /// digits after the decimal point, e.g. 3 -> "{:.3f}".
  static std::string ToFloatFmtString(int numDecimals)
  {
    return std::string("{:.") + std::to_string(numDecimals) + std::string("f}");
  }

  /// Formats a scalar in fixed-point notation with `numDecimals` decimals.
  template <typename T>
  static std::string FormatFloat(T v, int numDecimals = 3)
  {
    return fmt::format(ToFloatFmtString(numDecimals), v);
  }

  /// Formats each component of `v` with FormatFloat() as "[c0, c1, ...]".
  template <typename T, vtkm::IdComponent Size>
  static std::string FormatFloat(const vtkm::Vec<T, Size>& v, int numDecimals = 3)
  {
    std::stringstream out;
    out << "[";
    for (vtkm::IdComponent i = 0; i < Size; ++i)
    {
      out << Fmt::FormatFloat(v[i], numDecimals);
      if (i != Size - 1)
      {
        out << ", ";
      }
    }
    out << "]";
    return out.str();
  }

  /// Converts seconds to milliseconds, truncating toward zero.
  template <typename Precision>
  static int ToMs(Precision timeInSeconds)
  {
    return static_cast<int>(timeInSeconds * 1000);
  }

private:
  /// Rank from the recorded MPI environment, or 0 when no environment has been
  /// recorded (Mpi is null). This avoids a null dereference in single-process
  /// use.
  static int CurrentRank() { return Fmt::Mpi != nullptr ? static_cast<int>(Fmt::Mpi->Rank) : 0; }

  static beams::mpi::MpiEnv* Mpi;
};

#endif // beams_fmt_utils_h