Commit e621b6ba authored by Allison Vacanti
Browse files

Generalize the TBB radix sort implementation.

The core algorithm will be shared by OpenMP.
parent d6027843
......@@ -37,6 +37,9 @@ set(headers
DynamicTransform.h
FunctorsGeneral.h
IteratorFromArrayPortal.h
KXSort.h
ParallelRadixSort.h
ParallelRadixSortInterface.h
SimplePolymorphicContainer.h
StorageError.h
VirtualObjectTransfer.h
......
//=============================================================================
//
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
//
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//
// Copyright 2018 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
// Copyright 2018 UT-Battelle, LLC.
// Copyright 2018 Los Alamos National Security.
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
// Laboratory (LANL), the U.S. Government retains certain rights in
// this software.
//
//=============================================================================
/* The MIT License
Copyright (c) 2016 Dinghua Li <voutcn@gmail.com>
......
This diff is collapsed.
//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//
// Copyright 2017 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
// Copyright 2017 UT-Battelle, LLC.
// Copyright 2017 Los Alamos National Security.
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
// Laboratory (LANL), the U.S. Government retains certain rights in
// this software.
//============================================================================
#ifndef vtk_m_cont_internal_ParallelRadixSortInterface_h
#define vtk_m_cont_internal_ParallelRadixSortInterface_h
#include <vtkm/BinaryPredicates.h>
#include <vtkm/cont/ArrayHandle.h>
#include <functional>
#include <type_traits>
namespace vtkm
{
namespace cont
{
namespace internal
{
namespace radix
{
// Size thresholds, in bytes, for the shared radix sort.
// NOTE(review): the code that consumes these is not visible in this header;
// presumably inputs smaller than MIN_BYTES_FOR_PARALLEL are sorted without
// threading and full parallelism is reached at BYTES_FOR_MAX_PARALLELISM --
// confirm against ParallelRadixSort.h.
constexpr size_t MIN_BYTES_FOR_PARALLEL = 400000;
constexpr size_t BYTES_FOR_MAX_PARALLELISM = 4000000;
// Empty dispatch tag: selects the radix sort overloads.
struct RadixSortTag {};

// Empty dispatch tag: selects the generic comparison-based parallel sort
// overloads (the fallback when radix sort cannot be used).
struct PSortTag {};
// Detect supported functors for radix sort.
//
// Primary template: an arbitrary comparison functor is NOT supported.
template <typename Compare>
struct is_valid_compare_type : std::false_type
{
};

// std::less<T> (ascending order) is supported.
template <typename Elem>
struct is_valid_compare_type<std::less<Elem>> : std::true_type
{
};

// std::greater<T> (descending order) is supported.
template <typename Elem>
struct is_valid_compare_type<std::greater<Elem>> : std::true_type
{
};
// vtkm::SortLess is supported by radix sort.
template <>
struct is_valid_compare_type<vtkm::SortLess> : std::true_type
{
};

// vtkm::SortGreater is supported by radix sort.
template <>
struct is_valid_compare_type<vtkm::SortGreater> : std::true_type
{
};
// Convert vtkm::Sort[Less|Greater] to the std:: equivalents.
//
// Generic overload: any other comparison functor is handed back unchanged,
// preserving its value category. The second argument is never read; it only
// carries the element type for the overloads that need it.
template <typename BComp, typename T>
auto get_std_compare(BComp&& comp, T&&) -> BComp&&
{
  return std::forward<BComp>(comp);
}
template <typename T>
std::less<T> get_std_compare(vtkm::SortLess, T&&)
{
return std::less<T>{};
}
template <typename T>
std::greater<T> get_std_compare(vtkm::SortGreater, T&&)
{
return std::greater<T>{};
}
// Determine if radix sort can be used for a given ValueType, StorageType, and
// comparison functor. The nested `type` is RadixSortTag when radix sort is
// usable and PSortTag otherwise.
//
// Primary template: any storage other than StorageTagBasic falls back to the
// generic parallel sort.
template <typename T, typename StorageTag, typename BinaryCompare>
struct sort_tag_type
{
  using type = PSortTag;
};

// Basic storage: radix sort is selected only when
//  - T is an arithmetic type,
//  - T is neither `long double` nor `bool` (both satisfy std::is_arithmetic,
//    but VTKM_DECLARE_RADIX_SORT declares no parallel_radix_sort overload for
//    them, so selecting RadixSortTag would not compile), and
//  - the comparison functor is one of the supported less/greater functors.
template <typename T, typename BinaryCompare>
struct sort_tag_type<T, vtkm::cont::StorageTagBasic, BinaryCompare>
{
  using PrimT = std::is_arithmetic<T>;
  using LongDT = std::is_same<T, long double>;
  using BoolT = std::is_same<T, bool>;
  using BComp = is_valid_compare_type<BinaryCompare>;
  using type =
    typename std::conditional<PrimT::value && BComp::value && !LongDT::value && !BoolT::value,
                              RadixSortTag,
                              PSortTag>::type;
};
// Determine if radix sort can be used for a sort-by-key with the given key and
// value types, storage tags, and comparison functor. The nested `type` is
// RadixSortTag when radix sort is usable and PSortTag otherwise.
//
// Primary template: unless both arrays use StorageTagBasic, fall back to the
// generic parallel sort.
template <typename KeyType,
          typename ValueType,
          typename KeyStorageTagType,
          typename ValueStorageTagType,
          class BinaryCompare>
struct sortbykey_tag_type
{
  using type = PSortTag;
};

// Basic storage for keys and values: radix sort is selected only when
//  - both KeyType and ValueType are arithmetic,
//  - KeyType is neither `long double` nor `bool` (both satisfy
//    std::is_arithmetic, but VTKM_DECLARE_RADIX_SORT declares no
//    parallel_radix_sort_key_values overload for such keys, so selecting
//    RadixSortTag would not compile), and
//  - the comparison functor is one of the supported less/greater functors.
template <typename KeyType, typename ValueType, class BinaryCompare>
struct sortbykey_tag_type<KeyType,
                          ValueType,
                          vtkm::cont::StorageTagBasic,
                          vtkm::cont::StorageTagBasic,
                          BinaryCompare>
{
  using PrimKey = std::is_arithmetic<KeyType>;
  using PrimValue = std::is_arithmetic<ValueType>;
  using LongDKey = std::is_same<KeyType, long double>;
  using BoolKey = std::is_same<KeyType, bool>;
  using BComp = is_valid_compare_type<BinaryCompare>;
  using type = typename std::conditional<PrimKey::value && PrimValue::value && BComp::value &&
                                           !LongDKey::value && !BoolKey::value,
                                         RadixSortTag,
                                         PSortTag>::type;
};
// Declares (does not define) the four exported radix sort entry points for a
// single key type:
//   - parallel_radix_sort: sorts `num_elems` keys in `data`, overloaded on
//     std::greater<key_type> and std::less<key_type>.
//   - parallel_radix_sort_key_values: takes `keys` together with a vtkm::Id
//     array `vals`, with the same two comparator overloads.
// Every declaration ends with its own semicolon, so the expansion needs no
// trailing `;` at the point of use.
#define VTKM_INTERNAL_RADIX_SORT_DECLARE(key_type) \
VTKM_CONT_EXPORT void parallel_radix_sort( \
key_type* data, size_t num_elems, const std::greater<key_type>& comp); \
VTKM_CONT_EXPORT void parallel_radix_sort( \
key_type* data, size_t num_elems, const std::less<key_type>& comp); \
VTKM_CONT_EXPORT void parallel_radix_sort_key_values( \
key_type* keys, vtkm::Id* vals, size_t num_elems, const std::greater<key_type>& comp); \
VTKM_CONT_EXPORT void parallel_radix_sort_key_values( \
key_type* keys, vtkm::Id* vals, size_t num_elems, const std::less<key_type>& comp);
// Generate radix sort interfaces for key and key value sorts.
//
// Expands VTKM_INTERNAL_RADIX_SORT_DECLARE once per supported key type: the
// signed/unsigned integer types, the character types, float and double.
// `bool` and `long double` are not in the list, so no radix sort overloads
// are declared for them.
// Invoke this inside the namespace that should receive the declarations.
#define VTKM_DECLARE_RADIX_SORT() \
VTKM_INTERNAL_RADIX_SORT_DECLARE(short int) \
VTKM_INTERNAL_RADIX_SORT_DECLARE(unsigned short int) \
VTKM_INTERNAL_RADIX_SORT_DECLARE(int) \
VTKM_INTERNAL_RADIX_SORT_DECLARE(unsigned int) \
VTKM_INTERNAL_RADIX_SORT_DECLARE(long int) \
VTKM_INTERNAL_RADIX_SORT_DECLARE(unsigned long int) \
VTKM_INTERNAL_RADIX_SORT_DECLARE(long long int) \
VTKM_INTERNAL_RADIX_SORT_DECLARE(unsigned long long int) \
VTKM_INTERNAL_RADIX_SORT_DECLARE(unsigned char) \
VTKM_INTERNAL_RADIX_SORT_DECLARE(signed char) \
VTKM_INTERNAL_RADIX_SORT_DECLARE(char) \
VTKM_INTERNAL_RADIX_SORT_DECLARE(char16_t) \
VTKM_INTERNAL_RADIX_SORT_DECLARE(char32_t) \
VTKM_INTERNAL_RADIX_SORT_DECLARE(wchar_t) \
VTKM_INTERNAL_RADIX_SORT_DECLARE(float) \
VTKM_INTERNAL_RADIX_SORT_DECLARE(double)
}
}
}
} // end vtkm::cont::internal::radix
#endif // vtk_m_cont_internal_ParallelRadixSortInterface_h
......@@ -251,21 +251,21 @@ public:
{
//this is required to get sort to work with zip handles
std::less<T> lessOp;
vtkm::cont::tbb::internal::parallel_sort(values, lessOp);
vtkm::cont::tbb::sort::parallel_sort(values, lessOp);
}
template <typename T, class Container, class BinaryCompare>
VTKM_CONT static void Sort(vtkm::cont::ArrayHandle<T, Container>& values,
BinaryCompare binary_compare)
{
vtkm::cont::tbb::internal::parallel_sort(values, binary_compare);
vtkm::cont::tbb::sort::parallel_sort(values, binary_compare);
}
template <typename T, typename U, class StorageT, class StorageU>
VTKM_CONT static void SortByKey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
vtkm::cont::ArrayHandle<U, StorageU>& values)
{
vtkm::cont::tbb::internal::parallel_sort_bykey(keys, values, std::less<T>());
vtkm::cont::tbb::sort::parallel_sort_bykey(keys, values, std::less<T>());
}
template <typename T, typename U, class StorageT, class StorageU, class BinaryCompare>
......@@ -273,7 +273,7 @@ public:
vtkm::cont::ArrayHandle<U, StorageU>& values,
BinaryCompare binary_compare)
{
vtkm::cont::tbb::internal::parallel_sort_bykey(keys, values, binary_compare);
vtkm::cont::tbb::sort::parallel_sort_bykey(keys, values, binary_compare);
}
template <typename T, class Storage>
......
This diff is collapsed.
......@@ -24,6 +24,7 @@
#include <vtkm/BinaryPredicates.h>
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandleZip.h>
#include <vtkm/cont/internal/ParallelRadixSortInterface.h>
#include <vtkm/cont/tbb/internal/ArrayManagerExecutionTBB.h>
#include <vtkm/cont/tbb/internal/DeviceAdapterTagTBB.h>
......@@ -38,128 +39,27 @@ namespace cont
{
namespace tbb
{
namespace internal
namespace sort
{
struct RadixSortTag
{
};
struct PSortTag
{
};
template <typename T>
struct is_valid_compare_type : std::integral_constant<bool, false>
{
};
template <typename T>
struct is_valid_compare_type<std::less<T>> : std::integral_constant<bool, true>
{
};
template <typename T>
struct is_valid_compare_type<std::greater<T>> : std::integral_constant<bool, true>
{
};
template <>
struct is_valid_compare_type<vtkm::SortLess> : std::integral_constant<bool, true>
{
};
template <>
struct is_valid_compare_type<vtkm::SortGreater> : std::integral_constant<bool, true>
{
};
template <typename BComp, typename T>
BComp&& get_std_compare(BComp&& b, T&&)
{
return std::forward<BComp>(b);
}
template <typename T>
std::less<T> get_std_compare(vtkm::SortLess, T&&)
{
return std::less<T>{};
}
template <typename T>
std::greater<T> get_std_compare(vtkm::SortGreater, T&&)
{
return std::greater<T>{};
}
// Declare the compiled radix sort specializations:
VTKM_DECLARE_RADIX_SORT()
template <typename T, typename StorageTag, typename BinaryCompare>
struct sort_tag_type
{
using type = PSortTag;
};
template <typename T, typename BinaryCompare>
struct sort_tag_type<T, vtkm::cont::StorageTagBasic, BinaryCompare>
{
using PrimT = std::is_arithmetic<T>;
using LongDT = std::is_same<T, long double>;
using BComp = is_valid_compare_type<BinaryCompare>;
using type = typename std::conditional<PrimT::value && BComp::value && !LongDT::value,
RadixSortTag,
PSortTag>::type;
};
template <typename T, typename U, typename StorageTagT, typename StorageTagU, class BinaryCompare>
struct sortbykey_tag_type
{
using type = PSortTag;
};
template <typename T, typename U, typename BinaryCompare>
struct sortbykey_tag_type<T,
U,
vtkm::cont::StorageTagBasic,
vtkm::cont::StorageTagBasic,
BinaryCompare>
{
using PrimT = std::is_arithmetic<T>;
using PrimU = std::is_arithmetic<U>;
using LongDT = std::is_same<T, long double>;
using BComp = is_valid_compare_type<BinaryCompare>;
using type =
typename std::conditional<PrimT::value && PrimU::value && BComp::value && !LongDT::value,
RadixSortTag,
PSortTag>::type;
};
#define VTKM_TBB_SORT_EXPORT(key_type) \
VTKM_CONT_EXPORT void parallel_radix_sort( \
key_type* data, size_t num_elems, const std::greater<key_type>& comp); \
VTKM_CONT_EXPORT void parallel_radix_sort( \
key_type* data, size_t num_elems, const std::less<key_type>& comp); \
VTKM_CONT_EXPORT void parallel_radix_sort_key_values( \
key_type* keys, vtkm::Id* vals, size_t num_elems, const std::greater<key_type>& comp); \
VTKM_CONT_EXPORT void parallel_radix_sort_key_values( \
key_type* keys, vtkm::Id* vals, size_t num_elems, const std::less<key_type>& comp);
// Generate radix sort interfaces for key and key value sorts.
VTKM_TBB_SORT_EXPORT(short int);
VTKM_TBB_SORT_EXPORT(unsigned short int);
VTKM_TBB_SORT_EXPORT(int);
VTKM_TBB_SORT_EXPORT(unsigned int);
VTKM_TBB_SORT_EXPORT(long int);
VTKM_TBB_SORT_EXPORT(unsigned long int);
VTKM_TBB_SORT_EXPORT(long long int);
VTKM_TBB_SORT_EXPORT(unsigned long long int);
VTKM_TBB_SORT_EXPORT(unsigned char);
VTKM_TBB_SORT_EXPORT(signed char);
VTKM_TBB_SORT_EXPORT(char);
VTKM_TBB_SORT_EXPORT(char16_t);
VTKM_TBB_SORT_EXPORT(char32_t);
VTKM_TBB_SORT_EXPORT(wchar_t);
VTKM_TBB_SORT_EXPORT(float);
VTKM_TBB_SORT_EXPORT(double);
#undef VTKM_TBB_SORT_EXPORT
// Forward declare entry points (See stack overflow discussion 7255281 --
// templated overloads of template functions are not specialization, and will
// be resolved during the first phase of two part lookup).
template <typename T, typename Container, class BinaryCompare>
void parallel_sort(vtkm::cont::ArrayHandle<T, Container>& values, BinaryCompare binary_compare)
{
using SortAlgorithmTag = typename sort_tag_type<T, Container, BinaryCompare>::type;
parallel_sort(values, binary_compare, SortAlgorithmTag{});
}
void parallel_sort(vtkm::cont::ArrayHandle<T, Container>&, BinaryCompare);
template <typename T, typename StorageT, typename U, typename StorageU, class BinaryCompare>
void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>&,
vtkm::cont::ArrayHandle<U, StorageU>&,
BinaryCompare);
// Quicksort values:
template <typename HandleType, class BinaryCompare>
void parallel_sort(HandleType& values, BinaryCompare binary_compare, PSortTag)
void parallel_sort(HandleType& values,
BinaryCompare binary_compare,
vtkm::cont::internal::radix::PSortTag)
{
auto arrayPortal = values.PrepareForInPlace(vtkm::cont::DeviceAdapterTagTBB());
......@@ -169,31 +69,37 @@ void parallel_sort(HandleType& values, BinaryCompare binary_compare, PSortTag)
internal::WrappedBinaryOperator<bool, BinaryCompare> wrappedCompare(binary_compare);
::tbb::parallel_sort(iterators.GetBegin(), iterators.GetEnd(), wrappedCompare);
}
// Radix sort values:
template <typename T, typename StorageT, class BinaryCompare>
void parallel_sort(vtkm::cont::ArrayHandle<T, StorageT>& values,
BinaryCompare binary_compare,
RadixSortTag)
vtkm::cont::internal::radix::RadixSortTag)
{
using namespace vtkm::cont::internal::radix;
auto c = get_std_compare(binary_compare, T{});
parallel_radix_sort(
values.GetStorage().GetArray(), static_cast<std::size_t>(values.GetNumberOfValues()), c);
}
template <typename T, typename StorageT, typename U, typename StorageU, class BinaryCompare>
void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
vtkm::cont::ArrayHandle<U, StorageU>& values,
BinaryCompare binary_compare)
// Value sort -- static switch between quicksort and radix sort
template <typename T, typename Container, class BinaryCompare>
void parallel_sort(vtkm::cont::ArrayHandle<T, Container>& values, BinaryCompare binary_compare)
{
using SortAlgorithmTag =
typename sortbykey_tag_type<T, U, StorageT, StorageU, BinaryCompare>::type;
parallel_sort_bykey(keys, values, binary_compare, SortAlgorithmTag{});
using namespace vtkm::cont::internal::radix;
using SortAlgorithmTag = typename sort_tag_type<T, Container, BinaryCompare>::type;
parallel_sort(values, binary_compare, SortAlgorithmTag{});
}
// Quicksort by key
template <typename T, typename StorageT, typename U, typename StorageU, class BinaryCompare>
void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
vtkm::cont::ArrayHandle<U, StorageU>& values,
BinaryCompare binary_compare,
PSortTag)
vtkm::cont::internal::radix::PSortTag)
{
using namespace vtkm::cont::internal::radix;
using KeyType = vtkm::cont::ArrayHandle<T, StorageT>;
constexpr bool larger_than_64bits = sizeof(U) > sizeof(vtkm::Int64);
if (larger_than_64bits)
......@@ -243,23 +149,28 @@ void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
zipHandle, vtkm::cont::internal::KeyCompare<T, U, BinaryCompare>(binary_compare), PSortTag{});
}
}
// Radix sort by key -- Specialize for vtkm::Id values:
template <typename T, typename StorageT, typename StorageU, class BinaryCompare>
void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
vtkm::cont::ArrayHandle<vtkm::Id, StorageU>& values,
BinaryCompare binary_compare,
RadixSortTag)
vtkm::cont::internal::radix::RadixSortTag)
{
using namespace vtkm::cont::internal::radix;
auto c = get_std_compare(binary_compare, T{});
parallel_radix_sort_key_values(keys.GetStorage().GetArray(),
values.GetStorage().GetArray(),
static_cast<std::size_t>(keys.GetNumberOfValues()),
c);
}
// Radix sort by key -- Generic impl:
template <typename T, typename StorageT, typename U, typename StorageU, class BinaryCompare>
void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
vtkm::cont::ArrayHandle<U, StorageU>& values,
BinaryCompare binary_compare,
RadixSortTag)
vtkm::cont::internal::radix::RadixSortTag)
{
using KeyType = vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic>;
using ValueType = vtkm::cont::ArrayHandle<U, vtkm::cont::StorageTagBasic>;
......@@ -287,7 +198,7 @@ void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
ZipHandleType zipHandle = vtkm::cont::make_ArrayHandleZip(keys, indexArray);
parallel_sort(zipHandle,
vtkm::cont::internal::KeyCompare<T, vtkm::Id, BinaryCompare>(binary_compare),
PSortTag{});
vtkm::cont::internal::radix::PSortTag{});
}
tbb::ScatterPortal(values.PrepareForInput(vtkm::cont::DeviceAdapterTagTBB()),
......@@ -301,9 +212,21 @@ void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
tbb::CopyPortals(inputPortal, outputPortal, 0, 0, valuesScattered.GetNumberOfValues());
}
}
// Sort by key -- static switch between radix and quick sort:
// The algorithm is picked at compile time from the key/value types, their
// storage tags and the comparator, then forwarded to the matching tagged
// overload of parallel_sort_bykey.
template <typename T, typename StorageT, typename U, typename StorageU, class BinaryCompare>
void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
                         vtkm::cont::ArrayHandle<U, StorageU>& values,
                         BinaryCompare binary_compare)
{
  namespace radix = vtkm::cont::internal::radix;
  using AlgorithmTag =
    typename radix::sortbykey_tag_type<T, U, StorageT, StorageU, BinaryCompare>::type;
  parallel_sort_bykey(keys, values, binary_compare, AlgorithmTag{});
}
}
}
}
} // end namespace vtkm::cont::tbb::sort
#endif // vtk_m_cont_tbb_internal_ParallelSort_h
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment