Updates will be applied April 15th at 12pm EDT (UTC-0400). GitLab could be a little slow between 12 - 12:45pm EDT.

Commit 5d9f369d authored by Robert Maynard's avatar Robert Maynard

Adding ScanInclusive with custom binary operator to DeviceAdapterAlgorithm.

parent 9519737b
......@@ -133,6 +133,24 @@ struct DeviceAdapterAlgorithm
const vtkm::cont::ArrayHandle<T,CIn> &input,
vtkm::cont::ArrayHandle<T,COut>& output);
/// \brief Compute an inclusive prefix sum operation on the input ArrayHandle.
///
/// Computes an inclusive prefix sum operation on the \c input ArrayHandle,
/// storing the results in the \c output ArrayHandle. InclusiveScan is
/// similar to the stl partial sum function, exception that InclusiveScan
/// doesn't do a serial summation. This means that if you have defined a
/// custom plus operator for T it must be associative, or you will get
/// inconsistent results. When the input and output ArrayHandles are the same
/// ArrayHandle the operation will be done inplace.
///
/// \return The total sum.
///
template<typename T, class CIn, class COut, class BinaryOperation>
VTKM_CONT_EXPORT static T ScanInclusive(
const vtkm::cont::ArrayHandle<T,CIn> &input,
vtkm::cont::ArrayHandle<T,COut>& output,
BinaryOperation binaryOp);
/// \brief Compute an exclusive prefix sum operation on the input ArrayHandle.
///
/// Computes an exclusive prefix sum operation on the \c input ArrayHandle,
......
......@@ -330,6 +330,21 @@ private:
return *(IteratorEnd(output) - 1);
}
template<class InputPortal, class OutputPortal, class BinaryOperation>
VTKM_CONT_EXPORT static
typename InputPortal::ValueType ScanInclusivePortal(const InputPortal &input,
const OutputPortal &output,
BinaryOperation binaryOp)
{
::thrust::inclusive_scan(IteratorBegin(input),
IteratorEnd(input),
IteratorBegin(output),
binaryOp);
//return the value at the last index in the array, as that is the sum
return *(IteratorEnd(output) - 1);
}
template<class ValuesPortal>
VTKM_CONT_EXPORT static void SortPortal(const ValuesPortal &values)
{
......@@ -534,6 +549,7 @@ public:
LowerBoundsPortal(input.PrepareForInput(DeviceAdapterTag()),
values_output.PrepareForInPlace(DeviceAdapterTag()));
}
template<typename T, class SIn>
VTKM_CONT_EXPORT static T Reduce(
const vtkm::cont::ArrayHandle<T,SIn> &input, T initialValue)
......@@ -562,6 +578,7 @@ public:
return ScanExclusivePortal(input.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
}
template<typename T, class SIn, class SOut>
VTKM_CONT_EXPORT static T ScanInclusive(
const vtkm::cont::ArrayHandle<T,SIn> &input,
......@@ -578,6 +595,24 @@ public:
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
}
template<typename T, class SIn, class SOut, class BinaryOperation>
VTKM_CONT_EXPORT static T ScanInclusive(
const vtkm::cont::ArrayHandle<T,SIn> &input,
vtkm::cont::ArrayHandle<T,SOut>& output,
BinaryOperation binaryOp)
{
const vtkm::Id numberOfValues = input.GetNumberOfValues();
if (numberOfValues <= 0)
{
output.PrepareForOutput(0, DeviceAdapterTag());
return 0;
}
return ScanInclusivePortal(input.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
binaryOp);
}
// Because of some funny code conversions in nvcc, kernels for devices have to
// be public.
#ifndef VTKM_CUDA
......
......@@ -488,17 +488,20 @@ public:
//--------------------------------------------------------------------------
// Scan Inclusive
private:
template<typename PortalType>
template<typename PortalType, typename BinaryOperation>
struct ScanKernel : vtkm::exec::FunctorBase
{
PortalType Portal;
BinaryOperation BinaryOperator;
vtkm::Id Stride;
vtkm::Id Offset;
vtkm::Id Distance;
VTKM_CONT_EXPORT
ScanKernel(const PortalType &portal, vtkm::Id stride, vtkm::Id offset)
ScanKernel(const PortalType &portal, BinaryOperation binaryOp,
vtkm::Id stride, vtkm::Id offset)
: Portal(portal),
BinaryOperator(binaryOp),
Stride(stride),
Offset(offset),
Distance(stride/2)
......@@ -516,7 +519,7 @@ private:
{
ValueType leftValue = this->Portal.Get(leftIndex);
ValueType rightValue = this->Portal.Get(rightIndex);
this->Portal.Set(rightIndex, leftValue+rightValue);
this->Portal.Set(rightIndex, BinaryOperator(leftValue,rightValue) );
}
}
};
......@@ -526,17 +529,30 @@ public:
VTKM_CONT_EXPORT static T ScanInclusive(
const vtkm::cont::ArrayHandle<T,CIn> &input,
vtkm::cont::ArrayHandle<T,COut>& output)
{
return DerivedAlgorithm::ScanInclusive(input,
output,
vtkm::internal::Add());
}
template<typename T, class CIn, class COut, class BinaryOperation>
VTKM_CONT_EXPORT static T ScanInclusive(
const vtkm::cont::ArrayHandle<T,CIn> &input,
vtkm::cont::ArrayHandle<T,COut>& output,
BinaryOperation binaryOp)
{
typedef typename
vtkm::cont::ArrayHandle<T,COut>
::template ExecutionTypes<DeviceAdapterTag>::Portal PortalType;
typedef ScanKernel<PortalType,BinaryOperation> ScanKernelType;
DerivedAlgorithm::Copy(input, output);
vtkm::Id numValues = output.GetNumberOfValues();
if (numValues < 1)
{
return 0;
return T(0);
}
PortalType portal = output.PrepareForInPlace(DeviceAdapterTag());
......@@ -544,14 +560,14 @@ public:
vtkm::Id stride;
for (stride = 2; stride-1 < numValues; stride *= 2)
{
ScanKernel<PortalType> kernel(portal, stride, stride/2 - 1);
ScanKernelType kernel(portal, binaryOp, stride, stride/2 - 1);
DerivedAlgorithm::Schedule(kernel, numValues/stride);
}
// Do reverse operation on odd indices. Start at stride we were just at.
for (stride /= 2; stride > 1; stride /= 2)
{
ScanKernel<PortalType> kernel(portal, stride, stride - 1);
ScanKernelType kernel(portal, binaryOp, stride, stride - 1);
DerivedAlgorithm::Schedule(kernel, numValues/stride);
}
......
......@@ -63,7 +63,7 @@ public:
PortalIn inputPortal = input.PrepareForInput(Device());
PortalOut outputPortal = output.PrepareForOutput(numberOfValues, Device());
if (numberOfValues <= 0) { return 0; }
if (numberOfValues <= 0) { return T(0); }
std::partial_sum(vtkm::cont::ArrayPortalToIteratorBegin(inputPortal),
vtkm::cont::ArrayPortalToIteratorEnd(inputPortal),
......@@ -73,6 +73,33 @@ public:
return outputPortal.Get(numberOfValues - 1);
}
template<typename T, class CIn, class COut, class BinaryOperation>
VTKM_CONT_EXPORT static T ScanInclusive(
const vtkm::cont::ArrayHandle<T,CIn> &input,
vtkm::cont::ArrayHandle<T,COut>& output,
BinaryOperation binaryOp)
{
typedef typename vtkm::cont::ArrayHandle<T,COut>
::template ExecutionTypes<Device>::Portal PortalOut;
typedef typename vtkm::cont::ArrayHandle<T,CIn>
::template ExecutionTypes<Device>::PortalConst PortalIn;
vtkm::Id numberOfValues = input.GetNumberOfValues();
PortalIn inputPortal = input.PrepareForInput(Device());
PortalOut outputPortal = output.PrepareForOutput(numberOfValues, Device());
if (numberOfValues <= 0) { return T(0); }
std::partial_sum(vtkm::cont::ArrayPortalToIteratorBegin(inputPortal),
vtkm::cont::ArrayPortalToIteratorEnd(inputPortal),
vtkm::cont::ArrayPortalToIteratorBegin(outputPortal),
binaryOp);
// Return the value at the last index in the array, which is the full sum.
return outputPortal.Get(numberOfValues - 1);
}
template<typename T, class CIn, class COut>
VTKM_CONT_EXPORT static T ScanExclusive(
const vtkm::cont::ArrayHandle<T,CIn> &input,
......
......@@ -105,6 +105,16 @@ struct SortGreater
return valid;
}
};
struct MaxValue
{
template<typename T>
VTKM_EXEC_CONT_EXPORT T operator()(const T& a,const T& b) const
{
return (a > b) ? a : b;
}
};
}
......@@ -1112,6 +1122,54 @@ private:
}
}
static VTKM_CONT_EXPORT void TestScanInclusiveWithComparisonObject()
{
std::cout << "-------------------------------------------" << std::endl;
std::cout << "Testing Inclusive Scan with comparison object " << std::endl;
//construct the index array
IdArrayHandle array;
Algorithm::Schedule(
ClearArrayKernel(array.PrepareForOutput(ARRAY_SIZE,
DeviceAdapterTag())),
ARRAY_SIZE);
Algorithm::Schedule(
AddArrayKernel(array.PrepareForOutput(ARRAY_SIZE,
DeviceAdapterTag())),
ARRAY_SIZE);
//we know have an array whose sum is equal to OFFSET * ARRAY_SIZE,
//let's validate that
IdArrayHandle result;
vtkm::Id sum = Algorithm::ScanInclusive(array,
result,
comparison::MaxValue());
VTKM_TEST_ASSERT(sum == OFFSET + (ARRAY_SIZE-1),
"Got bad sum from Inclusive Scan with comparison object");
for(vtkm::Id i=0; i < ARRAY_SIZE; ++i)
{
const vtkm::Id input_value = array.GetPortalConstControl().Get(i);
const vtkm::Id result_value = result.GetPortalConstControl().Get(i);
VTKM_TEST_ASSERT(input_value == result_value, "Incorrect partial sum");
}
//now try it inline
sum = Algorithm::ScanInclusive(array,
array,
comparison::MaxValue());
VTKM_TEST_ASSERT(sum == OFFSET + (ARRAY_SIZE-1),
"Got bad sum from Inclusive Scan with comparison object");
for(vtkm::Id i=0; i < ARRAY_SIZE; ++i)
{
const vtkm::Id input_value = array.GetPortalConstControl().Get(i);
const vtkm::Id result_value = result.GetPortalConstControl().Get(i);
VTKM_TEST_ASSERT(input_value == result_value, "Incorrect partial sum");
}
}
static VTKM_CONT_EXPORT void TestScanExclusive()
{
std::cout << "-------------------------------------------" << std::endl;
......@@ -1371,6 +1429,8 @@ private:
TestReduce();
TestScanInclusive();
TestScanInclusiveWithComparisonObject();
TestScanExclusive();
TestSort();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment