Commit 5cb52bc8 authored by Robert Maynard's avatar Robert Maynard Committed by Kitware Robot
Browse files

Merge topic 'string_to_device_adapter_case_insensitive'

ce95b8f7

 VTK-m now supports case-insensitive construction of devices from strings.
Acked-by: Kitware Robot's avatarKitware Robot <kwrobot@kitware.com>
Acked-by: Allison Vacanti's avatarAllison Vacanti <allison.vacanti@kitware.com>
Merge-request: !1502
parents 0ae31eb6 ce95b8f7
# VTK-m `vtkm::cont::DeviceAdapterId` construction from string are now case-insensitive
You can now construct a `vtkm::cont::DeviceAdapterId` from a string no matter
the case of it. The following all will construct the same `vtkm::cont::DeviceAdapterId`.
```cpp
vtkm::cont::DeviceAdapterId id1 = vtkm::cont::make_DeviceAdapterId("cuda");
vtkm::cont::DeviceAdapterId id2 = vtkm::cont::make_DeviceAdapterId("CUDA");
vtkm::cont::DeviceAdapterId id3 = vtkm::cont::make_DeviceAdapterId("Cuda");
auto tracker = vtkm::cont::GetGlobalRuntimeDeviceTracker();
vtkm::cont::DeviceAdapterId id4 = tracker.GetDeviceAdapterId("cuda");
vtkm::cont::DeviceAdapterId id5 = tracker.GetDeviceAdapterId("CUDA");
vtkm::cont::DeviceAdapterId id6 = tracker.GetDeviceAdapterId("Cuda");
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
#include <algorithm> #include <algorithm>
#include <cctype> //for tolower
#include <map> #include <map>
#include <mutex> #include <mutex>
#include <sstream> #include <sstream>
...@@ -49,22 +50,33 @@ namespace ...@@ -49,22 +50,33 @@ namespace
struct VTKM_NEVER_EXPORT GetDeviceNameFunctor struct VTKM_NEVER_EXPORT GetDeviceNameFunctor
{ {
vtkm::cont::DeviceAdapterNameType* Names; vtkm::cont::DeviceAdapterNameType* Names;
vtkm::cont::DeviceAdapterNameType* LowerCaseNames;
VTKM_CONT VTKM_CONT
GetDeviceNameFunctor(vtkm::cont::DeviceAdapterNameType* names) GetDeviceNameFunctor(vtkm::cont::DeviceAdapterNameType* names,
vtkm::cont::DeviceAdapterNameType* lower)
: Names(names) : Names(names)
, LowerCaseNames(lower)
{ {
std::fill_n(this->Names, VTKM_MAX_DEVICE_ADAPTER_ID, "InvalidDeviceId"); std::fill_n(this->Names, VTKM_MAX_DEVICE_ADAPTER_ID, "InvalidDeviceId");
std::fill_n(this->LowerCaseNames, VTKM_MAX_DEVICE_ADAPTER_ID, "invaliddeviceid");
} }
template <typename Device> template <typename Device>
VTKM_CONT void operator()(Device device) VTKM_CONT void operator()(Device device)
{ {
auto lowerCaseFunc = [](char c) {
return static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
};
auto id = device.GetValue(); auto id = device.GetValue();
if (id > 0 && id < VTKM_MAX_DEVICE_ADAPTER_ID) if (id > 0 && id < VTKM_MAX_DEVICE_ADAPTER_ID)
{ {
this->Names[id] = vtkm::cont::DeviceAdapterTraits<Device>::GetName(); auto name = vtkm::cont::DeviceAdapterTraits<Device>::GetName();
this->Names[id] = name;
std::transform(name.begin(), name.end(), name.begin(), lowerCaseFunc);
this->LowerCaseNames[id] = name;
} }
} }
}; };
...@@ -87,6 +99,7 @@ struct RuntimeDeviceTrackerInternals ...@@ -87,6 +99,7 @@ struct RuntimeDeviceTrackerInternals
{ {
bool RuntimeValid[VTKM_MAX_DEVICE_ADAPTER_ID]; bool RuntimeValid[VTKM_MAX_DEVICE_ADAPTER_ID];
DeviceAdapterNameType DeviceNames[VTKM_MAX_DEVICE_ADAPTER_ID]; DeviceAdapterNameType DeviceNames[VTKM_MAX_DEVICE_ADAPTER_ID];
DeviceAdapterNameType LowerCaseDeviceNames[VTKM_MAX_DEVICE_ADAPTER_ID];
}; };
struct RuntimeDeviceTrackerFunctor struct RuntimeDeviceTrackerFunctor
...@@ -107,7 +120,7 @@ VTKM_CONT ...@@ -107,7 +120,7 @@ VTKM_CONT
RuntimeDeviceTracker::RuntimeDeviceTracker() RuntimeDeviceTracker::RuntimeDeviceTracker()
: Internals(std::make_shared<detail::RuntimeDeviceTrackerInternals>()) : Internals(std::make_shared<detail::RuntimeDeviceTrackerInternals>())
{ {
GetDeviceNameFunctor functor(this->Internals->DeviceNames); GetDeviceNameFunctor functor(this->Internals->DeviceNames, this->Internals->LowerCaseDeviceNames);
vtkm::ListForEach(functor, VTKM_DEFAULT_DEVICE_ADAPTER_LIST_TAG()); vtkm::ListForEach(functor, VTKM_DEFAULT_DEVICE_ADAPTER_LIST_TAG());
this->Reset(); this->Reset();
...@@ -197,6 +210,9 @@ RuntimeDeviceTracker::RuntimeDeviceTracker( ...@@ -197,6 +210,9 @@ RuntimeDeviceTracker::RuntimeDeviceTracker(
{ {
std::copy_n(internals->RuntimeValid, VTKM_MAX_DEVICE_ADAPTER_ID, this->Internals->RuntimeValid); std::copy_n(internals->RuntimeValid, VTKM_MAX_DEVICE_ADAPTER_ID, this->Internals->RuntimeValid);
std::copy_n(internals->DeviceNames, VTKM_MAX_DEVICE_ADAPTER_ID, this->Internals->DeviceNames); std::copy_n(internals->DeviceNames, VTKM_MAX_DEVICE_ADAPTER_ID, this->Internals->DeviceNames);
std::copy_n(internals->LowerCaseDeviceNames,
VTKM_MAX_DEVICE_ADAPTER_ID,
this->Internals->LowerCaseDeviceNames);
} }
VTKM_CONT VTKM_CONT
...@@ -265,22 +281,30 @@ DeviceAdapterNameType RuntimeDeviceTracker::GetDeviceName(DeviceAdapterId device ...@@ -265,22 +281,30 @@ DeviceAdapterNameType RuntimeDeviceTracker::GetDeviceName(DeviceAdapterId device
VTKM_CONT VTKM_CONT
DeviceAdapterId RuntimeDeviceTracker::GetDeviceAdapterId(DeviceAdapterNameType name) const DeviceAdapterId RuntimeDeviceTracker::GetDeviceAdapterId(DeviceAdapterNameType name) const
{ {
if (name == "Any") // The GetDeviceAdapterId call is case-insensitive so transform the name to be lower case
// as that is how we cache the case-insensitive version.
auto lowerCaseFunc = [](char c) {
return static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
};
std::transform(name.begin(), name.end(), name.begin(), lowerCaseFunc);
//lower-case the name here
if (name == "any")
{ {
return vtkm::cont::DeviceAdapterTagAny{}; return vtkm::cont::DeviceAdapterTagAny{};
} }
else if (name == "Error") else if (name == "error")
{ {
return vtkm::cont::DeviceAdapterTagError{}; return vtkm::cont::DeviceAdapterTagError{};
} }
else if (name == "Undefined") else if (name == "undefined")
{ {
return vtkm::cont::DeviceAdapterTagUndefined{}; return vtkm::cont::DeviceAdapterTagUndefined{};
} }
for (vtkm::Int8 id = 0; id < VTKM_MAX_DEVICE_ADAPTER_ID; ++id) for (vtkm::Int8 id = 0; id < VTKM_MAX_DEVICE_ADAPTER_ID; ++id)
{ {
if (name == this->Internals->DeviceNames[id]) if (name == this->Internals->LowerCaseDeviceNames[id])
{ {
return vtkm::cont::make_DeviceAdapterId(id); return vtkm::cont::make_DeviceAdapterId(id);
} }
......
...@@ -209,7 +209,8 @@ public: ...@@ -209,7 +209,8 @@ public:
DeviceAdapterNameType GetDeviceName(DeviceAdapterId id) const; DeviceAdapterNameType GetDeviceName(DeviceAdapterId id) const;
/// Returns the id corresponding to the device adapter name. If @a name is /// Returns the id corresponding to the device adapter name. If @a name is
/// not recognized, DeviceAdapterTagUndefined is returned. /// not recognized, DeviceAdapterTagUndefined is returned. Queries for a
/// name are all case-insensitive.
VTKM_CONT_EXPORT VTKM_CONT_EXPORT
VTKM_CONT VTKM_CONT
DeviceAdapterId GetDeviceAdapterId(DeviceAdapterNameType name) const; DeviceAdapterId GetDeviceAdapterId(DeviceAdapterNameType name) const;
......
...@@ -64,7 +64,6 @@ struct DeviceAdapterId ...@@ -64,7 +64,6 @@ struct DeviceAdapterId
protected: protected:
friend DeviceAdapterId make_DeviceAdapterId(vtkm::Int8 id); friend DeviceAdapterId make_DeviceAdapterId(vtkm::Int8 id);
friend DeviceAdapterId make_DeviceAdapterIdFromName(const std::string& name);
constexpr explicit DeviceAdapterId(vtkm::Int8 id) constexpr explicit DeviceAdapterId(vtkm::Int8 id)
: Value(id) : Value(id)
...@@ -75,9 +74,19 @@ private: ...@@ -75,9 +74,19 @@ private:
vtkm::Int8 Value; vtkm::Int8 Value;
}; };
/// Construct a device adapter id from a runtime string
/// The string is case-insensitive. So CUDA will be selected with 'cuda', 'Cuda', or 'CUDA'.
VTKM_CONT_EXPORT VTKM_CONT_EXPORT
DeviceAdapterId make_DeviceAdapterId(const DeviceAdapterNameType& name); DeviceAdapterId make_DeviceAdapterId(const DeviceAdapterNameType& name);
/// Construct a device adapter id a vtkm::Int8.
/// The mapping of integer value to devices are:
///
/// DeviceAdapterTagSerial == 1
/// DeviceAdapterTagCuda == 2
/// DeviceAdapterTagTBB == 3
/// DeviceAdapterTagOpenMP == 4
///
inline DeviceAdapterId make_DeviceAdapterId(vtkm::Int8 id) inline DeviceAdapterId make_DeviceAdapterId(vtkm::Int8 id)
{ {
return DeviceAdapterId(id); return DeviceAdapterId(id);
......
...@@ -28,6 +28,8 @@ ...@@ -28,6 +28,8 @@
#include <vtkm/cont/testing/Testing.h> #include <vtkm/cont/testing/Testing.h>
#include <cctype> //for tolower
namespace namespace
{ {
...@@ -53,12 +55,37 @@ void TestName(const std::string& name, Tag tag, vtkm::cont::DeviceAdapterId id) ...@@ -53,12 +55,37 @@ void TestName(const std::string& name, Tag tag, vtkm::cont::DeviceAdapterId id)
<< "\t" << tracker.GetDeviceName(id) << "\n" << "\t" << tracker.GetDeviceName(id) << "\n"
<< "\t" << tracker.GetDeviceName(tag) << "\n"; << "\t" << tracker.GetDeviceName(tag) << "\n";
#endif #endif
VTKM_TEST_ASSERT(id.GetName() == name, "Id::GetName() failed."); VTKM_TEST_ASSERT(id.GetName() == name, "Id::GetName() failed.");
VTKM_TEST_ASSERT(tag.GetName() == name, "Tag::GetName() failed."); VTKM_TEST_ASSERT(tag.GetName() == name, "Tag::GetName() failed.");
VTKM_TEST_ASSERT(vtkm::cont::make_DeviceAdapterId(id.GetValue()) == id,
"make_DeviceAdapterId(int8) failed");
VTKM_TEST_ASSERT(tracker.GetDeviceName(id) == name, "RTDeviceTracker::GetDeviceName(Id) failed."); VTKM_TEST_ASSERT(tracker.GetDeviceName(id) == name, "RTDeviceTracker::GetDeviceName(Id) failed.");
VTKM_TEST_ASSERT(tracker.GetDeviceName(tag) == name, VTKM_TEST_ASSERT(tracker.GetDeviceName(tag) == name,
"RTDeviceTracker::GetDeviceName(Tag) failed."); "RTDeviceTracker::GetDeviceName(Tag) failed.");
//check going from name to device id
auto lowerCaseFunc = [](char c) {
return static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
};
auto upperCaseFunc = [](char c) {
return static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
};
if (id.IsValueValid())
{ //only test make_DeviceAdapterId with valid device ids
VTKM_TEST_ASSERT(
vtkm::cont::make_DeviceAdapterId(name) == id, "make_DeviceAdapterId(", name, ") failed");
std::string casedName = name;
std::transform(casedName.begin(), casedName.end(), casedName.begin(), lowerCaseFunc);
VTKM_TEST_ASSERT(
vtkm::cont::make_DeviceAdapterId(casedName) == id, "make_DeviceAdapterId(", name, ") failed");
std::transform(casedName.begin(), casedName.end(), casedName.begin(), upperCaseFunc);
VTKM_TEST_ASSERT(
vtkm::cont::make_DeviceAdapterId(casedName) == id, "make_DeviceAdapterId(", name, ") failed");
}
} }
void TestNames() void TestNames()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment