Commit a8415d8e authored by Robert Maynard's avatar Robert Maynard
Browse files

VTK-m now widens result type for UInt8/Int8/UInt16/Int16 input.

When using vtkm::dot on narrow types you easily rollover the values.
Instead the result type of vtkm::dot should be wide enough to store the results
(32bits) when this occurs.

Fixes #193
parent de7162ab
...@@ -516,22 +516,6 @@ public: ...@@ -516,22 +516,6 @@ public:
VTKM_EXEC_CONT VTKM_EXEC_CONT
bool operator!=(const DerivedClass& other) const { return !(this->operator==(other)); } bool operator!=(const DerivedClass& other) const { return !(this->operator==(other)); }
VTKM_EXEC_CONT
ComponentType Dot(const VecBaseCommon<ComponentType, DerivedClass>& other) const
{
// Why the static_cast here and below? Because * on small integers (char,
// short) promotes the result to a 32-bit int. After helpfully promoting
// the width of the result, some compilers then warn you about casting it
// back to the type you were expecting in the first place. The static_cast
// suppresses this warning.
ComponentType result = static_cast<ComponentType>(this->Component(0) * other.Component(0));
for (vtkm::IdComponent i = 1; i < this->NumComponents(); ++i)
{
result = static_cast<ComponentType>(result + this->Component(i) * other.Component(i));
}
return result;
}
#if (!(defined(VTKM_CUDA) && (__CUDACC_VER_MAJOR__ < 8))) #if (!(defined(VTKM_CUDA) && (__CUDACC_VER_MAJOR__ < 8)))
#if (defined(VTKM_GCC) || defined(VTKM_CLANG)) #if (defined(VTKM_GCC) || defined(VTKM_CLANG))
#pragma GCC diagnostic push #pragma GCC diagnostic push
...@@ -1241,46 +1225,85 @@ VTKM_EXEC_CONT static inline vtkm::VecCConst<T> make_VecC(const T* array, vtkm:: ...@@ -1241,46 +1225,85 @@ VTKM_EXEC_CONT static inline vtkm::VecCConst<T> make_VecC(const T* array, vtkm::
return vtkm::VecCConst<T>(array, size); return vtkm::VecCConst<T>(array, size);
} }
// A pre-declaration of vtkm::Pair so that classes templated on them can refer namespace detail
// to it. The actual implementation is in vtkm/Pair.h. {
template <typename U, typename V> template <typename T>
struct Pair; struct DotType
{
//results when < 32bit can be float if somehow we are using float16/float8, otherwise is
// int32 or uint32 depending on if it signed or not.
using float_type = vtkm::Float32;
using integer_type =
typename std::conditional<std::is_signed<T>::value, vtkm::Int32, vtkm::UInt32>::type;
using promote_type =
typename std::conditional<std::is_integral<T>::value, integer_type, float_type>::type;
using type =
typename std::conditional<(sizeof(T) < sizeof(vtkm::Float32)), promote_type, T>::type;
};
template <typename T>
static inline VTKM_EXEC_CONT typename DotType<typename T::ComponentType>::type vec_dot(const T& a,
const T& b)
{
using U = typename DotType<typename T::ComponentType>::type;
U result = a[0] * b[0];
for (vtkm::IdComponent i = 1; i < a.GetNumberOfComponents(); ++i)
{
result = result + a[i] * b[i];
}
return result;
}
template <typename T, vtkm::IdComponent Size> template <typename T, vtkm::IdComponent Size>
static inline VTKM_EXEC_CONT T dot(const vtkm::Vec<T, Size>& a, const vtkm::Vec<T, Size>& b) static inline VTKM_EXEC_CONT typename DotType<T>::type vec_dot(const vtkm::Vec<T, Size>& a,
const vtkm::Vec<T, Size>& b)
{ {
T result = T(a[0] * b[0]); using U = typename DotType<T>::type;
U result = a[0] * b[0];
for (vtkm::IdComponent i = 1; i < Size; ++i) for (vtkm::IdComponent i = 1; i < Size; ++i)
{ {
result = T(result + a[i] * b[i]); result = result + a[i] * b[i];
} }
return result; return result;
} }
}
template <typename T> template <typename T>
static inline VTKM_EXEC_CONT T dot(const vtkm::Vec<T, 2>& a, const vtkm::Vec<T, 2>& b) static inline VTKM_EXEC_CONT auto dot(const T& a, const T& b) -> decltype(detail::vec_dot(a, b))
{ {
return T((a[0] * b[0]) + (a[1] * b[1])); return detail::vec_dot(a, b);
} }
template <typename T> template <typename T>
static inline VTKM_EXEC_CONT T dot(const vtkm::Vec<T, 3>& a, const vtkm::Vec<T, 3>& b) static inline VTKM_EXEC_CONT typename detail::DotType<T>::type dot(const vtkm::Vec<T, 2>& a,
const vtkm::Vec<T, 2>& b)
{ {
return T((a[0] * b[0]) + (a[1] * b[1]) + (a[2] * b[2])); return (a[0] * b[0]) + (a[1] * b[1]);
} }
template <typename T> template <typename T>
static inline VTKM_EXEC_CONT T dot(const vtkm::Vec<T, 4>& a, const vtkm::Vec<T, 4>& b) static inline VTKM_EXEC_CONT typename detail::DotType<T>::type dot(const vtkm::Vec<T, 3>& a,
const vtkm::Vec<T, 3>& b)
{ {
return T((a[0] * b[0]) + (a[1] * b[1]) + (a[2] * b[2]) + (a[3] * b[3])); return (a[0] * b[0]) + (a[1] * b[1]) + (a[2] * b[2]);
} }
template <typename T>
template <typename T, typename VecType> static inline VTKM_EXEC_CONT typename detail::DotType<T>::type dot(const vtkm::Vec<T, 4>& a,
static inline VTKM_EXEC_CONT T dot(const vtkm::detail::VecBaseCommon<T, VecType>& a, const vtkm::Vec<T, 4>& b)
const vtkm::detail::VecBaseCommon<T, VecType>& b)
{ {
return a.Dot(b); return (a[0] * b[0]) + (a[1] * b[1]) + (a[2] * b[2]) + (a[3] * b[3]);
} }
// Integer types of a width less than an integer get implicitly casted to
// an integer when doing a multiplication.
#define VTK_M_SCALAR_DOT(stype) \
static inline VTKM_EXEC_CONT detail::DotType<stype>::type dot(stype a, stype b) { return a * b; }
VTK_M_SCALAR_DOT(vtkm::Int8)
VTK_M_SCALAR_DOT(vtkm::UInt8)
VTK_M_SCALAR_DOT(vtkm::Int16)
VTK_M_SCALAR_DOT(vtkm::UInt16)
VTK_M_SCALAR_DOT(vtkm::Int32)
VTK_M_SCALAR_DOT(vtkm::UInt32)
VTK_M_SCALAR_DOT(vtkm::Int64)
VTK_M_SCALAR_DOT(vtkm::UInt64)
VTK_M_SCALAR_DOT(vtkm::Float32)
VTK_M_SCALAR_DOT(vtkm::Float64)
template <typename T, vtkm::IdComponent Size> template <typename T, vtkm::IdComponent Size>
VTKM_EXEC_CONT T ReduceSum(const vtkm::Vec<T, Size>& a) VTKM_EXEC_CONT T ReduceSum(const vtkm::Vec<T, Size>& a)
...@@ -1340,22 +1363,10 @@ VTKM_EXEC_CONT T ReduceProduct(const vtkm::Vec<T, 4>& a) ...@@ -1340,22 +1363,10 @@ VTKM_EXEC_CONT T ReduceProduct(const vtkm::Vec<T, 4>& a)
return a[0] * a[1] * a[2] * a[3]; return a[0] * a[1] * a[2] * a[3];
} }
// Integer types of a width less than an integer get implicitly casted to // A pre-declaration of vtkm::Pair so that classes templated on them can refer
// an integer when doing a multiplication. // to it. The actual implementation is in vtkm/Pair.h.
#define VTK_M_INTEGER_PROMOTION_SCALAR_DOT(type) \ template <typename U, typename V>
static inline VTKM_EXEC_CONT type dot(type a, type b) { return static_cast<type>(a * b); } struct Pair;
VTK_M_INTEGER_PROMOTION_SCALAR_DOT(vtkm::Int8)
VTK_M_INTEGER_PROMOTION_SCALAR_DOT(vtkm::UInt8)
VTK_M_INTEGER_PROMOTION_SCALAR_DOT(vtkm::Int16)
VTK_M_INTEGER_PROMOTION_SCALAR_DOT(vtkm::UInt16)
#define VTK_M_SCALAR_DOT(type) \
static inline VTKM_EXEC_CONT type dot(type a, type b) { return a * b; }
VTK_M_SCALAR_DOT(vtkm::Int32)
VTK_M_SCALAR_DOT(vtkm::UInt32)
VTK_M_SCALAR_DOT(vtkm::Int64)
VTK_M_SCALAR_DOT(vtkm::UInt64)
VTK_M_SCALAR_DOT(vtkm::Float32)
VTK_M_SCALAR_DOT(vtkm::Float64)
} // End of namespace vtkm } // End of namespace vtkm
......
...@@ -37,11 +37,10 @@ namespace ...@@ -37,11 +37,10 @@ namespace
const vtkm::Id ARRAY_SIZE = 10; const vtkm::Id ARRAY_SIZE = 10;
template <typename ValueType>
struct MySquare struct MySquare
{ {
template <typename U> template <typename U>
VTKM_EXEC ValueType operator()(U u) const VTKM_EXEC auto operator()(U u) const -> decltype(vtkm::dot(u, u))
{ {
return vtkm::dot(u, u); return vtkm::dot(u, u);
} }
...@@ -59,7 +58,7 @@ struct CheckTransformFunctor : vtkm::exec::FunctorBase ...@@ -59,7 +58,7 @@ struct CheckTransformFunctor : vtkm::exec::FunctorBase
using T = typename TransformedPortalType::ValueType; using T = typename TransformedPortalType::ValueType;
typename OriginalPortalType::ValueType original = this->OriginalPortal.Get(index); typename OriginalPortalType::ValueType original = this->OriginalPortal.Get(index);
T transformed = this->TransformedPortal.Get(index); T transformed = this->TransformedPortal.Get(index);
if (!test_equal(transformed, MySquare<T>()(original))) if (!test_equal(transformed, MySquare{}(original)))
{ {
this->RaiseError("Encountered bad transformed value."); this->RaiseError("Encountered bad transformed value.");
} }
...@@ -107,7 +106,7 @@ VTKM_CONT void CheckControlPortals(const OriginalArrayHandleType& originalArray, ...@@ -107,7 +106,7 @@ VTKM_CONT void CheckControlPortals(const OriginalArrayHandleType& originalArray,
using T = typename TransformedPortalType::ValueType; using T = typename TransformedPortalType::ValueType;
typename OriginalPortalType::ValueType original = originalPortal.Get(index); typename OriginalPortalType::ValueType original = originalPortal.Get(index);
T transformed = transformedPortal.Get(index); T transformed = transformedPortal.Get(index);
VTKM_TEST_ASSERT(test_equal(transformed, MySquare<T>()(original)), "Bad transform value."); VTKM_TEST_ASSERT(test_equal(transformed, MySquare{}(original)), "Bad transform value.");
} }
} }
...@@ -115,20 +114,19 @@ template <typename InputValueType> ...@@ -115,20 +114,19 @@ template <typename InputValueType>
struct TransformTests struct TransformTests
{ {
using OutputValueType = typename vtkm::VecTraits<InputValueType>::ComponentType; using OutputValueType = typename vtkm::VecTraits<InputValueType>::ComponentType;
using FunctorType = MySquare<OutputValueType>;
using TransformHandle = using TransformHandle =
vtkm::cont::ArrayHandleTransform<vtkm::cont::ArrayHandle<InputValueType>, FunctorType>; vtkm::cont::ArrayHandleTransform<vtkm::cont::ArrayHandle<InputValueType>, MySquare>;
using CountingTransformHandle = using CountingTransformHandle =
vtkm::cont::ArrayHandleTransform<vtkm::cont::ArrayHandleCounting<InputValueType>, FunctorType>; vtkm::cont::ArrayHandleTransform<vtkm::cont::ArrayHandleCounting<InputValueType>, MySquare>;
using Device = VTKM_DEFAULT_DEVICE_ADAPTER_TAG; using Device = VTKM_DEFAULT_DEVICE_ADAPTER_TAG;
using Algorithm = vtkm::cont::DeviceAdapterAlgorithm<Device>; using Algorithm = vtkm::cont::DeviceAdapterAlgorithm<Device>;
void operator()() const void operator()() const
{ {
FunctorType functor; MySquare functor;
std::cout << "Test a transform handle with a counting handle as the values" << std::endl; std::cout << "Test a transform handle with a counting handle as the values" << std::endl;
vtkm::cont::ArrayHandleCounting<InputValueType> counting = vtkm::cont::make_ArrayHandleCounting( vtkm::cont::ArrayHandleCounting<InputValueType> counting = vtkm::cont::make_ArrayHandleCounting(
......
...@@ -189,7 +189,7 @@ void GeneralVecCTypeTest(const vtkm::Vec<ComponentType, Size>&) ...@@ -189,7 +189,7 @@ void GeneralVecCTypeTest(const vtkm::Vec<ComponentType, Size>&)
div = aSrc / b; div = aSrc / b;
VTKM_TEST_ASSERT(test_equal(div, correct_div), "Tuples not divided correctly."); VTKM_TEST_ASSERT(test_equal(div, correct_div), "Tuples not divided correctly.");
ComponentType d = vtkm::dot(a, b); ComponentType d = static_cast<ComponentType>(vtkm::dot(a, b));
ComponentType correct_d = 0; ComponentType correct_d = 0;
for (vtkm::IdComponent i = 0; i < Size; ++i) for (vtkm::IdComponent i = 0; i < Size; ++i)
{ {
...@@ -286,7 +286,7 @@ void GeneralVecCConstTypeTest(const vtkm::Vec<ComponentType, Size>&) ...@@ -286,7 +286,7 @@ void GeneralVecCConstTypeTest(const vtkm::Vec<ComponentType, Size>&)
div = aSrc / b; div = aSrc / b;
VTKM_TEST_ASSERT(test_equal(div, correct_div), "Tuples not divided correctly."); VTKM_TEST_ASSERT(test_equal(div, correct_div), "Tuples not divided correctly.");
ComponentType d = vtkm::dot(a, b); ComponentType d = static_cast<ComponentType>(vtkm::dot(a, b));
ComponentType correct_d = 0; ComponentType correct_d = 0;
for (vtkm::IdComponent i = 0; i < Size; ++i) for (vtkm::IdComponent i = 0; i < Size; ++i)
{ {
...@@ -403,7 +403,7 @@ void GeneralVecTypeTest(const vtkm::Vec<ComponentType, Size>&) ...@@ -403,7 +403,7 @@ void GeneralVecTypeTest(const vtkm::Vec<ComponentType, Size>&)
div = a / ComponentType(2); div = a / ComponentType(2);
VTKM_TEST_ASSERT(test_equal(div, b), "Tuple does not divide by Scalar correctly."); VTKM_TEST_ASSERT(test_equal(div, b), "Tuple does not divide by Scalar correctly.");
ComponentType d = vtkm::dot(a, b); ComponentType d = static_cast<ComponentType>(vtkm::dot(a, b));
ComponentType correct_d = 0; ComponentType correct_d = 0;
for (vtkm::IdComponent i = 0; i < T::NUM_COMPONENTS; ++i) for (vtkm::IdComponent i = 0; i < T::NUM_COMPONENTS; ++i)
{ {
...@@ -477,7 +477,7 @@ void TypeTest(const vtkm::Vec<Scalar, 2>&) ...@@ -477,7 +477,7 @@ void TypeTest(const vtkm::Vec<Scalar, 2>&)
VTKM_TEST_ASSERT(test_equal(div, vtkm::make_Vec(1, 2)), VTKM_TEST_ASSERT(test_equal(div, vtkm::make_Vec(1, 2)),
"Vector does not divide by Scalar correctly."); "Vector does not divide by Scalar correctly.");
Scalar d = vtkm::dot(a, b); Scalar d = static_cast<Scalar>(vtkm::dot(a, b));
VTKM_TEST_ASSERT(test_equal(d, Scalar(10)), "dot(Vector2) wrong"); VTKM_TEST_ASSERT(test_equal(d, Scalar(10)), "dot(Vector2) wrong");
VTKM_TEST_ASSERT(!(a < b), "operator< wrong"); VTKM_TEST_ASSERT(!(a < b), "operator< wrong");
...@@ -539,7 +539,7 @@ void TypeTest(const vtkm::Vec<Scalar, 3>&) ...@@ -539,7 +539,7 @@ void TypeTest(const vtkm::Vec<Scalar, 3>&)
div = a / Scalar(2); div = a / Scalar(2);
VTKM_TEST_ASSERT(test_equal(div, b), "Vector does not divide by Scalar correctly."); VTKM_TEST_ASSERT(test_equal(div, b), "Vector does not divide by Scalar correctly.");
Scalar d = vtkm::dot(a, b); Scalar d = static_cast<Scalar>(vtkm::dot(a, b));
VTKM_TEST_ASSERT(test_equal(d, Scalar(28)), "dot(Vector3) wrong"); VTKM_TEST_ASSERT(test_equal(d, Scalar(28)), "dot(Vector3) wrong");
VTKM_TEST_ASSERT(!(a < b), "operator< wrong"); VTKM_TEST_ASSERT(!(a < b), "operator< wrong");
...@@ -601,7 +601,7 @@ void TypeTest(const vtkm::Vec<Scalar, 4>&) ...@@ -601,7 +601,7 @@ void TypeTest(const vtkm::Vec<Scalar, 4>&)
div = a / Scalar(2); div = a / Scalar(2);
VTKM_TEST_ASSERT(test_equal(div, b), "Vector does not divide by Scalar correctly."); VTKM_TEST_ASSERT(test_equal(div, b), "Vector does not divide by Scalar correctly.");
Scalar d = vtkm::dot(a, b); Scalar d = static_cast<Scalar>(vtkm::dot(a, b));
VTKM_TEST_ASSERT(test_equal(d, Scalar(60)), "dot(Vector4) wrong"); VTKM_TEST_ASSERT(test_equal(d, Scalar(60)), "dot(Vector4) wrong");
VTKM_TEST_ASSERT(!(a < b), "operator< wrong"); VTKM_TEST_ASSERT(!(a < b), "operator< wrong");
...@@ -672,6 +672,17 @@ void TypeTest(Scalar) ...@@ -672,6 +672,17 @@ void TypeTest(Scalar)
{ {
VTKM_TEST_FAIL("dot(Scalar) wrong"); VTKM_TEST_FAIL("dot(Scalar) wrong");
} }
//verify we don't roll over
Scalar c = 128;
Scalar d = 32;
auto r = vtkm::dot(c, d);
VTKM_TEST_ASSERT((sizeof(r) >= sizeof(int)),
"dot(Scalar) didn't promote smaller than 32bit types");
if (r != 4096)
{
VTKM_TEST_FAIL("dot(Scalar) wrong");
}
} }
struct TypeTestFunctor struct TypeTestFunctor
......
...@@ -135,14 +135,14 @@ static void TestVecTypeImpl(const typename std::remove_const<T>::type& inVector, ...@@ -135,14 +135,14 @@ static void TestVecTypeImpl(const typename std::remove_const<T>::type& inVector,
VTKM_TEST_ASSERT(test_equal(vectorCopy, inVector), "CopyInto does not work."); VTKM_TEST_ASSERT(test_equal(vectorCopy, inVector), "CopyInto does not work.");
{ {
ComponentType result = 0; auto expected = vtkm::dot(vectorCopy, vectorCopy);
decltype(expected) result = 0;
for (vtkm::IdComponent i = 0; i < NUM_COMPONENTS; i++) for (vtkm::IdComponent i = 0; i < NUM_COMPONENTS; i++)
{ {
ComponentType component = Traits::GetComponent(inVector, i); ComponentType component = Traits::GetComponent(inVector, i);
result = ComponentType(result + (component * component)); result = result + (component * component);
} }
VTKM_TEST_ASSERT(test_equal(result, vtkm::dot(vectorCopy, vectorCopy)), VTKM_TEST_ASSERT(test_equal(result, expected), "Got bad result for dot product");
"Got bad result for dot product");
} }
// This will fail to compile if the tags are wrong. // This will fail to compile if the tags are wrong.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment