//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//
//  Copyright 2014 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
//  Copyright 2014 UT-Battelle, LLC.
//  Copyright 2014 Los Alamos National Security.
//
//  Under the terms of Contract DE-NA0003525 with NTESS,
//  the U.S. Government retains certain rights in this software.
//
//  Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
//  Laboratory (LANL), the U.S. Government retains certain rights in
//  this software.
//============================================================================
#ifndef vtk_m_exec_cuda_internal_WrappedOperators_h
#define vtk_m_exec_cuda_internal_WrappedOperators_h

23
#include <vtkm/BinaryPredicates.h>
#include <vtkm/Pair.h>
#include <vtkm/Types.h>
#include <vtkm/exec/cuda/internal/IteratorFromArrayPortal.h>
#include <vtkm/internal/ExportMacros.h>

// Disable warnings we check vtkm for but Thrust does not.
VTKM_THIRDPARTY_PRE_INCLUDE
#include <thrust/system/cuda/memory.h>
VTKM_THIRDPARTY_POST_INCLUDE

#include <type_traits>
33

34 35 36 37 38 39 40 41
namespace vtkm
{
namespace exec
{
namespace cuda
{
namespace internal
{
42 43 44 45 46

// Unary function object wrapper which can detect and handle calling the
// wrapped operator with complex value types such as
// PortalValue which happen when passed an input array that
// is implicit.
47 48
template <typename T_, typename Function>
struct WrappedUnaryPredicate
49
{
50
  using T = typename std::remove_const<T_>::type;
51 52

  //make typedefs that thust expects unary operators to have
53 54
  using first_argument_type = T;
  using result_type = bool;
55 56 57

  Function m_f;

58
  VTKM_EXEC
59 60
  WrappedUnaryPredicate()
    : m_f()
61 62
  {
  }
63

64
  VTKM_CONT
65
  WrappedUnaryPredicate(const Function& f)
66 67 68 69
    : m_f(f)
  {
  }

70 71 72 73
  VTKM_EXEC bool operator()(const T& x) const { return m_f(x); }

  template <typename U>
  VTKM_EXEC bool operator()(const PortalValue<U>& x) const
74 75 76 77
  {
    return m_f((T)x);
  }

78
  VTKM_EXEC bool operator()(const T* x) const { return m_f(*x); }
79 80
};

81 82 83 84
// Binary function object wrapper which can detect and handle calling the
// wrapped operator with complex value types such as
// PortalValue which happen when passed an input array that
// is implicit.
85
template <typename T_, typename Function>
86
struct WrappedBinaryOperator
87
{
88
  using T = typename std::remove_const<T_>::type;
89 90

  //make typedefs that thust expects binary operators to have
91 92 93
  using first_argument_type = T;
  using second_argument_type = T;
  using result_type = T;
94

95 96
  Function m_f;

97
  VTKM_EXEC
98 99
  WrappedBinaryOperator()
    : m_f()
100 101
  {
  }
102

103
  VTKM_CONT
104
  WrappedBinaryOperator(const Function& f)
105 106 107 108
    : m_f(f)
  {
  }

109
  VTKM_EXEC T operator()(const T& x, const T& y) const { return m_f(x, y); }
110

111 112
  template <typename U>
  VTKM_EXEC T operator()(const T& x, const PortalValue<U>& y) const
113
  {
114 115 116 117
    // to support proper implicit conversion, and avoid overload
    // ambiguities.
    T conv_y = y;
    return m_f(x, conv_y);
118 119
  }

120 121
  template <typename U>
  VTKM_EXEC T operator()(const PortalValue<U>& x, const T& y) const
122
  {
123 124
    T conv_x = x;
    return m_f(conv_x, y);
125 126
  }

127 128
  template <typename U, typename V>
  VTKM_EXEC T operator()(const PortalValue<U>& x, const PortalValue<V>& y) const
129
  {
130 131 132
    T conv_x = x;
    T conv_y = y;
    return m_f(conv_x, conv_y);
133 134
  }

135
  VTKM_EXEC T operator()(const T* const x, const T& y) const { return m_f(*x, y); }
136

137
  VTKM_EXEC T operator()(const T& x, const T* const y) const { return m_f(x, *y); }
138

139
  VTKM_EXEC T operator()(const T* const x, const T* const y) const { return m_f(*x, *y); }
140 141
};

142
template <typename T_, typename Function>
143 144
struct WrappedBinaryPredicate
{
145
  using T = typename std::remove_const<T_>::type;
146 147

  //make typedefs that thust expects binary operators to have
148 149 150
  using first_argument_type = T;
  using second_argument_type = T;
  using result_type = bool;
151 152 153

  Function m_f;

154
  VTKM_EXEC
155 156
  WrappedBinaryPredicate()
    : m_f()
157 158
  {
  }
159

160
  VTKM_CONT
161
  WrappedBinaryPredicate(const Function& f)
162 163 164 165
    : m_f(f)
  {
  }

166
  VTKM_EXEC bool operator()(const T& x, const T& y) const { return m_f(x, y); }
167

168 169
  template <typename U>
  VTKM_EXEC bool operator()(const T& x, const PortalValue<U>& y) const
170 171 172 173
  {
    return m_f(x, (T)y);
  }

174 175
  template <typename U>
  VTKM_EXEC bool operator()(const PortalValue<U>& x, const T& y) const
176 177 178 179
  {
    return m_f((T)x, y);
  }

180 181
  template <typename U, typename V>
  VTKM_EXEC bool operator()(const PortalValue<U>& x, const PortalValue<V>& y) const
182 183 184 185
  {
    return m_f((T)x, (T)y);
  }

186
  VTKM_EXEC bool operator()(const T* const x, const T& y) const { return m_f(*x, y); }
187

188
  VTKM_EXEC bool operator()(const T& x, const T* const y) const { return m_f(x, *y); }
189

190
  VTKM_EXEC bool operator()(const T* const x, const T* const y) const { return m_f(*x, *y); }
191 192 193 194 195 196
};
}
}
}
} //namespace vtkm::exec::cuda::internal

197 198 199 200
namespace thrust
{
namespace detail
{
201 202 203 204 205 206
//
// We tell Thrust that our WrappedBinaryOperator is commutative so that we
// activate numerous fast paths inside thrust which are only available when
// the binary functor is commutative and the T type is is_arithmetic
//
//
207 208 209 210 211 212 213
template <typename T, typename F>
struct is_commutative<vtkm::exec::cuda::internal::WrappedBinaryOperator<T, F>>
  : public thrust::detail::is_arithmetic<T>
{
};
}
} //namespace thrust::detail
214

215
#endif //vtk_m_exec_cuda_internal_WrappedOperators_h