//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//
//  Copyright 2014 Sandia Corporation.
//  Copyright 2014 UT-Battelle, LLC.
//  Copyright 2014 Los Alamos National Security.
//
//  Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
//  the U.S. Government retains certain rights in this software.
//
//  Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
//  Laboratory (LANL), the U.S. Government retains certain rights in
//  this software.
//============================================================================
#ifndef vtk_m_cont_openmp_internal_TiledScanKernel_h
#define vtk_m_cont_openmp_internal_TiledScanKernel_h

#include <iostream>

namespace vtkm {
namespace cont {
namespace internal {

/*
 * Tiled inclusive scan kernel.
 *
 * The portal is treated as a sequence of tiles of VECTOR_SIZE (16)
 * consecutive elements; tile `index` covers the elements
 * [index * VECTOR_SIZE, index * VECTOR_SIZE + VECTOR_SIZE). Each call to
 * no_carry_in / with_carry_in inclusive-scans one tile in place and writes
 * the tile's local (carry-free) total to CarryOut. The last tile may be
 * partial: only elements that exist in the portal are read, combined and
 * written back.
 *
 * NOTE(review): CarryIn and CarryOut are references shared by every call on
 * one kernel object, so calls on the same instance must not run concurrently.
 * Presumably the caller walks the tiles and propagates carries between
 * them -- confirm against the caller.
 */
template<typename PortalType, typename BinaryFunctor>
struct TiledInclusiveScanKernel : vtkm::exec::FunctorBase
{
  typedef typename PortalType::ValueType ValueType;
  const static vtkm::Id VECTOR_SIZE = 16;

  PortalType Portal;
  const BinaryFunctor BinaryOperator;
  // Combined on the left into every element of the tile by with_carry_in.
  const ValueType &CarryIn;
  // Receives the local inclusive total of the most recently scanned tile.
  ValueType &CarryOut;

  /*
   * portal:         array scanned in place (the portal object is copied).
   * binary_functor: operator used to combine elements; the scan assumes it
   *                 is associative.
   * carry_in:       bound by reference -- must outlive this kernel.
   * carry_out:      bound by reference -- must outlive this kernel.
   */
  VTKM_CONT_EXPORT
  TiledInclusiveScanKernel(const PortalType &portal, BinaryFunctor binary_functor,
      const ValueType &carry_in, ValueType &carry_out)
  : Portal(portal), BinaryOperator(binary_functor),
  CarryIn(carry_in), CarryOut(carry_out)
  {}

  /*
   * Inclusive-scan tile `index` in place and store its total in CarryOut.
   * Does nothing (CarryOut untouched) if the tile starts at or past the end
   * of the portal.
   */
  VTKM_EXEC_EXPORT
  void no_carry_in(vtkm::Id index) const {
    ValueType vector[VECTOR_SIZE];
    const vtkm::Id offset = index * VECTOR_SIZE;
    const vtkm::Id n_valid = load_and_scan(vector, offset);
    if (n_valid <= 0){
      return;
    }

    CarryOut = vector[n_valid - 1];

    store_vector(vector, offset, n_valid);
  }

  /*
   * Inclusive-scan tile `index`, store the tile's local total (before the
   * carry is applied) in CarryOut, then combine CarryIn on the left into
   * every valid element and write the tile back to the portal.
   * Does nothing (CarryOut untouched) if the tile starts at or past the end
   * of the portal.
   */
  VTKM_EXEC_EXPORT
  void with_carry_in(vtkm::Id index) const
  {
    // Snapshot the carry before CarryOut is written, so the value combined
    // into the tile is the carry as it was on entry even if the constructor
    // bound carry_in and carry_out to the same object.
    const ValueType carry_in = CarryIn;

    ValueType vector[VECTOR_SIZE];
    const vtkm::Id offset = index * VECTOR_SIZE;
    const vtkm::Id n_valid = load_and_scan(vector, offset);
    if (n_valid <= 0){
      return;
    }

    CarryOut = vector[n_valid - 1];

    // Only the first n_valid slots hold data; the rest are uninitialized
    // for a partial tile and must not be passed to the operator.
    for (vtkm::Id i = 0; i < n_valid; ++i){
      vector[i] = BinaryOperator(carry_in, vector[i]);
    }

    store_vector(vector, offset, n_valid);
  }

private:

  // Number of portal elements in the tile starting at `offset`: VECTOR_SIZE
  // for a full tile, fewer for the trailing partial tile, and <= 0 when the
  // tile lies entirely past the end. (Written without std::min so the
  // in-class constant VECTOR_SIZE is never bound to a reference / ODR-used.)
  VTKM_EXEC_EXPORT
  vtkm::Id tile_size(const vtkm::Id offset) const {
    const vtkm::Id remaining = Portal.GetNumberOfValues() - offset;
    if (remaining < VECTOR_SIZE){
      return remaining;
    }
    return VECTOR_SIZE;
  }

  // Load the tile starting at `offset` into `vector` and inclusive-scan it.
  // Returns the number of valid elements; when that is <= 0 nothing is
  // loaded and `vector` is left untouched.
  VTKM_EXEC_EXPORT
  vtkm::Id load_and_scan(ValueType *vector, const vtkm::Id offset) const {
    const vtkm::Id n_valid = tile_size(offset);
    if (n_valid == VECTOR_SIZE){
      scan16(vector, offset);
    }
    else if (n_valid > 0){
      scan_remainder(vector, offset, n_valid);
    }
    return n_valid;
  }

  // Load a full tile of VECTOR_SIZE elements starting at `offset` and scan
  // it with the fixed 16-wide tree in scan_vector.
  VTKM_EXEC_EXPORT
  void scan16(ValueType *vector, const vtkm::Id offset) const {
#pragma unroll
    for (vtkm::Id i = 0; i < VECTOR_SIZE; ++i){
      vector[i] = Portal.Get(i + offset);
    }
    scan_vector(vector);
  }

  // Load the n_valid (1 .. VECTOR_SIZE-1) elements starting at `offset` and
  // inclusive-scan them sequentially. The fixed 16-wide scan_vector cannot
  // be used here: it would read the uninitialized slots
  // vector[n_valid .. VECTOR_SIZE-1] and feed them to the operator.
  VTKM_EXEC_EXPORT
  void scan_remainder(ValueType *vector, const vtkm::Id offset, const vtkm::Id n_valid) const {
    vector[0] = Portal.Get(offset);
    for (vtkm::Id i = 1; i < n_valid; ++i){
      vector[i] = BinaryOperator(vector[i - 1], Portal.Get(offset + i));
    }
  }

  // Write the first n_valid elements of `vector` back to the portal,
  // starting at `offset`.
  VTKM_EXEC_EXPORT
  void store_vector(ValueType *vector, const vtkm::Id offset, const vtkm::Id n_valid) const {
    for (vtkm::Id i = 0; i < n_valid; ++i){
      Portal.Set(i + offset, vector[i]);
    }
  }

  // In-place inclusive scan of exactly 16 elements using a fixed up-sweep /
  // down-sweep tree. All 16 slots are read, so all must be initialized.
  // Every update has the form vector[j] = op(vector[k], vector[j]) with
  // k < j, so output slot i depends only on inputs 0..i.
  VTKM_EXEC_EXPORT
  void scan_vector(ValueType *vector) const {
    // UPWARD PASS
    // Outer loop iter 0
    // stride = 2
    vector[1]  = BinaryOperator(vector[0],  vector[1]);
    vector[3]  = BinaryOperator(vector[2],  vector[3]);
    vector[5]  = BinaryOperator(vector[4],  vector[5]);
    vector[7]  = BinaryOperator(vector[6],  vector[7]);
    vector[9]  = BinaryOperator(vector[8],  vector[9]);
    vector[11] = BinaryOperator(vector[10], vector[11]);
    vector[13] = BinaryOperator(vector[12], vector[13]);
    vector[15] = BinaryOperator(vector[14], vector[15]);

    // Outer loop iter 1
    // stride = 4
    vector[3]  = BinaryOperator(vector[1],  vector[3]);
    vector[7]  = BinaryOperator(vector[5],  vector[7]);
    vector[11] = BinaryOperator(vector[9],  vector[11]);
    vector[15] = BinaryOperator(vector[13], vector[15]);

    // Outer loop iter 2
    // stride = 8
    vector[7]  = BinaryOperator(vector[3],  vector[7]);
    vector[15] = BinaryOperator(vector[11], vector[15]);

    // Outer loop iter 3
    // stride = 16
    vector[15] = BinaryOperator(vector[7], vector[15]);

    // DOWNWARD PASS
    // Outer loop iter 0
    // stride = 8
    vector[11] = BinaryOperator(vector[7], vector[11]);

    // Outer loop iter 1
    // stride = 4
    vector[5]  = BinaryOperator(vector[3],  vector[5]);
    vector[9]  = BinaryOperator(vector[7],  vector[9]);
    vector[13] = BinaryOperator(vector[11], vector[13]);

    // Outer loop iter 2
    // stride = 2
    vector[2]  = BinaryOperator(vector[1],  vector[2]);
    vector[4]  = BinaryOperator(vector[3],  vector[4]);
    vector[6]  = BinaryOperator(vector[5],  vector[6]);
    vector[8]  = BinaryOperator(vector[7],  vector[8]);
    vector[10] = BinaryOperator(vector[9],  vector[10]);
    vector[12] = BinaryOperator(vector[11], vector[12]);
    vector[14] = BinaryOperator(vector[13], vector[14]);
  }
};
}
}
}

#endif

