Commit 2c91cdfa authored by Robert Maynard's avatar Robert Maynard

Update cuda/thrust backend scan algorithms to work with vec types.

parent 19121fbb
......@@ -257,15 +257,21 @@ private:
template<class InputPortal, class ValuesPortal, class OutputPortal>
VTKM_CONT_EXPORT static void LowerBoundsPortal(const InputPortal &input,
const ValuesPortal &values,
const OutputPortal &output)
const ValuesPortal &values,
const OutputPortal &output)
{
::thrust::lower_bound(thrust::cuda::par,
IteratorBegin(input),
IteratorEnd(input),
IteratorBegin(values),
IteratorEnd(values),
IteratorBegin(output));
typedef typename ValuesPortal::ValueType ValueType;
LowerBoundsPortal(input, values, output, ::thrust::less<ValueType>() );
}
template<class InputPortal, class OutputPortal>
VTKM_CONT_EXPORT static
void LowerBoundsPortal(const InputPortal &input,
const OutputPortal &values_output)
{
typedef typename InputPortal::ValueType ValueType;
LowerBoundsPortal(input, values_output, values_output,
::thrust::less<ValueType>() );
}
template<class InputPortal, class ValuesPortal, class OutputPortal,
......@@ -275,6 +281,7 @@ private:
const OutputPortal &output,
Compare comp)
{
vtkm::exec::cuda::internal::WrappedBinaryOperator<bool, Compare> bop(comp);
::thrust::lower_bound(thrust::cuda::par,
IteratorBegin(input),
IteratorEnd(input),
......@@ -284,19 +291,6 @@ private:
comp);
}
template<class InputPortal, class OutputPortal>
VTKM_CONT_EXPORT static
void LowerBoundsPortal(const InputPortal &input,
const OutputPortal &values_output)
{
::thrust::lower_bound(thrust::cuda::par,
IteratorBegin(input),
IteratorEnd(input),
IteratorBegin(values_output),
IteratorEnd(values_output),
IteratorBegin(values_output));
}
template<class InputPortal>
VTKM_CONT_EXPORT static
typename InputPortal::ValueType ReducePortal(const InputPortal &input,
......@@ -364,18 +358,40 @@ private:
VTKM_CONT_EXPORT static
typename InputPortal::ValueType ScanExclusivePortal(const InputPortal &input,
const OutputPortal &output)
{
typedef typename InputPortal::ValueType ValueType;
return ScanExclusivePortal(input,
output,
(::thrust::plus<ValueType>()) );
}
template<class InputPortal, class OutputPortal, class BinaryOperation>
VTKM_CONT_EXPORT static
typename InputPortal::ValueType ScanExclusivePortal(const InputPortal &input,
const OutputPortal &output,
BinaryOperation binaryOp)
{
// Use iterator to get value so that thrust device_ptr has chance to handle
// data on device.
typename InputPortal::ValueType inputEnd = *(IteratorEnd(input) - 1);
typedef typename InputPortal::ValueType ValueType;
ValueType inputEnd = *(IteratorEnd(input) - 1);
vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType,
BinaryOperation> bop(binaryOp);
::thrust::exclusive_scan(thrust::cuda::par,
IteratorBegin(input),
IteratorEnd(input),
IteratorBegin(output));
typedef typename detail::IteratorTraits<OutputPortal>::IteratorType
IteratorType;
IteratorType end = ::thrust::exclusive_scan(thrust::cuda::par,
IteratorBegin(input),
IteratorEnd(input),
IteratorBegin(output),
vtkm::cont::internal::zeroinit::init<ValueType>(),
bop);
//return the value at the last index in the array, as that is the sum
return *(IteratorEnd(output) - 1) + inputEnd;
return binaryOp( *(end-1), inputEnd);
}
template<class InputPortal, class OutputPortal>
......@@ -383,13 +399,8 @@ private:
typename InputPortal::ValueType ScanInclusivePortal(const InputPortal &input,
const OutputPortal &output)
{
::thrust::inclusive_scan(thrust::cuda::par,
IteratorBegin(input),
IteratorEnd(input),
IteratorBegin(output));
//return the value at the last index in the array, as that is the sum
return *(IteratorEnd(output) - 1);
typedef typename InputPortal::ValueType ValueType;
return ScanInclusivePortal(input, output, ::thrust::plus<ValueType>() );
}
template<class InputPortal, class OutputPortal, class BinaryOperation>
......@@ -398,14 +409,20 @@ private:
const OutputPortal &output,
BinaryOperation binaryOp)
{
::thrust::inclusive_scan(thrust::cuda::par,
IteratorBegin(input),
IteratorEnd(input),
IteratorBegin(output),
binaryOp);
vtkm::exec::cuda::internal::WrappedBinaryOperator<typename InputPortal::ValueType,
BinaryOperation> bop(binaryOp);
typedef typename detail::IteratorTraits<OutputPortal>::IteratorType
IteratorType;
IteratorType end = ::thrust::inclusive_scan(thrust::cuda::par,
IteratorBegin(input),
IteratorEnd(input),
IteratorBegin(output),
bop);
//return the value at the last index in the array, as that is the sum
return *(IteratorEnd(output) - 1);
return *(end-1);
}
template<class ValuesPortal>
......
......@@ -25,6 +25,20 @@
#include <vtkm/internal/ExportMacros.h>
#include <vtkm/exec/cuda/internal/IteratorFromArrayPortal.h>
// Disable warnings we check vtkm for but Thrust does not.
#if defined(__GNUC__) || defined(____clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wshadow"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wconversion"
#endif // gcc || clang
#include <thrust/system/cuda/memory.h>
#if defined(__GNUC__) || defined(____clang__)
#pragma GCC diagnostic pop
#endif // gcc || clang
namespace vtkm {
namespace exec {
namespace cuda {
......@@ -39,6 +53,11 @@ template<typename ResultType, typename Function>
{
Function m_f;
VTKM_EXEC_EXPORT
WrappedBinaryOperator()
: m_f()
{}
VTKM_CONT_EXPORT
WrappedBinaryOperator(const Function &f)
: m_f(f)
......@@ -75,6 +94,55 @@ template<typename ResultType, typename Function>
return m_f((ValueTypeT)x, (ValueTypeU)y);
}
template<typename T>
VTKM_EXEC_EXPORT ResultType operator()(const thrust::system::cuda::reference<T> &x,
const T &y) const
{
return m_f(*x, y);
}
template<typename T>
VTKM_EXEC_EXPORT ResultType operator()(const T &x,
const thrust::system::cuda::reference<T> &y) const
{
return m_f(x, *y);
}
template<typename T>
VTKM_EXEC_EXPORT ResultType operator()(const thrust::system::cuda::reference<T> &x,
const thrust::system::cuda::reference<T> &y) const
{
return m_f(*x, *y);
}
template<typename T>
VTKM_EXEC_EXPORT ResultType operator()(const thrust::system::cuda::pointer<T> x,
const T* y) const
{
return m_f(*x, *y);
}
template<typename T>
VTKM_EXEC_EXPORT ResultType operator()(const thrust::system::cuda::pointer<T> x,
const T& y) const
{
return m_f(*x, y);
}
template<typename T>
VTKM_EXEC_EXPORT ResultType operator()(const T& x,
const thrust::system::cuda::pointer<T> y) const
{
return m_f(x, *y);
}
template<typename T>
VTKM_EXEC_EXPORT ResultType operator()(const thrust::system::cuda::pointer<T> x,
const thrust::system::cuda::pointer<T> y) const
{
return m_f(*x, *y);
}
};
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment