DIY  3.0
data-parallel out-of-core C++ library
 All Classes Namespaces Functions Typedefs Groups Pages
collectives.hpp
1 #include <vector>
2 
3 #include "../constants.h" // for DIY_UNUSED.
4 #include "operations.hpp"
5 
6 namespace diy
7 {
8 namespace mpi
9 {
12 
13  template<class T, class Op>
14  struct Collectives
15  {
16  typedef detail::mpi_datatype<T> Datatype;
17 
18  static void broadcast(const communicator& comm, T& x, int root)
19  {
20 #ifndef DIY_NO_MPI
21  MPI_Bcast(Datatype::address(x),
22  Datatype::count(x),
23  Datatype::datatype(), root, comm);
24 #else
25  DIY_UNUSED(comm);
26  DIY_UNUSED(x);
27  DIY_UNUSED(root);
28 #endif
29  }
30 
31  static void broadcast(const communicator& comm, std::vector<T>& x, int root)
32  {
33 #ifndef DIY_NO_MPI
34  size_t sz = x.size();
35  int elem_size = Datatype::count(x[0]); // size of 1 vector element in units of mpi datatype
37 
38  if (comm.rank() != root)
39  x.resize(sz);
40 
41  MPI_Bcast(Datatype::address(x[0]),
42  elem_size * x.size(),
43  Datatype::datatype(), root, comm);
44 #else
45  DIY_UNUSED(comm);
46  DIY_UNUSED(x);
47  DIY_UNUSED(root);
48 #endif
49  }
50 
51  static request ibroadcast(const communicator& comm, T& x, int root)
52  {
53 #ifndef DIY_NO_MPI
54  request r;
55  MPI_Ibcast(Datatype::address(x),
56  Datatype::count(x),
57  Datatype::datatype(), root, comm, &r.r);
58  return r;
59 #else
60  DIY_UNUSED(comm);
61  DIY_UNUSED(x);
62  DIY_UNUSED(root);
63  DIY_UNSUPPORTED_MPI_CALL(MPI_Ibcast);
64 #endif
65  }
66 
67  static void gather(const communicator& comm, const T& in, std::vector<T>& out, int root)
68  {
69  out.resize(comm.size());
70 #ifndef DIY_NO_MPI
71  MPI_Gather(Datatype::address(const_cast<T&>(in)),
72  Datatype::count(in),
73  Datatype::datatype(),
74  Datatype::address(out[0]),
75  Datatype::count(in),
76  Datatype::datatype(),
77  root, comm);
78 #else
79  DIY_UNUSED(comm);
80  DIY_UNUSED(root);
81  out[0] = in;
82 #endif
83  }
84 
85  static void gather(const communicator& comm, const std::vector<T>& in, std::vector< std::vector<T> >& out, int root)
86  {
87 #ifndef DIY_NO_MPI
88  std::vector<int> counts(comm.size());
89  int elem_size = Datatype::count(in[0]); // size of 1 vector element in units of mpi datatype
90  Collectives<int,void*>::gather(comm, (int)(elem_size * in.size()), counts, root);
91 
92  std::vector<int> offsets(comm.size(), 0);
93  for (unsigned i = 1; i < offsets.size(); ++i)
94  offsets[i] = offsets[i-1] + counts[i-1];
95 
96  std::vector<T> buffer((offsets.back() + counts.back()) / elem_size);
97  MPI_Gatherv(Datatype::address(const_cast<T&>(in[0])),
98  elem_size * in.size(),
99  Datatype::datatype(),
100  Datatype::address(buffer[0]),
101  &counts[0],
102  &offsets[0],
103  Datatype::datatype(),
104  root, comm);
105 
106  out.resize(comm.size());
107  size_t cur = 0;
108  for (unsigned i = 0; i < (unsigned)comm.size(); ++i)
109  {
110  out[i].reserve(counts[i] / elem_size);
111  for (unsigned j = 0; j < (unsigned)(counts[i] / elem_size); ++j)
112  out[i].push_back(buffer[cur++]);
113  }
114 #else
115  DIY_UNUSED(comm);
116  DIY_UNUSED(root);
117  out.resize(1);
118  out[0] = in;
119 #endif
120  }
121 
122  static void gather(const communicator& comm, const T& in, int root)
123  {
124 #ifndef DIY_NO_MPI
125  MPI_Gather(Datatype::address(const_cast<T&>(in)),
126  Datatype::count(in),
127  Datatype::datatype(),
128  Datatype::address(const_cast<T&>(in)),
129  Datatype::count(in),
130  Datatype::datatype(),
131  root, comm);
132 #else
133  DIY_UNUSED(comm);
134  DIY_UNUSED(in);
135  DIY_UNUSED(root);
136  DIY_UNSUPPORTED_MPI_CALL("MPI_Gather");
137 #endif
138  }
139 
140  static void gather(const communicator& comm, const std::vector<T>& in, int root)
141  {
142 #ifndef DIY_NO_MPI
143  int elem_size = Datatype::count(in[0]); // size of 1 vector element in units of mpi datatype
144  Collectives<int,void*>::gather(comm, (int)(elem_size * in.size()), root);
145 
146  MPI_Gatherv(Datatype::address(const_cast<T&>(in[0])),
147  elem_size * in.size(),
148  Datatype::datatype(),
149  0, 0, 0,
150  Datatype::datatype(),
151  root, comm);
152 #else
153  DIY_UNUSED(comm);
154  DIY_UNUSED(in);
155  DIY_UNUSED(root);
156  DIY_UNSUPPORTED_MPI_CALL("MPI_Gatherv");
157 #endif
158  }
159 
160  static void all_gather(const communicator& comm, const T& in, std::vector<T>& out)
161  {
162  out.resize(comm.size());
163 #ifndef DIY_NO_MPI
164  MPI_Allgather(Datatype::address(const_cast<T&>(in)),
165  Datatype::count(in),
166  Datatype::datatype(),
167  Datatype::address(out[0]),
168  Datatype::count(in),
169  Datatype::datatype(),
170  comm);
171 #else
172  DIY_UNUSED(comm);
173  out[0] = in;
174 #endif
175  }
176 
177  static void all_gather(const communicator& comm, const std::vector<T>& in, std::vector< std::vector<T> >& out)
178  {
179 #ifndef DIY_NO_MPI
180  std::vector<int> counts(comm.size());
181  int elem_size = Datatype::count(in[0]); // size of 1 vector element in units of mpi datatype
182  Collectives<int,void*>::all_gather(comm, (int)(elem_size * in.size()), counts);
183 
184  std::vector<int> offsets(comm.size(), 0);
185  for (unsigned i = 1; i < offsets.size(); ++i)
186  offsets[i] = offsets[i-1] + counts[i-1];
187 
188  std::vector<T> buffer((offsets.back() + counts.back()) / elem_size);
189  MPI_Allgatherv(Datatype::address(const_cast<T&>(in[0])),
190  elem_size * in.size(),
191  Datatype::datatype(),
192  Datatype::address(buffer[0]),
193  &counts[0],
194  &offsets[0],
195  Datatype::datatype(),
196  comm);
197 
198  out.resize(comm.size());
199  size_t cur = 0;
200  for (int i = 0; i < comm.size(); ++i)
201  {
202  out[i].reserve(counts[i] / elem_size);
203  for (int j = 0; j < (int)(counts[i] / elem_size); ++j)
204  out[i].push_back(buffer[cur++]);
205  }
206 #else
207  DIY_UNUSED(comm);
208  out.resize(1);
209  out[0] = in;
210 #endif
211  }
212 
213  static void reduce(const communicator& comm, const T& in, T& out, int root, const Op&)
214  {
215 #ifndef DIY_NO_MPI
216  MPI_Reduce(Datatype::address(const_cast<T&>(in)),
217  Datatype::address(out),
218  Datatype::count(in),
219  Datatype::datatype(),
220  detail::mpi_op<Op>::get(),
221  root, comm);
222 #else
223  DIY_UNUSED(comm);
224  DIY_UNUSED(root);
225  out = in;
226 #endif
227  }
228 
229  static void reduce(const communicator& comm, const T& in, int root, const Op&)
230  {
231 #ifndef DIY_NO_MPI
232  MPI_Reduce(Datatype::address(const_cast<T&>(in)),
233  Datatype::address(const_cast<T&>(in)),
234  Datatype::count(in),
235  Datatype::datatype(),
236  detail::mpi_op<Op>::get(),
237  root, comm);
238 #else
239  DIY_UNUSED(comm);
240  DIY_UNUSED(in);
241  DIY_UNUSED(root);
242  DIY_UNSUPPORTED_MPI_CALL("MPI_Reduce");
243 #endif
244  }
245 
246  static void all_reduce(const communicator& comm, const T& in, T& out, const Op&)
247  {
248 #ifndef DIY_NO_MPI
249  MPI_Allreduce(Datatype::address(const_cast<T&>(in)),
250  Datatype::address(out),
251  Datatype::count(in),
252  Datatype::datatype(),
253  detail::mpi_op<Op>::get(),
254  comm);
255 #else
256  DIY_UNUSED(comm);
257  out = in;
258 #endif
259  }
260 
261  static void all_reduce(const communicator& comm, const std::vector<T>& in, std::vector<T>& out, const Op&)
262  {
263 #ifndef DIY_NO_MPI
264  out.resize(in.size());
265  int elem_size = Datatype::count(in[0]); // size of 1 vector element in units of mpi datatype
266  MPI_Allreduce(Datatype::address(const_cast<T&>(in[0])),
267  Datatype::address(out[0]),
268  elem_size * in.size(),
269  Datatype::datatype(),
270  detail::mpi_op<Op>::get(),
271  comm);
272 #else
273  DIY_UNUSED(comm);
274  out = in;
275 #endif
276  }
277 
278  static void scan(const communicator& comm, const T& in, T& out, const Op&)
279  {
280 #ifndef DIY_NO_MPI
281  MPI_Scan(Datatype::address(const_cast<T&>(in)),
282  Datatype::address(out),
283  Datatype::count(in),
284  Datatype::datatype(),
285  detail::mpi_op<Op>::get(),
286  comm);
287 #else
288  DIY_UNUSED(comm);
289  out = in;
290 #endif
291  }
292 
293  static void all_to_all(const communicator& comm, const std::vector<T>& in, std::vector<T>& out, int n = 1)
294  {
295 #ifndef DIY_NO_MPI
296  int elem_size = Datatype::count(in[0]); // size of 1 vector element in units of mpi datatype
297  // NB: this will fail if T is a vector
298  MPI_Alltoall(Datatype::address(const_cast<T&>(in[0])),
299  elem_size * n,
300  Datatype::datatype(),
301  Datatype::address(out[0]),
302  elem_size * n,
303  Datatype::datatype(),
304  comm);
305 #else
306  DIY_UNUSED(comm);
307  DIY_UNUSED(n);
308  out = in;
309 #endif
310  }
311  };
312 
314  template<class T>
315  void broadcast(const communicator& comm, T& x, int root)
316  {
317  Collectives<T,void*>::broadcast(comm, x, root);
318  }
319 
321  template<class T>
322  void broadcast(const communicator& comm, std::vector<T>& x, int root)
323  {
324  Collectives<T,void*>::broadcast(comm, x, root);
325  }
326 
328  template<class T>
329  request ibroadcast(const communicator& comm, T& x, int root)
330  {
331  return Collectives<T,void*>::ibroadcast(comm, x, root);
332  }
333 
337  template<class T>
338  void gather(const communicator& comm, const T& in, std::vector<T>& out, int root)
339  {
340  Collectives<T,void*>::gather(comm, in, out, root);
341  }
342 
344  template<class T>
345  void gather(const communicator& comm, const std::vector<T>& in, std::vector< std::vector<T> >& out, int root)
346  {
347  Collectives<T,void*>::gather(comm, in, out, root);
348  }
349 
351  template<class T>
352  void gather(const communicator& comm, const T& in, int root)
353  {
354  Collectives<T,void*>::gather(comm, in, root);
355  }
356 
358  template<class T>
359  void gather(const communicator& comm, const std::vector<T>& in, int root)
360  {
361  Collectives<T,void*>::gather(comm, in, root);
362  }
363 
367  template<class T>
368  void all_gather(const communicator& comm, const T& in, std::vector<T>& out)
369  {
370  Collectives<T,void*>::all_gather(comm, in, out);
371  }
372 
374  template<class T>
375  void all_gather(const communicator& comm, const std::vector<T>& in, std::vector< std::vector<T> >& out)
376  {
377  Collectives<T,void*>::all_gather(comm, in, out);
378  }
379 
381  template<class T, class Op>
382  void reduce(const communicator& comm, const T& in, T& out, int root, const Op& op)
383  {
384  Collectives<T, Op>::reduce(comm, in, out, root, op);
385  }
386 
388  template<class T, class Op>
389  void reduce(const communicator& comm, const T& in, int root, const Op& op)
390  {
391  Collectives<T, Op>::reduce(comm, in, root, op);
392  }
393 
395  template<class T, class Op>
396  void all_reduce(const communicator& comm, const T& in, T& out, const Op& op)
397  {
398  Collectives<T, Op>::all_reduce(comm, in, out, op);
399  }
400 
402  template<class T, class Op>
403  void all_reduce(const communicator& comm, const std::vector<T>& in, std::vector<T>& out, const Op& op)
404  {
405  Collectives<T, Op>::all_reduce(comm, in, out, op);
406  }
407 
409  template<class T, class Op>
410  void scan(const communicator& comm, const T& in, T& out, const Op& op)
411  {
412  Collectives<T, Op>::scan(comm, in, out, op);
413  }
414 
416  template<class T>
417  void all_to_all(const communicator& comm, const std::vector<T>& in, std::vector<T>& out, int n = 1)
418  {
419  Collectives<T, void*>::all_to_all(comm, in, out, n);
420  }
421 
423 }
424 }
request ibroadcast(const communicator &comm, T &x, int root)
iBroadcast to all processes in comm.
Definition: collectives.hpp:329
void all_reduce(const communicator &comm, const std::vector< T > &in, std::vector< T > &out, const Op &op)
Same as above, but for vectors.
Definition: collectives.hpp:403
void all_to_all(const communicator &comm, const std::vector< T > &in, std::vector< T > &out, int n=1)
all_to_all
Definition: collectives.hpp:417
void reduce(const communicator &comm, const T &in, T &out, int root, const Op &op)
reduce
Definition: collectives.hpp:382
void broadcast(const communicator &comm, std::vector< T > &x, int root)
Broadcast for vectors.
Definition: collectives.hpp:322
void scan(const communicator &comm, const T &in, T &out, const Op &op)
scan
Definition: collectives.hpp:410
void in(const RegularLink< Bounds > &link, const Point &p, OutIter out, const Bounds &domain)
Finds the neighbor(s) containing the target point.
Definition: pick.hpp:102
Simple wrapper around MPI_Comm.
Definition: communicator.hpp:8
void all_gather(const communicator &comm, const T &in, std::vector< T > &out)
all_gather from all processes in comm. out is resized to comm.size() and filled with elements from the respective ranks.
Definition: collectives.hpp:368
void reduce(const communicator &comm, const T &in, int root, const Op &op)
Simplified version (without out) for use on non-root processes.
Definition: collectives.hpp:389
void all_reduce(const communicator &comm, const T &in, T &out, const Op &op)
all_reduce
Definition: collectives.hpp:396
Definition: request.hpp:5
void all_gather(const communicator &comm, const std::vector< T > &in, std::vector< std::vector< T > > &out)
Same as above, but for vectors.
Definition: collectives.hpp:375
void broadcast(const communicator &comm, T &x, int root)
Broadcast to all processes in comm.
Definition: collectives.hpp:315
void gather(const communicator &comm, const T &in, std::vector< T > &out, int root)
Gather from all processes in comm. On root process, out is resized to comm.size() and filled with elements from the respective ranks.
Definition: collectives.hpp:338
void gather(const communicator &comm, const std::vector< T > &in, int root)
Simplified version (without out) for use on non-root processes.
Definition: collectives.hpp:359