Commit 18a0cd35 authored by Robert Maynard's avatar Robert Maynard

vtkm::worklet::Invoker now supports scatter types

Fixes #297
parent afc3f530
# `vtkm::worklet::Invoker` now able to worklets that have non-default scatter type
This change allows the `Invoker` class to support launching worklets that require
a custom scatter operation. This is done by providing the scatter as the second
argument when launch a worklet with the `()` operator.
The following example shows a scatter being provided with a worklet launch.
```cpp
struct CheckTopology : vtkm::worklet::WorkletMapPointToCell
{
using ControlSignature = void(CellSetIn cellset, FieldOutCell);
using ExecutionSignature = _2(FromIndices);
using ScatterType = vtkm::worklet::ScatterPermutation<>;
...
};
vtkm::worklet::Ivoker invoke;
invoke( CheckTopology{}, vtkm::worklet::ScatterPermutation{}, cellset, result );
```
......@@ -52,20 +52,46 @@ struct Invoker
{
}
/// Launch the worklet that is provided as the first parameter. The additional
/// parameters are the ControlSignature arguments for the worklet.
/// Launch the worklet that is provided as the first parameter.
/// Optional second parameter is the scatter type associated with the worklet.
/// Any additional parameters are the ControlSignature arguments for the worklet.
///
template <typename Worklet, typename... Args>
inline void operator()(Worklet&& worklet, Args&&... args) const
template <typename Worklet,
typename T,
typename... Args,
typename std::enable_if<
std::is_base_of<internal::ScatterBase, internal::detail::remove_cvref<T>>::value,
int>::type* = nullptr>
inline void operator()(Worklet&& worklet, T&& scatter, Args&&... args) const
{
using WorkletType = typename std::decay<Worklet>::type;
using WorkletType = internal::detail::remove_cvref<Worklet>;
using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>;
DispatcherType dispatcher(worklet);
DispatcherType dispatcher(worklet, scatter);
dispatcher.SetDevice(this->DeviceId);
dispatcher.Invoke(std::forward<Args>(args)...);
}
/// Launch the worklet that is provided as the first parameter.
/// Optional second parameter is the scatter type associated with the worklet.
/// Any additional parameters are the ControlSignature arguments for the worklet.
///
template <typename Worklet,
typename T,
typename... Args,
typename std::enable_if<
!std::is_base_of<internal::ScatterBase, internal::detail::remove_cvref<T>>::value,
int>::type* = nullptr>
inline void operator()(Worklet&& worklet, T&& t, Args&&... args) const
{
using WorkletType = internal::detail::remove_cvref<Worklet>;
using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>;
DispatcherType dispatcher(worklet);
dispatcher.SetDevice(this->DeviceId);
dispatcher.Invoke(std::forward<T>(t), std::forward<Args>(args)...);
}
/// Get the device adapter that this Invoker is bound too
///
vtkm::cont::DeviceAdapterId GetDevice() const { return DeviceId; }
......
......@@ -76,7 +76,7 @@
#include <vtkm/cont/ArrayHandleIndex.h>
#include <vtkm/cont/ArrayHandleReverse.h>
#include <vtkm/cont/ArrayHandleTransform.h>
#include <vtkm/worklet/DispatcherMapField.h>
#include <vtkm/worklet/Invoker.h>
#include <vtkm/worklet/ScatterCounting.h>
#include <vtkm/BinaryPredicates.h>
......
......@@ -316,16 +316,15 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
// Setup the ScatterCounting worklets needed to expand the ReduceByKeyResults
vtkm::worklet::ScatterCounting scatter(particlesPerHalo);
vtkm::worklet::DispatcherMapField<ScatterWorklet<vtkm::Id>> scatterWorkletIdDispatcher(scatter);
vtkm::worklet::DispatcherMapField<ScatterWorklet<T>> scatterWorkletDispatcher(scatter);
vtkm::worklet::Invoker invoke;
// Calculate the minimum particle index per halo id and scatter
DeviceAlgorithm::ScanExclusive(particlesPerHalo, tempI);
scatterWorkletIdDispatcher.Invoke(tempI, minParticle);
invoke(ScatterWorklet<vtkm::Id>{}, scatter, tempI, minParticle);
// Calculate the maximum particle index per halo id and scatter
DeviceAlgorithm::ScanInclusive(particlesPerHalo, tempI);
scatterWorkletIdDispatcher.Invoke(tempI, maxParticle);
invoke(ScatterWorklet<vtkm::Id>{}, scatter, tempI, maxParticle);
using IdArrayType = vtkm::cont::ArrayHandle<vtkm::Id>;
vtkm::cont::ArrayHandleTransform<IdArrayType, ScaleBiasFunctor<vtkm::Id>> scaleBias =
......@@ -354,7 +353,7 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
// Find minimum potential for all particles in a halo and scatter
DeviceAlgorithm::ReduceByKey(haloId, potential, uniqueHaloIds, tempT, vtkm::Minimum());
scatterWorkletDispatcher.Invoke(tempT, minPotential);
invoke(ScatterWorklet<T>{}, scatter, tempT, minPotential);
#ifdef DEBUG_PRINT
DebugPrint("potential", potential);
DebugPrint("minPotential", minPotential);
......@@ -371,7 +370,7 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
vtkm::cont::ArrayHandle<vtkm::Id> minIndx;
minIndx.Allocate(nParticles);
DeviceAlgorithm::ReduceByKey(haloId, mbpId, uniqueHaloIds, minIndx, vtkm::Maximum());
scatterWorkletIdDispatcher.Invoke(minIndx, mbpId);
invoke(ScatterWorklet<vtkm::Id>{}, scatter, minIndx, mbpId);
// Resort particle ids and mbpId to starting order
vtkm::cont::ArrayHandle<vtkm::Id> savePartId;
......
......@@ -158,7 +158,6 @@ using remove_pointer_and_decay = typename std::remove_pointer<typename std::deca
template <typename T>
using remove_cvref = typename std::remove_cv<typename std::remove_reference<T>::type>::type;
// Is designed as a brigand fold operation.
template <typename Type, typename State>
struct DetermineIfHasDynamicParameter
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment