//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//
//  Copyright 2014 Sandia Corporation.
//  Copyright 2014 UT-Battelle, LLC.
//  Copyright 2014 Los Alamos National Security.
//
//  Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
//  the U.S. Government retains certain rights in this software.
//
//  Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
//  Laboratory (LANL), the U.S. Government retains certain rights in
//  this software.
//============================================================================
#ifndef vtk_m_cont_openmp_internal_DeviceAdapterAlgorithmOpenMP_h
#define vtk_m_cont_openmp_internal_DeviceAdapterAlgorithmOpenMP_h

#include <vtkm/cont/internal/IteratorFromArrayPortal.h>
#include <vtkm/cont/openmp/internal/DeviceAdapterTagOpenMP.h>
#include <vtkm/cont/openmp/internal/ArrayManagerExecutionOpenMP.h>
#include <vtkm/exec/internal/ErrorMessageBuffer.h>
#include <vtkm/Extent.h>
#include <vtkm/Math.h>
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandleZip.h>
#include <vtkm/cont/DeviceAdapterAlgorithm.h>
#include <vtkm/cont/ErrorExecution.h>
#include <vtkm/cont/internal/DeviceAdapterAlgorithmGeneral.h>

#include <vtkm/cont/openmp/internal/TiledScanKernel.h>

#include <boost/type_traits/remove_reference.hpp>

#include <omp.h>
#include <iostream>
#include <algorithm>
#include <utility>
#include <cmath>

namespace vtkm {
namespace cont {

template<>
struct DeviceAdapterAlgorithm<vtkm::cont::DeviceAdapterTagOpenMP> :
    vtkm::cont::internal::DeviceAdapterAlgorithmGeneral<
        DeviceAdapterAlgorithm<vtkm::cont::DeviceAdapterTagOpenMP>,
        vtkm::cont::DeviceAdapterTagOpenMP>
{
private:
  const static vtkm::Id GRAIN_SIZE = 2048;
  typedef vtkm::cont::DeviceAdapterTagOpenMP DeviceAdapterTag;

  // Adapter that holds its own copy of a user functor and forwards each
  // scheduled index to it (modelled on the Serial adapter's ScheduleKernel).
  template<class FunctorType>
  class ScheduleKernel
  {
      // Copy of the functor taken when the kernel is constructed.
      const FunctorType Functor;

    public:
      ScheduleKernel(const FunctorType &functor)
        : Functor(functor)
      {
      }

      // Invoke the stored functor for a single index.
      VTKM_EXEC_EXPORT void operator()(vtkm::Id index) const
      {
        this->Functor(index);
      }
  };

public:
  // Calls `functor(i)` for every i in [0, numInstances) from an OpenMP
  // parallel loop. The functor is handed an error message buffer first; if
  // an error has been raised in that buffer once the loop has finished, it
  // is rethrown as a vtkm::cont::ErrorExecution carrying the message text.
  template<class Functor>
  VTKM_CONT_EXPORT static void Schedule(Functor functor, vtkm::Id numInstances)
  {
    // Stack storage the execution side writes its error text into.
    const vtkm::Id ERROR_BUFFER_SIZE = 1024;
    char errorText[ERROR_BUFFER_SIZE];
    errorText[0] = '\0';
    vtkm::exec::internal::ErrorMessageBuffer
      errorBuffer(errorText, ERROR_BUFFER_SIZE);
    functor.SetErrorMessageBuffer(errorBuffer);

    // The kernel copies the functor after the error buffer was attached.
    const ScheduleKernel<Functor> scheduled(functor);

#pragma omp parallel for schedule(guided, GRAIN_SIZE)
    for(vtkm::Id index = 0; index < numInstances; ++index)
    {
      scheduled(index);
    }

    if (!errorBuffer.IsErrorRaised())
    {
      return;
    }
    throw vtkm::cont::ErrorExecution(errorText);
  }

  // 3D variant of Schedule: the three extents in `rangeMax` are multiplied
  // into a single flat count and `functor` is called once per flat index in
  // [0, count) from an OpenMP parallel loop. Errors raised through the
  // functor's error message buffer are rethrown as
  // vtkm::cont::ErrorExecution after the loop.
  template<class Functor>
  VTKM_CONT_EXPORT static void Schedule(Functor functor, vtkm::Id3 rangeMax)
  {
    // Stack storage the execution side writes its error text into.
    const vtkm::Id ERROR_BUFFER_SIZE = 1024;
    char errorText[ERROR_BUFFER_SIZE];
    errorText[0] = '\0';
    vtkm::exec::internal::ErrorMessageBuffer
      errorBuffer(errorText, ERROR_BUFFER_SIZE);
    functor.SetErrorMessageBuffer(errorBuffer);

    // The kernel copies the functor after the error buffer was attached.
    const ScheduleKernel<Functor> scheduled(functor);

    // The nested i/j/k loops are flattened into one index range.
    const vtkm::Id flatCount = rangeMax[0] * rangeMax[1] * rangeMax[2];
#pragma omp parallel for schedule(guided, GRAIN_SIZE)
    for(vtkm::Id flatIndex = 0; flatIndex < flatCount; ++flatIndex)
    {
      scheduled(flatIndex);
    }

    if (!errorBuffer.IsErrorRaised())
    {
      return;
    }
    throw vtkm::cont::ErrorExecution(errorText);
  }

  //--------------------------------------------------------------------------
  // Scan Inclusive
private:
  // Functor for one combine step of a strided, in-place scan. For a given
  // index it pairs the element at Offset + index*Stride with the element
  // Distance (= Stride/2) positions further on, and overwrites the latter
  // with BinaryOperator(first, second). Pairs whose second element lies
  // outside the portal are skipped.
  template<typename PortalType, typename BinaryFunctor>
  struct ScanKernel : vtkm::exec::FunctorBase
  {
    PortalType Portal;
    BinaryFunctor BinaryOperator;
    vtkm::Id Stride;
    vtkm::Id Offset;
    vtkm::Id Distance;

    VTKM_CONT_EXPORT
    ScanKernel(const PortalType &portal, BinaryFunctor binary_functor,
               vtkm::Id stride, vtkm::Id offset)
      : Portal(portal)
      , BinaryOperator(binary_functor)
      , Stride(stride)
      , Offset(offset)
      , Distance(stride/2)
    {
    }

    VTKM_EXEC_EXPORT
    void operator()(vtkm::Id index) const
    {
      typedef typename PortalType::ValueType ValueType;

      const vtkm::Id sourceIndex = this->Offset + index*this->Stride;
      const vtkm::Id targetIndex = sourceIndex + this->Distance;

      // Nothing to combine when the partner element is out of range.
      if (targetIndex >= this->Portal.GetNumberOfValues())
      {
        return;
      }

      ValueType sourceValue = this->Portal.Get(sourceIndex);
      ValueType targetValue = this->Portal.Get(targetIndex);
      this->Portal.Set(targetIndex, BinaryOperator(sourceValue,targetValue) );
    }
  };

public:
  // Inclusive scan of `input` into `output` using vtkm::internal::Add as the
  // combining operator. Forwards to the functor-taking overload and returns
  // whatever that overload returns.
  template<typename T, class CIn, class COut>
  VTKM_CONT_EXPORT static T ScanInclusive(
      const vtkm::cont::ArrayHandle<T,CIn> &input,
      vtkm::cont::ArrayHandle<T,COut>& output)
  {
    typedef vtkm::internal::Add AddOperator;
    return ScanInclusive(input, output, AddOperator());
  }

  /*
   * This ScanInclusive is a port of Jeff Inman's vectorized MIC scan, however
   * we seem to perform quite a bit worse than Intel's TBB scan on MIC, so
   * this may need to be revisited. On CPU this scan performs close to theirs.
   *
   * `input` is copied into `output`, which is then scanned in place with
   * `binary_functor`. The return value is the last element of the scanned
   * output, or T() when the input is empty.
   */
  template<typename T, class CIn, class COut, class BinaryFunctor>
  VTKM_CONT_EXPORT static T ScanInclusive(
      const vtkm::cont::ArrayHandle<T,CIn> &input,
      vtkm::cont::ArrayHandle<T,COut>& output,
      BinaryFunctor binary_functor)
  {
    typedef typename
        vtkm::cont::ArrayHandle<T,COut>
            ::template ExecutionTypes<DeviceAdapterTag>::Portal PortalType;
    typedef internal::TiledInclusiveScanKernel<PortalType, BinaryFunctor> TiledScanKernelType;
    const vtkm::Id vector_size = TiledScanKernelType::VECTOR_SIZE;

    Copy(input, output);

    const vtkm::Id numValues = output.GetNumberOfValues();
    if (numValues < 1){
      // Nothing to scan. An empty array has no element 0 to read back, so
      // hand back a default-constructed value instead.
      return T();
    }

    PortalType portal = output.PrepareForInPlace(DeviceAdapterTag());
    const vtkm::Id max_threads = omp_get_max_threads();

    // Number of vectors that fit in 32Kb cache
    // NOTE(review): the expression below evaluates to 32 MiB / (vector_size * 4),
    // not 32 KiB, and assumes 4-byte elements. Left unchanged because it only
    // affects tile sizing -- confirm the intended cache size.
    const vtkm::Id l1_vectors = (32 * 1024 * 1024) / (vector_size * 4);

    const vtkm::Id count512 = numValues / vector_size
      + (numValues % vector_size == 0 ? 0 : 1);
    // number of 512-bit vectors in a "tile"
    const vtkm::Id tile_size = count512 > l1_vectors ? l1_vectors : count512;
    // Never use more threads than there are vectors in a tile, so that the
    // static schedule gives every thread at least one iteration.
    const vtkm::Id team_size = max_threads > tile_size ? tile_size : max_threads;

    // meta_scan[i] will get the carry-out for the sub-scan performed by
    // thread[i]. NOTE: This *must* have a value for every thread.
    // Round-up size of meta_scan[] to the nearest multiple of VECTOR_SIZE, to allow
    // use of simd_scan(). The rounding is expressed in terms of vector_size
    // rather than a fixed 16-wide bit mask so it stays correct for any
    // VECTOR_SIZE.
    const vtkm::Id meta_scan_count = (team_size % vector_size != 0)
      ? 1 + ((team_size / vector_size) + 1) * vector_size
      : team_size + 1;

    vtkm::cont::ArrayHandle<T, COut> meta_scan_handle;
    meta_scan_handle.Allocate(meta_scan_count);
    PortalType meta_scan_portal = meta_scan_handle.PrepareForInPlace(vtkm::cont::DeviceAdapterTagOpenMP());

    // Note: Carry in and tile offset don't have valid values until they're actually set
    // by the output of a previous tile/thread. This is because T() is not the identity
    // type for all operators we may do a scan with
    T carry_in = T();
    T carry_out = T();
    T tile_offset = T();
    T tile_offset_next = T();
    bool tile_offset_valid = false;

    for (vtkm::Id i = 0; i < meta_scan_count; ++i){
      meta_scan_portal.Set(i, T());
    }

    // Loop over each tile
    for (vtkm::Id t = 0; t < count512; t += tile_size){
      // PASS1 (per tile) parallel
      //
      // A "tile" represents a subset of the total input.  Tiles are
      // processed sequentially, and their size is chosen so as to restrict
      // the amount of memory that is written through the cache.  The goal is
      // that whenever we write and subsequently read, the read should find
      // the value in cache.  Tiles are chosen with sizes to allow this.
      //
      // Within a given tile, multiple threads create independent scans of
      // their local region of iterations.  This results in N distinct local
      // scans, where N is the number of threads.  Each scan produces a
      // "carry-out" value into meta_scan, such that meta_scan[i] is the
      // carry-out value of thread[i].
      //
      // The team size is requested with a num_threads clause on each region
      // instead of omp_set_num_threads(), so the process-wide OpenMP thread
      // setting is never modified (and cannot be left modified if an
      // exception escapes this function).
#pragma omp parallel num_threads(team_size) firstprivate(carry_in), private(carry_out)
      {
        TiledScanKernelType kernel = TiledScanKernelType(portal, binary_functor, carry_in, carry_out);
        bool carry_in_valid = false;
#pragma omp for schedule(static)
        for (vtkm::Id i = 0; i < tile_size; ++i){
          // On the first tile for each thread we don't have a valid carry in value
          if (!carry_in_valid){
            kernel.no_carry_in(t + i);
            carry_in = carry_out;
            carry_in_valid = true;
          }
          else {
            kernel.with_carry_in(t + i);
            carry_in = binary_functor(carry_in, carry_out);
          }
        }
        // Now share our carry out with everyone else in the meta scan
        // It's a bit awkward but the carry out of this thread is actually in
        // `carry_in`
        const vtkm::Id thread = omp_get_thread_num();
        meta_scan_portal.Set(thread, carry_in);
      }

      // META-PASS: (per tile) serial
      //
      // Still within the given tile, we now run a scan on the values in
      // meta-scan.  This produces offsets that will be added back into the
      // sub-scans within the tile.  We also want to add the global offset
      // for this tile to all the values in the sub-scans.
      //
      // After the meta-scan, values in the meta-scan provide fully
      // "relocated" offsets, for the corresponding threads within the
      // current tile. The threads will add these offsets to their sub-scan
      // elements in PASS2.
      //
      // The meta-scan also produces a carry-out.  This is the carry-out
      // value for this tile, which is also the global offset of the next tile
      //
      // NOTE(review): the padding entries of meta_scan hold T() and take part
      // in this scan, so for operators where T() is not the identity the
      // tile carry-out may be affected -- confirm against TiledScanKernel.
      {
        T carry_in_meta = T();
        TiledScanKernelType kernel = TiledScanKernelType(meta_scan_portal,
            binary_functor, carry_in_meta, carry_out);
        kernel.no_carry_in(0);
        carry_in_meta = carry_out;

        // On the first chunk of the scan we don't have a valid carry in
        for (vtkm::Id i = 1; i < meta_scan_count / vector_size + 1; ++i){
          kernel.with_carry_in(i);
          carry_in_meta = binary_functor(carry_in_meta, carry_out);
        }

        // Add the tile offset to each element in the meta scan if we have a valid one.
        // This must start at element 0: in PASS2 thread k (k > 0) reads
        // meta_scan[k - 1], so thread 1 reads meta_scan[0] and needs the
        // global tile offset folded into it as well.
        if (tile_offset_valid){
          for (vtkm::Id i = 0; i < meta_scan_count; ++i){
            T value = meta_scan_portal.Get(i);
            meta_scan_portal.Set(i, binary_functor(tile_offset, value));
          }
          tile_offset_next = binary_functor(tile_offset, carry_in_meta);
        }
        else {
          tile_offset_next = carry_in_meta;
          tile_offset_valid = true;
        }
      }

      // PASS2 (per tile) parallel
      //
      // Now each thread just adds the values from the corresponding
      // meta-scan, to all the elements within its chunk, within the current
      // tile.
      //
      // The resulting elements are "fully relocated", such that, when we are
      // finished with a tile, we don't need to revisit it again.  Thus, we
      // can now forget about cache for this tile.
      //
      // The loop below uses the same index type, bounds and static schedule
      // as PASS1 so that each thread revisits exactly the vectors it scanned.
#pragma omp parallel num_threads(team_size)
      {
        const vtkm::Id thread = omp_get_thread_num();
        // On the first tile thread 0 has no valid offset to apply
        T thread_offset = thread == 0 ? tile_offset : meta_scan_portal.Get(thread - 1);

#pragma omp for schedule(static)
        for (vtkm::Id i = 0; i < tile_size; ++i){
          if (thread != 0 || t != 0){
            const vtkm::Id offset = (t + i) * vector_size;
            // Number of real elements in this vector; the last vector may be
            // partial, and vectors past the end of the array yield a
            // non-positive count so the inner loop does nothing.
            const vtkm::Id n_valid = offset + vector_size > numValues ?
              numValues - offset : vector_size;

            for (vtkm::Id j = 0; j < n_valid; ++j){
              T value = portal.Get(offset + j);
              portal.Set(offset + j, binary_functor(thread_offset, value));
            }
          }
        }
        // carry-out of final iteration becomes new global carry-out
        if (thread == omp_get_num_threads() - 1) {
          carry_out = binary_functor(carry_out, thread_offset);
        }
      }
      tile_offset = tile_offset_next;

      // CLEANUP (per-tile) serial
      {
        // If meta_scan was rounded up, then the last several values are not
        // reset by any thread, during PASS1.  Reset them to zero now.
        for (vtkm::Id i = 0; i < meta_scan_count; ++i){
          meta_scan_portal.Set(i, T());
        }
        // carry_in should be reset to zeros, for the next tile.  You'd think
        // we could use the tile_offset as the new carry_in, but that would
        // mess up the meta-scan.
        carry_in = T();
      }
    }
    // The last element of the scanned output is the result of the scan on the array
    return output.GetPortalConstControl().Get(numValues - 1);
  }

  //--------------------------------------------------------------------------
  // Sort
private:
  // Given three indices into `portal`, returns the index whose value is the
  // median of the three under the ordering defined by `compare`.
  template<typename PortalType, typename BinaryCompare>
  static vtkm::Id MedianOfThree(const PortalType *portal, BinaryCompare compare,
      const vtkm::Id l, const vtkm::Id m, const vtkm::Id h)
  {
    if (compare(portal->Get(l), portal->Get(m))){
      // value[l] orders before value[m]
      if (compare(portal->Get(m), portal->Get(h))){
        return m;
      }
      if (compare(portal->Get(l), portal->Get(h))){
        return h;
      }
      return l;
    }

    // value[l] does not order before value[m]
    if (compare(portal->Get(h), portal->Get(m))){
      return m;
    }
    if (compare(portal->Get(h), portal->Get(l))){
      return h;
    }
    return l;
  }

  // Partitions the inclusive index range [low, high] of `portal` in place
  // and returns the split index. The pivot is picked as a median of three
  // medians-of-three over nine samples spaced (high - low) / 8 apart, and is
  // moved to `low` before the two scans start.
  //
  // Elements are exchanged with explicit Get/Set sequences; the original
  // author noted that std::swap did not work here with vtkm::Pair values.
  //
  // NOTE(review): both scans compare against the value currently stored at
  // `low`, and the first exchange can involve index `low` itself, so the
  // value compared against may change after that exchange.
  template<typename PortalType, typename BinaryCompare>
  static vtkm::Id PartitionArray(PortalType *portal, BinaryCompare compare,
      const vtkm::Id low, const vtkm::Id high)
  {
    typedef typename PortalType::ValueType ValueType;

    // Sample spacing for the nine pivot candidates.
    const vtkm::Id step = (high - low) / 8;
    const vtkm::Id firstMedian = MedianOfThree(portal, compare,
        low, low + step, low + step * 2);
    const vtkm::Id secondMedian = MedianOfThree(portal, compare,
        low + step * 3, low + step * 4, low + step * 5);
    const vtkm::Id thirdMedian = MedianOfThree(portal, compare,
        low + step * 6, low + step * 7, high);
    const vtkm::Id pivot = MedianOfThree(portal, compare,
        firstMedian, secondMedian, thirdMedian);

    // Bring the chosen pivot to the front of the range.
    if (pivot > low){
      ValueType front = portal->Get(low);
      portal->Set(low, portal->Get(pivot));
      portal->Set(pivot, front);
    }

    vtkm::Id left = low - 1;
    vtkm::Id right = high + 1;
    for (;;){
      // Walk `right` down past everything that orders after the front value.
      do {
        --right;
      }
      while (compare(portal->Get(low), portal->Get(right)));
      // Walk `left` up past everything that orders before the front value.
      do {
        ++left;
      }
      while (compare(portal->Get(left), portal->Get(low)));

      if (left >= right){
        break;
      }
      ValueType displaced = portal->Get(left);
      portal->Set(left, portal->Get(right));
      portal->Set(right, displaced);
    }

    // Exchange the front value with the element at the split position.
    ValueType front = portal->Get(low);
    portal->Set(low, portal->Get(right));
    portal->Set(right, front);
    return right;
  }

  // Sort the range of indices passed, range is inclusive
  // Sorts the inclusive index range [low, high] of `portal`.
  //
  // Ranges spanning more than GRAIN_SIZE are partitioned; the lower part is
  // handed to an OpenMP task while the upper part is sorted by recursing in
  // the calling thread. Smaller ranges are sorted directly with std::sort.
  // Empty, single-element or negative ranges are ignored.
  template<typename PortalType, typename BinaryCompare>
  static void QuicksortParallel(PortalType *portal, BinaryCompare compare,
      const vtkm::Id low, const vtkm::Id high)
  {
    if (low < 0 || high < 0 || low >= high){
      return;
    }

    if (high - low <= GRAIN_SIZE){
      // Small enough: sort this slice serially through portal iterators.
      typedef vtkm::cont::ArrayPortalToIterators<PortalType> IteratorsType;
      typedef typename IteratorsType::IteratorType Iterator;

      IteratorsType iterators(*portal);
      internal::WrappedBinaryOperator<bool, BinaryCompare> wrappedCompare(compare);

      Iterator first = iterators.GetBegin();
      std::advance(first, low);
      // `high` is inclusive, so the end iterator sits one past it.
      Iterator last = iterators.GetBegin();
      std::advance(last, high + 1);

      std::sort(first, last, wrappedCompare);
      return;
    }

    const vtkm::Id split = PartitionArray(portal, compare, low, high);
#pragma omp task firstprivate(split, low, high, compare)
    {
      QuicksortParallel(portal, compare, low, split);
    }
    QuicksortParallel(portal, compare, split + 1, high);
  }

public:
  // Sorts `values` in place according to `binary_compare`. Arrays with fewer
  // than two values are left untouched. The quicksort is started by a single
  // thread inside an OpenMP parallel region.
  template<typename T, class Storage, class BinaryCompare>
  VTKM_CONT_EXPORT static void Sort(
      vtkm::cont::ArrayHandle<T,Storage> &values,
      BinaryCompare binary_compare)
  {
    typedef vtkm::cont::ArrayHandle<T,Storage> HandleType;
    typedef typename HandleType::template ExecutionTypes<DeviceAdapterTag>
        ::Portal ExecPortalType;

    if (values.GetNumberOfValues() < 2) { return; }

    ExecPortalType portal = values.PrepareForInPlace(DeviceAdapterTag());

#pragma omp parallel shared(values, portal)
    {
      // Only one thread starts the recursion.
#pragma omp single
      {
        QuicksortParallel(&portal, binary_compare, 0, portal.GetNumberOfValues() - 1);
      }
    }
  }

  // Sorts `values` in place using std::less<T> as the comparison, by
  // forwarding to the comparator-taking overload.
  template<typename T, class Storage>
  VTKM_CONT_EXPORT static void Sort(
      vtkm::cont::ArrayHandle<T,Storage> &values)
  {
    typedef std::less<T> AscendingCompare;
    Sort(values, AscendingCompare());
  }

  VTKM_CONT_EXPORT static void Synchronize()
  {
    // Intentionally a no-op. The algorithms in this adapter use a split/join
    // model, so there is nothing left to synchronize here. The OpenMP
    // backend should behave identically to Intel TBB, where all threads
    // have been re-joined before returning from an execution.
  }
};
}
}// namespace vtkm::cont

#endif //vtk_m_cont_openmp_internal_DeviceAdapterAlgorithmOpenMP_h

