1 #ifndef DIY_COLLECTIVES_HPP
2 #define DIY_COLLECTIVES_HPP
10 virtual void init() =0;
11 virtual void update(
const CollectiveOp& other) =0;
12 virtual void global(
const mpi::communicator& comm) =0;
13 virtual void copy_from(
const CollectiveOp& other) =0;
14 virtual void result_out(
void* dest)
const =0;
15 virtual ~CollectiveOp() {}
18 template<
class T,
class Op>
19 struct AllReduceOp:
public CollectiveOp
21 AllReduceOp(
const T& x, Op op):
24 void init() { out_ = in_; }
25 void update(
const CollectiveOp& other) { out_ = op_(out_, static_cast<const AllReduceOp&>(other).in_); }
26 void global(
const mpi::communicator& comm) { T res;
mpi::all_reduce(comm, out_, res, op_); out_ = res; }
27 void copy_from(
const CollectiveOp& other) { out_ =
static_cast<const AllReduceOp&
>(other).out_; }
28 void result_out(
void* dest)
const { *
reinterpret_cast<T*
>(dest) = out_; }
36 struct Scratch:
public CollectiveOp
42 void update(
const CollectiveOp&) {}
43 void global(
const mpi::communicator&) {}
44 void copy_from(
const CollectiveOp&) {}
45 void result_out(
void* dest)
const { *
reinterpret_cast<T*
>(dest) = x_; }
void all_reduce(const communicator &comm, const T &in, T &out, const Op &op)
all_reduce
Definition: collectives.hpp:396