//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//
//  Copyright 2016 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
//  Copyright 2016 UT-Battelle, LLC.
//  Copyright 2016 Los Alamos National Security.
//
//  Under the terms of Contract DE-NA0003525 with NTESS,
//  the U.S. Government retains certain rights in this software.
//
//  Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
//  Laboratory (LANL), the U.S. Government retains certain rights in
//  this software.
//============================================================================

#include <vtkm/cont/RuntimeDeviceTracker.h>

#include <vtkm/cont/DeviceAdapter.h>
#include <vtkm/cont/DeviceAdapterListTag.h>
#include <vtkm/cont/ErrorBadValue.h>
#include <vtkm/cont/internal/DeviceAdapterError.h>

#include <vtkm/cont/cuda/DeviceAdapterCuda.h>
#include <vtkm/cont/serial/DeviceAdapterSerial.h>
#include <vtkm/cont/tbb/DeviceAdapterTBB.h>

//Bring in each device adapter's runtime class
#include <vtkm/cont/cuda/internal/DeviceAdapterRuntimeDetectorCuda.h>
#include <vtkm/cont/internal/DeviceAdapterError.h>
#include <vtkm/cont/openmp/internal/DeviceAdapterRuntimeDetectorOpenMP.h>
#include <vtkm/cont/serial/internal/DeviceAdapterRuntimeDetectorSerial.h>
#include <vtkm/cont/tbb/internal/DeviceAdapterRuntimeDetectorTBB.h>

#include <algorithm>
#include <cctype> //for tolower
#include <map>
#include <mutex>
#include <sstream>
#include <thread>

namespace
{

struct VTKM_NEVER_EXPORT GetDeviceNameFunctor
{
  vtkm::cont::DeviceAdapterNameType* Names;
53
  vtkm::cont::DeviceAdapterNameType* LowerCaseNames;
54 55

  VTKM_CONT
56 57
  GetDeviceNameFunctor(vtkm::cont::DeviceAdapterNameType* names,
                       vtkm::cont::DeviceAdapterNameType* lower)
58
    : Names(names)
59
    , LowerCaseNames(lower)
60 61
  {
    std::fill_n(this->Names, VTKM_MAX_DEVICE_ADAPTER_ID, "InvalidDeviceId");
62
    std::fill_n(this->LowerCaseNames, VTKM_MAX_DEVICE_ADAPTER_ID, "invaliddeviceid");
63 64 65 66 67
  }

  template <typename Device>
  VTKM_CONT void operator()(Device device)
  {
68 69 70 71
    auto lowerCaseFunc = [](char c) {
      return static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
    };

72 73 74 75
    auto id = device.GetValue();

    if (id > 0 && id < VTKM_MAX_DEVICE_ADAPTER_ID)
    {
76 77 78 79
      auto name = vtkm::cont::DeviceAdapterTraits<Device>::GetName();
      this->Names[id] = name;
      std::transform(name.begin(), name.end(), name.begin(), lowerCaseFunc);
      this->LowerCaseNames[id] = name;
80 81 82 83
    }
  }
};

84 85 86 87
#if !(defined(VTKM_CLANG) && (__apple_build_version__ < 8000000))
thread_local static vtkm::cont::RuntimeDeviceTracker runtimeDeviceTracker;
#endif

88 89
} // end anon namespace

90 91 92 93
namespace vtkm
{
namespace cont
{
94

95 96
namespace detail
{
97 98 99 100

struct RuntimeDeviceTrackerInternals
{
  bool RuntimeValid[VTKM_MAX_DEVICE_ADAPTER_ID];
101
  DeviceAdapterNameType DeviceNames[VTKM_MAX_DEVICE_ADAPTER_ID];
102
  DeviceAdapterNameType LowerCaseDeviceNames[VTKM_MAX_DEVICE_ADAPTER_ID];
103
};
104 105 106 107 108 109 110 111 112 113 114 115 116

struct RuntimeDeviceTrackerFunctor
{
  template <typename DeviceAdapter>
  VTKM_CONT void operator()(DeviceAdapter, DeviceAdapterId id, RuntimeDeviceTracker* rdt)
  {
    vtkm::cont::RuntimeDeviceInformation runtimeDevice;
    if (DeviceAdapter() == id)
    {
      rdt->ForceDeviceImpl(DeviceAdapter(), runtimeDevice.Exists(DeviceAdapter()));
    }
  }
};
117 118 119 120
}

VTKM_CONT
RuntimeDeviceTracker::RuntimeDeviceTracker()
  : Internals(std::make_shared<detail::RuntimeDeviceTrackerInternals>())
{
  // Populate both name tables from the default device adapter list, then
  // establish each device's enabled state through Reset().
  auto& state = *this->Internals;
  GetDeviceNameFunctor nameCollector(state.DeviceNames, state.LowerCaseDeviceNames);
  vtkm::ListForEach(nameCollector, VTKM_DEFAULT_DEVICE_ADAPTER_LIST_TAG());

  this->Reset();
}

VTKM_CONT
RuntimeDeviceTracker::~RuntimeDeviceTracker() = default;

VTKM_CONT
void RuntimeDeviceTracker::CheckDevice(vtkm::cont::DeviceAdapterId deviceId) const
{
  /// Throws vtkm::cont::ErrorBadValue when deviceId fails IsValueValid(), so
  /// callers can safely index the per-device tables afterwards.
  if (!deviceId.IsValueValid())
  {
    std::stringstream message;
    // Widen to int before streaming so that a char-sized ID type prints as a
    // number rather than as a character.
    message << "Device '" << deviceId.GetName() << "' has invalid ID of "
            << static_cast<int>(deviceId.GetValue());
    throw vtkm::cont::ErrorBadValue(message.str());
  }
}

VTKM_CONT
bool RuntimeDeviceTracker::CanRunOnImpl(vtkm::cont::DeviceAdapterId deviceId) const
{
  // CheckDevice throws ErrorBadValue for IDs that fail IsValueValid().
  this->CheckDevice(deviceId);

  const auto& enabled = this->Internals->RuntimeValid;
  return enabled[deviceId.GetValue()];
}

VTKM_CONT
void RuntimeDeviceTracker::SetDeviceState(vtkm::cont::DeviceAdapterId deviceId, bool state)
{
  // Reject invalid IDs before touching the flag table.
  this->CheckDevice(deviceId);

  VTKM_LOG_S(vtkm::cont::LogLevel::Info,
             "Setting device '" << deviceId.GetName() << "' to " << state);

  auto& enabledFlag = this->Internals->RuntimeValid[deviceId.GetValue()];
  enabledFlag = state;
}

namespace
{
165 166 167 168 169 170

struct VTKM_NEVER_EXPORT RuntimeDeviceTrackerResetFunctor
{
  vtkm::cont::RuntimeDeviceTracker Tracker;

  VTKM_CONT
171
  RuntimeDeviceTrackerResetFunctor(const vtkm::cont::RuntimeDeviceTracker& tracker)
172
    : Tracker(tracker)
173 174
  {
  }
175

176
  template <typename Device>
177
  VTKM_CONT void operator()(Device device)
178
  {
179
    this->Tracker.ResetDevice(device);
180 181 182 183 184 185 186
  }
};
}

VTKM_CONT
void RuntimeDeviceTracker::Reset()
{
  // Start with every device disabled. Then let each device in the default
  // list restore its own state through ResetDevice().
  bool* const first = this->Internals->RuntimeValid;
  std::fill(first, first + VTKM_MAX_DEVICE_ADAPTER_ID, false);

  RuntimeDeviceTrackerResetFunctor resetter(*this);
  vtkm::ListForEach(resetter, VTKM_DEFAULT_DEVICE_ADAPTER_LIST_TAG());
}

VTKM_CONT
vtkm::cont::RuntimeDeviceTracker RuntimeDeviceTracker::DeepCopy() const
{
  // The internals-taking constructor copies this tracker's state into a newly
  // allocated RuntimeDeviceTrackerInternals, so the result does not share
  // state with *this.
  vtkm::cont::RuntimeDeviceTracker copy(this->Internals);
  return copy;
}

VTKM_CONT
void RuntimeDeviceTracker::DeepCopy(const vtkm::cont::RuntimeDeviceTracker& src)
{
  // Only the per-device enabled flags are copied. The name tables are built
  // from the same default device list by the constructors in this file.
  const bool* const from = src.Internals->RuntimeValid;
  bool* const to = this->Internals->RuntimeValid;
  std::copy_n(from, VTKM_MAX_DEVICE_ADAPTER_ID, to);
}

VTKM_CONT
RuntimeDeviceTracker::RuntimeDeviceTracker(
  const std::shared_ptr<detail::RuntimeDeviceTrackerInternals>& internals)
  : Internals(std::make_shared<detail::RuntimeDeviceTrackerInternals>(*internals))
{
  // The implicit copy constructor of RuntimeDeviceTrackerInternals copies the
  // enabled flags and both name tables element by element into a new
  // allocation that this tracker owns alone.
}

VTKM_CONT
void RuntimeDeviceTracker::ForceDeviceImpl(vtkm::cont::DeviceAdapterId deviceId, bool runtimeExists)
{
  // A device whose runtime is missing cannot be forced.
  if (!runtimeExists)
  {
    std::ostringstream reason;
    reason << "Cannot force to device '" << deviceId.GetName()
           << "' because that device is not available on this system";
    throw vtkm::cont::ErrorBadValue(reason.str());
  }

  this->CheckDevice(deviceId);

  VTKM_LOG_S(vtkm::cont::LogLevel::Info,
             "Forcing execution to occur on device '" << deviceId.GetName() << "'");

  // Disable every device, then enable only the forced one. runtimeExists is
  // necessarily true here because the false case threw above.
  bool* const flags = this->Internals->RuntimeValid;
  std::fill_n(flags, VTKM_MAX_DEVICE_ADAPTER_ID, false);
  flags[deviceId.GetValue()] = runtimeExists;
}

VTKM_CONT
void RuntimeDeviceTracker::ForceDevice(DeviceAdapterId id)
{
  // Visit every device in the default list; the functor calls ForceDeviceImpl
  // only for the tag equal to id.
  // NOTE(review): an id that matches no device in
  // VTKM_DEFAULT_DEVICE_ADAPTER_LIST_TAG leaves the tracker unchanged and
  // reports no error.
  detail::RuntimeDeviceTrackerFunctor forcer;
  vtkm::ListForEach(forcer, VTKM_DEFAULT_DEVICE_ADAPTER_LIST_TAG(), id, this);
}

VTKM_CONT
DeviceAdapterNameType RuntimeDeviceTracker::GetDeviceName(DeviceAdapterId device) const
{
  const auto id = device.GetValue();

  // IDs within the table range index the name table directly.
  if (id >= 0 && id < VTKM_MAX_DEVICE_ADAPTER_ID)
  {
    return this->Internals->DeviceNames[id];
  }

  // Outside the table range, a negative ID may be the error or undefined tag.
  // An ID past the table end may be the "any" tag.
  if (id < 0)
  {
    if (id == VTKM_DEVICE_ADAPTER_ERROR)
    {
      return vtkm::cont::DeviceAdapterTraits<vtkm::cont::DeviceAdapterTagError>::GetName();
    }
    if (id == VTKM_DEVICE_ADAPTER_UNDEFINED)
    {
      return vtkm::cont::DeviceAdapterTraits<vtkm::cont::DeviceAdapterTagUndefined>::GetName();
    }
  }
  else if (id == VTKM_DEVICE_ADAPTER_ANY)
  {
    return vtkm::cont::DeviceAdapterTraits<vtkm::cont::DeviceAdapterTagAny>::GetName();
  }

  // Any other unrecognized ID reports slot 0, the invalid-device placeholder.
  return this->Internals->DeviceNames[0];
}

VTKM_CONT
DeviceAdapterId RuntimeDeviceTracker::GetDeviceAdapterId(DeviceAdapterNameType name) const
{
  // Matching is case-insensitive. Fold the query to lower case, the same form
  // that is cached in LowerCaseDeviceNames.
  std::transform(name.begin(), name.end(), name.begin(), [](char ch) {
    return static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
  });

  // Special tags are recognized by name and are not stored in the table.
  if (name == "any")
  {
    return vtkm::cont::DeviceAdapterTagAny{};
  }
  if (name == "error")
  {
    return vtkm::cont::DeviceAdapterTagError{};
  }
  if (name == "undefined")
  {
    return vtkm::cont::DeviceAdapterTagUndefined{};
  }

  // Return the first table slot whose cached lower-case name matches.
  // NOTE(review): slot 0 (and any unregistered slot) holds the
  // "invaliddeviceid" placeholder, so that name resolves to ID 0.
  const DeviceAdapterNameType* const tableBegin = this->Internals->LowerCaseDeviceNames;
  const DeviceAdapterNameType* const tableEnd = tableBegin + VTKM_MAX_DEVICE_ADAPTER_ID;
  const DeviceAdapterNameType* const match = std::find(tableBegin, tableEnd, name);
  if (match != tableEnd)
  {
    return vtkm::cont::make_DeviceAdapterId(static_cast<vtkm::Int8>(match - tableBegin));
  }

  return vtkm::cont::DeviceAdapterTagUndefined{};
}

VTKM_CONT
vtkm::cont::RuntimeDeviceTracker GetGlobalRuntimeDeviceTracker()
{
  /// Returns the calling thread's runtime device tracker.
#if defined(VTKM_CLANG) && (__apple_build_version__ < 8000000)
  // Fallback for compilers excluded from the thread_local path: one tracker per
  // thread id in a mutex-guarded map. Entries are never removed.
  static std::mutex mtx;
  static std::map<std::thread::id, vtkm::cont::RuntimeDeviceTracker> globalTrackers;
  const std::thread::id this_id = std::this_thread::get_id();

  std::lock_guard<std::mutex> lock(mtx);
  auto iter = globalTrackers.find(this_id);
  if (iter == globalTrackers.end())
  {
    // Insert one freshly constructed tracker. operator[] followed by assignment
    // would default-construct a second tracker first, repeating the device
    // name and Reset() setup, only to overwrite it.
    iter = globalTrackers.emplace(this_id, vtkm::cont::RuntimeDeviceTracker()).first;
  }
  return iter->second;
#else
  return runtimeDeviceTracker;
#endif
}
}
} // namespace vtkm::cont