Commit 273281f5 authored by Sujin Philip's avatar Sujin Philip
Browse files

Fix Parallel Scan implementation for TBB Device

The previous implementation assumed the identity value to be zero, which does
not work for multiplication. Changed the interface to require an initial value
for Exclusive Scan with custom operator (TBB Device only, for now).
parent 6b70893b
......@@ -81,6 +81,7 @@ private:
typedef typename boost::remove_reference<
typename OutputPortalType::ValueType>::type ValueType;
ValueType Sum;
bool FirstCall;
InputPortalType InputPortal;
OutputPortalType OutputPortal;
BinaryOperationType BinaryOperation;
......@@ -90,6 +91,7 @@ private:
const OutputPortalType &outputPortal,
BinaryOperationType binaryOperation)
: Sum( vtkm::TypeTraits<ValueType>::ZeroInitialization() ),
FirstCall(true),
InputPortal(inputPortal),
OutputPortal(outputPortal),
BinaryOperation(binaryOperation)
......@@ -98,6 +100,7 @@ private:
VTKM_EXEC_CONT_EXPORT
ScanInclusiveBody(const ScanInclusiveBody &body, ::tbb::split)
: Sum( vtkm::TypeTraits<ValueType>::ZeroInitialization() ),
FirstCall(true),
InputPortal(body.InputPortal),
OutputPortal(body.OutputPortal),
BinaryOperation(body.BinaryOperation) { }
......@@ -110,10 +113,12 @@ private:
InputIteratorsType inputIterators(this->InputPortal);
//use temp, and iterators instead of member variable to reduce false sharing
ValueType temp = this->Sum;
typename InputIteratorsType::IteratorType inIter =
inputIterators.GetBegin() + range.begin();
for (vtkm::Id index = range.begin(); index != range.end();
ValueType temp = this->FirstCall ? *inIter++ :
this->BinaryOperation(this->Sum, *inIter++);
this->FirstCall = false;
for (vtkm::Id index = range.begin() + 1; index != range.end();
++index, ++inIter)
{
temp = this->BinaryOperation(temp, *inIter);
......@@ -133,12 +138,15 @@ private:
OutputIteratorsType outputIterators(this->OutputPortal);
//use temp, and iterators instead of member variable to reduce false sharing
ValueType temp = this->Sum;
typename InputIteratorsType::IteratorType inIter =
inputIterators.GetBegin() + range.begin();
typename OutputIteratorsType::IteratorType outIter =
outputIterators.GetBegin() + range.begin();
for (vtkm::Id index = range.begin(); index != range.end();
ValueType temp = this->FirstCall ? *inIter++ :
this->BinaryOperation(this->Sum, *inIter++);
this->FirstCall = false;
*outIter++ = temp;
for (vtkm::Id index = range.begin() + 1; index != range.end();
++index, ++inIter, ++outIter)
{
*outIter = temp = this->BinaryOperation(temp, *inIter);
......@@ -188,6 +196,7 @@ private:
typedef typename boost::remove_reference<
typename OutputPortalType::ValueType>::type ValueType;
ValueType Sum;
ValueType InitialValue;
InputPortalType InputPortal;
OutputPortalType OutputPortal;
BinaryOperationType BinaryOperation;
......@@ -195,8 +204,10 @@ private:
VTKM_CONT_EXPORT
ScanExclusiveBody(const InputPortalType &inputPortal,
const OutputPortalType &outputPortal,
BinaryOperationType binaryOperation)
: Sum( vtkm::TypeTraits<ValueType>::ZeroInitialization() ),
BinaryOperationType binaryOperation,
const ValueType& initialValue)
: Sum(initialValue),
InitialValue(initialValue),
InputPortal(inputPortal),
OutputPortal(outputPortal),
BinaryOperation(binaryOperation)
......@@ -204,10 +215,12 @@ private:
VTKM_EXEC_CONT_EXPORT
ScanExclusiveBody(const ScanExclusiveBody &body, ::tbb::split)
: Sum( vtkm::TypeTraits<ValueType>::ZeroInitialization() ),
: Sum(body.InitialValue),
InitialValue(body.InitialValue),
InputPortal(body.InputPortal),
OutputPortal(body.OutputPortal),
BinaryOperation(body.BinaryOperation) { }
BinaryOperation(body.BinaryOperation)
{ }
VTKM_EXEC_EXPORT
void operator()(const ::tbb::blocked_range<vtkm::Id> &range, ::tbb::pre_scan_tag)
......@@ -216,11 +229,10 @@ private:
InputIteratorsType;
InputIteratorsType inputIterators(this->InputPortal);
ValueType temp = this->Sum;
//move the iterator to the first item
typename InputIteratorsType::IteratorType iter =
inputIterators.GetBegin() + range.begin();
ValueType temp = this->Sum;
for (vtkm::Id index = range.begin(); index != range.end(); ++index, ++iter)
{
temp = this->BinaryOperation(temp, *iter);
......@@ -239,13 +251,12 @@ private:
InputIteratorsType inputIterators(this->InputPortal);
OutputIteratorsType outputIterators(this->OutputPortal);
ValueType temp = this->Sum;
//move the iterators to the first item
typename InputIteratorsType::IteratorType inIter =
inputIterators.GetBegin() + range.begin();
typename OutputIteratorsType::IteratorType outIter =
outputIterators.GetBegin() + range.begin();
ValueType temp = this->Sum;
for (vtkm::Id index = range.begin(); index != range.end();
++index, ++inIter, ++outIter)
{
......@@ -277,7 +288,9 @@ private:
typename boost::remove_reference<typename OutputPortalType::ValueType>::type
ScanExclusivePortals(InputPortalType inputPortal,
OutputPortalType outputPortal,
BinaryOperationType binaryOperation)
BinaryOperationType binaryOperation,
typename boost::remove_reference<
typename OutputPortalType::ValueType>::type initialValue)
{
typedef typename
boost::remove_reference<typename OutputPortalType::ValueType>::type
......@@ -287,7 +300,7 @@ private:
WrappedBinaryOp wrappedBinaryOp(binaryOperation);
ScanExclusiveBody<InputPortalType, OutputPortalType, WrappedBinaryOp>
body(inputPortal, outputPortal, wrappedBinaryOp);
body(inputPortal, outputPortal, wrappedBinaryOp, initialValue);
vtkm::Id arrayLength = inputPortal.GetNumberOfValues();
::tbb::parallel_scan( ::tbb::blocked_range<vtkm::Id>(0, arrayLength), body);
......@@ -331,19 +344,21 @@ public:
return ScanExclusivePortals(
input.PrepareForInput(vtkm::cont::DeviceAdapterTagTBB()),
output.PrepareForOutput(input.GetNumberOfValues(),
vtkm::cont::DeviceAdapterTagTBB()), vtkm::internal::Add());
vtkm::cont::DeviceAdapterTagTBB()),
vtkm::internal::Add(), vtkm::TypeTraits<T>::ZeroInitialization());
}
template<typename T, class CIn, class COut, class BinaryFunctor>
VTKM_CONT_EXPORT static T ScanExclusive(
const vtkm::cont::ArrayHandle<T,CIn> &input,
vtkm::cont::ArrayHandle<T,COut> &output,
BinaryFunctor binary_functor)
BinaryFunctor binary_functor,
const T& initialValue)
{
return ScanExclusivePortals(
input.PrepareForInput(vtkm::cont::DeviceAdapterTagTBB()),
output.PrepareForOutput(input.GetNumberOfValues(),
vtkm::cont::DeviceAdapterTagTBB()), binary_functor);
vtkm::cont::DeviceAdapterTagTBB()), binary_functor, initialValue);
}
private:
......
......@@ -1283,6 +1283,39 @@ private:
}
std::cout << "-------------------------------------------" << std::endl;
std::cout << "Testing Inclusive Scan with multiplication operator" << std::endl;
{
vtkm::Float64 inputValues[ARRAY_SIZE];
for (vtkm::Id i = 0; i < ARRAY_SIZE; ++i)
{
inputValues[i] = 1.01;
}
vtkm::Id mid = ARRAY_SIZE/2;
inputValues[mid] = 0.0;
vtkm::cont::ArrayHandle<vtkm::Float64> array = MakeArrayHandle(inputValues,
ARRAY_SIZE);
vtkm::Float64 product = Algorithm::ScanInclusive(array, array,
vtkm::internal::Multiply());
VTKM_TEST_ASSERT(product == 0.0f, "ScanInclusive product result not 0.0");
for (vtkm::Id i = 0; i < mid; ++i)
{
vtkm::Float64 expected = pow(1.01, static_cast<vtkm::Float64>(i + 1));
vtkm::Float64 got = array.GetPortalConstControl().Get(i);
VTKM_TEST_ASSERT(test_equal(got, expected),
"Incorrect results for ScanInclusive");
}
for (vtkm::Id i = mid; i < ARRAY_SIZE; ++i)
{
VTKM_TEST_ASSERT(array.GetPortalConstControl().Get(i) == 0.0f,
"Incorrect results for ScanInclusive");
}
}
std::cout << "-------------------------------------------" << std::endl;
std::cout << "Testing Inclusive Scan with a vtkm::Vec" << std::endl;
......@@ -1389,6 +1422,47 @@ private:
}
}
// Enable when Exclusive Scan with custom operator is implemented for all
// device adaptors
#if 0
std::cout << "-------------------------------------------" << std::endl;
std::cout << "Testing Exclusive Scan with multiplication operator" << std::endl;
{
vtkm::Float64 inputValues[ARRAY_SIZE];
for (vtkm::Id i = 0; i < ARRAY_SIZE; ++i)
{
inputValues[i] = 1.01;
}
vtkm::Id mid = ARRAY_SIZE/2;
inputValues[mid] = 0.0;
vtkm::cont::ArrayHandle<vtkm::Float64> array = MakeArrayHandle(inputValues,
ARRAY_SIZE);
vtkm::Float64 initialValue = 2.00;
vtkm::Float64 product = Algorithm::ScanExclusive(array, array,
vtkm::internal::Multiply(), initialValue);
VTKM_TEST_ASSERT(product == 0.0f, "ScanExclusive product result not 0.0");
VTKM_TEST_ASSERT(array.GetPortalConstControl().Get(0) == initialValue,
"ScanExclusive result's first value != initialValue");
for (vtkm::Id i = 1; i <= mid; ++i)
{
vtkm::Float64 expected = pow(1.01, static_cast<vtkm::Float64>(i)) *
initialValue;
vtkm::Float64 got = array.GetPortalConstControl().Get(i);
VTKM_TEST_ASSERT(test_equal(got, expected),
"Incorrect results for ScanExclusive");
}
for (vtkm::Id i = mid + 1; i < ARRAY_SIZE; ++i)
{
VTKM_TEST_ASSERT(array.GetPortalConstControl().Get(i) == 0.0f,
"Incorrect results for ScanExclusive");
}
}
#endif
std::cout << "-------------------------------------------" << std::endl;
std::cout << "Testing Exclusive Scan with a vtkm::Vec" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment