Commit 3159b376 authored by Allison Vacanti's avatar Allison Vacanti
Browse files

Make Swizzle and ExtractComponent array parameters runtime vars.

parent 19344707
......@@ -32,7 +32,7 @@ namespace cont
namespace internal
{
template <typename PortalType, vtkm::IdComponent Component>
template <typename PortalType>
class VTKM_ALWAYS_EXPORT ArrayPortalExtractComponent
{
public:
......@@ -40,24 +40,24 @@ public:
using Traits = vtkm::VecTraits<VectorType>;
using ValueType = typename Traits::ComponentType;
static constexpr vtkm::IdComponent COMPONENT = Component;
VTKM_EXEC_CONT
ArrayPortalExtractComponent()
: Portal()
, Component(0)
{
}
VTKM_EXEC_CONT
ArrayPortalExtractComponent(const PortalType& portal)
ArrayPortalExtractComponent(const PortalType& portal, vtkm::IdComponent component)
: Portal(portal)
, Component(component)
{
}
// Copy constructor
VTKM_EXEC_CONT ArrayPortalExtractComponent(
const ArrayPortalExtractComponent<PortalType, Component>& src)
: Portal(src.GetPortal())
VTKM_EXEC_CONT ArrayPortalExtractComponent(const ArrayPortalExtractComponent<PortalType>& src)
: Portal(src.Portal)
, Component(src.Component)
{
}
......@@ -67,14 +67,14 @@ public:
VTKM_EXEC_CONT
ValueType Get(vtkm::Id index) const
{
return Traits::GetComponent(this->Portal.Get(index), Component);
return Traits::GetComponent(this->Portal.Get(index), this->Component);
}
VTKM_EXEC_CONT
void Set(vtkm::Id index, const ValueType& value) const
{
VectorType vec = this->Portal.Get(index);
Traits::SetComponent(vec, Component, value);
Traits::SetComponent(vec, this->Component, value);
this->Portal.Set(index, vec);
}
......@@ -83,39 +83,40 @@ public:
private:
PortalType Portal;
vtkm::IdComponent Component;
}; // class ArrayPortalExtractComponent
} // namespace internal
template <typename ArrayHandleType, vtkm::IdComponent Component>
template <typename ArrayHandleType>
class StorageTagExtractComponent
{
static constexpr vtkm::IdComponent COMPONENT = Component;
};
namespace internal
{
template <typename ArrayHandleType, vtkm::IdComponent Component>
template <typename ArrayHandleType>
class Storage<typename vtkm::VecTraits<typename ArrayHandleType::ValueType>::ComponentType,
StorageTagExtractComponent<ArrayHandleType, Component>>
StorageTagExtractComponent<ArrayHandleType>>
{
public:
using PortalType =
ArrayPortalExtractComponent<typename ArrayHandleType::PortalControl, Component>;
using PortalConstType =
ArrayPortalExtractComponent<typename ArrayHandleType::PortalConstControl, Component>;
using PortalType = ArrayPortalExtractComponent<typename ArrayHandleType::PortalControl>;
using PortalConstType = ArrayPortalExtractComponent<typename ArrayHandleType::PortalConstControl>;
using ValueType = typename PortalType::ValueType;
VTKM_CONT
Storage()
: Valid(false)
: Array()
, Component(0)
, Valid(false)
{
}
VTKM_CONT
Storage(const ArrayHandleType& array)
Storage(const ArrayHandleType& array, vtkm::IdComponent component)
: Array(array)
, Component(component)
, Valid(true)
{
}
......@@ -124,14 +125,14 @@ public:
PortalConstType GetPortalConst() const
{
VTKM_ASSERT(this->Valid);
return PortalConstType(this->Array.GetPortalConstControl());
return PortalConstType(this->Array.GetPortalConstControl(), this->Component);
}
VTKM_CONT
PortalType GetPortal()
{
VTKM_ASSERT(this->Valid);
return PortalType(this->Array.GetPortalControl());
return PortalType(this->Array.GetPortalControl(), this->Component);
}
VTKM_CONT
......@@ -169,21 +170,29 @@ public:
return this->Array;
}
VTKM_CONT
vtkm::IdComponent GetComponent() const
{
VTKM_ASSERT(this->Valid);
return this->Component;
}
private:
ArrayHandleType Array;
vtkm::IdComponent Component;
bool Valid;
}; // class Storage
template <typename ArrayHandleType, vtkm::IdComponent Component, typename Device>
template <typename ArrayHandleType, typename Device>
class ArrayTransfer<typename vtkm::VecTraits<typename ArrayHandleType::ValueType>::ComponentType,
StorageTagExtractComponent<ArrayHandleType, Component>,
StorageTagExtractComponent<ArrayHandleType>,
Device>
{
public:
using ValueType = typename vtkm::VecTraits<typename ArrayHandleType::ValueType>::ComponentType;
private:
using StorageTag = StorageTagExtractComponent<ArrayHandleType, Component>;
using StorageTag = StorageTagExtractComponent<ArrayHandleType>;
using StorageType = vtkm::cont::internal::Storage<ValueType, StorageTag>;
using ArrayValueType = typename ArrayHandleType::ValueType;
using ArrayStorageTag = typename ArrayHandleType::StorageTag;
......@@ -195,13 +204,13 @@ public:
using PortalConstControl = typename StorageType::PortalConstType;
using ExecutionTypes = typename ArrayHandleType::template ExecutionTypes<Device>;
using PortalExecution = ArrayPortalExtractComponent<typename ExecutionTypes::Portal, Component>;
using PortalConstExecution =
ArrayPortalExtractComponent<typename ExecutionTypes::PortalConst, Component>;
using PortalExecution = ArrayPortalExtractComponent<typename ExecutionTypes::Portal>;
using PortalConstExecution = ArrayPortalExtractComponent<typename ExecutionTypes::PortalConst>;
VTKM_CONT
ArrayTransfer(StorageType* storage)
: Array(storage->GetArray())
, Component(storage->GetComponent())
{
}
......@@ -211,19 +220,19 @@ public:
VTKM_CONT
PortalConstExecution PrepareForInput(bool vtkmNotUsed(updateData))
{
return PortalConstExecution(this->Array.PrepareForInput(Device()));
return PortalConstExecution(this->Array.PrepareForInput(Device()), this->Component);
}
VTKM_CONT
PortalExecution PrepareForInPlace(bool vtkmNotUsed(updateData))
{
return PortalExecution(this->Array.PrepareForInPlace(Device()));
return PortalExecution(this->Array.PrepareForInPlace(Device()), this->Component);
}
VTKM_CONT
PortalExecution PrepareForOutput(vtkm::Id numberOfValues)
{
return PortalExecution(this->Array.PrepareForOutput(numberOfValues, Device()));
return PortalExecution(this->Array.PrepareForOutput(numberOfValues, Device()), this->Component);
}
VTKM_CONT
......@@ -242,6 +251,7 @@ public:
private:
ArrayHandleType Array;
vtkm::IdComponent Component;
};
}
}
......@@ -263,40 +273,39 @@ namespace cont
/// the index array and reads or writes to the specified component, leave all
/// other components unmodified. This is done on the fly rather than creating a
/// copy of the array.
template <typename ArrayHandleType, vtkm::IdComponent Component>
template <typename ArrayHandleType>
class ArrayHandleExtractComponent
: public vtkm::cont::ArrayHandle<
typename vtkm::VecTraits<typename ArrayHandleType::ValueType>::ComponentType,
StorageTagExtractComponent<ArrayHandleType, Component>>
StorageTagExtractComponent<ArrayHandleType>>
{
public:
static constexpr vtkm::IdComponent COMPONENT = Component;
VTKM_ARRAY_HANDLE_SUBCLASS(
ArrayHandleExtractComponent,
(ArrayHandleExtractComponent<ArrayHandleType, Component>),
(ArrayHandleExtractComponent<ArrayHandleType>),
(vtkm::cont::ArrayHandle<
typename vtkm::VecTraits<typename ArrayHandleType::ValueType>::ComponentType,
StorageTagExtractComponent<ArrayHandleType, Component>>));
StorageTagExtractComponent<ArrayHandleType>>));
protected:
using StorageType = vtkm::cont::internal::Storage<ValueType, StorageTag>;
public:
VTKM_CONT
ArrayHandleExtractComponent(const ArrayHandleType& array)
: Superclass(StorageType(array))
ArrayHandleExtractComponent(const ArrayHandleType& array, vtkm::IdComponent component)
: Superclass(StorageType(array, component))
{
}
};
/// make_ArrayHandleExtractComponent is convenience function to generate an
/// ArrayHandleExtractComponent.
template <vtkm::IdComponent Component, typename ArrayHandleType>
VTKM_CONT ArrayHandleExtractComponent<ArrayHandleType, Component> make_ArrayHandleExtractComponent(
const ArrayHandleType& array)
template <typename ArrayHandleType>
VTKM_CONT ArrayHandleExtractComponent<ArrayHandleType> make_ArrayHandleExtractComponent(
const ArrayHandleType& array,
vtkm::IdComponent component)
{
return ArrayHandleExtractComponent<ArrayHandleType, Component>(array);
return ArrayHandleExtractComponent<ArrayHandleType>(array, component);
}
}
} // namespace vtkm::cont
......
This diff is collapsed.
......@@ -415,7 +415,7 @@ private:
{
vtkm::cont::ArrayHandle<T, StorageTagBasic> a1;
vtkm::cont::ArrayHandle<vtkm::Vec<T, 3>, StorageTagBasic> tmp;
auto a2 = vtkm::cont::make_ArrayHandleExtractComponent<1>(tmp);
auto a2 = vtkm::cont::make_ArrayHandleExtractComponent(tmp, 1);
VTKM_TEST_ASSERT(a1 != a2, "Arrays with different storage type compared equal.");
VTKM_TEST_ASSERT(!(a1 == a2), "Arrays with different storage type compared equal.");
......@@ -434,7 +434,7 @@ private:
{
vtkm::cont::ArrayHandle<T, StorageTagBasic> a1;
vtkm::cont::ArrayHandle<vtkm::Vec<typename OtherType<T>::Type, 3>, StorageTagBasic> tmp;
auto a2 = vtkm::cont::make_ArrayHandleExtractComponent<1>(tmp);
auto a2 = vtkm::cont::make_ArrayHandleExtractComponent(tmp, 1);
VTKM_TEST_ASSERT(a1 != a2, "Arrays with different storage and value type compared equal.");
VTKM_TEST_ASSERT(!(a1 == a2),
......
......@@ -36,8 +36,7 @@ template <typename ValueType>
struct ExtractComponentTests
{
using InputArray = vtkm::cont::ArrayHandle<vtkm::Vec<ValueType, 4>>;
template <vtkm::IdComponent Component>
using ExtractArray = vtkm::cont::ArrayHandleExtractComponent<InputArray, Component>;
using ExtractArray = vtkm::cont::ArrayHandleExtractComponent<InputArray>;
using ReferenceComponentArray = vtkm::cont::ArrayHandleCounting<ValueType>;
using ReferenceCompositeArray =
typename vtkm::cont::ArrayHandleCompositeVectorType<ReferenceComponentArray,
......@@ -70,36 +69,32 @@ struct ExtractComponentTests
return result;
}
template <vtkm::IdComponent Component>
void SanityCheck() const
void SanityCheck(vtkm::IdComponent component) const
{
InputArray composite = this->BuildInputArray();
ExtractArray<Component> extract =
vtkm::cont::make_ArrayHandleExtractComponent<Component>(composite);
ExtractArray extract(composite, component);
VTKM_TEST_ASSERT(composite.GetNumberOfValues() == extract.GetNumberOfValues(),
"Number of values in copied ExtractComponent array does not match input.");
}
template <vtkm::IdComponent Component>
void ReadTestComponentExtraction() const
void ReadTestComponentExtraction(vtkm::IdComponent component) const
{
// Test that the expected values are read from an ExtractComponent array.
InputArray composite = this->BuildInputArray();
ExtractArray<Component> extract =
vtkm::cont::make_ArrayHandleExtractComponent<Component>(composite);
ExtractArray extract(composite, component);
// Test reading the data back in the control env:
this->ValidateReadTestArray<Component>(extract);
this->ValidateReadTestArray(extract, component);
// Copy the extract array in the execution environment to test reading:
vtkm::cont::ArrayHandle<ValueType> execCopy;
Algo::Copy(extract, execCopy);
this->ValidateReadTestArray<Component>(execCopy);
this->ValidateReadTestArray(execCopy, component);
}
template <vtkm::IdComponent Component, typename ArrayHandleType>
void ValidateReadTestArray(ArrayHandleType testArray) const
template <typename ArrayHandleType>
void ValidateReadTestArray(ArrayHandleType testArray, vtkm::IdComponent component) const
{
using RefVectorType = typename ReferenceCompositeArray::ValueType;
using Traits = vtkm::VecTraits<RefVectorType>;
......@@ -113,13 +108,13 @@ struct ExtractComponentTests
for (vtkm::Id i = 0; i < testPortal.GetNumberOfValues(); ++i)
{
VTKM_TEST_ASSERT(
test_equal(testPortal.Get(i), Traits::GetComponent(refPortal.Get(i), Component), 0.),
test_equal(testPortal.Get(i), Traits::GetComponent(refPortal.Get(i), component), 0.),
"Value mismatch in read test.");
}
}
// Doubles the specified component (reading from RefVectorType).
template <typename PortalType, typename RefPortalType, vtkm::IdComponent Component>
template <typename PortalType, typename RefPortalType>
struct WriteTestFunctor : vtkm::exec::FunctorBase
{
using RefVectorType = typename RefPortalType::ValueType;
......@@ -127,63 +122,64 @@ struct ExtractComponentTests
PortalType Portal;
RefPortalType RefPortal;
vtkm::IdComponent Component;
VTKM_CONT
WriteTestFunctor(const PortalType& portal, const RefPortalType& ref)
WriteTestFunctor(const PortalType& portal,
const RefPortalType& ref,
vtkm::IdComponent component)
: Portal(portal)
, RefPortal(ref)
, Component(component)
{
}
VTKM_EXEC_CONT
void operator()(vtkm::Id index) const
{
this->Portal.Set(index, Traits::GetComponent(this->RefPortal.Get(index), Component) * 2);
this->Portal.Set(index,
Traits::GetComponent(this->RefPortal.Get(index), this->Component) * 2);
}
};
template <vtkm::IdComponent Component>
void WriteTestComponentExtraction() const
void WriteTestComponentExtraction(vtkm::IdComponent component) const
{
// Control test:
{
InputArray composite = this->BuildInputArray();
ExtractArray<Component> extract =
vtkm::cont::make_ArrayHandleExtractComponent<Component>(composite);
ExtractArray extract(composite, component);
WriteTestFunctor<typename ExtractArray<Component>::PortalControl,
typename ReferenceCompositeArray::PortalConstControl,
Component>
functor(extract.GetPortalControl(), this->RefComposite.GetPortalConstControl());
WriteTestFunctor<typename ExtractArray::PortalControl,
typename ReferenceCompositeArray::PortalConstControl>
functor(extract.GetPortalControl(), this->RefComposite.GetPortalConstControl(), component);
for (vtkm::Id i = 0; i < extract.GetNumberOfValues(); ++i)
{
functor(i);
}
this->ValidateWriteTestArray<Component>(composite);
this->ValidateWriteTestArray(composite, component);
}
// Exec test:
{
InputArray composite = this->BuildInputArray();
ExtractArray<Component> extract =
vtkm::cont::make_ArrayHandleExtractComponent<Component>(composite);
ExtractArray extract(composite, component);
using Portal = typename ExtractArray<Component>::template ExecutionTypes<DeviceTag>::Portal;
using Portal = typename ExtractArray::template ExecutionTypes<DeviceTag>::Portal;
using RefPortal =
typename ReferenceCompositeArray::template ExecutionTypes<DeviceTag>::PortalConst;
WriteTestFunctor<Portal, RefPortal, Component> functor(
extract.PrepareForInPlace(DeviceTag()), this->RefComposite.PrepareForInput(DeviceTag()));
WriteTestFunctor<Portal, RefPortal> functor(extract.PrepareForInPlace(DeviceTag()),
this->RefComposite.PrepareForInput(DeviceTag()),
component);
Algo::Schedule(functor, extract.GetNumberOfValues());
this->ValidateWriteTestArray<Component>(composite);
this->ValidateWriteTestArray(composite, component);
}
}
template <vtkm::IdComponent Component>
void ValidateWriteTestArray(InputArray testArray) const
void ValidateWriteTestArray(InputArray testArray, vtkm::IdComponent component) const
{
using VectorType = typename ReferenceCompositeArray::ValueType;
using Traits = vtkm::VecTraits<VectorType>;
......@@ -199,28 +195,27 @@ struct ExtractComponentTests
{
auto value = portal.Get(i);
auto refValue = refPortal.Get(i);
Traits::SetComponent(refValue, Component, Traits::GetComponent(refValue, Component) * 2);
Traits::SetComponent(refValue, component, Traits::GetComponent(refValue, component) * 2);
VTKM_TEST_ASSERT(test_equal(refValue, value, 0.), "Value mismatch in write test.");
}
}
template <vtkm::IdComponent Component>
void TestComponent() const
void TestComponent(vtkm::IdComponent component) const
{
this->SanityCheck<Component>();
this->ReadTestComponentExtraction<Component>();
this->WriteTestComponentExtraction<Component>();
this->SanityCheck(component);
this->ReadTestComponentExtraction(component);
this->WriteTestComponentExtraction(component);
}
void operator()()
{
this->ConstructReferenceArray();
this->TestComponent<0>();
this->TestComponent<1>();
this->TestComponent<2>();
this->TestComponent<3>();
this->TestComponent(0);
this->TestComponent(1);
this->TestComponent(2);
this->TestComponent(3);
}
};
......
......@@ -39,8 +39,8 @@ struct SwizzleTests
{
using SwizzleInputArrayType = vtkm::cont::ArrayHandle<vtkm::Vec<ValueType, 4>>;
template <vtkm::IdComponent... ComponentMap>
using SwizzleArrayType = vtkm::cont::ArrayHandleSwizzle<SwizzleInputArrayType, ComponentMap...>;
template <vtkm::IdComponent OutSize>
using SwizzleArrayType = vtkm::cont::ArrayHandleSwizzle<SwizzleInputArrayType, OutSize>;
using ReferenceComponentArrayType = vtkm::cont::ArrayHandleCounting<ValueType>;
using ReferenceArrayType =
......@@ -49,6 +49,9 @@ struct SwizzleTests
ReferenceComponentArrayType,
ReferenceComponentArrayType>::type;
template <vtkm::IdComponent Size>
using MapType = vtkm::Vec<vtkm::IdComponent, Size>;
using DeviceTag = VTKM_DEFAULT_DEVICE_ADAPTER_TAG;
using Algo = vtkm::cont::DeviceAdapterAlgorithm<DeviceTag>;
......@@ -78,62 +81,57 @@ struct SwizzleTests
return result;
}
template <vtkm::IdComponent... ComponentMap>
void SanityCheck() const
template <vtkm::IdComponent OutSize>
void SanityCheck(const MapType<OutSize>& map) const
{
using Swizzle = SwizzleArrayType<ComponentMap...>;
using Swizzle = SwizzleArrayType<OutSize>;
using Traits = typename Swizzle::SwizzleTraits;
VTKM_TEST_ASSERT(Traits::COUNT == vtkm::VecTraits<typename Swizzle::ValueType>::NUM_COMPONENTS,
"Traits::COUNT invalid.");
VTKM_TEST_ASSERT(Traits::OutVecSize ==
vtkm::VecTraits<typename Swizzle::ValueType>::NUM_COMPONENTS,
"Traits::OutVecSize invalid.");
VTKM_TEST_ASSERT(
VTKM_PASS_COMMAS(std::is_same<typename Traits::ComponentType, ValueType>::value),
"Traits::ComponentType invalid.");
VTKM_TEST_ASSERT(
VTKM_PASS_COMMAS(
std::is_same<
typename Traits::OutputType,
vtkm::Vec<ValueType, static_cast<vtkm::IdComponent>(sizeof...(ComponentMap))>>::value),
"Traits::OutputType invalid.");
std::is_same<typename Traits::OutValueType, vtkm::Vec<ValueType, OutSize>>::value),
"Traits::OutValueType invalid.");
SwizzleInputArrayType input = this->BuildSwizzleInputArray();
SwizzleArrayType<ComponentMap...> swizzle =
vtkm::cont::make_ArrayHandleSwizzle<ComponentMap...>(input);
auto swizzle = vtkm::cont::make_ArrayHandleSwizzle(input, map);
VTKM_TEST_ASSERT(input.GetNumberOfValues() == swizzle.GetNumberOfValues(),
"Number of values in copied Swizzle array does not match input.");
}
template <vtkm::IdComponent... ComponentMap>
void ReadTest() const
template <vtkm::IdComponent OutSize>
void ReadTest(const MapType<OutSize>& map) const
{
using Traits = typename SwizzleArrayType<ComponentMap...>::SwizzleTraits;
using Traits = typename SwizzleArrayType<OutSize>::SwizzleTraits;
// Test that the expected values are read from an Swizzle array.
SwizzleInputArrayType input = this->BuildSwizzleInputArray();
SwizzleArrayType<ComponentMap...> swizzle =
vtkm::cont::make_ArrayHandleSwizzle<ComponentMap...>(input);
auto swizzle = vtkm::cont::make_ArrayHandleSwizzle(input, map);
// Test reading the data back in the control env:
this->ValidateReadTest<ComponentMap...>(swizzle);
this->ValidateReadTest(swizzle, map);
// Copy the extract array in the execution environment to test reading:
vtkm::cont::ArrayHandle<typename Traits::OutputType> execCopy;
// Copy the extracted array in the execution environment to test reading:
vtkm::cont::ArrayHandle<typename Traits::OutValueType> execCopy;
Algo::Copy(swizzle, execCopy);
this->ValidateReadTest<ComponentMap...>(execCopy);
this->ValidateReadTest(execCopy, map);
}
template <vtkm::IdComponent... ComponentMap, typename ArrayHandleType>
void ValidateReadTest(ArrayHandleType testArray) const
template <typename ArrayHandleType, vtkm::IdComponent OutSize>
void ValidateReadTest(ArrayHandleType testArray, const MapType<OutSize>& map) const
{
using Traits = typename SwizzleArrayType<ComponentMap...>::SwizzleTraits;
using MapType = typename Traits::RuntimeComponentMapType;
const MapType map = Traits::GenerateRuntimeComponentMap();
using Traits = typename SwizzleArrayType<OutSize>::SwizzleTraits;
using ReferenceVectorType = typename ReferenceArrayType::ValueType;
using SwizzleVectorType = typename Traits::OutputType;
using SwizzleVectorType = typename Traits::OutValueType;
VTKM_TEST_ASSERT(map.size() == vtkm::VecTraits<SwizzleVectorType>::NUM_COMPONENTS,
VTKM_TEST_ASSERT(map.GetNumberOfComponents() ==
vtkm::VecTraits<SwizzleVectorType>::NUM_COMPONENTS,
"Unexpected runtime component map size.");
VTKM_TEST_ASSERT(testArray.GetNumberOfValues() == this->RefArray.GetNumberOfValues(),
"Number of values incorrect in Read test.");
......@@ -147,9 +145,9 @@ struct SwizzleTests
ReferenceVectorType refVec = refPortal.Get(i);
// Manually swizzle the reference vector using the runtime map information:
for (size_t j = 0; j < map.size(); ++j)
for (vtkm::IdComponent j = 0; j < map.GetNumberOfComponents(); ++j)
{
refVecSwizzle[static_cast<vtkm::IdComponent>(j)] = refVec[map[j]];
refVecSwizzle[j] = refVec[map[j]];
}
VTKM_TEST_ASSERT(test_equal(refVecSwizzle, testPortal.Get(i), 0.),
......@@ -173,16 +171,15 @@ struct SwizzleTests
void operator()(vtkm::Id index) const { this->Portal.Set(index, this->Portal.Get(index) * 2.); }
};
template <vtkm::IdComponent... ComponentMap>