Commit 1bb7dde9 authored by Utkarsh Ayachit

Update MultiBlock to use `diy` for block-based operations.

Updating MultiBlock to use `diy` for computing block summaries such as
ranges and bounds. This makes it possible for MultiBlock to work in
distributed environments without explicit communication logic.
parent e9c7e561
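With MPI enabled, the global accessors now behave like collectives and must be
called on every rank. A minimal usage sketch, where the local data set and the
field name "pointvar" are hypothetical and not part of this commit:

  vtkm::cont::MultiBlock mb;
  mb.AddBlock(localDataSet); // each rank adds its own block(s)
  // Reduced across all ranks via diy collectives:
  vtkm::Id totalBlocks = mb.GetGlobalNumberOfBlocks();
  vtkm::cont::ArrayHandle<vtkm::Range> ranges = mb.GetGlobalRange("pointvar");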
@@ -19,13 +19,142 @@
//============================================================================
#include <vtkm/StaticAssert.h>
#include <vtkm/cont/ArrayCopy.h>
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/DataSet.h>
#include <vtkm/cont/DeviceAdapterAlgorithm.h>
#include <vtkm/cont/DynamicArrayHandle.h>
#include <vtkm/cont/EnvironmentTracker.h>
#include <vtkm/cont/ErrorExecution.h>
#include <vtkm/cont/Field.h>
#include <vtkm/cont/MultiBlock.h>
#if defined(VTKM_ENABLE_MPI)
#include <diy/master.hpp>
namespace vtkm
{
namespace cont
{
namespace detail
{
template <typename PortalType>
VTKM_CONT std::vector<typename PortalType::ValueType> CopyArrayPortalToVector(
const PortalType& portal)
{
using ValueType = typename PortalType::ValueType;
std::vector<ValueType> result(portal.GetNumberOfValues());
vtkm::cont::ArrayPortalToIterators<PortalType> iterators(portal);
std::copy(iterators.GetBegin(), iterators.GetEnd(), result.begin());
return result;
}
}
}
}
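// diy's collectives look up the MPI reduction operation from the functor type
// passed to all_reduce. Specializing std::plus for vtkm::Bounds/vtkm::Range
// (together with the diy::mpi::detail::mpi_op specializations further below)
// lets `cp.all_reduce(value, std::plus<...>())` run over a custom MPI_Op that
// combines values with their operator+ (a union of the two bounds/ranges).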
namespace std
{
namespace detail
{
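// Reference-counted wrapper around a user-defined MPI_Op that combines
// elements of T with T::operator+; the op is freed when the last copy of the
// wrapper goes away.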
template <typename T, size_t ElementSize = sizeof(T)>
struct MPIPlus
{
  MPIPlus()
  {
    this->OpPtr = std::shared_ptr<MPI_Op>(new MPI_Op(MPI_NO_OP), [](MPI_Op* ptr) {
      MPI_Op_free(ptr);
      delete ptr;
    });
    MPI_Op_create(
      [](void* a, void* b, int* len, MPI_Datatype* type) {
        // MPI passes `*len` as a count of datatype elements (vtkm::Float64
        // here), not bytes, so convert it to a count of T elements.
        int typeSize;
        MPI_Type_size(*type, &typeSize);
        const int count = (*len) * typeSize / static_cast<int>(ElementSize);
        T* ba = reinterpret_cast<T*>(a);
        T* bb = reinterpret_cast<T*>(b);
        for (int cc = 0; cc < count; ++cc)
        {
          bb[cc] = ba[cc] + bb[cc];
        }
      },
      /*commute=*/1,
      this->OpPtr.get());
  }
  ~MPIPlus() {}
  operator MPI_Op() const { return *this->OpPtr.get(); }

private:
  std::shared_ptr<MPI_Op> OpPtr;
};
} // std::detail
template <>
struct plus<vtkm::Bounds>
{
MPI_Op get_mpi_op() const { return this->Op; }
vtkm::Bounds operator()(const vtkm::Bounds& lhs, const vtkm::Bounds& rhs) const
{
return lhs + rhs;
}
private:
std::detail::MPIPlus<vtkm::Bounds> Op;
};
template <>
struct plus<vtkm::Range>
{
MPI_Op get_mpi_op() const { return this->Op; }
vtkm::Range operator()(const vtkm::Range& lhs, const vtkm::Range& rhs) const { return lhs + rhs; }
private:
std::detail::MPIPlus<vtkm::Range> Op;
};
}
namespace diy
{
namespace mpi
{
namespace detail
{
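// These specializations tell diy how to ship vtkm types over MPI: Bounds and
// Range are transmitted as raw vtkm::Float64 values (6 and 2, respectively),
// and reductions with std::plus<> map to the custom MPI_Op defined above.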
template <>
struct mpi_datatype<vtkm::Bounds>
{
static MPI_Datatype datatype() { return get_mpi_datatype<vtkm::Float64>(); }
static const void* address(const vtkm::Bounds& x) { return &x; }
static void* address(vtkm::Bounds& x) { return &x; }
static int count(const vtkm::Bounds&) { return 6; }
};
template <>
struct mpi_op<std::plus<vtkm::Bounds>>
{
static MPI_Op get(const std::plus<vtkm::Bounds>& op) { return op.get_mpi_op(); }
};
template <>
struct mpi_datatype<vtkm::Range>
{
static MPI_Datatype datatype() { return get_mpi_datatype<vtkm::Float64>(); }
static const void* address(const vtkm::Range& x) { return &x; }
static void* address(vtkm::Range& x) { return &x; }
static int count(const vtkm::Range&) { return 2; }
};
template <>
struct mpi_op<std::plus<vtkm::Range>>
{
static MPI_Op get(const std::plus<vtkm::Range>& op) { return op.get_mpi_op(); }
};
} // diy::mpi::detail
} // diy::mpi
} // diy
#endif
namespace vtkm
{
namespace cont
@@ -86,6 +215,28 @@ vtkm::Id MultiBlock::GetNumberOfBlocks() const
return static_cast<vtkm::Id>(this->Blocks.size());
}
VTKM_CONT
vtkm::Id MultiBlock::GetGlobalNumberOfBlocks() const
{
#if defined(VTKM_ENABLE_MPI)
auto world = vtkm::cont::EnvironmentTracker::GetCommunicator();
const auto local_count = this->GetNumberOfBlocks();
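  // One dummy block per rank; Master(world, 1, -1) requests a single thread
  // and keeps all blocks in memory.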
diy::Master master(world, 1, -1);
int block_not_used = 1;
master.add(world.rank(), &block_not_used, new diy::Link());
// empty link since we're only using collectives.
master.foreach ([=](void*, const diy::Master::ProxyWithLink& cp) {
cp.all_reduce(local_count, std::plus<vtkm::Id>());
});
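  // process_collectives() executes the queued all_reduce; the reduced value
  // can then be fetched from any local block's proxy.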
master.process_collectives();
vtkm::Id global_count = master.proxy(0).get<vtkm::Id>();
return global_count;
#else
return this->GetNumberOfBlocks();
#endif
}
VTKM_CONT
const vtkm::cont::DataSet& MultiBlock::GetBlock(vtkm::Id blockId) const
{
@@ -158,6 +309,30 @@ VTKM_CONT vtkm::Bounds MultiBlock::GetBounds(vtkm::Id coordinate_system_index,
VTKM_IS_LIST_TAG(TypeList);
VTKM_IS_LIST_TAG(StorageList);
#if defined(VTKM_ENABLE_MPI)
auto world = vtkm::cont::EnvironmentTracker::GetCommunicator();
//const auto global_num_blocks = this->GetGlobalNumberOfBlocks();
const auto num_blocks = this->GetNumberOfBlocks();
diy::Master master(world, 1, -1);
for (vtkm::Id cc = 0; cc < num_blocks; ++cc)
{
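    // Interleave local block indices across ranks so every block gets a
    // globally unique gid.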
int gid = cc * world.size() + world.rank();
master.add(gid, const_cast<vtkm::cont::DataSet*>(&this->Blocks[cc]), new diy::Link());
}
master.foreach ([&](const vtkm::cont::DataSet* block, const diy::Master::ProxyWithLink& cp) {
auto coords = block->GetCoordinateSystem(coordinate_system_index);
const vtkm::Bounds bounds = coords.GetBounds(TypeList(), StorageList());
cp.all_reduce(bounds, std::plus<vtkm::Bounds>());
});
master.process_collectives();
auto bounds = master.proxy(0).get<vtkm::Bounds>();
return bounds;
#else
const vtkm::Id index = coordinate_system_index;
const size_t num_blocks = this->Blocks.size();
@@ -167,8 +342,8 @@ VTKM_CONT vtkm::Bounds MultiBlock::GetBounds(vtkm::Id coordinate_system_index,
vtkm::Bounds block_bounds = this->GetBlockBounds(i, index, TypeList(), StorageList());
bounds.Include(block_bounds);
}
return bounds;
#endif
}
VTKM_CONT
@@ -267,6 +442,71 @@ template <typename TypeList, typename StorageList>
VTKM_CONT vtkm::cont::ArrayHandle<vtkm::Range>
MultiBlock::GetGlobalRange(const std::string& field_name, TypeList, StorageList) const
{
#if defined(VTKM_ENABLE_MPI)
auto world = vtkm::cont::EnvironmentTracker::GetCommunicator();
const auto num_blocks = this->GetNumberOfBlocks();
diy::Master master(world);
for (vtkm::Id cc = 0; cc < num_blocks; ++cc)
{
int gid = cc * world.size() + world.rank();
master.add(gid, const_cast<vtkm::cont::DataSet*>(&this->Blocks[cc]), new diy::Link());
}
// Collect the maximum number of components in the field across all blocks and ranks.
master.foreach ([&](const vtkm::cont::DataSet* dataset, const diy::Master::ProxyWithLink& cp) {
if (dataset->HasField(field_name))
{
auto field = dataset->GetField(field_name);
const vtkm::cont::ArrayHandle<vtkm::Range> range = field.GetRange(TypeList(), StorageList());
vtkm::Id components = range.GetPortalConstControl().GetNumberOfValues();
cp.all_reduce(components, diy::mpi::maximum<vtkm::Id>());
}
});
master.process_collectives();
const vtkm::Id components = master.size() ? master.proxy(0).read<vtkm::Id>() : 0;
// clear all collectives.
master.foreach ([&](const vtkm::cont::DataSet*, const diy::Master::ProxyWithLink& cp) {
cp.collectives()->clear();
});
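  // Second pass: reduce the per-component ranges. Every rank contributes
  // exactly `components` Range values so the queued collectives line up
  // across ranks.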
master.foreach ([&](const vtkm::cont::DataSet* dataset, const diy::Master::ProxyWithLink& cp) {
if (dataset->HasField(field_name))
{
auto field = dataset->GetField(field_name);
const vtkm::cont::ArrayHandle<vtkm::Range> range = field.GetRange(TypeList(), StorageList());
const auto v_range =
vtkm::cont::detail::CopyArrayPortalToVector(range.GetPortalConstControl());
for (const vtkm::Range& r : v_range)
{
cp.all_reduce(r, std::plus<vtkm::Range>());
}
// If the current block has fewer than the max number of components, pad with empty ranges for the rest.
for (vtkm::Id cc = static_cast<vtkm::Id>(v_range.size()); cc < components; ++cc)
{
cp.all_reduce(vtkm::Range(), std::plus<vtkm::Range>());
}
}
});
master.process_collectives();
std::vector<vtkm::Range> ranges(components);
// FIXME: if master.size() == 0, i.e. there are no blocks on the current rank,
// this method won't return a valid range.
if (master.size() > 0)
{
for (vtkm::Id cc = 0; cc < components; ++cc)
{
ranges[cc] = master.proxy(0).get<vtkm::Range>();
}
}
vtkm::cont::ArrayHandle<vtkm::Range> range;
vtkm::cont::ArrayCopy(vtkm::cont::make_ArrayHandle(ranges), range);
return range;
#else
bool valid_field = true;
const size_t num_blocks = this->Blocks.size();
@@ -324,6 +564,7 @@ MultiBlock::GetGlobalRange(const std::string& field_name, TypeList, StorageList)
}
return range;
#endif
}
VTKM_CONT
@@ -64,6 +64,13 @@ public:
VTKM_CONT
vtkm::Id GetNumberOfBlocks() const;
/// Returns the number of blocks across all ranks. For non-MPI builds, this
/// will be the same as `GetNumberOfBlocks()`.
/// This method is not thread-safe and may involve global communication across
/// all ranks in distributed environments with MPI.
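/// For example, in a hypothetical 4-rank run where each rank holds 2 blocks,
/// `GetNumberOfBlocks()` returns 2 on every rank while this method returns 8.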
VTKM_CONT
vtkm::Id GetGlobalNumberOfBlocks() const;
VTKM_CONT
const vtkm::cont::DataSet& GetBlock(vtkm::Id blockId) const;
@@ -105,7 +112,11 @@
vtkm::Id coordinate_system_index,
TypeList,
StorageList) const;
//@{
/// Get the unified range of the same field within all contained DataSets.
/// These methods are not thread-safe and may involve global communication
/// across all ranks in distributed environments with MPI.
VTKM_CONT
vtkm::cont::ArrayHandle<vtkm::Range> GetGlobalRange(const std::string& field_name) const;
@@ -128,6 +139,7 @@ public:
VTKM_CONT vtkm::cont::ArrayHandle<vtkm::Range> GetGlobalRange(const int& index,
TypeList,
StorageList) const;
//@}
VTKM_CONT
void PrintSummary(std::ostream& stream) const;
@@ -27,6 +27,7 @@
#include <vtkm/cont/DataSet.h>
#include <vtkm/cont/DataSetFieldAdd.h>
#include <vtkm/cont/DynamicArrayHandle.h>
#include <vtkm/cont/EnvironmentTracker.h>
#include <vtkm/cont/Field.h>
#include <vtkm/cont/MultiBlock.h>
#include <vtkm/cont/serial/DeviceAdapterSerial.h>
@@ -34,6 +35,10 @@
#include <vtkm/cont/testing/Testing.h>
#include <vtkm/exec/ConnectivityStructured.h>
#if defined(VTKM_ENABLE_MPI)
#include <diy/master.hpp>
#endif
void DataSet_Compare(vtkm::cont::DataSet& LeftDateSet, vtkm::cont::DataSet& RightDateSet);
static void MultiBlockTest()
{
@@ -46,7 +51,14 @@ static void MultiBlockTest()
multiblock.AddBlock(TDset1);
multiblock.AddBlock(TDset2);
int procsize = 1;
#if defined(VTKM_ENABLE_MPI)
procsize = vtkm::cont::EnvironmentTracker::GetCommunicator().size();
#endif
VTKM_TEST_ASSERT(multiblock.GetNumberOfBlocks() == 2, "Incorrect number of blocks");
VTKM_TEST_ASSERT(multiblock.GetGlobalNumberOfBlocks() == 2 * procsize,
"Incorrect number of blocks");
vtkm::cont::DataSet TestDSet = multiblock.GetBlock(0);
VTKM_TEST_ASSERT(TDset1.GetNumberOfFields() == TestDSet.GetNumberOfFields(),
@@ -155,7 +167,13 @@ void DataSet_Compare(vtkm::cont::DataSet& LeftDateSet, vtkm::cont::DataSet& RightDateSet)
return;
}
int UnitTestMultiBlock(int argc, char* argv[])
{
(void)argc;
(void)argv;
#if defined(VTKM_ENABLE_MPI)
diy::mpi::environment env(argc, argv);
vtkm::cont::EnvironmentTracker::SetCommunicator(diy::mpi::communicator(MPI_COMM_WORLD));
#endif
return vtkm::cont::testing::Testing::Run(MultiBlockTest);
}
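To exercise the MPI code paths, the test has to run under an MPI launcher,
along the lines of the following (the exact test binary name depends on the
build configuration):

  mpirun -np 2 ./UnitTestMultiBlock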