//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//============================================================================
#ifndef vtk_m_exec_cuda_internal_WrappedOperators_h
#define vtk_m_exec_cuda_internal_WrappedOperators_h

#include <vtkm/BinaryPredicates.h>
#include <vtkm/Pair.h>
#include <vtkm/Types.h>
#include <vtkm/exec/cuda/internal/IteratorFromArrayPortal.h>
#include <vtkm/internal/ExportMacros.h>

#include <type_traits>

// Disable warnings we check vtkm for but Thrust does not.
#include <vtkm/exec/cuda/internal/ThrustPatches.h>
VTKM_THIRDPARTY_PRE_INCLUDE
#include <thrust/system/cuda/memory.h>
VTKM_THIRDPARTY_POST_INCLUDE

namespace vtkm
{
namespace exec
{
namespace cuda
{
namespace internal
{
33 34 35

// Unary function object wrapper which can detect and handle calling the
// wrapped operator with complex value types such as
36
// ArrayPortalValueReference which happen when passed an input array that
37
// is implicit.
38 39
template <typename T_, typename Function>
struct WrappedUnaryPredicate
40
{
41
  using T = typename std::remove_const<T_>::type;
42 43

  //make typedefs that thust expects unary operators to have
44 45
  using first_argument_type = T;
  using result_type = bool;
46 47 48

  Function m_f;

49
  VTKM_EXEC
50 51
  WrappedUnaryPredicate()
    : m_f()
52 53
  {
  }
54

55
  VTKM_CONT
56
  WrappedUnaryPredicate(const Function& f)
57 58 59 60
    : m_f(f)
  {
  }

61 62 63
  VTKM_EXEC bool operator()(const T& x) const { return m_f(x); }

  template <typename U>
64
  VTKM_EXEC bool operator()(const vtkm::internal::ArrayPortalValueReference<U>& x) const
65
  {
66
    return m_f(x.Get());
67 68
  }

69
  VTKM_EXEC bool operator()(const T* x) const { return m_f(*x); }
70 71
};

72 73
// Binary function object wrapper which can detect and handle calling the
// wrapped operator with complex value types such as
74
// ArrayPortalValueReference which happen when passed an input array that
75
// is implicit.
76
template <typename T_, typename Function>
77
struct WrappedBinaryOperator
78
{
79
  using T = typename std::remove_const<T_>::type;
80 81

  //make typedefs that thust expects binary operators to have
82 83 84
  using first_argument_type = T;
  using second_argument_type = T;
  using result_type = T;
85

86 87
  Function m_f;

88
  VTKM_EXEC
89 90
  WrappedBinaryOperator()
    : m_f()
91 92
  {
  }
93

94
  VTKM_CONT
95
  WrappedBinaryOperator(const Function& f)
96 97 98 99
    : m_f(f)
  {
  }

100
  VTKM_EXEC T operator()(const T& x, const T& y) const { return m_f(x, y); }
101

102
  template <typename U>
103
  VTKM_EXEC T operator()(const T& x, const vtkm::internal::ArrayPortalValueReference<U>& y) const
104
  {
105 106
    // to support proper implicit conversion, and avoid overload
    // ambiguities.
107
    return m_f(x, y.Get());
108 109
  }

110
  template <typename U>
111
  VTKM_EXEC T operator()(const vtkm::internal::ArrayPortalValueReference<U>& x, const T& y) const
112
  {
113
    return m_f(x.Get(), y);
114 115
  }

116
  template <typename U, typename V>
117 118
  VTKM_EXEC T operator()(const vtkm::internal::ArrayPortalValueReference<U>& x,
                         const vtkm::internal::ArrayPortalValueReference<V>& y) const
119
  {
120
    return m_f(x.Get(), y.Get());
121 122
  }

123
  VTKM_EXEC T operator()(const T* const x, const T& y) const { return m_f(*x, y); }
124

125
  VTKM_EXEC T operator()(const T& x, const T* const y) const { return m_f(x, *y); }
126

127
  VTKM_EXEC T operator()(const T* const x, const T* const y) const { return m_f(*x, *y); }
128 129
};

130
template <typename T_, typename Function>
131 132
struct WrappedBinaryPredicate
{
133
  using T = typename std::remove_const<T_>::type;
134 135

  //make typedefs that thust expects binary operators to have
136 137 138
  using first_argument_type = T;
  using second_argument_type = T;
  using result_type = bool;
139 140 141

  Function m_f;

142
  VTKM_EXEC
143 144
  WrappedBinaryPredicate()
    : m_f()
145 146
  {
  }
147

148
  VTKM_CONT
149
  WrappedBinaryPredicate(const Function& f)
150 151 152 153
    : m_f(f)
  {
  }

154
  VTKM_EXEC bool operator()(const T& x, const T& y) const { return m_f(x, y); }
155

156
  template <typename U>
157
  VTKM_EXEC bool operator()(const T& x, const vtkm::internal::ArrayPortalValueReference<U>& y) const
158
  {
159
    return m_f(x, y.Get());
160 161
  }

162
  template <typename U>
163
  VTKM_EXEC bool operator()(const vtkm::internal::ArrayPortalValueReference<U>& x, const T& y) const
164
  {
165
    return m_f(x.Get(), y);
166 167
  }

168
  template <typename U, typename V>
169 170
  VTKM_EXEC bool operator()(const vtkm::internal::ArrayPortalValueReference<U>& x,
                            const vtkm::internal::ArrayPortalValueReference<V>& y) const
171
  {
172
    return m_f(x.Get(), y.Get());
173 174
  }

175
  VTKM_EXEC bool operator()(const T* const x, const T& y) const { return m_f(*x, y); }
176

177
  VTKM_EXEC bool operator()(const T& x, const T* const y) const { return m_f(x, *y); }
178

179
  VTKM_EXEC bool operator()(const T* const x, const T* const y) const { return m_f(*x, *y); }
180 181 182 183 184 185
};
}
}
}
} //namespace vtkm::exec::cuda::internal

186 187 188 189
namespace thrust
{
namespace detail
{
190 191 192 193 194 195
//
// We tell Thrust that our WrappedBinaryOperator is commutative so that we
// activate numerous fast paths inside thrust which are only available when
// the binary functor is commutative and the T type is is_arithmetic
//
//
196 197 198 199 200 201 202
template <typename T, typename F>
struct is_commutative<vtkm::exec::cuda::internal::WrappedBinaryOperator<T, F>>
  : public thrust::detail::is_arithmetic<T>
{
};
}
} //namespace thrust::detail
203

204
#endif //vtk_m_exec_cuda_internal_WrappedOperators_h