Commit d4adcb46 authored by Robert Maynard's avatar Robert Maynard

Merge branch 'reduce'

parents f514a9c7 e38caafe
......@@ -101,12 +101,42 @@ struct DeviceAdapterAlgorithm
const vtkm::cont::ArrayHandle<vtkm::Id,CIn>& input,
vtkm::cont::ArrayHandle<vtkm::Id,COut>& values_output);
/// \brief Compute a accumulated sum operation on the input ArrayHandle
///
/// Computes an accumulated sum on the \c input ArrayHandle, returning the
/// total sum. Reduce is similar to the stl accumulate sum function,
/// exception that Reduce doesn't do a serial summation. This means that if
/// you have defined a custom plus operator for T it must be commutative,
/// or you will get inconsistent results.
///
/// \return The total sum.
template<typename T, class CIn>
VTKM_CONT_EXPORT static T Reduce(
const vtkm::cont::ArrayHandle<T,CIn> &input,
T initialValue);
/// \brief Compute a accumulated sum operation on the input ArrayHandle
///
/// Computes an accumulated sum (or any user binary operation) on the
/// \c input ArrayHandle, returning the total sum. Reduce is
/// similar to the stl accumulate sum function, exception that Reduce
/// doesn't do a serial summation. This means that if you have defined a
/// custom plus operator for T it must be commutative, or you will get
/// inconsistent results.
///
/// \return The total sum.
template<typename T, class CIn, class BinaryOperation>
VTKM_CONT_EXPORT static T Reduce(
const vtkm::cont::ArrayHandle<T,CIn> &input,
T initialValue,
BinaryOperation binaryOp);
/// \brief Compute an inclusive prefix sum operation on the input ArrayHandle.
///
/// Computes an inclusive prefix sum operation on the \c input ArrayHandle,
/// storing the results in the \c output ArrayHandle. InclusiveScan is
/// similiar to the stl partial sum function, exception that InclusiveScan
/// doesn't do a serial sumnation. This means that if you have defined a
/// similar to the stl partial sum function, exception that InclusiveScan
/// doesn't do a serial summation. This means that if you have defined a
/// custom plus operator for T it must be associative, or you will get
/// inconsistent results. When the input and output ArrayHandles are the same
/// ArrayHandle the operation will be done inplace.
......@@ -118,12 +148,30 @@ struct DeviceAdapterAlgorithm
const vtkm::cont::ArrayHandle<T,CIn> &input,
vtkm::cont::ArrayHandle<T,COut>& output);
/// \brief Compute an inclusive prefix sum operation on the input ArrayHandle.
///
/// Computes an inclusive prefix sum operation on the \c input ArrayHandle,
/// storing the results in the \c output ArrayHandle. InclusiveScan is
/// similar to the stl partial sum function, exception that InclusiveScan
/// doesn't do a serial summation. This means that if you have defined a
/// custom plus operator for T it must be associative, or you will get
/// inconsistent results. When the input and output ArrayHandles are the same
/// ArrayHandle the operation will be done inplace.
///
/// \return The total sum.
///
template<typename T, class CIn, class COut, class BinaryOperation>
VTKM_CONT_EXPORT static T ScanInclusive(
const vtkm::cont::ArrayHandle<T,CIn> &input,
vtkm::cont::ArrayHandle<T,COut>& output,
BinaryOperation binaryOp);
/// \brief Compute an exclusive prefix sum operation on the input ArrayHandle.
///
/// Computes an exclusive prefix sum operation on the \c input ArrayHandle,
/// storing the results in the \c output ArrayHandle. ExclusiveScan is
/// similiar to the stl partial sum function, exception that ExclusiveScan
/// doesn't do a serial sumnation. This means that if you have defined a
/// similar to the stl partial sum function, exception that ExclusiveScan
/// doesn't do a serial summation. This means that if you have defined a
/// custom plus operator for T it must be associative, or you will get
/// inconsistent results. When the input and output ArrayHandles are the same
/// ArrayHandle the operation will be done inplace.
......
......@@ -290,6 +290,28 @@ private:
IteratorBegin(values_output));
}
template<class InputPortal>
VTKM_CONT_EXPORT static
typename InputPortal::ValueType ReducePortal(const InputPortal &input,
typename InputPortal::ValueType initialValue)
{
return ::thrust::reduce(IteratorBegin(input),
IteratorEnd(input),
initialValue);
}
template<class InputPortal, class BinaryOperation>
VTKM_CONT_EXPORT static
typename InputPortal::ValueType ReducePortal(const InputPortal &input,
typename InputPortal::ValueType initialValue,
BinaryOperation binaryOP)
{
return ::thrust::reduce(IteratorBegin(input),
IteratorEnd(input),
initialValue,
binaryOP);
}
template<class InputPortal, class OutputPortal>
VTKM_CONT_EXPORT static
typename InputPortal::ValueType ScanExclusivePortal(const InputPortal &input,
......@@ -320,6 +342,21 @@ private:
return *(IteratorEnd(output) - 1);
}
template<class InputPortal, class OutputPortal, class BinaryOperation>
VTKM_CONT_EXPORT static
typename InputPortal::ValueType ScanInclusivePortal(const InputPortal &input,
const OutputPortal &output,
BinaryOperation binaryOp)
{
::thrust::inclusive_scan(IteratorBegin(input),
IteratorEnd(input),
IteratorBegin(output),
binaryOp);
//return the value at the last index in the array, as that is the sum
return *(IteratorEnd(output) - 1);
}
template<class ValuesPortal>
VTKM_CONT_EXPORT static void SortPortal(const ValuesPortal &values)
{
......@@ -485,7 +522,7 @@ public:
const vtkm::cont::ArrayHandle<T,SIn> &input,
vtkm::cont::ArrayHandle<T,SOut> &output)
{
vtkm::Id numberOfValues = input.GetNumberOfValues();
const vtkm::Id numberOfValues = input.GetNumberOfValues();
CopyPortal(input.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
}
......@@ -525,12 +562,42 @@ public:
values_output.PrepareForInPlace(DeviceAdapterTag()));
}
template<typename T, class SIn>
VTKM_CONT_EXPORT static T Reduce(
const vtkm::cont::ArrayHandle<T,SIn> &input,
T initialValue)
{
const vtkm::Id numberOfValues = input.GetNumberOfValues();
if (numberOfValues <= 0)
{
return initialValue;
}
return ReducePortal(input.PrepareForInput( DeviceAdapterTag() ),
initialValue);
}
template<typename T, class SIn, class BinaryOperation>
VTKM_CONT_EXPORT static T Reduce(
const vtkm::cont::ArrayHandle<T,SIn> &input,
T initialValue,
BinaryOperation binaryOp)
{
const vtkm::Id numberOfValues = input.GetNumberOfValues();
if (numberOfValues <= 0)
{
return initialValue;
}
return ReducePortal(input.PrepareForInput( DeviceAdapterTag() ),
initialValue,
binaryOp);
}
template<typename T, class SIn, class SOut>
VTKM_CONT_EXPORT static T ScanExclusive(
const vtkm::cont::ArrayHandle<T,SIn> &input,
vtkm::cont::ArrayHandle<T,SOut>& output)
{
vtkm::Id numberOfValues = input.GetNumberOfValues();
const vtkm::Id numberOfValues = input.GetNumberOfValues();
if (numberOfValues <= 0)
{
output.PrepareForOutput(0, DeviceAdapterTag());
......@@ -540,12 +607,13 @@ public:
return ScanExclusivePortal(input.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
}
template<typename T, class SIn, class SOut>
VTKM_CONT_EXPORT static T ScanInclusive(
const vtkm::cont::ArrayHandle<T,SIn> &input,
vtkm::cont::ArrayHandle<T,SOut>& output)
{
vtkm::Id numberOfValues = input.GetNumberOfValues();
const vtkm::Id numberOfValues = input.GetNumberOfValues();
if (numberOfValues <= 0)
{
output.PrepareForOutput(0, DeviceAdapterTag());
......@@ -556,6 +624,24 @@ public:
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
}
template<typename T, class SIn, class SOut, class BinaryOperation>
VTKM_CONT_EXPORT static T ScanInclusive(
const vtkm::cont::ArrayHandle<T,SIn> &input,
vtkm::cont::ArrayHandle<T,SOut>& output,
BinaryOperation binaryOp)
{
const vtkm::Id numberOfValues = input.GetNumberOfValues();
if (numberOfValues <= 0)
{
output.PrepareForOutput(0, DeviceAdapterTag());
return 0;
}
return ScanInclusivePortal(input.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
binaryOp);
}
// Because of some funny code conversions in nvcc, kernels for devices have to
// be public.
#ifndef VTKM_CUDA
......
......@@ -22,8 +22,9 @@
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandleCounting.h>
#include <vtkm/cont/ArrayPortalToIterators.h>
#include <vtkm/cont/ArrayHandleImplicit.h>
#include <vtkm/cont/ArrayHandleZip.h>
#include <vtkm/cont/ArrayPortalToIterators.h>
#include <vtkm/cont/StorageBasic.h>
#include <vtkm/exec/FunctorBase.h>
......@@ -316,6 +317,117 @@ public:
values_output);
}
//--------------------------------------------------------------------------
// Reduce
private:
template<int ReduceWidth, typename T, typename ArrayType, typename BinaryOperation >
struct ReduceKernel : vtkm::exec::FunctorBase
{
typedef typename ArrayType::template ExecutionTypes<
DeviceAdapterTag> ExecutionTypes;
typedef typename ExecutionTypes::PortalConst PortalConst;
PortalConst Portal;
BinaryOperation BinaryOperator;
vtkm::Id ArrayLength;
VTKM_CONT_EXPORT
ReduceKernel()
: Portal(),
BinaryOperator(),
ArrayLength(0)
{
}
VTKM_CONT_EXPORT
ReduceKernel(const ArrayType &array, BinaryOperation op)
: Portal(array.PrepareForInput( DeviceAdapterTag() ) ),
BinaryOperator(op),
ArrayLength( array.GetNumberOfValues() )
{ }
VTKM_EXEC_EXPORT
T operator()(vtkm::Id index) const
{
const vtkm::Id offset = index * ReduceWidth;
//at least the first value access to the portal will be valid
//only the rest could be invalid
T partialSum = this->Portal.Get( offset );
if( offset + ReduceWidth >= this->ArrayLength )
{
vtkm::Id currentIndex = offset + 1;
while( currentIndex < this->ArrayLength)
{
partialSum = BinaryOperator(partialSum, this->Portal.Get(currentIndex));
++currentIndex;
}
}
else
{
//optimize the usecase where all values are valid and we don't
//need to check that we might go out of bounds
for(int i=1; i < ReduceWidth; ++i)
{
partialSum = BinaryOperator(partialSum,
this->Portal.Get( offset + i )
);
}
}
return partialSum;
}
};
public:
template<typename T, class CIn>
VTKM_CONT_EXPORT static T Reduce(
const vtkm::cont::ArrayHandle<T,CIn> &input, T initialValue)
{
return DerivedAlgorithm::Reduce(input, initialValue, vtkm::internal::Add());
}
template<typename T, class CIn, class BinaryOperator>
VTKM_CONT_EXPORT static T Reduce(
const vtkm::cont::ArrayHandle<T,CIn> &input,
T initialValue,
BinaryOperator binaryOp)
{
//Crazy Idea:
//We create a implicit array handle that wraps the input
//array handle. The implicit functor is passed the input array handle, and
//the number of elements it needs to sum. This way the implicit handle
//acts as the first level reduction. Say for example reducing 16 values
//at a time.
//
//Now that we have an implicit array that is 1/16 the length of full array
//we can use scan inclusive to compute the final sum
typedef ReduceKernel<
16,
T,
vtkm::cont::ArrayHandle<T,CIn>,
BinaryOperator
> ReduceKernelType;
typedef vtkm::cont::ArrayHandleImplicit<
T,
ReduceKernelType > ReduceHandleType;
typedef vtkm::cont::ArrayHandle<
T,
vtkm::cont::StorageTagBasic> TempArrayType;
ReduceKernelType kernel(input, binaryOp);
vtkm::Id length = (input.GetNumberOfValues() / 16);
length += (input.GetNumberOfValues() % 16 == 0) ? 0 : 1;
ReduceHandleType reduced = vtkm::cont::make_ArrayHandleImplicit<T>(kernel,
length);
TempArrayType inclusiveScanStorage;
T scanResult = DerivedAlgorithm::ScanInclusive(reduced,
inclusiveScanStorage,
binaryOp);
return binaryOp(initialValue, scanResult);
}
//--------------------------------------------------------------------------
// Scan Exclusive
private:
......@@ -386,17 +498,20 @@ public:
//--------------------------------------------------------------------------
// Scan Inclusive
private:
template<typename PortalType>
template<typename PortalType, typename BinaryOperation>
struct ScanKernel : vtkm::exec::FunctorBase
{
PortalType Portal;
BinaryOperation BinaryOperator;
vtkm::Id Stride;
vtkm::Id Offset;
vtkm::Id Distance;
VTKM_CONT_EXPORT
ScanKernel(const PortalType &portal, vtkm::Id stride, vtkm::Id offset)
ScanKernel(const PortalType &portal, BinaryOperation binaryOp,
vtkm::Id stride, vtkm::Id offset)
: Portal(portal),
BinaryOperator(binaryOp),
Stride(stride),
Offset(offset),
Distance(stride/2)
......@@ -414,7 +529,7 @@ private:
{
ValueType leftValue = this->Portal.Get(leftIndex);
ValueType rightValue = this->Portal.Get(rightIndex);
this->Portal.Set(rightIndex, leftValue+rightValue);
this->Portal.Set(rightIndex, BinaryOperator(leftValue,rightValue) );
}
}
};
......@@ -424,17 +539,30 @@ public:
VTKM_CONT_EXPORT static T ScanInclusive(
const vtkm::cont::ArrayHandle<T,CIn> &input,
vtkm::cont::ArrayHandle<T,COut>& output)
{
return DerivedAlgorithm::ScanInclusive(input,
output,
vtkm::internal::Add());
}
template<typename T, class CIn, class COut, class BinaryOperation>
VTKM_CONT_EXPORT static T ScanInclusive(
const vtkm::cont::ArrayHandle<T,CIn> &input,
vtkm::cont::ArrayHandle<T,COut>& output,
BinaryOperation binaryOp)
{
typedef typename
vtkm::cont::ArrayHandle<T,COut>
::template ExecutionTypes<DeviceAdapterTag>::Portal PortalType;
typedef ScanKernel<PortalType,BinaryOperation> ScanKernelType;
DerivedAlgorithm::Copy(input, output);
vtkm::Id numValues = output.GetNumberOfValues();
if (numValues < 1)
{
return 0;
return T(0);
}
PortalType portal = output.PrepareForInPlace(DeviceAdapterTag());
......@@ -442,14 +570,14 @@ public:
vtkm::Id stride;
for (stride = 2; stride-1 < numValues; stride *= 2)
{
ScanKernel<PortalType> kernel(portal, stride, stride/2 - 1);
ScanKernelType kernel(portal, binaryOp, stride, stride/2 - 1);
DerivedAlgorithm::Schedule(kernel, numValues/stride);
}
// Do reverse operation on odd indices. Start at stride we were just at.
for (stride /= 2; stride > 1; stride /= 2)
{
ScanKernel<PortalType> kernel(portal, stride, stride - 1);
ScanKernelType kernel(portal, binaryOp, stride, stride - 1);
DerivedAlgorithm::Schedule(kernel, numValues/stride);
}
......
......@@ -38,6 +38,61 @@
namespace vtkm {
namespace cont {
namespace internal
{
template<typename ResultType, typename Function>
struct WrappedBinaryOperator
{
Function m_f;
VTKM_CONT_EXPORT
WrappedBinaryOperator(const Function &f)
: m_f(f)
{}
template<typename Argument1, typename Argument2>
VTKM_CONT_EXPORT ResultType operator()(const Argument1 &x, const Argument2 &y) const
{
return m_f(x, y);
}
template<typename Argument1, typename Argument2>
VTKM_CONT_EXPORT ResultType operator()(
const detail::IteratorFromArrayPortalValue<Argument1> &x,
const detail::IteratorFromArrayPortalValue<Argument2> &y) const
{
typedef typename detail::IteratorFromArrayPortalValue<Argument1>::ValueType
ValueTypeX;
typedef typename detail::IteratorFromArrayPortalValue<Argument2>::ValueType
ValueTypeY;
return m_f( (ValueTypeX)x, (ValueTypeY)y );
}
template<typename Argument1, typename Argument2>
VTKM_CONT_EXPORT ResultType operator()(
const Argument1 &x,
const detail::IteratorFromArrayPortalValue<Argument2> &y) const
{
typedef typename detail::IteratorFromArrayPortalValue<Argument2>::ValueType
ValueTypeY;
return m_f( x, (ValueTypeY)y );
}
template<typename Argument1, typename Argument2>
VTKM_CONT_EXPORT ResultType operator()(
const detail::IteratorFromArrayPortalValue<Argument1> &x,
const Argument2 &y) const
{
typedef typename detail::IteratorFromArrayPortalValue<Argument1>::ValueType
ValueTypeX;
return m_f( (ValueTypeX)x, y );
}
};
}
template<>
struct DeviceAdapterAlgorithm<vtkm::cont::DeviceAdapterTagSerial> :
vtkm::cont::internal::DeviceAdapterAlgorithmGeneral<
......@@ -63,7 +118,7 @@ public:
PortalIn inputPortal = input.PrepareForInput(Device());
PortalOut outputPortal = output.PrepareForOutput(numberOfValues, Device());
if (numberOfValues <= 0) { return 0; }
if (numberOfValues <= 0) { return T(0); }
std::partial_sum(vtkm::cont::ArrayPortalToIteratorBegin(inputPortal),
vtkm::cont::ArrayPortalToIteratorEnd(inputPortal),
......@@ -73,6 +128,41 @@ public:
return outputPortal.Get(numberOfValues - 1);
}
template<typename T, class CIn, class COut, class BinaryOperation>
VTKM_CONT_EXPORT static T ScanInclusive(
const vtkm::cont::ArrayHandle<T,CIn> &input,
vtkm::cont::ArrayHandle<T,COut>& output,
BinaryOperation binaryOp)
{
typedef typename vtkm::cont::ArrayHandle<T,COut>
::template ExecutionTypes<Device>::Portal PortalOut;
typedef typename vtkm::cont::ArrayHandle<T,CIn>
::template ExecutionTypes<Device>::PortalConst PortalIn;
//We need to wrap the operator in a WrappedBinaryOperator struct
//which can detect and handle calling the binary operator with complex
//value types such as IteratorFromArrayPortalValue which happen
//when passed an input array that is implicit. This occurs when
//invoking reduce which calls ScanInclusive
internal::WrappedBinaryOperator<T,BinaryOperation> wrappedBinaryOp(
binaryOp);
vtkm::Id numberOfValues = input.GetNumberOfValues();
PortalIn inputPortal = input.PrepareForInput(Device());
PortalOut outputPortal = output.PrepareForOutput(numberOfValues, Device());
if (numberOfValues <= 0) { return T(0); }
std::partial_sum(vtkm::cont::ArrayPortalToIteratorBegin(inputPortal),
vtkm::cont::ArrayPortalToIteratorEnd(inputPortal),
vtkm::cont::ArrayPortalToIteratorBegin(outputPortal),
wrappedBinaryOp);
// Return the value at the last index in the array, which is the full sum.
return outputPortal.Get(numberOfValues - 1);
}
template<typename T, class CIn, class COut>
VTKM_CONT_EXPORT static T ScanExclusive(
const vtkm::cont::ArrayHandle<T,CIn> &input,
......
......@@ -105,6 +105,16 @@ struct SortGreater
return valid;
}
};
struct MaxValue
{
template<typename T>
VTKM_EXEC_CONT_EXPORT T operator()(const T& a,const T& b) const
{
return (a > b) ? a : b;
}
};
}
......@@ -1051,11 +1061,61 @@ private:
VTKM_TEST_ASSERT(value == OFFSET, "Got bad unique value");
}
static VTKM_CONT_EXPORT void TestReduce()
{
std::cout << "-------------------------------------------" << std::endl;
std::cout << "Testing Reduce" << std::endl;
//construct the index array
IdArrayHandle array;
Algorithm::Schedule(
ClearArrayKernel(array.PrepareForOutput(ARRAY_SIZE,
DeviceAdapterTag())),
ARRAY_SIZE);
//the output of reduce and scan inclusive should be the same
vtkm::Id reduce_sum = Algorithm::Reduce(array, vtkm::Id(0));
vtkm::Id reduce_sum_with_intial_value = Algorithm::Reduce(array,
vtkm::Id(ARRAY_SIZE));
vtkm::Id inclusive_sum = Algorithm::ScanInclusive(array, array);
VTKM_TEST_ASSERT(reduce_sum == OFFSET * ARRAY_SIZE,
"Got bad sum from Reduce");
VTKM_TEST_ASSERT(reduce_sum_with_intial_value == reduce_sum + ARRAY_SIZE,
"Got bad sum from Reduce with initial value");
VTKM_TEST_ASSERT(reduce_sum == inclusive_sum,
"Got different sums from Reduce and ScanInclusive");
}
static VTKM_CONT_EXPORT void TestReduceWithComparisonObject()
{
std::cout << "-------------------------------------------" << std::endl;
std::cout << "Testing Reduce with comparison object " << std::endl;
//construct the index array. Assign an abnormally large value
//to the middle of the array, that should be what we see as our sum.
vtkm::Id testData[ARRAY_SIZE];
const vtkm::Id maxValue = ARRAY_SIZE*2;
for(vtkm::Id i=0; i < ARRAY_SIZE; ++i)