Commit c9bcdd01 authored by Kenneth Moreland's avatar Kenneth Moreland
Browse files

Use a union in Variant for safe type punning

Create a `VaraintUnion` that is an actual C++ `union` to store the data
in a `Variant`.

You may be asking yourself, why not just use an `std::aligned_union`
rather than a real union type? That was our first implementation, but
the problem is that the `std::aligned_union` reference needs to be
recast to the actual type. Typically you would do that with
`reinterpret_cast`. However, doing that leads to undefined behavior. The
C++ compiler assumes that 2 pointers of different types point to
different memory (even if it is clear that they are set to the same
address). That means optimizers can remove code because it "knows" that
data in one type cannot affect data in another type. To safely change
the type of an `std::aligned_union`, you really have to do an
`std::memcpy`. This is problematic for types that cannot be trivially
copied. Another problem is that we found that device compilers do not
optimize the memcpy as well as most CPU compilers. Likely, memcpy is
used much less frequently on GPU devices.
parent d7c6ffbd
......@@ -248,50 +248,51 @@ void TestTriviallyCopyable()
VTKM_STATIC_ASSERT(!std::is_trivial<TrivialCopy>::value);
// A variant of trivially constructable things should be trivially constructable
VTKM_STATIC_ASSERT((vtkm::exec::internal::detail::AllTriviallyConstructible<float, int>::value));
VTKM_STATIC_ASSERT((vtkmstd::is_trivially_constructible<
vtkm::exec::internal::detail::VariantUnion<float, int>>::value));
VTKM_STATIC_ASSERT(
(std::is_trivially_constructible<vtkm::exec::internal::Variant<float, int>>::value));
(vtkmstd::is_trivially_constructible<vtkm::exec::internal::Variant<float, int>>::value));
// A variant of trivially copyable things should be trivially copyable
VTKM_STATIC_ASSERT(
(vtkm::exec::internal::detail::AllTriviallyCopyable<float, int, TrivialCopy>::value));
VTKM_STATIC_ASSERT(
(std::is_trivially_copyable<vtkm::exec::internal::Variant<float, int, TrivialCopy>>::value));
VTKM_STATIC_ASSERT((vtkmstd::is_trivially_copyable<
vtkm::exec::internal::detail::VariantUnion<float, int, TrivialCopy>>::value));
VTKM_STATIC_ASSERT((
vtkmstd::is_trivially_copyable<vtkm::exec::internal::Variant<float, int, TrivialCopy>>::value));
// A variant of any non-trivially constructable things is not trivially copyable
VTKM_STATIC_ASSERT(
(!vtkm::exec::internal::detail::AllTriviallyConstructible<NonTrivial, float, int>::value));
VTKM_STATIC_ASSERT(
(!vtkm::exec::internal::detail::AllTriviallyConstructible<float, NonTrivial, int>::value));
VTKM_STATIC_ASSERT(
(!vtkm::exec::internal::detail::AllTriviallyConstructible<float, int, NonTrivial>::value));
VTKM_STATIC_ASSERT((!std::is_trivially_constructible<
VTKM_STATIC_ASSERT((!vtkmstd::is_trivially_constructible<
vtkm::exec::internal::detail::VariantUnion<NonTrivial, float, int>>::value));
VTKM_STATIC_ASSERT((!vtkmstd::is_trivially_constructible<
vtkm::exec::internal::detail::VariantUnion<float, NonTrivial, int>>::value));
VTKM_STATIC_ASSERT((!vtkmstd::is_trivially_constructible<
vtkm::exec::internal::detail::VariantUnion<float, int, NonTrivial>>::value));
VTKM_STATIC_ASSERT((!vtkmstd::is_trivially_constructible<
vtkm::exec::internal::Variant<NonTrivial, float, int>>::value));
VTKM_STATIC_ASSERT((!std::is_trivially_constructible<
VTKM_STATIC_ASSERT((!vtkmstd::is_trivially_constructible<
vtkm::exec::internal::Variant<float, NonTrivial, int>>::value));
VTKM_STATIC_ASSERT((!std::is_trivially_constructible<
VTKM_STATIC_ASSERT((!vtkmstd::is_trivially_constructible<
vtkm::exec::internal::Variant<float, int, NonTrivial>>::value));
// A variant of any non-trivially copyable things is not trivially copyable
VTKM_STATIC_ASSERT(
(!vtkm::exec::internal::detail::AllTriviallyCopyable<NonTrivial, float, int>::value));
VTKM_STATIC_ASSERT(
(!vtkm::exec::internal::detail::AllTriviallyCopyable<float, NonTrivial, int>::value));
VTKM_STATIC_ASSERT(
(!vtkm::exec::internal::detail::AllTriviallyCopyable<float, int, NonTrivial>::value));
VTKM_STATIC_ASSERT(
(!std::is_trivially_copyable<vtkm::exec::internal::Variant<NonTrivial, float, int>>::value));
VTKM_STATIC_ASSERT(
(!std::is_trivially_copyable<vtkm::exec::internal::Variant<float, NonTrivial, int>>::value));
VTKM_STATIC_ASSERT(
(!std::is_trivially_copyable<vtkm::exec::internal::Variant<float, int, NonTrivial>>::value));
VTKM_STATIC_ASSERT((!vtkmstd::is_trivially_copyable<
vtkm::exec::internal::detail::VariantUnion<NonTrivial, float, int>>::value));
VTKM_STATIC_ASSERT((!vtkmstd::is_trivially_copyable<
vtkm::exec::internal::detail::VariantUnion<float, NonTrivial, int>>::value));
VTKM_STATIC_ASSERT((!vtkmstd::is_trivially_copyable<
vtkm::exec::internal::detail::VariantUnion<float, int, NonTrivial>>::value));
VTKM_STATIC_ASSERT((
!vtkmstd::is_trivially_copyable<vtkm::exec::internal::Variant<NonTrivial, float, int>>::value));
VTKM_STATIC_ASSERT((
!vtkmstd::is_trivially_copyable<vtkm::exec::internal::Variant<float, NonTrivial, int>>::value));
VTKM_STATIC_ASSERT((
!vtkmstd::is_trivially_copyable<vtkm::exec::internal::Variant<float, int, NonTrivial>>::value));
// A variant of trivial things should be trivial
VTKM_STATIC_ASSERT((std::is_trivial<vtkm::exec::internal::Variant<float, int>>::value));
VTKM_STATIC_ASSERT((vtkmstd::is_trivial<vtkm::exec::internal::Variant<float, int>>::value));
VTKM_STATIC_ASSERT(
(!std::is_trivial<vtkm::exec::internal::Variant<float, int, TrivialCopy>>::value));
(!vtkmstd::is_trivial<vtkm::exec::internal::Variant<float, int, TrivialCopy>>::value));
VTKM_STATIC_ASSERT(
(!std::is_trivial<vtkm::exec::internal::Variant<float, int, NonTrivial>>::value));
(!vtkmstd::is_trivial<vtkm::exec::internal::Variant<float, int, NonTrivial>>::value));
#endif // !VTKM_USING_GLIBCXX_4
}
......@@ -434,6 +435,7 @@ struct CountConstructDestruct
++(*this->Count);
}
~CountConstructDestruct() { --(*this->Count); }
CountConstructDestruct& operator=(const CountConstructDestruct&) = delete;
};
void TestCopyDestroy()
......
......@@ -20,8 +20,7 @@
#include <vtkm/Deprecated.h>
#include <vtkm/List.h>
#include <vtkmstd/aligned_union.h>
#include <vtkmstd/is_trivial.h>
#include <vtkm/internal/Assume.h>
namespace vtkm
{
......@@ -37,12 +36,33 @@ class Variant;
namespace detail
{
// --------------------------------------------------------------------------------
// Helper classes for Variant
template <typename UnionType>
struct VariantUnionToListImpl;
template <typename... Ts>
struct VariantUnionToListImpl<detail::VariantUnionTD<Ts...>>
{
using type = vtkm::List<Ts...>;
};
template <typename... Ts>
struct VariantUnionToListImpl<detail::VariantUnionNTD<Ts...>>
{
using type = vtkm::List<Ts...>;
};
template <typename UnionType>
using VariantUnionToList =
typename VariantUnionToListImpl<typename std::decay<UnionType>::type>::type;
struct VariantCopyFunctor
{
template <typename T>
VTK_M_DEVICE void operator()(const T& src, void* destPointer) const noexcept
template <typename T, typename UnionType>
VTK_M_DEVICE void operator()(const T& src, UnionType& destUnion) const noexcept
{
new (destPointer) T(src);
constexpr vtkm::IdComponent Index = vtkm::ListIndexOf<VariantUnionToList<UnionType>, T>::value;
new (&VariantUnionGet<Index>(destUnion)) T(src);
}
};
......@@ -68,115 +88,12 @@ struct VariantCheckType
VTKM_STATIC_ASSERT_MSG(!std::is_pointer<T>::value, "Pointers are not allowed in VTK-m Variant.");
};
template <typename... Ts>
struct AllTriviallyCopyable;
template <>
struct AllTriviallyCopyable<> : std::true_type
{
};
template <typename T0>
struct AllTriviallyCopyable<T0>
: std::integral_constant<bool, (vtkmstd::is_trivially_copyable<T0>::value)>
{
};
template <typename T0, typename T1>
struct AllTriviallyCopyable<T0, T1>
: std::integral_constant<bool,
(vtkmstd::is_trivially_copyable<T0>::value &&
vtkmstd::is_trivially_copyable<T1>::value)>
{
};
template <typename T0, typename T1, typename T2>
struct AllTriviallyCopyable<T0, T1, T2>
: std::integral_constant<bool,
(vtkmstd::is_trivially_copyable<T0>::value &&
vtkmstd::is_trivially_copyable<T1>::value &&
vtkmstd::is_trivially_copyable<T2>::value)>
{
};
template <typename T0, typename T1, typename T2, typename T3>
struct AllTriviallyCopyable<T0, T1, T2, T3>
: std::integral_constant<
bool,
(vtkmstd::is_trivially_copyable<T0>::value && vtkmstd::is_trivially_copyable<T1>::value &&
vtkmstd::is_trivially_copyable<T2>::value && vtkmstd::is_trivially_copyable<T3>::value)>
{
};
template <typename T0, typename T1, typename T2, typename T3, typename T4, typename... Ts>
struct AllTriviallyCopyable<T0, T1, T2, T3, T4, Ts...>
: std::integral_constant<
bool,
(vtkmstd::is_trivially_copyable<T0>::value && vtkmstd::is_trivially_copyable<T1>::value &&
vtkmstd::is_trivially_copyable<T2>::value && vtkmstd::is_trivially_copyable<T3>::value &&
vtkmstd::is_trivially_copyable<T4>::value && AllTriviallyCopyable<Ts...>::value)>
{
};
template <typename VariantType>
struct VariantTriviallyCopyable;
template <typename... Ts>
struct VariantTriviallyCopyable<vtkm::VTK_M_NAMESPACE::internal::Variant<Ts...>>
: AllTriviallyCopyable<Ts...>
{
};
template <typename... Ts>
struct AllTriviallyConstructible;
template <>
struct AllTriviallyConstructible<> : std::true_type
{
};
template <typename T0>
struct AllTriviallyConstructible<T0>
: std::integral_constant<bool, (vtkmstd::is_trivially_constructible<T0>::value)>
{
};
template <typename T0, typename T1>
struct AllTriviallyConstructible<T0, T1>
: std::integral_constant<bool,
(vtkmstd::is_trivially_constructible<T0>::value &&
vtkmstd::is_trivially_constructible<T1>::value)>
{
};
template <typename T0, typename T1, typename T2>
struct AllTriviallyConstructible<T0, T1, T2>
: std::integral_constant<bool,
(vtkmstd::is_trivially_constructible<T0>::value &&
vtkmstd::is_trivially_constructible<T1>::value &&
vtkmstd::is_trivially_constructible<T2>::value)>
{
};
template <typename T0, typename T1, typename T2, typename T3>
struct AllTriviallyConstructible<T0, T1, T2, T3>
: std::integral_constant<bool,
(vtkmstd::is_trivially_constructible<T0>::value &&
vtkmstd::is_trivially_constructible<T1>::value &&
vtkmstd::is_trivially_constructible<T2>::value &&
vtkmstd::is_trivially_constructible<T3>::value)>
{
};
template <typename T0, typename T1, typename T2, typename T3, typename T4, typename... Ts>
struct AllTriviallyConstructible<T0, T1, T2, T3, T4, Ts...>
: std::integral_constant<bool,
(vtkmstd::is_trivially_constructible<T0>::value &&
vtkmstd::is_trivially_constructible<T1>::value &&
vtkmstd::is_trivially_constructible<T2>::value &&
vtkmstd::is_trivially_constructible<T3>::value &&
vtkmstd::is_trivially_constructible<T4>::value &&
AllTriviallyConstructible<Ts...>::value)>
: vtkmstd::is_trivially_copyable<VariantUnion<Ts...>>
{
};
......@@ -185,26 +102,28 @@ struct VariantTriviallyConstructible;
template <typename... Ts>
struct VariantTriviallyConstructible<vtkm::VTK_M_NAMESPACE::internal::Variant<Ts...>>
: AllTriviallyConstructible<Ts...>
: vtkmstd::is_trivially_constructible<VariantUnion<Ts...>>
{
};
// --------------------------------------------------------------------------------
// Variant superclass that defines its storage
template <typename... Ts>
struct VariantStorageImpl
{
typename vtkmstd::aligned_union<0, Ts...>::type Storage;
VariantUnion<Ts...> Storage;
vtkm::IdComponent Index;
template <vtkm::IdComponent Index>
using TypeAt = typename vtkm::ListAt<vtkm::List<Ts...>, Index>;
VariantStorageImpl() = default;
VTK_M_DEVICE void* GetPointer() { return reinterpret_cast<void*>(&this->Storage); }
VTK_M_DEVICE const void* GetPointer() const
VTK_M_DEVICE VariantStorageImpl(vtkm::internal::NullType dummy)
: Storage({ dummy })
{
return reinterpret_cast<const void*>(&this->Storage);
}
template <vtkm::IdComponent Index>
using TypeAt = typename vtkm::ListAt<vtkm::List<Ts...>, Index>;
VTK_M_DEVICE vtkm::IdComponent GetIndex() const noexcept { return this->Index; }
VTK_M_DEVICE bool IsValid() const noexcept
{
......@@ -226,12 +145,11 @@ struct VariantStorageImpl
-> decltype(f(std::declval<const TypeAt<0>&>(), args...))
{
VTKM_ASSERT(this->IsValid());
return detail::VariantCastAndCallImpl<decltype(f(std::declval<const TypeAt<0>&>(), args...))>(
brigand::list<Ts...>{},
this->GetIndex(),
std::forward<Functor>(f),
this->GetPointer(),
std::forward<Args>(args)...);
return detail::VariantCastAndCallImpl(vtkm::ListSize<vtkm::List<Ts...>>{},
this->GetIndex(),
std::forward<Functor>(f),
this->Storage,
std::forward<Args>(args)...);
}
template <typename Functor, typename... Args>
......@@ -240,15 +158,17 @@ struct VariantStorageImpl
-> decltype(f(std::declval<TypeAt<0>&>(), args...))
{
VTKM_ASSERT(this->IsValid());
return detail::VariantCastAndCallImpl<decltype(f(std::declval<TypeAt<0>&>(), args...))>(
brigand::list<Ts...>{},
this->GetIndex(),
std::forward<Functor>(f),
this->GetPointer(),
std::forward<Args>(args)...);
return detail::VariantCastAndCallImpl(vtkm::ListSize<vtkm::List<Ts...>>{},
this->GetIndex(),
std::forward<Functor>(f),
this->Storage,
std::forward<Args>(args)...);
}
};
// --------------------------------------------------------------------------------
// Variant superclass that helps preserve trivially copyable and trivially constructable
// properties where possible.
template <typename VariantType,
typename TriviallyConstructible =
typename VariantTriviallyConstructible<VariantType>::type,
......@@ -277,7 +197,11 @@ struct VariantConstructorImpl<vtkm::VTK_M_NAMESPACE::internal::Variant<Ts...>,
std::false_type,
std::true_type> : VariantStorageImpl<Ts...>
{
VTK_M_DEVICE VariantConstructorImpl() { this->Index = -1; }
VTK_M_DEVICE VariantConstructorImpl()
: VariantStorageImpl<Ts...>(vtkm::internal::NullType{})
{
this->Index = -1;
}
// Any trivially copyable class is trivially destructable.
~VariantConstructorImpl() = default;
......@@ -294,19 +218,24 @@ struct VariantConstructorImpl<vtkm::VTK_M_NAMESPACE::internal::Variant<Ts...>,
construct_type,
std::false_type> : VariantStorageImpl<Ts...>
{
VTK_M_DEVICE VariantConstructorImpl() { this->Index = -1; }
VTK_M_DEVICE VariantConstructorImpl()
: VariantStorageImpl<Ts...>(vtkm::internal::NullType{})
{
this->Index = -1;
}
VTK_M_DEVICE ~VariantConstructorImpl() { this->Reset(); }
VTK_M_DEVICE VariantConstructorImpl(const VariantConstructorImpl& src) noexcept
: VariantStorageImpl<Ts...>(vtkm::internal::NullType{})
{
src.CastAndCall(VariantCopyFunctor{}, this->GetPointer());
src.CastAndCall(VariantCopyFunctor{}, this->Storage);
this->Index = src.Index;
}
VTK_M_DEVICE VariantConstructorImpl& operator=(const VariantConstructorImpl& src) noexcept
{
this->Reset();
src.CastAndCall(detail::VariantCopyFunctor{}, this->GetPointer());
src.CastAndCall(detail::VariantCopyFunctor{}, this->Storage);
this->Index = src.Index;
return *this;
}
......@@ -323,20 +252,6 @@ class Variant : detail::VariantConstructorImpl<Variant<Ts...>>
using CheckTypes = vtkm::List<detail::VariantCheckType<Ts>...>;
public:
/// Returns the index of the type of object this variant is storing. If no object is currently
/// stored (i.e. the `Variant` is invalid), an invalid is returned.
///
VTK_M_DEVICE vtkm::IdComponent GetIndex() const noexcept { return this->Superclass::GetIndex(); }
/// Returns true if this `Variant` is storing an object from one of the types in the template
/// list, false otherwise.
///
/// Note that if this `Variant` was not initialized with an object, the result of `IsValid`
/// is undefined. The `Variant` could report itself as validly containing an object that
/// is trivially constructed.
///
VTK_M_DEVICE bool IsValid() const noexcept { return this->Superclass::IsValid(); }
/// Type that converts to a std::integral_constant containing the index of the given type (or
/// -1 if that type is not in the list).
template <typename T>
......@@ -353,12 +268,29 @@ public:
/// Type that converts to the type at the given index.
///
template <vtkm::IdComponent Index>
using TypeAt = typename Superclass::template TypeAt<Index>;
using TypeAt = typename vtkm::ListAt<vtkm::List<Ts...>, Index>;
/// The number of types representable by this Variant.
///
static constexpr vtkm::IdComponent NumberOfTypes = vtkm::IdComponent{ sizeof...(Ts) };
/// Returns the index of the type of object this variant is storing. If no object is currently
/// stored (i.e. the `Variant` is invalid), an invalid is returned.
///
VTK_M_DEVICE vtkm::IdComponent GetIndex() const noexcept { return this->Index; }
/// Returns true if this `Variant` is storing an object from one of the types in the template
/// list, false otherwise.
///
/// Note that if this `Variant` was not initialized with an object, the result of `IsValid`
/// is undefined. The `Variant` could report itself as validly containing an object that
/// is trivially constructed.
///
VTK_M_DEVICE bool IsValid() const noexcept
{
return (this->Index >= 0) && (this->Index < NumberOfTypes);
}
Variant() = default;
~Variant() = default;
Variant(const Variant&) = default;
......@@ -373,19 +305,15 @@ public:
// Might be a way to use an enable_if to enforce a proper type.
VTKM_STATIC_ASSERT_MSG(index >= 0, "Attempting to put invalid type into a Variant");
new (this->GetPointer()) T(src);
this->Index = index;
new (&this->Get<index>()) T(src);
}
template <typename T>
VTK_M_DEVICE Variant(const T&& src) noexcept
VTK_M_DEVICE Variant& operator=(const T& src)
{
constexpr vtkm::IdComponent index = IndexOf<T>::value;
// Might be a way to use an enable_if to enforce a proper type.
VTKM_STATIC_ASSERT_MSG(index >= 0, "Attempting to put invalid type into a Variant");
new (this->GetPointer()) T(std::move(src));
this->Index = index;
this->Emplace<T>(src);
return *this;
}
template <typename T, typename... Args>
......@@ -425,18 +353,16 @@ private:
VTK_M_DEVICE T& EmplaceImpl(Args&&... args)
{
this->Reset();
T* value = new (this->GetPointer()) T{ args... };
this->Index = I;
return *value;
return *(new (&this->Get<I>()) T{ args... });
}
template <typename T, vtkm::IdComponent I, typename U, typename... Args>
VTK_M_DEVICE T& EmplaceImpl(std::initializer_list<U> il, Args&&... args)
{
this->Reset();
T* value = new (this->GetPointer()) T(il, args...);
this->Index = I;
return *value;
return *(new (&this->Get<I>()) T(il, args...));
}
public:
......@@ -448,14 +374,14 @@ public:
VTK_M_DEVICE TypeAt<I>& Get() noexcept
{
VTKM_ASSERT(I == this->GetIndex());
return *reinterpret_cast<TypeAt<I>*>(this->GetPointer());
return detail::VariantUnionGet<I>(this->Storage);
}
template <vtkm::IdComponent I>
VTK_M_DEVICE const TypeAt<I>& Get() const noexcept
{
VTKM_ASSERT(I == this->GetIndex());
return *reinterpret_cast<const TypeAt<I>*>(this->GetPointer());
return detail::VariantUnionGet<I>(this->Storage);
}
//@}
......@@ -467,14 +393,14 @@ public:
VTK_M_DEVICE T& Get() noexcept
{
VTKM_ASSERT(this->GetIndexOf<T>() == this->GetIndex());
return *reinterpret_cast<T*>(this->GetPointer());
return detail::VariantUnionGet<IndexOf<T>::value>(this->Storage);
}
template <typename T>
VTK_M_DEVICE const T& Get() const noexcept
{
VTKM_ASSERT(this->GetIndexOf<T>() == this->GetIndex());
return *reinterpret_cast<const T*>(this->GetPointer());
return detail::VariantUnionGet<IndexOf<T>::value>(this->Storage);
}
//@}
......@@ -490,7 +416,13 @@ public:
noexcept(noexcept(f(std::declval<const TypeAt<0>&>(), args...)))
-> decltype(f(std::declval<const TypeAt<0>&>(), args...))
{
return this->Superclass::CastAndCall(std::forward<Functor>(f), std::forward<Args>(args)...);
VTKM_ASSERT(this->IsValid());
return detail::VariantCastAndCallImpl(
std::integral_constant<vtkm::IdComponent, NumberOfTypes>{},
this->GetIndex(),
std::forward<Functor>(f),
this->Storage,
std::forward<Args>(args)...);
}
template <typename Functor, typename... Args>
......@@ -498,13 +430,26 @@ public:
noexcept(f(std::declval<const TypeAt<0>&>(), args...)))
-> decltype(f(std::declval<TypeAt<0>&>(), args...))
{
return this->Superclass::CastAndCall(std::forward<Functor>(f), std::forward<Args>(args)...);
VTKM_ASSERT(this->IsValid());
return detail::VariantCastAndCallImpl(
std::integral_constant<vtkm::IdComponent, NumberOfTypes>{},
this->GetIndex(),
std::forward<Functor>(f),
this->Storage,
std::forward<Args>(args)...);
}
/// Destroys any object the Variant is holding and sets the Variant to an invalid state. This
/// method is not thread safe.
///
VTK_M_DEVICE void Reset() noexcept { this->Superclass::Reset(); }
VTK_M_DEVICE void Reset() noexcept
{
if (this->IsValid())
{
this->CastAndCall(detail::VariantDestroyFunctor{});
this->Index = -1;
}
}
};
/// \brief Convert a ListTag to a Variant.
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment