//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//============================================================================
#ifndef vtk_m_exec_cuda_internal_TaskStrided_h
#define vtk_m_exec_cuda_internal_TaskStrided_h

#include <vtkm/exec/TaskBase.h>

#include <vtkm/cont/cuda/internal/CudaAllocator.h>

#include <vtkmtaotuple/include/Tuple.h>
#include <vtkmtaotuple/include/tao/seq/make_integer_sequence.hpp>

namespace vtkm
{
namespace exec
{
namespace cuda
{
namespace internal
{

/// Calls \c functor with \c index as its first argument, followed by the
/// elements of \c tuple selected by the index pack (each read through
/// vtkmstd::get).
///
/// \c index is perfectly forwarded to the call; \c functor and \c tuple are
/// used as lvalues. The unnamed index_sequence parameter carries no data and
/// exists only so that the index pack can be deduced.
template <typename FunctorT, typename IndexT, typename TupleT, std::size_t... Positions>
VTKM_EXEC inline void TaskStridedApply(FunctorT&& functor,
                                       IndexT&& index,
                                       TupleT&& tuple,
                                       tao::seq::index_sequence<Positions...>)
{
  functor(std::forward<IndexT>(index), vtkmstd::get<Positions>(tuple)...);
}

/// Callback that restores the concrete functor type behind a type-erased
/// pointer and hands it an error message buffer.
///
/// \param f      Pointer that is cast to \c FType (with cv-qualifiers
///               removed); it must therefore point to an object of that type.
/// \param buffer Passed unchanged to the functor's SetErrorMessageBuffer().
template <typename FType>
void TaskStridedSetErrorBuffer(void* f, const vtkm::exec::internal::ErrorMessageBuffer& buffer)
{
  using ConcreteFunctor = typename std::remove_cv<FType>::type;
  static_cast<ConcreteFunctor*>(f)->SetErrorMessageBuffer(buffer);
}

/// Base class for the strided tasks that hides the concrete functor type.
///
/// It stores a type-erased pointer plus a callback that knows how to forward
/// an error message buffer to whatever that pointer refers to. Both members
/// start out as nullptr and are meant to be filled in by a derived class.
class TaskStrided : public vtkm::exec::TaskBase
{
public:
  /// Forwards \c buffer, together with the stored pointer, to the registered
  /// callback. The callback is invoked without a null check, so it must have
  /// been set before this method is called.
  void SetErrorMessageBuffer(const vtkm::exec::internal::ErrorMessageBuffer& buffer)
  {
    this->SetErrorBufferFunction(this->FPtr, buffer);
  }

protected:
  /// Signature of the callback: (type-erased pointer, error message buffer).
  using SetErrorBufferSignature = void (*)(void*, const vtkm::exec::internal::ErrorMessageBuffer&);

  /// Type-erased pointer handed to the callback as its first argument.
  void* FPtr = nullptr;

  /// Callback invoked by SetErrorMessageBuffer().
  SetErrorBufferSignature SetErrorBufferFunction = nullptr;
};

template <typename FType, typename... Args>
class TaskStrided1D : public TaskStrided
{
  using TType =
    vtkmstd::tuple<typename std::remove_cv<typename std::remove_reference<Args>::type>::type...>;

public:
  TaskStrided1D(const FType& functor, Args... arguments)
    : TaskStrided()
    , Functor(functor)
    , Arguments(arguments...)
  {
    this->SetErrorBufferFunction = &TaskStridedSetErrorBuffer<FType>;
    //Bind the Functor to void*
    this->FPtr = reinterpret_cast<void*>(&this->Functor);
  }

  VTKM_EXEC
  void operator()(vtkm::Id start, vtkm::Id end, vtkm::Id inc) const
  {
    for (vtkm::Id index = start; index < end; index += inc)
    {
      TaskStridedApply(
        this->Functor, index, this->Arguments, tao::seq::make_index_sequence<sizeof...(Args)>{});
    }
  }

private:
  typename std::remove_const<FType>::type Functor;
  // This is held by by value so that when we transfer the invocation object
  // over to CUDA it gets properly copied to the device. While we want to
  // hold by reference to reduce the number of copies, it is not possible
  // currently.
  const TType Arguments;
};

template <typename FType, typename... Args>
class TaskStrided3D : public TaskStrided
{
  using TType =
    vtkmstd::tuple<typename std::remove_cv<typename std::remove_reference<Args>::type>::type...>;

public:
  TaskStrided3D(const FType& functor, Args... arguments)
    : TaskStrided()
    , Functor(functor)
    , Arguments(arguments...)
  {
    this->SetErrorBufferFunction = &TaskStridedSetErrorBuffer<FType>;
    //Bind the Functor to void*
    this->FPtr = reinterpret_cast<void*>(&this->Functor);
  }

  VTKM_EXEC
  void operator()(vtkm::Id start, vtkm::Id end, vtkm::Id inc, vtkm::Id j, vtkm::Id k) const
  {
    vtkm::Id3 index(start, j, k);
    for (; index[0] < end; index[0] += inc)
    {
      TaskStridedApply(
        this->Functor, index, this->Arguments, tao::seq::make_index_sequence<sizeof...(Args)>{});
    }
  }

private:
  typename std::remove_const<FType>::type Functor;
  // This is held by by value so that when we transfer the invocation object
  // over to CUDA it gets properly copied to the device. While we want to
  // hold by reference to reduce the number of copies, it is not possible
  // currently.
  const TType Arguments;
};
}
}
}
} // vtkm::exec::cuda::internal

#endif //vtk_m_exec_cuda_internal_TaskStrided_h
