Commit e480fd7a authored by Kenneth Moreland's avatar Kenneth Moreland
Browse files

Support copying a Variant to itself

parent d2d9ba33
......@@ -34,7 +34,7 @@ static vtkm::Id g_NonTrivialCount;
// A class that must is not trivial to copy nor construct.
struct NonTrivial
{
vtkm::Id Value = 12345;
vtkm::Id Value;
NonTrivial* Self;
void CheckState() const
......@@ -44,17 +44,19 @@ struct NonTrivial
}
NonTrivial()
: Self(this)
: Value(12345)
, Self(this)
{
this->CheckState();
++g_NonTrivialCount;
}
NonTrivial(const NonTrivial& src)
: Self(this)
: Value(src.Value)
{
this->CheckState();
src.CheckState();
this->Self = this;
this->CheckState();
++g_NonTrivialCount;
}
......@@ -445,9 +447,9 @@ void TestCopyDestroy()
using VariantType = vtkm::exec::internal::Variant<TypePlaceholder<0>,
TypePlaceholder<1>,
CountConstructDestruct,
TypePlaceholder<2>,
TypePlaceholder<3>>;
#ifndef VTKM_USING_GLIBCXX_4
TypePlaceholder<3>,
TypePlaceholder<4>>;
#ifdef VTKM_USE_STD_IS_TRIVIAL
VTKM_STATIC_ASSERT(!std::is_trivially_copyable<VariantType>::value);
#endif // !VTKM_USING_GLIBCXX_4
vtkm::Id count = 0;
......@@ -547,6 +549,19 @@ void TestConstructDestruct()
VTKM_TEST_ASSERT(g_NonTrivialCount == 0);
}
void TestCopySelf()
{
std::cout << "Make sure copying a Variant to itself works" << std::endl;
using VariantType =
vtkm::exec::internal::Variant<TypePlaceholder<0>, NonTrivial, TypePlaceholder<2>>;
VariantType variant{ NonTrivial{} };
VariantType& variantRef = variant;
variant = variantRef;
variant = variant.Get<NonTrivial>();
}
void RunTest()
{
TestSize();
......@@ -557,6 +572,7 @@ void RunTest()
TestCopyDestroy();
TestEmplace();
TestConstructDestruct();
TestCopySelf();
}
} // namespace test_variant
......
......@@ -56,16 +56,50 @@ template <typename UnionType>
using VariantUnionToList =
typename VariantUnionToListImpl<typename std::decay<UnionType>::type>::type;
struct VariantCopyFunctor
struct VariantCopyConstructFunctor
{
template <typename T, typename UnionType>
VTK_M_DEVICE void operator()(const T& src, UnionType& destUnion) const noexcept
{
constexpr vtkm::IdComponent Index = vtkm::ListIndexOf<VariantUnionToList<UnionType>, T>::value;
// If we are using this functor, we can assume the union does not hold a valid type.
new (&VariantUnionGet<Index>(destUnion)) T(src);
}
};
struct VariantCopyFunctor
{
template <typename T, typename UnionType>
VTK_M_DEVICE void operator()(const T& src, UnionType& destUnion) const noexcept
{
constexpr vtkm::IdComponent Index = vtkm::ListIndexOf<VariantUnionToList<UnionType>, T>::value;
// If we are using this functor, we can assume the union holds type T.
this->DoCopy(
src, VariantUnionGet<Index>(destUnion), typename std::is_copy_assignable<T>::type{});
}
template <typename T>
VTK_M_DEVICE void DoCopy(const T& src, T& dest, std::true_type) const noexcept
{
dest = src;
}
template <typename T>
VTK_M_DEVICE void DoCopy(const T& src, T& dest, std::false_type) const noexcept
{
if (&src != &dest)
{
// Do not have an assignment operator, so destroy the old object and create a new one.
dest.~T();
new (&dest) T(src);
}
else
{
// Objects are already the same.
}
}
};
struct VariantDestroyFunctor
{
template <typename T>
......@@ -228,15 +262,22 @@ struct VariantConstructorImpl<vtkm::VTK_M_NAMESPACE::internal::Variant<Ts...>,
VTK_M_DEVICE VariantConstructorImpl(const VariantConstructorImpl& src) noexcept
: VariantStorageImpl<Ts...>(vtkm::internal::NullType{})
{
src.CastAndCall(VariantCopyFunctor{}, this->Storage);
src.CastAndCall(VariantCopyConstructFunctor{}, this->Storage);
this->Index = src.Index;
}
VTK_M_DEVICE VariantConstructorImpl& operator=(const VariantConstructorImpl& src) noexcept
{
this->Reset();
src.CastAndCall(detail::VariantCopyFunctor{}, this->Storage);
this->Index = src.Index;
if (this->GetIndex() == src.GetIndex())
{
src.CastAndCall(detail::VariantCopyFunctor{}, this->Storage);
}
else
{
this->Reset();
src.CastAndCall(detail::VariantCopyConstructFunctor{}, this->Storage);
this->Index = src.Index;
}
return *this;
}
};
......@@ -312,7 +353,14 @@ public:
template <typename T>
VTK_M_DEVICE Variant& operator=(const T& src)
{
this->Emplace<T>(src);
if (this->GetIndex() == this->GetIndexOf<T>())
{
this->Get<T>() = src;
}
else
{
this->Emplace<T>(src);
}
return *this;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment