Commit 1ca55ac3 authored by Kenneth Moreland's avatar Kenneth Moreland

Add specialized operators for ArrayPortalValueReference

The ArrayPortalValueReference is supposed to behave just like the value
it encapsulates and does so by automatically converting to the base type
when necessary. However, when it is possible to convert that to
something else, it is possible to get errors about ambiguous overloads.
To avoid these, add specialized versions of the operators to specify
which ones should be used.

Also consolidated the CUDA version of an ArrayPortalValueReference to the
standard one. The two implementations were equivalent and we would like
changes to apply to both.
parent 6851077e
# Added specialized operators for ArrayPortalValueReference
The ArrayPortalValueReference is supposed to behave just like the value it
encapsulates and does so by automatically converting to the base type when
necessary. However, when it is possible to convert that to something else,
it is possible to get errors about ambiguous overloads. To avoid these, add
specialized versions of the operators to specify which ones should be used.
Also consolidated the CUDA version of an ArrayPortalValueReference to the
standard one. The two implementations were equivalent and we would like
changes to apply to both.
......@@ -38,7 +38,7 @@ namespace internal
// Binary function object wrapper which can detect and handle calling the
// wrapped operator with complex value types such as
// IteratorFromArrayPortalValue which happen when passed an input array that
// ArrayPortalValueReference which happen when passed an input array that
// is implicit.
template <typename ResultType, typename Function>
struct WrappedBinaryOperator
......
......@@ -22,6 +22,7 @@
#include <vtkm/Pair.h>
#include <vtkm/Types.h>
#include <vtkm/internal/ArrayPortalValueReference.h>
#include <vtkm/internal/ExportMacros.h>
// Disable warnings we check vtkm for but Thrust does not.
......@@ -40,57 +41,13 @@ namespace cuda
namespace internal
{
template <class ArrayPortalType>
struct PortalValue
{
using ValueType = typename ArrayPortalType::ValueType;
VTKM_EXEC_CONT
PortalValue(const ArrayPortalType& portal, vtkm::Id index)
: Portal(portal)
, Index(index)
{
}
VTKM_EXEC
void Swap(PortalValue<ArrayPortalType>& rhs) throw()
{
//we need use the explicit type not a proxy temp object
//A proxy temp object would point to the same underlying data structure
//and would not hold the old value of *this once *this was set to rhs.
const ValueType aValue = *this;
*this = rhs;
rhs = aValue;
}
VTKM_EXEC
PortalValue<ArrayPortalType>& operator=(const PortalValue<ArrayPortalType>& rhs)
{
this->Portal.Set(this->Index, rhs.Portal.Get(rhs.Index));
return *this;
}
VTKM_EXEC
ValueType operator=(const ValueType& value) const
{
this->Portal.Set(this->Index, value);
return value;
}
VTKM_EXEC
operator ValueType(void) const { return this->Portal.Get(this->Index); }
const ArrayPortalType& Portal;
vtkm::Id Index;
};
template <class ArrayPortalType>
class IteratorFromArrayPortal
: public ::thrust::iterator_facade<IteratorFromArrayPortal<ArrayPortalType>,
typename ArrayPortalType::ValueType,
::thrust::system::cuda::tag,
::thrust::random_access_traversal_tag,
PortalValue<ArrayPortalType>,
vtkm::internal::ArrayPortalValueReference<ArrayPortalType>,
std::ptrdiff_t>
{
public:
......@@ -109,9 +66,11 @@ public:
}
VTKM_EXEC
PortalValue<ArrayPortalType> operator[](std::ptrdiff_t idx) const //NEEDS to be signed
vtkm::internal::ArrayPortalValueReference<ArrayPortalType> operator[](
std::ptrdiff_t idx) const //NEEDS to be signed
{
return PortalValue<ArrayPortalType>(this->Portal, this->Index + static_cast<vtkm::Id>(idx));
return vtkm::internal::ArrayPortalValueReference<ArrayPortalType>(
this->Portal, this->Index + static_cast<vtkm::Id>(idx));
}
private:
......@@ -122,9 +81,9 @@ private:
friend class ::thrust::iterator_core_access;
VTKM_EXEC
PortalValue<ArrayPortalType> dereference() const
vtkm::internal::ArrayPortalValueReference<ArrayPortalType> dereference() const
{
return PortalValue<ArrayPortalType>(this->Portal, this->Index);
return vtkm::internal::ArrayPortalValueReference<ArrayPortalType>(this->Portal, this->Index);
}
VTKM_EXEC
......@@ -167,7 +126,8 @@ private:
//
//But for vtk-m we pass in facade objects, which are passed by value, but
//must be treated as references. So do to do that properly we need to specialize
//is_non_const_reference to state a PortalValue by value is valid for writing
//is_non_const_reference to state an ArrayPortalValueReference by value is valid
//for writing
namespace thrust
{
namespace detail
......@@ -177,7 +137,7 @@ template <typename T>
struct is_non_const_reference;
template <typename T>
struct is_non_const_reference<vtkm::exec::cuda::internal::PortalValue<T>>
struct is_non_const_reference<vtkm::internal::ArrayPortalValueReference<T>>
: thrust::detail::true_type
{
};
......
......@@ -42,7 +42,7 @@ namespace internal
// Unary function object wrapper which can detect and handle calling the
// wrapped operator with complex value types such as
// PortalValue which happen when passed an input array that
// ArrayPortalValueReference which happen when passed an input array that
// is implicit.
template <typename T_, typename Function>
struct WrappedUnaryPredicate
......@@ -70,9 +70,9 @@ struct WrappedUnaryPredicate
VTKM_EXEC bool operator()(const T& x) const { return m_f(x); }
template <typename U>
VTKM_EXEC bool operator()(const PortalValue<U>& x) const
VTKM_EXEC bool operator()(const vtkm::internal::ArrayPortalValueReference<U>& x) const
{
return m_f((T)x);
return m_f(x.Get());
}
VTKM_EXEC bool operator()(const T* x) const { return m_f(*x); }
......@@ -80,7 +80,7 @@ struct WrappedUnaryPredicate
// Binary function object wrapper which can detect and handle calling the
// wrapped operator with complex value types such as
// PortalValue which happen when passed an input array that
// ArrayPortalValueReference which happen when passed an input array that
// is implicit.
template <typename T_, typename Function>
struct WrappedBinaryOperator
......@@ -109,27 +109,24 @@ struct WrappedBinaryOperator
VTKM_EXEC T operator()(const T& x, const T& y) const { return m_f(x, y); }
template <typename U>
VTKM_EXEC T operator()(const T& x, const PortalValue<U>& y) const
VTKM_EXEC T operator()(const T& x, const vtkm::internal::ArrayPortalValueReference<U>& y) const
{
// to support proper implicit conversion, and avoid overload
// ambiguities.
T conv_y = y;
return m_f(x, conv_y);
return m_f(x, y.Get());
}
template <typename U>
VTKM_EXEC T operator()(const PortalValue<U>& x, const T& y) const
VTKM_EXEC T operator()(const vtkm::internal::ArrayPortalValueReference<U>& x, const T& y) const
{
T conv_x = x;
return m_f(conv_x, y);
return m_f(x.Get(), y);
}
template <typename U, typename V>
VTKM_EXEC T operator()(const PortalValue<U>& x, const PortalValue<V>& y) const
VTKM_EXEC T operator()(const vtkm::internal::ArrayPortalValueReference<U>& x,
const vtkm::internal::ArrayPortalValueReference<V>& y) const
{
T conv_x = x;
T conv_y = y;
return m_f(conv_x, conv_y);
return m_f(x.Get(), y.Get());
}
VTKM_EXEC T operator()(const T* const x, const T& y) const { return m_f(*x, y); }
......@@ -166,21 +163,22 @@ struct WrappedBinaryPredicate
VTKM_EXEC bool operator()(const T& x, const T& y) const { return m_f(x, y); }
template <typename U>
VTKM_EXEC bool operator()(const T& x, const PortalValue<U>& y) const
VTKM_EXEC bool operator()(const T& x, const vtkm::internal::ArrayPortalValueReference<U>& y) const
{
return m_f(x, (T)y);
return m_f(x, y.Get());
}
template <typename U>
VTKM_EXEC bool operator()(const PortalValue<U>& x, const T& y) const
VTKM_EXEC bool operator()(const vtkm::internal::ArrayPortalValueReference<U>& x, const T& y) const
{
return m_f((T)x, y);
return m_f(x.Get(), y);
}
template <typename U, typename V>
VTKM_EXEC bool operator()(const PortalValue<U>& x, const PortalValue<V>& y) const
VTKM_EXEC bool operator()(const vtkm::internal::ArrayPortalValueReference<U>& x,
const vtkm::internal::ArrayPortalValueReference<V>& y) const
{
return m_f((T)x, (T)y);
return m_f(x.Get(), y.Get());
}
VTKM_EXEC bool operator()(const T* const x, const T& y) const { return m_f(*x, y); }
......
This diff is collapsed.
......@@ -22,11 +22,15 @@
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/TypeTraits.h>
#include <vtkm/cont/testing/Testing.h>
namespace
{
static constexpr vtkm::Id ARRAY_SIZE = 10;
template <typename ArrayPortalType>
void SetReference(vtkm::Id index, vtkm::internal::ArrayPortalValueReference<ArrayPortalType> ref)
{
......@@ -41,7 +45,204 @@ void CheckReference(vtkm::Id index, vtkm::internal::ArrayPortalValueReference<Ar
VTKM_TEST_ASSERT(test_equal(ref, TestValue(index, ValueType())), "Got bad value from reference.");
}
static constexpr vtkm::Id ARRAY_SIZE = 10;
template <typename ArrayPortalType>
void TryOperatorsNoVec(vtkm::Id index,
vtkm::internal::ArrayPortalValueReference<ArrayPortalType> ref,
vtkm::TypeTraitsScalarTag)
{
using ValueType = typename ArrayPortalType::ValueType;
ValueType expected = TestValue(index, ValueType());
VTKM_TEST_ASSERT(ref.Get() == expected, "Reference did not start out as expected.");
VTKM_TEST_ASSERT(!(ref < ref));
VTKM_TEST_ASSERT(ref < ValueType(expected + ValueType(1)));
VTKM_TEST_ASSERT(ValueType(expected - ValueType(1)) < ref);
VTKM_TEST_ASSERT(!(ref > ref));
VTKM_TEST_ASSERT(ref > ValueType(expected - ValueType(1)));
VTKM_TEST_ASSERT(ValueType(expected + ValueType(1)) > ref);
VTKM_TEST_ASSERT(ref <= ref);
VTKM_TEST_ASSERT(ref <= ValueType(expected + ValueType(1)));
VTKM_TEST_ASSERT(ValueType(expected - ValueType(1)) <= ref);
VTKM_TEST_ASSERT(ref >= ref);
VTKM_TEST_ASSERT(ref >= ValueType(expected - ValueType(1)));
VTKM_TEST_ASSERT(ValueType(expected + ValueType(1)) >= ref);
}
template <typename ArrayPortalType>
void TryOperatorsNoVec(vtkm::Id,
vtkm::internal::ArrayPortalValueReference<ArrayPortalType>,
vtkm::TypeTraitsVectorTag)
{
}
template <typename ArrayPortalType>
void TryOperatorsInt(vtkm::Id index,
vtkm::internal::ArrayPortalValueReference<ArrayPortalType> ref,
vtkm::TypeTraitsScalarTag,
vtkm::TypeTraitsIntegerTag)
{
using ValueType = typename ArrayPortalType::ValueType;
const ValueType operand = TestValue(ARRAY_SIZE, ValueType());
ValueType expected = TestValue(index, ValueType());
VTKM_TEST_ASSERT(ref.Get() == expected, "Reference did not start out as expected.");
VTKM_TEST_ASSERT((ref % ref) == (expected % expected));
VTKM_TEST_ASSERT((ref % expected) == (expected % expected));
VTKM_TEST_ASSERT((expected % ref) == (expected % expected));
VTKM_TEST_ASSERT((ref ^ ref) == (expected ^ expected));
VTKM_TEST_ASSERT((ref ^ expected) == (expected ^ expected));
VTKM_TEST_ASSERT((expected ^ ref) == (expected ^ expected));
VTKM_TEST_ASSERT((ref | ref) == (expected | expected));
VTKM_TEST_ASSERT((ref | expected) == (expected | expected));
VTKM_TEST_ASSERT((expected | ref) == (expected | expected));
VTKM_TEST_ASSERT((ref & ref) == (expected & expected));
VTKM_TEST_ASSERT((ref & expected) == (expected & expected));
VTKM_TEST_ASSERT((expected & ref) == (expected & expected));
VTKM_TEST_ASSERT((ref << ref) == (expected << expected));
VTKM_TEST_ASSERT((ref << expected) == (expected << expected));
VTKM_TEST_ASSERT((expected << ref) == (expected << expected));
VTKM_TEST_ASSERT((ref << ref) == (expected << expected));
VTKM_TEST_ASSERT((ref << expected) == (expected << expected));
VTKM_TEST_ASSERT((expected << ref) == (expected << expected));
VTKM_TEST_ASSERT(~ref == ~expected);
VTKM_TEST_ASSERT(!(!ref));
VTKM_TEST_ASSERT(ref && ref);
VTKM_TEST_ASSERT(ref && expected);
VTKM_TEST_ASSERT(expected && ref);
VTKM_TEST_ASSERT(ref || ref);
VTKM_TEST_ASSERT(ref || expected);
VTKM_TEST_ASSERT(expected || ref);
ref &= ref;
expected &= expected;
VTKM_TEST_ASSERT(ref == expected);
ref &= operand;
expected &= operand;
VTKM_TEST_ASSERT(ref == expected);
ref |= ref;
expected |= expected;
VTKM_TEST_ASSERT(ref == expected);
ref |= operand;
expected |= operand;
VTKM_TEST_ASSERT(ref == expected);
ref >>= ref;
expected >>= expected;
VTKM_TEST_ASSERT(ref == expected);
ref >>= operand;
expected >>= operand;
VTKM_TEST_ASSERT(ref == expected);
ref <<= ref;
expected <<= expected;
VTKM_TEST_ASSERT(ref == expected);
ref <<= operand;
expected <<= operand;
VTKM_TEST_ASSERT(ref == expected);
ref ^= ref;
expected ^= expected;
VTKM_TEST_ASSERT(ref == expected);
ref ^= operand;
expected ^= operand;
VTKM_TEST_ASSERT(ref == expected);
}
template <typename ArrayPortalType, typename DimTag, typename NumericTag>
void TryOperatorsInt(vtkm::Id,
vtkm::internal::ArrayPortalValueReference<ArrayPortalType>,
DimTag,
NumericTag)
{
}
template <typename ArrayPortalType>
void TryOperators(vtkm::Id index, vtkm::internal::ArrayPortalValueReference<ArrayPortalType> ref)
{
using ValueType = typename ArrayPortalType::ValueType;
const ValueType operand = TestValue(ARRAY_SIZE, ValueType());
ValueType expected = TestValue(index, ValueType());
VTKM_TEST_ASSERT(ref.Get() == expected, "Reference did not start out as expected.");
// Test comparison operators.
VTKM_TEST_ASSERT(ref == ref);
VTKM_TEST_ASSERT(ref == expected);
VTKM_TEST_ASSERT(expected == ref);
VTKM_TEST_ASSERT(!(ref != ref));
VTKM_TEST_ASSERT(!(ref != expected));
VTKM_TEST_ASSERT(!(expected != ref));
TryOperatorsNoVec(index, ref, typename vtkm::TypeTraits<ValueType>::DimensionalityTag());
VTKM_TEST_ASSERT((ref + ref) == (expected + expected));
VTKM_TEST_ASSERT((ref + expected) == (expected + expected));
VTKM_TEST_ASSERT((expected + ref) == (expected + expected));
VTKM_TEST_ASSERT((ref - ref) == (expected - expected));
VTKM_TEST_ASSERT((ref - expected) == (expected - expected));
VTKM_TEST_ASSERT((expected - ref) == (expected - expected));
VTKM_TEST_ASSERT((ref * ref) == (expected * expected));
VTKM_TEST_ASSERT((ref * expected) == (expected * expected));
VTKM_TEST_ASSERT((expected * ref) == (expected * expected));
VTKM_TEST_ASSERT((ref / ref) == (expected / expected));
VTKM_TEST_ASSERT((ref / expected) == (expected / expected));
VTKM_TEST_ASSERT((expected / ref) == (expected / expected));
ref += ref;
expected += expected;
VTKM_TEST_ASSERT(ref == expected);
ref += operand;
expected += operand;
VTKM_TEST_ASSERT(ref == expected);
ref -= ref;
expected -= expected;
VTKM_TEST_ASSERT(ref == expected);
ref -= operand;
expected -= operand;
VTKM_TEST_ASSERT(ref == expected);
ref *= ref;
expected *= expected;
VTKM_TEST_ASSERT(ref == expected);
ref *= operand;
expected *= operand;
VTKM_TEST_ASSERT(ref == expected);
ref /= ref;
expected /= expected;
VTKM_TEST_ASSERT(ref == expected);
ref /= operand;
expected /= operand;
VTKM_TEST_ASSERT(ref == expected);
// Reset ref
ref = TestValue(index, ValueType());
TryOperatorsInt(index,
ref,
typename vtkm::TypeTraits<ValueType>::DimensionalityTag(),
typename vtkm::TypeTraits<ValueType>::NumericTag());
}
struct DoTestForType
{
......@@ -54,7 +255,7 @@ struct DoTestForType
std::cout << "Set array using reference" << std::endl;
using PortalType = typename vtkm::cont::ArrayHandle<ValueType>::PortalControl;
PortalType portal = array.GetPortalControl();
for (vtkm::Id index = 0; index < ARRAY_SIZE; index++)
for (vtkm::Id index = 0; index < ARRAY_SIZE; ++index)
{
SetReference(index, vtkm::internal::ArrayPortalValueReference<PortalType>(portal, index));
}
......@@ -63,10 +264,17 @@ struct DoTestForType
CheckPortal(portal);
std::cout << "Check references in set array." << std::endl;
for (vtkm::Id index = 0; index < ARRAY_SIZE; index++)
for (vtkm::Id index = 0; index < ARRAY_SIZE; ++index)
{
CheckReference(index, vtkm::internal::ArrayPortalValueReference<PortalType>(portal, index));
}
std::cout << "Check that operators work." << std::endl;
// Start at 1 to avoid issues with 0.
for (vtkm::Id index = 1; index < ARRAY_SIZE; ++index)
{
TryOperators(index, vtkm::internal::ArrayPortalValueReference<PortalType>(portal, index));
}
}
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment