Commit be5b51fb authored by Robert Maynard's avatar Robert Maynard Committed by Kitware Robot

Merge topic 'invoker_support_scatter'

18a0cd35 vtkm::worklet::Invoker now supports scatter types
afc3f530 Remove unneeded ScatterType as it was the default
a1ea509f All scatter types now inherit from a common base
77378993 DispatcherBase: Simplify remove_cvref and remove_pointer_and_decay.
Acked-by: Kitware Robot's avatarKitware Robot <kwrobot@kitware.com>
Acked-by: Kenneth Moreland's avatarKenneth Moreland <kmorel@sandia.gov>
Merge-request: !1673
parents bd626dac 18a0cd35
# `vtkm::worklet::Invoker` now able to worklets that have non-default scatter type
This change allows the `Invoker` class to support launching worklets that require
a custom scatter operation. This is done by providing the scatter as the second
argument when launch a worklet with the `()` operator.
The following example shows a scatter being provided with a worklet launch.
```cpp
struct CheckTopology : vtkm::worklet::WorkletMapPointToCell
{
using ControlSignature = void(CellSetIn cellset, FieldOutCell);
using ExecutionSignature = _2(FromIndices);
using ScatterType = vtkm::worklet::ScatterPermutation<>;
...
};
vtkm::worklet::Ivoker invoke;
invoke( CheckTopology{}, vtkm::worklet::ScatterPermutation{}, cellset, result );
```
......@@ -808,8 +808,6 @@ public:
using ExecutionSignature = void(_2, _3);
using ScatterType = vtkm::worklet::ScatterIdentity;
template <typename MappedValueVecType, typename MappedValueType>
VTKM_EXEC void operator()(const MappedValueVecType& toReduce, MappedValueType& centroid) const
{
......
......@@ -52,20 +52,46 @@ struct Invoker
{
}
/// Launch the worklet that is provided as the first parameter. The additional
/// parameters are the ControlSignature arguments for the worklet.
/// Launch the worklet that is provided as the first parameter.
/// Optional second parameter is the scatter type associated with the worklet.
/// Any additional parameters are the ControlSignature arguments for the worklet.
///
template <typename Worklet, typename... Args>
inline void operator()(Worklet&& worklet, Args&&... args) const
template <typename Worklet,
typename T,
typename... Args,
typename std::enable_if<
std::is_base_of<internal::ScatterBase, internal::detail::remove_cvref<T>>::value,
int>::type* = nullptr>
inline void operator()(Worklet&& worklet, T&& scatter, Args&&... args) const
{
using WorkletType = typename std::decay<Worklet>::type;
using WorkletType = internal::detail::remove_cvref<Worklet>;
using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>;
DispatcherType dispatcher(worklet);
DispatcherType dispatcher(worklet, scatter);
dispatcher.SetDevice(this->DeviceId);
dispatcher.Invoke(std::forward<Args>(args)...);
}
/// Launch the worklet that is provided as the first parameter.
/// Optional second parameter is the scatter type associated with the worklet.
/// Any additional parameters are the ControlSignature arguments for the worklet.
///
template <typename Worklet,
typename T,
typename... Args,
typename std::enable_if<
!std::is_base_of<internal::ScatterBase, internal::detail::remove_cvref<T>>::value,
int>::type* = nullptr>
inline void operator()(Worklet&& worklet, T&& t, Args&&... args) const
{
using WorkletType = internal::detail::remove_cvref<Worklet>;
using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>;
DispatcherType dispatcher(worklet);
dispatcher.SetDevice(this->DeviceId);
dispatcher.Invoke(std::forward<T>(t), std::forward<Args>(args)...);
}
/// Get the device adapter that this Invoker is bound too
///
vtkm::cont::DeviceAdapterId GetDevice() const { return DeviceId; }
......
......@@ -10,6 +10,7 @@
#ifndef vtk_m_worklet_ScatterCounting_h
#define vtk_m_worklet_ScatterCounting_h
#include <vtkm/worklet/internal/ScatterBase.h>
#include <vtkm/worklet/vtkm_worklet_export.h>
#include <vtkm/cont/VariantArrayHandle.h>
......@@ -40,7 +41,7 @@ struct ScatterCountingBuilder;
/// taken in the constructor and the index arrays are derived from that. So
/// changing the counts after the scatter is created will have no effect.
///
struct VTKM_WORKLET_EXPORT ScatterCounting
struct VTKM_WORKLET_EXPORT ScatterCounting : internal::ScatterBase
{
struct CountTypes : vtkm::ListTagBase<vtkm::Int64,
vtkm::Int32,
......
......@@ -12,6 +12,7 @@
#include <vtkm/cont/ArrayHandleConstant.h>
#include <vtkm/cont/ArrayHandleIndex.h>
#include <vtkm/worklet/internal/ScatterBase.h>
namespace vtkm
{
......@@ -26,7 +27,7 @@ namespace worklet
/// element generates one output element associated with it. This is the
/// default for basic maps.
///
struct ScatterIdentity
struct ScatterIdentity : internal::ScatterBase
{
using OutputToInputMapType = vtkm::cont::ArrayHandleIndex;
VTKM_CONT
......
......@@ -12,6 +12,7 @@
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandleConstant.h>
#include <vtkm/worklet/internal/ScatterBase.h>
namespace vtkm
{
......@@ -28,7 +29,7 @@ namespace worklet
/// can be duplicates. Note that even with duplicates the VistIndex is always 0.
///
template <typename PermutationStorage = VTKM_DEFAULT_STORAGE_TAG>
class ScatterPermutation
class ScatterPermutation : public internal::ScatterBase
{
private:
using PermutationArrayHandle = vtkm::cont::ArrayHandle<vtkm::Id, PermutationStorage>;
......
......@@ -13,6 +13,7 @@
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandleCounting.h>
#include <vtkm/cont/ArrayHandleImplicit.h>
#include <vtkm/worklet/internal/ScatterBase.h>
namespace vtkm
{
......@@ -49,7 +50,7 @@ struct FunctorDiv
/// elements are grouped by the input associated.
///
template <vtkm::IdComponent NumOutputsPerInput>
struct ScatterUniform
struct ScatterUniform : internal::ScatterBase
{
VTKM_CONT ScatterUniform() = default;
......
......@@ -76,7 +76,7 @@
#include <vtkm/cont/ArrayHandleIndex.h>
#include <vtkm/cont/ArrayHandleReverse.h>
#include <vtkm/cont/ArrayHandleTransform.h>
#include <vtkm/worklet/DispatcherMapField.h>
#include <vtkm/worklet/Invoker.h>
#include <vtkm/worklet/ScatterCounting.h>
#include <vtkm/BinaryPredicates.h>
......
......@@ -316,16 +316,15 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
// Setup the ScatterCounting worklets needed to expand the ReduceByKeyResults
vtkm::worklet::ScatterCounting scatter(particlesPerHalo);
vtkm::worklet::DispatcherMapField<ScatterWorklet<vtkm::Id>> scatterWorkletIdDispatcher(scatter);
vtkm::worklet::DispatcherMapField<ScatterWorklet<T>> scatterWorkletDispatcher(scatter);
vtkm::worklet::Invoker invoke;
// Calculate the minimum particle index per halo id and scatter
DeviceAlgorithm::ScanExclusive(particlesPerHalo, tempI);
scatterWorkletIdDispatcher.Invoke(tempI, minParticle);
invoke(ScatterWorklet<vtkm::Id>{}, scatter, tempI, minParticle);
// Calculate the maximum particle index per halo id and scatter
DeviceAlgorithm::ScanInclusive(particlesPerHalo, tempI);
scatterWorkletIdDispatcher.Invoke(tempI, maxParticle);
invoke(ScatterWorklet<vtkm::Id>{}, scatter, tempI, maxParticle);
using IdArrayType = vtkm::cont::ArrayHandle<vtkm::Id>;
vtkm::cont::ArrayHandleTransform<IdArrayType, ScaleBiasFunctor<vtkm::Id>> scaleBias =
......@@ -354,7 +353,7 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
// Find minimum potential for all particles in a halo and scatter
DeviceAlgorithm::ReduceByKey(haloId, potential, uniqueHaloIds, tempT, vtkm::Minimum());
scatterWorkletDispatcher.Invoke(tempT, minPotential);
invoke(ScatterWorklet<T>{}, scatter, tempT, minPotential);
#ifdef DEBUG_PRINT
DebugPrint("potential", potential);
DebugPrint("minPotential", minPotential);
......@@ -371,7 +370,7 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
vtkm::cont::ArrayHandle<vtkm::Id> minIndx;
minIndx.Allocate(nParticles);
DeviceAlgorithm::ReduceByKey(haloId, mbpId, uniqueHaloIds, minIndx, vtkm::Maximum());
scatterWorkletIdDispatcher.Invoke(minIndx, mbpId);
invoke(ScatterWorklet<vtkm::Id>{}, scatter, minIndx, mbpId);
// Resort particle ids and mbpId to starting order
vtkm::cont::ArrayHandle<vtkm::Id> savePartId;
......
......@@ -10,6 +10,7 @@
set(headers
DispatcherBase.h
ScatterBase.h
TriangulateTables.h
WorkletBase.h
)
......
......@@ -153,15 +153,16 @@ struct ReportValueOnError<Value, true> : std::true_type
};
template <typename T>
struct remove_pointer_and_decay : std::remove_pointer<typename std::decay<T>::type>
{
};
using remove_pointer_and_decay = typename std::remove_pointer<typename std::decay<T>::type>::type;
template <typename T>
using remove_cvref = typename std::remove_cv<typename std::remove_reference<T>::type>::type;
// Is designed as a brigand fold operation.
template <typename Type, typename State>
struct DetermineIfHasDynamicParameter
{
using T = typename std::remove_pointer<Type>::type;
using T = remove_pointer_and_decay<Type>;
using DynamicTag = typename vtkm::cont::internal::DynamicTransformTraits<T>::DynamicTag;
using isDynamic =
typename std::is_same<DynamicTag, vtkm::cont::internal::DynamicTransformTagCastAndCall>::type;
......@@ -314,7 +315,7 @@ struct DispatcherBaseTransportFunctor
{
using TransportTag =
typename DispatcherBaseTransportInvokeTypes<ControlInterface, Index>::TransportTag;
using T = typename remove_pointer_and_decay<ControlParameter>::type;
using T = remove_pointer_and_decay<ControlParameter>;
using TransportType = typename vtkm::cont::arg::Transport<TransportTag, T, Device>;
using type = typename TransportType::ExecObjectType;
};
......@@ -326,7 +327,7 @@ struct DispatcherBaseTransportFunctor
{
using TransportTag =
typename DispatcherBaseTransportInvokeTypes<ControlInterface, Index>::TransportTag;
using T = typename remove_pointer_and_decay<ControlParameter>::type;
using T = remove_pointer_and_decay<ControlParameter>;
vtkm::cont::arg::Transport<TransportTag, T, Device> transport;
not_nullptr(invokeData, Index);
......@@ -412,7 +413,7 @@ struct for_each_dynamic_arg
void operator()(const Trampoline& trampoline, ContParams&& sig, T&& t, Args&&... args) const
{
//Determine that state of T when it is either a `cons&` or a `* const&`
using Type = typename std::remove_pointer<typename std::decay<T>::type>::type;
using Type = remove_pointer_and_decay<T>;
using tag = typename vtkm::cont::internal::DynamicTransformTraits<Type>::DynamicTag;
//convert the first item to a known type
convert_arg<LeftToProcess>(
......@@ -494,7 +495,7 @@ private:
VTKM_CONT void StartInvoke(Args&&... args) const
{
using ParameterInterface =
vtkm::internal::FunctionInterface<void(typename std::decay<Args>::type...)>;
vtkm::internal::FunctionInterface<void(detail::remove_cvref<Args>...)>;
VTKM_STATIC_ASSERT_MSG(ParameterInterface::ARITY == NUM_INVOKE_PARAMS,
"Dispatcher Invoke called with wrong number of arguments.");
......@@ -540,7 +541,7 @@ private:
VTKM_CONT void StartInvokeDynamic(std::false_type, Args&&... args) const
{
using ParameterInterface =
vtkm::internal::FunctionInterface<void(typename std::decay<Args>::type...)>;
vtkm::internal::FunctionInterface<void(detail::remove_cvref<Args>...)>;
//Nothing requires a conversion from dynamic to static types, so
//next we need to verify that each argument's type is correct. If not
......@@ -561,8 +562,7 @@ private:
static_assert(isAllValid::value == expectedLen::value,
"All arguments failed the TypeCheck pass");
auto fi =
vtkm::internal::make_FunctionInterface<void, typename std::decay<Args>::type...>(args...);
auto fi = vtkm::internal::make_FunctionInterface<void, detail::remove_cvref<Args>...>(args...);
auto ivc = vtkm::internal::Invocation<ParameterInterface,
ControlInterface,
ExecutionInterface,
......
//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
//
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//============================================================================
#ifndef vtk_m_worklet_internal_ScatterBase_h
#define vtk_m_worklet_internal_ScatterBase_h
#include <vtkm/internal/ExportMacros.h>
namespace vtkm
{
namespace worklet
{
namespace internal
{
/// Base class for all scatter classes.
///
/// This allows VTK-m to determine when a parameter
/// is a scatter type instead of a worklet parameter.
///
struct VTKM_ALWAYS_EXPORT ScatterBase
{
};
}
}
}
#endif
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment