Commit e38caafe authored by Robert Maynard's avatar Robert Maynard

Adding Reduce with custom operator to the DeviceAdapterAlgorithm.

parent 5d9f369d
......@@ -101,6 +101,20 @@ struct DeviceAdapterAlgorithm
const vtkm::cont::ArrayHandle<vtkm::Id,CIn>& input,
vtkm::cont::ArrayHandle<vtkm::Id,COut>& values_output);
/// \brief Compute a accumulated sum operation on the input ArrayHandle
///
/// Computes an accumulated sum on the \c input ArrayHandle, returning the
/// total sum. Reduce is similar to the stl accumulate sum function,
/// exception that Reduce doesn't do a serial summation. This means that if
/// you have defined a custom plus operator for T it must be commutative,
/// or you will get inconsistent results.
///
/// \return The total sum.
template<typename T, class CIn>
VTKM_CONT_EXPORT static T Reduce(
const vtkm::cont::ArrayHandle<T,CIn> &input,
T initialValue);
/// \brief Compute a accumulated sum operation on the input ArrayHandle
///
/// Computes an accumulated sum (or any user binary operation) on the
......@@ -111,10 +125,11 @@ struct DeviceAdapterAlgorithm
/// inconsistent results.
///
/// \return The total sum.
template<typename T, class CIn>
template<typename T, class CIn, class BinaryOperation>
VTKM_CONT_EXPORT static T Reduce(
const vtkm::cont::ArrayHandle<T,CIn> &input,
T initialValue);
T initialValue,
BinaryOperation binaryOp);
/// \brief Compute an inclusive prefix sum operation on the input ArrayHandle.
///
......
......@@ -300,6 +300,18 @@ private:
initialValue);
}
template<class InputPortal, class BinaryOperation>
VTKM_CONT_EXPORT static
typename InputPortal::ValueType ReducePortal(const InputPortal &input,
typename InputPortal::ValueType initialValue,
BinaryOperation binaryOP)
{
return ::thrust::reduce(IteratorBegin(input),
IteratorEnd(input),
initialValue,
binaryOP);
}
template<class InputPortal, class OutputPortal>
VTKM_CONT_EXPORT static
typename InputPortal::ValueType ScanExclusivePortal(const InputPortal &input,
......@@ -552,7 +564,8 @@ public:
template<typename T, class SIn>
VTKM_CONT_EXPORT static T Reduce(
const vtkm::cont::ArrayHandle<T,SIn> &input, T initialValue)
const vtkm::cont::ArrayHandle<T,SIn> &input,
T initialValue)
{
const vtkm::Id numberOfValues = input.GetNumberOfValues();
if (numberOfValues <= 0)
......@@ -563,6 +576,22 @@ public:
initialValue);
}
template<typename T, class SIn, class BinaryOperation>
VTKM_CONT_EXPORT static T Reduce(
const vtkm::cont::ArrayHandle<T,SIn> &input,
T initialValue,
BinaryOperation binaryOp)
{
const vtkm::Id numberOfValues = input.GetNumberOfValues();
if (numberOfValues <= 0)
{
return initialValue;
}
return ReducePortal(input.PrepareForInput( DeviceAdapterTag() ),
initialValue,
binaryOp);
}
template<typename T, class SIn, class SOut>
VTKM_CONT_EXPORT static T ScanExclusive(
const vtkm::cont::ArrayHandle<T,SIn> &input,
......
......@@ -379,10 +379,18 @@ private:
}
};
public:
template<typename T, class CIn>
VTKM_CONT_EXPORT static T Reduce(
const vtkm::cont::ArrayHandle<T,CIn> &input, T initialValue)
{
return DerivedAlgorithm::Reduce(input, initialValue, vtkm::internal::Add());
}
template<typename T, class CIn, class BinaryOperator>
VTKM_CONT_EXPORT static T Reduce(
const vtkm::cont::ArrayHandle<T,CIn> &input,
T initialValue,
BinaryOperator binaryOp)
{
//Crazy Idea:
//We create a implicit array handle that wraps the input
......@@ -397,7 +405,7 @@ public:
16,
T,
vtkm::cont::ArrayHandle<T,CIn>,
vtkm::internal::Add
BinaryOperator
> ReduceKernelType;
typedef vtkm::cont::ArrayHandleImplicit<
......@@ -407,15 +415,17 @@ public:
T,
vtkm::cont::StorageTagBasic> TempArrayType;
ReduceKernelType kernel(input, vtkm::internal::Add());
ReduceKernelType kernel(input, binaryOp);
vtkm::Id length = (input.GetNumberOfValues() / 16);
length += (input.GetNumberOfValues() % 16 == 0) ? 0 : 1;
ReduceHandleType reduced = vtkm::cont::make_ArrayHandleImplicit<T>(kernel,
length);
TempArrayType inclusiveScan;
return initialValue + DerivedAlgorithm::ScanInclusive(reduced,
inclusiveScan);
TempArrayType inclusiveScanStorage;
T scanResult = DerivedAlgorithm::ScanInclusive(reduced,
inclusiveScanStorage,
binaryOp);
return binaryOp(initialValue, scanResult);
}
//--------------------------------------------------------------------------
......
......@@ -38,6 +38,61 @@
namespace vtkm {
namespace cont {
namespace internal
{
template<typename ResultType, typename Function>
struct WrappedBinaryOperator
{
Function m_f;
VTKM_CONT_EXPORT
WrappedBinaryOperator(const Function &f)
: m_f(f)
{}
template<typename Argument1, typename Argument2>
VTKM_CONT_EXPORT ResultType operator()(const Argument1 &x, const Argument2 &y) const
{
return m_f(x, y);
}
template<typename Argument1, typename Argument2>
VTKM_CONT_EXPORT ResultType operator()(
const detail::IteratorFromArrayPortalValue<Argument1> &x,
const detail::IteratorFromArrayPortalValue<Argument2> &y) const
{
typedef typename detail::IteratorFromArrayPortalValue<Argument1>::ValueType
ValueTypeX;
typedef typename detail::IteratorFromArrayPortalValue<Argument2>::ValueType
ValueTypeY;
return m_f( (ValueTypeX)x, (ValueTypeY)y );
}
template<typename Argument1, typename Argument2>
VTKM_CONT_EXPORT ResultType operator()(
const Argument1 &x,
const detail::IteratorFromArrayPortalValue<Argument2> &y) const
{
typedef typename detail::IteratorFromArrayPortalValue<Argument2>::ValueType
ValueTypeY;
return m_f( x, (ValueTypeY)y );
}
template<typename Argument1, typename Argument2>
VTKM_CONT_EXPORT ResultType operator()(
const detail::IteratorFromArrayPortalValue<Argument1> &x,
const Argument2 &y) const
{
typedef typename detail::IteratorFromArrayPortalValue<Argument1>::ValueType
ValueTypeX;
return m_f( (ValueTypeX)x, y );
}
};
}
template<>
struct DeviceAdapterAlgorithm<vtkm::cont::DeviceAdapterTagSerial> :
vtkm::cont::internal::DeviceAdapterAlgorithmGeneral<
......@@ -84,6 +139,14 @@ public:
typedef typename vtkm::cont::ArrayHandle<T,CIn>
::template ExecutionTypes<Device>::PortalConst PortalIn;
//We need to wrap the operator in a WrappedBinaryOperator struct
//which can detect and handle calling the binary operator with complex
//value types such as IteratorFromArrayPortalValue which happen
//when passed an input array that is implicit. This occurs when
//invoking reduce which calls ScanInclusive
internal::WrappedBinaryOperator<T,BinaryOperation> wrappedBinaryOp(
binaryOp);
vtkm::Id numberOfValues = input.GetNumberOfValues();
PortalIn inputPortal = input.PrepareForInput(Device());
......@@ -92,9 +155,9 @@ public:
if (numberOfValues <= 0) { return T(0); }
std::partial_sum(vtkm::cont::ArrayPortalToIteratorBegin(inputPortal),
vtkm::cont::ArrayPortalToIteratorEnd(inputPortal),
vtkm::cont::ArrayPortalToIteratorBegin(outputPortal),
binaryOp);
vtkm::cont::ArrayPortalToIteratorEnd(inputPortal),
vtkm::cont::ArrayPortalToIteratorBegin(outputPortal),
wrappedBinaryOp);
// Return the value at the last index in the array, which is the full sum.
return outputPortal.Get(numberOfValues - 1);
......
......@@ -1090,11 +1090,34 @@ private:
"Got different sums from Reduce and ScanInclusive");
}
static VTKM_CONT_EXPORT void TestReduceWithComparisonObject()
{
std::cout << "-------------------------------------------" << std::endl;
std::cout << "Testing Reduce with comparison object " << std::endl;
//construct the index array. Assign an abnormally large value
//to the middle of the array, that should be what we see as our sum.
vtkm::Id testData[ARRAY_SIZE];
const vtkm::Id maxValue = ARRAY_SIZE*2;
for(vtkm::Id i=0; i < ARRAY_SIZE; ++i)
{
testData[i]= i;
}
testData[ARRAY_SIZE/2] = maxValue;
IdArrayHandle input = MakeArrayHandle(testData, ARRAY_SIZE);
vtkm::Id largestValue = Algorithm::Reduce(input,
vtkm::Id(),
comparison::MaxValue());
VTKM_TEST_ASSERT(largestValue == maxValue,
"Got bad value from Reduce with comparison object");
}
static VTKM_CONT_EXPORT void TestScanInclusive()
{
std::cout << "-------------------------------------------" << std::endl;
std::cout << "Testing Inclusive Scan" << std::endl;
//construct the index array
IdArrayHandle array;
Algorithm::Schedule(
......@@ -1427,6 +1450,7 @@ private:
TestErrorExecution();
TestReduce();
TestReduceWithComparisonObject();
TestScanInclusive();
TestScanInclusiveWithComparisonObject();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment