Commit 1e28a9e3 authored by Robert Maynard's avatar Robert Maynard Committed by Kitware Robot

Merge topic 'support_pointers_as_input_to_dispatcher_invoke'

64958b01 VTK-m now supports passing pointers when invoking worklets.
c631dccf Invocation parameters are now non const and can be 'modified'
Acked-by: Kitware Robot's avatarKitware Robot <kwrobot@kitware.com>
Acked-by: Kenneth Moreland's avatarKenneth Moreland <kmorel@sandia.gov>
Merge-request: !1303
parents 9238cedc 64958b01
# VTK-m now supports dispatcher parameters being pointers
Previously it was only possible to pass values to a dispatcher when
you wanted to invoke a VTK-m worklet. This caused problems when it came
to designing new types that used inheritance as the types couldn't be
past as the base type to the dispatcher. To fix this issue we now
support invoking worklets with pointers as seen below.
```cpp
vtkm::cont::ArrayHandle<T> input;
//fill input
vtkm::cont::ArrayHandle<T> output;
vtkm::worklet::DispatcherMapField<WorkletType> dispatcher;
dispatcher(&input, output);
dispatcher(input, &output);
dispatcher(&input, &output);
```
......@@ -120,7 +120,8 @@ public:
template <typename T>
struct ArrayHandleCheck
{
using type = typename std::is_base_of<::vtkm::cont::internal::ArrayHandleBase, T>::type;
using U = typename std::remove_pointer<T>::type;
using type = typename std::is_base_of<::vtkm::cont::internal::ArrayHandleBase, U>::type;
};
#define VTKM_IS_ARRAY_HANDLE(T) \
......
......@@ -87,7 +87,8 @@ namespace internal
template <typename T>
struct CellSetCheck
{
using type = typename std::is_base_of<vtkm::cont::CellSet, T>;
using U = typename std::remove_pointer<T>::type;
using type = typename std::is_base_of<vtkm::cont::CellSet, U>;
};
#define VTKM_IS_CELL_SET(T) VTKM_STATIC_ASSERT(::vtkm::cont::internal::CellSetCheck<T>::type::value)
......
......@@ -53,7 +53,7 @@ struct Transport<vtkm::cont::arg::TransportTagArrayInOut, ContObjectType, Device
using ExecObjectType = decltype(std::declval<ContObjectType>().PrepareForInPlace(Device()));
template <typename InputDomainType>
VTKM_CONT ExecObjectType operator()(ContObjectType object,
VTKM_CONT ExecObjectType operator()(ContObjectType& object,
const InputDomainType& vtkmNotUsed(inputDomain),
vtkm::Id vtkmNotUsed(inputRange),
vtkm::Id outputRange) const
......
......@@ -53,7 +53,7 @@ struct Transport<vtkm::cont::arg::TransportTagArrayOut, ContObjectType, Device>
decltype(std::declval<ContObjectType>().PrepareForOutput(vtkm::Id{}, Device()));
template <typename InputDomainType>
VTKM_CONT ExecObjectType operator()(ContObjectType object,
VTKM_CONT ExecObjectType operator()(ContObjectType& object,
const InputDomainType& vtkmNotUsed(inputDomain),
vtkm::Id vtkmNotUsed(inputRange),
vtkm::Id outputRange) const
......
......@@ -56,10 +56,11 @@ struct Transport<vtkm::cont::arg::TransportTagAtomicArray,
using ExecObjectType = vtkm::exec::AtomicArray<T, Device>;
template <typename InputDomainType>
VTKM_CONT ExecObjectType operator()(vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic> array,
const InputDomainType&,
vtkm::Id,
vtkm::Id) const
VTKM_CONT ExecObjectType
operator()(vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic>& array,
const InputDomainType&,
vtkm::Id,
vtkm::Id) const
{
// Note: we ignore the size of the domain because the randomly accessed
// array might not have the same size depending on how the user is using
......
......@@ -56,7 +56,7 @@ struct Transport<vtkm::cont::arg::TransportTagExecObject, ContObjectType, Device
using ExecObjectType = decltype(std::declval<ContObjectType>().PrepareForExecution(Device()));
template <typename InputDomainType>
VTKM_CONT ExecObjectType
operator()(const ContObjectType& object, const InputDomainType&, vtkm::Id, vtkm::Id) const
operator()(ContObjectType& object, const InputDomainType&, vtkm::Id, vtkm::Id) const
{
return object.PrepareForExecution(Device());
}
......
......@@ -61,7 +61,7 @@ struct Transport<vtkm::cont::arg::TransportTagWholeArrayIn, ContObjectType, Devi
template <typename InputDomainType>
VTKM_CONT ExecObjectType
operator()(ContObjectType array, const InputDomainType&, vtkm::Id, vtkm::Id) const
operator()(ContObjectType& array, const InputDomainType&, vtkm::Id, vtkm::Id) const
{
// Note: we ignore the size of the domain because the randomly accessed
// array might not have the same size depending on how the user is using
......
......@@ -63,7 +63,7 @@ struct Transport<vtkm::cont::arg::TransportTagWholeArrayInOut, ContObjectType, D
template <typename InputDomainType>
VTKM_CONT ExecObjectType
operator()(ContObjectType array, const InputDomainType&, vtkm::Id, vtkm::Id) const
operator()(ContObjectType& array, const InputDomainType&, vtkm::Id, vtkm::Id) const
{
// Note: we ignore the size of the domain because the randomly accessed
// array might not have the same size depending on how the user is using
......
......@@ -63,7 +63,7 @@ struct Transport<vtkm::cont::arg::TransportTagWholeArrayOut, ContObjectType, Dev
template <typename InputDomainType>
VTKM_CONT ExecObjectType
operator()(ContObjectType array, const InputDomainType&, vtkm::Id, vtkm::Id) const
operator()(ContObjectType& array, const InputDomainType&, vtkm::Id, vtkm::Id) const
{
// Note: we ignore the size of the domain because the randomly accessed
// array might not have the same size depending on how the user is using
......
......@@ -23,7 +23,7 @@ set(unit_tests
UnitTestCudaArrayHandleFancy.cu
UnitTestCudaArrayHandleVirtualCoordinates.cu
UnitTestCudaCellLocatorTwoLevelUniformGrid.cu
UnitTestCudaComputeRange.cu
#UnitTestCudaComputeRange.cu
UnitTestCudaColorTable.cu
UnitTestCudaDataSetExplicit.cu
UnitTestCudaDataSetSingleType.cu
......
......@@ -600,7 +600,7 @@ public:
///
template <typename Transform>
VTKM_CONT typename StaticTransformType<Transform>::type StaticTransformCont(
const Transform& transform) const
const Transform& transform)
{
typename StaticTransformType<Transform>::type newFuncInterface;
detail::DoStaticTransformCont(transform, this->Parameters, newFuncInterface.Parameters);
......@@ -608,7 +608,7 @@ public:
}
template <typename Transform>
VTKM_EXEC typename StaticTransformType<Transform>::type StaticTransformExec(
const Transform& transform) const
const Transform& transform)
{
typename StaticTransformType<Transform>::type newFuncInterface;
detail::DoStaticTransformExec(transform, this->Parameters, newFuncInterface.Parameters);
......
This diff is collapsed.
......@@ -329,7 +329,7 @@ template <typename Transform,
$template_params(num_params,0,'Transformed')>
VTKM_$(environment.upper()) void DoStaticTransform$(environment)(
const Transform& transform,
const ParameterContainer<$signature(num_params,ptype(0,'Original'),'Original')>& originalParameters,
ParameterContainer<$signature(num_params,ptype(0,'Original'),'Original')>& originalParameters,
ParameterContainer<$signature(num_params,ptype(0,'Transformed'),'Transformed')>& transformedParameters)
{
$if(num_params < 1)\
......
......@@ -265,7 +265,7 @@ struct Invocation
/// over to CUDA it gets properly copied to the device. While we want to
/// hold by reference to reduce the number of copies, it is not possible
/// currently.
const ParameterInterface Parameters;
ParameterInterface Parameters;
OutputToInputMapType OutputToInputMap;
VisitArrayType VisitArray;
......
......@@ -63,7 +63,7 @@ public:
}
template <typename Invocation>
VTKM_CONT void DoInvoke(const Invocation& invocation) const
VTKM_CONT void DoInvoke(Invocation& invocation) const
{
// This is the type for the input domain
using InputDomainType = typename Invocation::InputDomainType;
......@@ -76,7 +76,7 @@ public:
// a DynamicArrayHandle that gets cast to one). The size of the domain
// (number of threads/worklet instances) is equal to the size of the
// array.
vtkm::Id numInstances = inputDomain.GetNumberOfValues();
auto numInstances = internal::scheduling_range(inputDomain);
// A MapField is a pretty straightforward dispatch. Once we know the number
// of invocations, the superclass can take care of the rest.
......
......@@ -64,10 +64,11 @@ public:
}
template <typename Invocation>
VTKM_CONT void DoInvoke(const Invocation& invocation) const
VTKM_CONT void DoInvoke(Invocation& invocation) const
{
// This is the type for the input domain
using InputDomainType = typename Invocation::InputDomainType;
using SchedulingRangeType = typename WorkletType::ToTopologyType;
// If you get a compile error on this line, then you have tried to use
// something that is not a vtkm::cont::CellSet as the input domain to a
......@@ -76,12 +77,12 @@ public:
// We can pull the input domain parameter (the data specifying the input
// domain) from the invocation object.
const InputDomainType& inputDomain = invocation.GetInputDomain();
const auto& inputDomain = invocation.GetInputDomain();
// Now that we have the input domain, we can extract the range of the
// scheduling and call BadicInvoke.
this->BasicInvoke(
invocation, inputDomain.GetSchedulingRange(typename WorkletType::ToTopologyType()), Device());
invocation, internal::scheduling_range(inputDomain, SchedulingRangeType{}), Device());
}
};
}
......
......@@ -65,7 +65,7 @@ public:
}
template <typename Invocation>
void DoInvoke(const Invocation& invocation) const
void DoInvoke(Invocation& invocation) const
{
// This is the type for the input domain
using InputDomainType = typename Invocation::InputDomainType;
......@@ -78,7 +78,7 @@ public:
// We can pull the input domain parameter (the data specifying the input
// domain) from the invocation object.
const InputDomainType& inputDomain = invocation.GetInputDomain();
auto inputRange = inputDomain.GetSchedulingRange(vtkm::TopologyElementTagPoint());
auto inputRange = internal::scheduling_range(inputDomain, vtkm::TopologyElementTagPoint{});
// This is pretty straightforward dispatch. Once we know the number
// of invocations, the superclass can take care of the rest.
......
......@@ -65,7 +65,7 @@ public:
}
template <typename Invocation>
void DoInvoke(const Invocation& invocation) const
void DoInvoke(Invocation& invocation) const
{
// This is the type for the input domain
using InputDomainType = typename Invocation::InputDomainType;
......
......@@ -200,7 +200,7 @@ public:
void SetNumberOfBlocks(vtkm::Id numberOfBlocks) { NumberOfBlocks = numberOfBlocks; }
template <typename Invocation, typename DeviceAdapter>
VTKM_CONT void BasicInvoke(const Invocation& invocation,
VTKM_CONT void BasicInvoke(Invocation& invocation,
vtkm::Id numInstances,
vtkm::Id globalIndexOffset,
DeviceAdapter device) const
......@@ -213,7 +213,7 @@ public:
}
template <typename Invocation>
VTKM_CONT void DoInvoke(const Invocation& invocation) const
VTKM_CONT void DoInvoke(Invocation& invocation) const
{
// This is the type for the input domain
using InputDomainType = typename Invocation::InputDomainType;
......@@ -226,7 +226,7 @@ public:
// a DynamicArrayHandle that gets cast to one). The size of the domain
// (number of threads/worklet instances) is equal to the size of the
// array.
vtkm::Id fullSize = inputDomain.GetNumberOfValues();
vtkm::Id fullSize = internal::scheduling_range(inputDomain);
vtkm::Id blockSize = fullSize / NumberOfBlocks;
if (fullSize % NumberOfBlocks != 0)
blockSize += 1;
......@@ -259,7 +259,7 @@ public:
// Loop over parameters again to sync results for this block into control array
using ParameterInterfaceType2 = typename ChangedType::ParameterInterface;
const ParameterInterfaceType2& parameters2 = changedParams.Parameters;
ParameterInterfaceType2& parameters2 = changedParams.Parameters;
parameters2.StaticTransformCont(TransferFunctorType());
}
}
......@@ -269,14 +269,14 @@ private:
typename InputRangeType,
typename OutputRangeType,
typename DeviceAdapter>
VTKM_CONT void InvokeTransportParameters(const Invocation& invocation,
VTKM_CONT void InvokeTransportParameters(Invocation& invocation,
const InputRangeType& inputRange,
const InputRangeType& globalIndexOffset,
const OutputRangeType& outputRange,
DeviceAdapter device) const
{
using ParameterInterfaceType = typename Invocation::ParameterInterface;
const ParameterInterfaceType& parameters = invocation.Parameters;
ParameterInterfaceType& parameters = invocation.Parameters;
using TransportFunctorType = vtkm::worklet::internal::detail::DispatcherBaseTransportFunctor<
typename Invocation::ControlInterface,
......
......@@ -57,6 +57,34 @@ namespace worklet
{
namespace internal
{
template <typename Domain>
inline auto scheduling_range(const Domain& inputDomain) -> decltype(inputDomain.GetNumberOfValues())
{
return inputDomain.GetNumberOfValues();
}
template <typename Domain>
inline auto scheduling_range(const Domain* const inputDomain)
-> decltype(inputDomain->GetNumberOfValues())
{
return inputDomain->GetNumberOfValues();
}
template <typename Domain, typename SchedulingRangeType>
inline auto scheduling_range(const Domain& inputDomain, SchedulingRangeType type)
-> decltype(inputDomain.GetSchedulingRange(type))
{
return inputDomain.GetSchedulingRange(type);
}
template <typename Domain, typename SchedulingRangeType>
inline auto scheduling_range(const Domain* const inputDomain, SchedulingRangeType type)
-> decltype(inputDomain->GetSchedulingRange(type))
{
return inputDomain->GetSchedulingRange(type);
}
namespace detail
{
......@@ -76,6 +104,47 @@ inline void PrintFailureMessage(int index)
throw vtkm::cont::ErrorBadType(message.str());
}
inline void PrintNullPtrMessage(int index, int mode)
{
std::stringstream message;
if (mode == 0)
{
message << "Encountered nullptr for parameter " << index;
}
else
{
message << "Encountered nullptr for " << index << " from last parameter ";
}
message << " when calling Invoke on a dispatcher.";
throw vtkm::cont::ErrorBadValue(message.str());
}
template <typename T>
inline void not_nullptr(T* ptr, int index, int mode = 0)
{
if (!ptr)
{
PrintNullPtrMessage(index, mode);
}
}
template <typename T>
inline void not_nullptr(T&&, int, int mode = 0)
{
(void)mode;
}
template <typename T>
inline T& as_ref(T* ptr)
{
return *ptr;
}
template <typename T>
inline T&& as_ref(T&& t)
{
return std::forward<T>(t);
}
template <typename T, bool noError>
struct ReportTypeOnError;
template <typename T>
......@@ -90,10 +159,16 @@ struct ReportValueOnError<Value, true> : std::true_type
{
};
template <typename T>
struct remove_pointer_and_decay : std::remove_pointer<typename std::decay<T>::type>
{
};
// Is designed as a brigand fold operation.
template <typename T, typename State>
template <typename Type, typename State>
struct DetermineIfHasDynamicParameter
{
using T = typename std::remove_pointer<Type>::type;
using DynamicTag = typename vtkm::cont::internal::DynamicTransformTraits<T>::DynamicTag;
using isDynamic =
typename std::is_same<DynamicTag, vtkm::cont::internal::DynamicTransformTagCastAndCall>::type;
......@@ -106,7 +181,7 @@ struct DetermineIfHasDynamicParameter
template <typename WorkletType>
struct DetermineHasCorrectParameters
{
template <typename T, typename State, typename SigTypes>
template <typename Type, typename State, typename SigTypes>
struct Functor
{
//T is the type of the Param at the current index
......@@ -114,6 +189,7 @@ struct DetermineHasCorrectParameters
using ControlSignatureTag = typename brigand::at_c<SigTypes, State::value>;
using TypeCheckTag = typename ControlSignatureTag::TypeCheckTag;
using T = typename std::remove_pointer<Type>::type;
static constexpr bool isCorrect = vtkm::cont::arg::TypeCheck<TypeCheckTag, T>::value;
// If you get an error on the line below, that means that your code has called the
......@@ -224,27 +300,46 @@ struct DispatcherBaseTransportFunctor
{
}
template <typename ControlParameter, vtkm::IdComponent Index>
struct ReturnType
{
using TransportTag =
typename DispatcherBaseTransportInvokeTypes<ControlInterface, Index>::TransportTag;
using TransportType =
typename vtkm::cont::arg::Transport<TransportTag, ControlParameter, Device>;
using T = typename remove_pointer_and_decay<ControlParameter>::type;
using TransportType = typename vtkm::cont::arg::Transport<TransportTag, T, Device>;
using type = typename TransportType::ExecObjectType;
};
// template<typename ControlParameter, vtkm::IdComponent Index>
// VTKM_CONT typename ReturnType<ControlParameter, Index>::type operator()(
// ControlParameter const& invokeData,
// vtkm::internal::IndexTag<Index>) const
// {
// using TransportTag =
// typename DispatcherBaseTransportInvokeTypes<ControlInterface, Index>::TransportTag;
// using T = typename remove_pointer_and_decay<ControlParameter>::type;
// vtkm::cont::arg::Transport<TransportTag, T, Device> transport;
// return transport(invokeData, as_ref(this->InputDomain), this->InputRange, this->OutputRange);
// }
template <typename ControlParameter, vtkm::IdComponent Index>
VTKM_CONT typename ReturnType<ControlParameter, Index>::type operator()(
const ControlParameter& invokeData,
ControlParameter&& invokeData,
vtkm::internal::IndexTag<Index>) const
{
using TransportTag =
typename DispatcherBaseTransportInvokeTypes<ControlInterface, Index>::TransportTag;
vtkm::cont::arg::Transport<TransportTag, ControlParameter, Device> transport;
return transport(invokeData, this->InputDomain, this->InputRange, this->OutputRange);
using T = typename remove_pointer_and_decay<ControlParameter>::type;
vtkm::cont::arg::Transport<TransportTag, T, Device> transport;
not_nullptr(invokeData, Index);
return transport(
as_ref(invokeData), as_ref(this->InputDomain), this->InputRange, this->OutputRange);
}
private:
void operator=(const DispatcherBaseTransportFunctor&) = delete;
};
......@@ -306,7 +401,8 @@ inline void convert_arg(vtkm::cont::internal::DynamicTransformTagCastAndCall,
using tag_check = typename brigand::at_c<ContParams, 0>::TypeCheckTag;
using popped_sig = brigand::pop_front<ContParams>;
vtkm::cont::CastAndCall(t,
not_nullptr(t, LeftToProcess, 1);
vtkm::cont::CastAndCall(as_ref(t),
convert_arg_wrapper<LeftToProcess, tag_check>(),
trampoline,
popped_sig(),
......@@ -319,8 +415,8 @@ struct for_each_dynamic_arg
template <typename Trampoline, typename ContParams, typename T, typename... Args>
void operator()(const Trampoline& trampoline, ContParams&& sig, T&& t, Args&&... args) const
{
//Determine that state of T
using Type = typename std::decay<T>::type;
//Determine that state of T when it is either a `cons&` or a `* const&`
using Type = typename std::remove_pointer<typename std::decay<T>::type>::type;
using tag = typename vtkm::cont::internal::DynamicTransformTraits<Type>::DynamicTag;
//convert the first item to a known type
convert_arg<LeftToProcess>(
......@@ -456,7 +552,8 @@ private:
// argument) and the ControlSignature tags (in the ControlInterface type).
using ContParamsInfo =
vtkm::internal::detail::FunctionSigInfo<typename WorkletType::ControlSignature>;
detail::deduce(*this, typename ContParamsInfo::Parameters(), std::forward<Args>(args)...);
typename ContParamsInfo::Parameters parameters;
detail::deduce(*this, parameters, std::forward<Args>(args)...);
}
template <typename... Args>
......@@ -518,7 +615,7 @@ protected:
}
template <typename Invocation, typename DeviceAdapter>
VTKM_CONT void BasicInvoke(const Invocation& invocation,
VTKM_CONT void BasicInvoke(Invocation& invocation,
vtkm::Id numInstances,
DeviceAdapter device) const
{
......@@ -527,7 +624,7 @@ protected:
}
template <typename Invocation, typename DeviceAdapter>
VTKM_CONT void BasicInvoke(const Invocation& invocation,
VTKM_CONT void BasicInvoke(Invocation& invocation,
vtkm::Id2 dimensions,
DeviceAdapter device) const
{
......@@ -535,7 +632,7 @@ protected:
}
template <typename Invocation, typename DeviceAdapter>
VTKM_CONT void BasicInvoke(const Invocation& invocation,
VTKM_CONT void BasicInvoke(Invocation& invocation,
vtkm::Id3 dimensions,
DeviceAdapter device) const
{
......@@ -555,7 +652,7 @@ private:
typename InputRangeType,
typename OutputRangeType,
typename DeviceAdapter>
VTKM_CONT void InvokeTransportParameters(const Invocation& invocation,
VTKM_CONT void InvokeTransportParameters(Invocation& invocation,
const InputRangeType& inputRange,
OutputRangeType&& outputRange,
DeviceAdapter device) const
......@@ -570,7 +667,7 @@ private:
// static transform of the FunctionInterface to call the transport on each
// argument and return the corresponding execution environment object.
using ParameterInterfaceType = typename Invocation::ParameterInterface;
const ParameterInterfaceType& parameters = invocation.Parameters;
ParameterInterfaceType& parameters = invocation.Parameters;
using TransportFunctorType =
detail::DispatcherBaseTransportFunctor<typename Invocation::ControlInterface,
......
......@@ -33,16 +33,33 @@ using Device = vtkm::cont::DeviceAdapterTagSerial;
static constexpr vtkm::Id ARRAY_SIZE = 10;
struct TestExecObject
struct TestExecObjectIn
{
VTKM_EXEC_CONT
TestExecObject()
TestExecObjectIn()
: Array(nullptr)
{
}
VTKM_EXEC_CONT
TestExecObject(vtkm::Id* array)
TestExecObjectIn(const vtkm::Id* array)
: Array(array)
{
}
const vtkm::Id* Array;
};
struct TestExecObjectOut
{
VTKM_EXEC_CONT
TestExecObjectOut()
: Array(nullptr)
{
}
VTKM_EXEC_CONT
TestExecObjectOut(vtkm::Id* array)
: Array(array)
{
}
......@@ -85,7 +102,10 @@ struct TestExecObjectTypeBad
struct TestTypeCheckTag
{
};
struct TestTransportTag
struct TestTransportTagIn
{
};
struct TestTransportTagOut
{
};
struct TestFetchTagInput
......@@ -105,25 +125,43 @@ namespace arg
{
template <>
struct TypeCheck<TestTypeCheckTag, vtkm::Id*>
struct TypeCheck<TestTypeCheckTag, std::vector<vtkm::Id>>
{
static constexpr bool value = true;
};
template <>
struct Transport<TestTransportTag, vtkm::Id*, Device>
struct Transport<TestTransportTagIn, std::vector<vtkm::Id>,