DIY  3.0
data-parallel out-of-core C++ library
 All Classes Namespaces Functions Typedefs Groups Pages
all-to-all.hpp
1 #ifndef DIY_DETAIL_ALL_TO_ALL_HPP
2 #define DIY_DETAIL_ALL_TO_ALL_HPP
3 
4 #include "../block_traits.hpp"
5 
6 namespace diy
7 {
8 
9 namespace detail
10 {
11  template<class Op>
12  struct AllToAllReduce
13  {
14  using Block = typename block_traits<Op>::type;
15 
16  AllToAllReduce(const Op& op_, const Assigner& assigner):
17  op(op_)
18  {
19  for (int gid = 0; gid < assigner.nblocks(); ++gid)
20  {
21  BlockID nbr = { gid, assigner.rank(gid) };
22  all_neighbors_link.add_neighbor(nbr);
23  }
24  }
25 
26  void operator()(Block* b, const ReduceProxy& srp, const RegularSwapPartners& partners) const
27  {
28  int k_in = srp.in_link().size();
29  int k_out = srp.out_link().size();
30 
31  if (k_in == 0 && k_out == 0) // special case of a single block
32  {
33  ReduceProxy all_srp_out(srp, srp.block(), 0, srp.assigner(), empty_link, all_neighbors_link);
34  ReduceProxy all_srp_in (srp, srp.block(), 1, srp.assigner(), all_neighbors_link, empty_link);
35 
36  op(b, all_srp_out);
37  MemoryBuffer& in_queue = all_srp_in.incoming(all_srp_in.in_link().target(0).gid);
38  in_queue.swap(all_srp_out.outgoing(all_srp_out.out_link().target(0)));
39  in_queue.reset();
40 
41  op(b, all_srp_in);
42  return;
43  }
44 
45  if (k_in == 0) // initial round
46  {
47  ReduceProxy all_srp(srp, srp.block(), 0, srp.assigner(), empty_link, all_neighbors_link);
48  op(b, all_srp);
49 
50  Master::OutgoingQueues all_queues;
51  all_queues.swap(*all_srp.outgoing()); // clears out the queues and stores them locally
52 
53  // enqueue outgoing
54  int group = all_srp.out_link().size() / k_out;
55  for (int i = 0; i < k_out; ++i)
56  {
57  std::pair<int,int> range(i*group, (i+1)*group);
58  srp.enqueue(srp.out_link().target(i), range);
59  for (int j = i*group; j < (i+1)*group; ++j)
60  {
61  int from = srp.gid();
62  int to = all_srp.out_link().target(j).gid;
63  srp.enqueue(srp.out_link().target(i), std::make_pair(from, to));
64  srp.enqueue(srp.out_link().target(i), all_queues[all_srp.out_link().target(j)]);
65  }
66  }
67  } else if (k_out == 0) // final round
68  {
69  // dequeue incoming + reorder into the correct order
70  ReduceProxy all_srp(srp, srp.block(), 1, srp.assigner(), all_neighbors_link, empty_link);
71 
72  Master::IncomingQueues all_incoming;
73  all_incoming.swap(*srp.incoming());
74 
75  std::pair<int, int> range; // all the ranges should be the same
76  for (int i = 0; i < k_in; ++i)
77  {
78  int gid_in = srp.in_link().target(i).gid;
79  MemoryBuffer& in = all_incoming[gid_in];
80  load(in, range);
81  while(in)
82  {
83  std::pair<int, int> from_to;
84  load(in, from_to);
85  load(in, all_srp.incoming(from_to.first));
86  all_srp.incoming(from_to.first).reset();
87  }
88  }
89 
90  op(b, all_srp);
91  } else // intermediate round: reshuffle queues
92  {
93  // add up buffer sizes
94  std::vector<size_t> sizes_out(k_out, sizeof(std::pair<int,int>));
95  std::pair<int, int> range; // all the ranges should be the same
96  for (int i = 0; i < k_in; ++i)
97  {
98  MemoryBuffer& in = srp.incoming(srp.in_link().target(i).gid);
99 
100  load(in, range);
101  int group = (range.second - range.first)/k_out;
102 
103  std::pair<int, int> from_to;
104  size_t s;
105  while(in)
106  {
107  diy::load(in, from_to);
108  diy::load(in, s);
109 
110  int j = (from_to.second - range.first) / group;
111  sizes_out[j] += s + sizeof(size_t) + sizeof(std::pair<int,int>);
112  in.skip(s);
113  }
114  in.reset();
115  }
116 
117  // reserve outgoing buffers of correct size
118  int group = (range.second - range.first)/k_out;
119  for (int i = 0; i < k_out; ++i)
120  {
121  MemoryBuffer& out = srp.outgoing(srp.out_link().target(i));
122  out.reserve(sizes_out[i]);
123 
124  std::pair<int, int> out_range;
125  out_range.first = range.first + group*i;
126  out_range.second = range.first + group*(i+1);
127  save(out, out_range);
128  }
129 
130  // re-direct the queues
131  for (int i = 0; i < k_in; ++i)
132  {
133  MemoryBuffer& in = srp.incoming(srp.in_link().target(i).gid);
134 
135  load(in, range);
136 
137  std::pair<int, int> from_to;
138  while(in)
139  {
140  load(in, from_to);
141  int j = (from_to.second - range.first) / group;
142 
143  MemoryBuffer& out = srp.outgoing(srp.out_link().target(j));
144  save(out, from_to);
145  MemoryBuffer::copy(in, out);
146  }
147  }
148  }
149  }
150 
151  const Op& op;
152  Link all_neighbors_link, empty_link;
153  };
154 
155  struct SkipIntermediate
156  {
157  SkipIntermediate(size_t rounds_):
158  rounds(rounds_) {}
159 
160  bool operator()(int round, int, const Master&) const { if (round == 0 || round == (int) rounds) return false; return true; }
161 
162  size_t rounds;
163  };
164 }
165 
166 }
167 
168 #endif
void load(BinaryBuffer &bb, T &x)
Loads x from bb by calling diy::Serialization<T>::load(bb,x).
Definition: serialization.hpp:106
static void copy(MemoryBuffer &from, MemoryBuffer &to)
copy a memory buffer from one buffer to another, bypassing making a temporary copy first ...
Definition: serialization.hpp:450
void save(BinaryBuffer &bb, const T &x)
Saves x to bb by calling diy::Serialization<T>::save(bb,x).
Definition: serialization.hpp:102
void in(const RegularLink< Bounds > &link, const Point &p, OutIter out, const Bounds &domain)
Finds the neighbor(s) containing the target point.
Definition: pick.hpp:102