Commit 2ce14c77 authored by Robert Maynard's avatar Robert Maynard
Browse files

DeviceAdapter's now properly init all value types.

This is needed when use vtkm::Vec, as it doesn't zero the memory it contains.
parent 095f25c3
......@@ -25,6 +25,7 @@
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ErrorExecution.h>
#include <vtkm/cont/Timer.h>
#include <vtkm/TypeTraits.h>
#include <vtkm/cont/cuda/internal/MakeThrustIterator.h>
......@@ -387,7 +388,7 @@ private:
IteratorBegin(input),
IteratorEnd(input),
IteratorBegin(output),
vtkm::cont::internal::zeroinit::init<ValueType>(),
vtkm::TypeTraits<ValueType>::ZeroInitialization(),
bop);
//return the value at the last index in the array, as that is the sum
......@@ -692,7 +693,7 @@ public:
if (numberOfValues <= 0)
{
output.PrepareForOutput(0, DeviceAdapterTag());
return 0;
return vtkm::TypeTraits<T>::ZeroInitialization();
}
return ScanExclusivePortal(input.PrepareForInput(DeviceAdapterTag()),
......@@ -708,7 +709,7 @@ public:
if (numberOfValues <= 0)
{
output.PrepareForOutput(0, DeviceAdapterTag());
return 0;
return vtkm::TypeTraits<T>::ZeroInitialization();
}
return ScanInclusivePortal(input.PrepareForInput(DeviceAdapterTag()),
......@@ -725,7 +726,7 @@ public:
if (numberOfValues <= 0)
{
output.PrepareForOutput(0, DeviceAdapterTag());
return 0;
return vtkm::TypeTraits<T>::ZeroInitialization();
}
return ScanInclusivePortal(input.PrepareForInput(DeviceAdapterTag()),
......
......@@ -20,6 +20,7 @@
#ifndef vtk_m_cont_internal_DeviceAdapterAlgorithmGeneral_h
#define vtk_m_cont_internal_DeviceAdapterAlgorithmGeneral_h
#include <vtkm/TypeTraits.h>
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandleCounting.h>
#include <vtkm/cont/ArrayHandleImplicit.h>
......@@ -765,7 +766,10 @@ public:
// Set first value in output (always 0).
DerivedAlgorithm::Schedule(
SetConstantKernel<DestPortalType>(destPortal,0), 1);
SetConstantKernel<DestPortalType>(
destPortal,
vtkm::TypeTraits<T>::ZeroInitialization()),
1);
// Shift remaining values over by one.
DerivedAlgorithm::Schedule(
CopyKernel<SrcPortalType,DestPortalType>(srcPortal,
......
......@@ -39,7 +39,6 @@
namespace vtkm {
namespace cont {
template<>
struct DeviceAdapterAlgorithm<vtkm::cont::DeviceAdapterTagSerial> :
vtkm::cont::internal::DeviceAdapterAlgorithmGeneral<
......@@ -152,7 +151,7 @@ public:
PortalIn inputPortal = input.PrepareForInput(Device());
PortalOut outputPortal = output.PrepareForOutput(numberOfValues, Device());
if (numberOfValues <= 0) { return T(); }
if (numberOfValues <= 0) { return vtkm::TypeTraits<T>::ZeroInitialization(); }
std::partial_sum(vtkm::cont::ArrayPortalToIteratorBegin(inputPortal),
vtkm::cont::ArrayPortalToIteratorEnd(inputPortal),
......@@ -181,7 +180,7 @@ public:
PortalIn inputPortal = input.PrepareForInput(Device());
PortalOut outputPortal = output.PrepareForOutput(numberOfValues, Device());
if (numberOfValues <= 0) { return T(); }
if (numberOfValues <= 0) { return vtkm::TypeTraits<T>::ZeroInitialization(); }
std::partial_sum(vtkm::cont::ArrayPortalToIteratorBegin(inputPortal),
vtkm::cont::ArrayPortalToIteratorEnd(inputPortal),
......@@ -207,7 +206,7 @@ public:
PortalIn inputPortal = input.PrepareForInput(Device());
PortalOut outputPortal = output.PrepareForOutput(numberOfValues, Device());
if (numberOfValues <= 0) { return T(); }
if (numberOfValues <= 0) { return vtkm::TypeTraits<T>::ZeroInitialization(); }
std::partial_sum(vtkm::cont::ArrayPortalToIteratorBegin(inputPortal),
vtkm::cont::ArrayPortalToIteratorEnd(inputPortal),
......@@ -219,7 +218,7 @@ public:
std::copy_backward(vtkm::cont::ArrayPortalToIteratorBegin(outputPortal),
vtkm::cont::ArrayPortalToIteratorEnd(outputPortal)-1,
vtkm::cont::ArrayPortalToIteratorEnd(outputPortal));
outputPortal.Set(0, 0);
outputPortal.Set(0, vtkm::TypeTraits<T>::ZeroInitialization());
return fullSum;
}
......
......@@ -89,13 +89,15 @@ private:
ScanInclusiveBody(const InputPortalType &inputPortal,
const OutputPortalType &outputPortal,
BinaryOperationType binaryOperation)
: Sum(), InputPortal(inputPortal), OutputPortal(outputPortal),
: Sum( vtkm::TypeTraits<ValueType>::ZeroInitialization() ),
InputPortal(inputPortal),
OutputPortal(outputPortal),
BinaryOperation(binaryOperation)
{ }
VTKM_EXEC_CONT_EXPORT
ScanInclusiveBody(const ScanInclusiveBody &body, ::tbb::split)
: Sum(),
: Sum( vtkm::TypeTraits<ValueType>::ZeroInitialization() ),
InputPortal(body.InputPortal),
OutputPortal(body.OutputPortal),
BinaryOperation(body.BinaryOperation) { }
......@@ -194,13 +196,15 @@ private:
ScanExclusiveBody(const InputPortalType &inputPortal,
const OutputPortalType &outputPortal,
BinaryOperationType binaryOperation)
: Sum(), InputPortal(inputPortal), OutputPortal(outputPortal),
: Sum( vtkm::TypeTraits<ValueType>::ZeroInitialization() ),
InputPortal(inputPortal),
OutputPortal(outputPortal),
BinaryOperation(binaryOperation)
{ }
VTKM_EXEC_CONT_EXPORT
ScanExclusiveBody(const ScanExclusiveBody &body, ::tbb::split)
: Sum(),
: Sum( vtkm::TypeTraits<ValueType>::ZeroInitialization() ),
InputPortal(body.InputPortal),
OutputPortal(body.OutputPortal),
BinaryOperation(body.BinaryOperation) { }
......
......@@ -143,9 +143,6 @@ private:
typedef typename IdArrayHandle::template ExecutionTypes<DeviceAdapterTag>
::PortalConst IdPortalConstType;
typedef vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::FloatDefault,3>,StorageTag>
Vec3ArrayHandle;
typedef vtkm::cont::DeviceAdapterAlgorithm<DeviceAdapterTag>
Algorithm;
......@@ -787,6 +784,9 @@ private:
std::cout << "Sort by keys" << std::endl;
typedef vtkm::Vec<FloatDefault,3> Vec3;
typedef vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::FloatDefault,3>,StorageTag>
Vec3ArrayHandle;
vtkm::Id testKeys[ARRAY_SIZE];
Vec3 testValues[ARRAY_SIZE];
......@@ -1183,6 +1183,8 @@ private:
{
std::cout << "-------------------------------------------" << std::endl;
std::cout << "Testing Inclusive Scan" << std::endl;
{
//construct the index array
IdArrayHandle array;
Algorithm::Schedule(
......@@ -1208,6 +1210,31 @@ private:
VTKM_TEST_ASSERT(partialSum == triangleNumber * OFFSET,
"Incorrect partial sum");
}
}
std::cout << "-------------------------------------------" << std::endl;
std::cout << "Testing Inclusive Scan with a vtkm::Vec" << std::endl;
{
typedef vtkm::Vec<Float64,3> Vec3;
typedef vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::Float64,3>,StorageTag>
Vec3ArrayHandle;
Vec3 testValues[ARRAY_SIZE];
for(vtkm::Id i=0; i < ARRAY_SIZE; ++i)
{
testValues[i] = TestValue(i, Vec3());
}
Vec3ArrayHandle values = MakeArrayHandle(testValues, ARRAY_SIZE);
Vec3 sum = Algorithm::ScanInclusive(values, values);
std::cout << "Sum that was returned " << sum << std::endl;
VTKM_TEST_ASSERT( test_equal(sum, vtkm::make_Vec(6996.0,7996.0,8996.0) ),
"Got bad sum from Inclusive Scan");
}
}
static VTKM_CONT_EXPORT void TestScanInclusiveWithComparisonObject()
......@@ -1263,6 +1290,7 @@ private:
std::cout << "-------------------------------------------" << std::endl;
std::cout << "Testing Exclusive Scan" << std::endl;
{
//construct the index array
IdArrayHandle array;
Algorithm::Schedule(
......@@ -1273,7 +1301,7 @@ private:
// we know have an array whose sum = (OFFSET * ARRAY_SIZE),
// let's validate that
vtkm::Id sum = Algorithm::ScanExclusive(array, array);
std::cout << "Sum that was returned " << sum << std::endl;
VTKM_TEST_ASSERT(sum == (OFFSET * ARRAY_SIZE),
"Got bad sum from Exclusive Scan");
......@@ -1289,6 +1317,29 @@ private:
VTKM_TEST_ASSERT(partialSum == triangleNumber * OFFSET,
"Incorrect partial sum");
}
}
std::cout << "-------------------------------------------" << std::endl;
std::cout << "Testing Exclusive Scan with a vtkm::Vec" << std::endl;
{
typedef vtkm::Vec<Float64,3> Vec3;
typedef vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::Float64,3>,StorageTag>
Vec3ArrayHandle;
Vec3 testValues[ARRAY_SIZE];
for(vtkm::Id i=0; i < ARRAY_SIZE; ++i)
{
testValues[i] = TestValue(i, Vec3());
}
Vec3ArrayHandle values = MakeArrayHandle(testValues, ARRAY_SIZE);
Vec3 sum = Algorithm::ScanExclusive(values, values);
std::cout << "Sum that was returned " << sum << std::endl;
VTKM_TEST_ASSERT( test_equal(sum, vtkm::make_Vec(6996.0,7996.0,8996.0) ),
"Got bad sum from Exclusive Scan");
}
}
static VTKM_CONT_EXPORT void TestErrorExecution()
......@@ -1367,11 +1418,11 @@ private:
TestReduceByKey();
TestScanExclusive();
TestScanInclusive();
TestScanInclusiveWithComparisonObject();
TestScanExclusive();
TestSort();
TestSortWithComparisonObject();
TestSortByKey();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment