1 #ifndef DIY_DETAIL_ALGORITHMS_KDTREE_HPP
2 #define DIY_DETAIL_ALGORITHMS_KDTREE_HPP
6 #include "../../partners/all-reduce.hpp"
7 #include "../../log.hpp"
14 struct KDTreePartners;
// KDTreePartition: per-block callable used by the k-d tree reduction. Its
// constructor stores three values: dim_ (used by the member definitions as a
// dimension count), points_ (pointer-to-member selecting the
// std::vector<Point> inside a Block) and bins_ (the histogram size).
// NOTE(review): this listing appears to come from a rendered source browser:
// the leading integers look like original line numbers, lines such as braces
// are absent, and `class Po` / `int>` is presumably `class Point>` split in
// two. Confirm against the upstream header before relying on it.
16 template<
class Block,
class Po
int>
17 struct KDTreePartition
// Per-bin counters; add_histogram sums these element-wise.
22 typedef std::vector<size_t> Histogram;
// The `bins` argument consumed in the initializer list is declared on a line
// that is not present in this listing.
24 KDTreePartition(
int dim,
25 std::vector<Point> Block::* points,
27 dim_(dim), points_(points), bins_(bins) {}
// Per-round entry point: dispatches on srp.round() and the round table held
// by `partners`.
29 void operator()(Block* b,
const diy::ReduceProxy& srp,
const KDTreePartners& partners)
const;
// Returns gid with bit (rounds - 1 - round) either cleared or set.
31 int divide_gid(
int gid,
bool lower,
int round,
int rounds)
const;
// Rebuilds the block's link after a split along `dim`.
32 void update_links(Block* b,
const diy::ReduceProxy& srp,
int dim,
int round,
int rounds,
bool wrap,
const Bounds& domain)
const;
// Called from operator() right after dequeue_exchange.
33 void split_to_neighbors(Block* b,
const diy::ReduceProxy& srp,
int dim)
const;
// Compares bounds / nbr_bounds against the domain extents per dimension.
// The return type is not shown in this listing.
35 find_wrap(
const Bounds& bounds,
const Bounds& nbr_bounds,
const Bounds& domain)
const;
// Bins the block's points along `dim` using bins_ equal-width bins.
37 void compute_local_histogram(Block* b,
const diy::ReduceProxy& srp,
int dim)
const;
// Adds histograms associated with the incoming link into `histogram`.
38 void add_histogram(Block* b,
const diy::ReduceProxy& srp, Histogram& histogram)
const;
// Called from operator() before enqueue_exchange; body not visible here.
39 void receive_histogram(Block* b,
const diy::ReduceProxy& srp, Histogram& histogram)
const;
// Iterates over the outgoing link; presumably enqueues `histogram` to each
// target (the loop body is not visible here).
40 void forward_histogram(Block* b,
const diy::ReduceProxy& srp,
const Histogram& histogram)
const;
// Derives a split from `histogram`, sorts the points into two sides and
// narrows the link's core bounds along `dim`.
42 void enqueue_exchange(Block* b,
const diy::ReduceProxy& srp,
int dim,
const Histogram& histogram)
const;
// Overwrites bounds.max[dim] or bounds.min[dim] with `split`.
45 void update_neighbor_bounds(Bounds& bounds,
float split,
int dim,
bool lower)
const;
// Overlap test of x and y along `dim`, with extra comparisons against the
// domain extents.
46 bool intersects(
const Bounds& x,
const Bounds& y,
int dim,
bool wrap,
const Bounds& domain)
const;
// Returns the first coordinate at which `changed` differs from `original`.
47 float find_split(
const Bounds& changed,
const Bounds& original)
const;
// Pointer-to-member locating the point container inside a Block.
50 std::vector<Point> Block::* points_;
// KDTreePartners: round schedule for the k-d tree reduction. It owns two
// sub-schedules, `histogram` and `swap`, both constructed over the same
// `decomposer`, and flattens them into the parallel vectors rounds_ / dim_.
// NOTE(review): this listing has gaps (braces, some declarations and return
// statements are absent); the member types of decomposer / histogram / swap
// and the storage of wrap_ / domain_ are not visible here.
57 struct diy::detail::KDTreePartners
// first: whether the round is a swap round; second: the sub-round index
// (see swap_round() and sub_round() below).
61 typedef std::pair<bool, int> RoundType;
// `dim` is only used as a modulus below, so round dimensions cycle 0..dim-1.
64 KDTreePartners(
int dim,
int nblocks,
bool wrap_,
const Bounds& domain_):
// 1-D decomposition over the discrete interval [0, nblocks-1].
65 decomposer(1,
interval(0,nblocks-1), nblocks),
66 histogram(decomposer, 2),
67 swap(decomposer, 2, false),
// For every swap round i ...
71 for (
unsigned i = 0; i < swap.rounds(); ++i)
// ... append one (false, j) entry per histogram round j, tagged with the
// dimension i % dim.
74 for (
unsigned j = 0; j < histogram.rounds(); ++j)
76 rounds_.push_back(std::make_pair(
false, j));
77 dim_.push_back(i % dim);
// At this histogram round the swap entry (true, i) is inserted into the
// sequence.
78 if (j == histogram.rounds() / 2 - 1 - i)
83 rounds_.push_back(std::make_pair(
true, i));
84 dim_.push_back(i % dim);
// A swap entry with sub-round -1; active/incoming/outgoing treat
// sub_round < 0 as its own case. It still uses i, so it belongs to the outer
// loop (presumably after the inner loop - braces are not visible).
87 rounds_.push_back(std::make_pair(
true, -1));
88 dim_.push_back(i % dim);
// Total number of flattened rounds.
92 size_t rounds()
const {
return rounds_.size(); }
// Number of rounds of the underlying swap schedule.
93 size_t swap_rounds()
const {
return swap.rounds(); }
// Dimension recorded for a flattened round (unchecked index).
95 int dim(
int round)
const {
return dim_[round]; }
// True when the flattened round is a swap entry (unchecked index).
96 bool swap_round(
int round)
const {
return rounds_[round].first; }
// Sub-round index of the flattened round; -1 for the special swap entry.
97 int sub_round(
int round)
const {
return rounds_[round].second; }
// Participation test for `gid` in `round`. Ordinary swap / histogram rounds
// defer to the matching sub-schedule; the results for the first two branches
// are on lines missing from this listing.
99 inline bool active(
int round,
int gid,
const diy::Master& m)
const
101 if (round == (
int) rounds())
103 else if (swap_round(round) && sub_round(round) < 0)
105 else if (swap_round(round))
106 return swap.active(sub_round(round), gid, m);
108 return histogram.active(sub_round(round), gid, m);
// Fills `partners` for the incoming side of `round`.
111 inline void incoming(
int round,
int gid, std::vector<int>& partners,
const diy::Master& m)
const
// One past the last round: partners are the link neighbors.
113 if (round == (
int) rounds())
114 link_neighbors(-1, gid, partners, m);
// Special swap entry: use the swap schedule at the sub-round following the
// previous entry's sub-round.
115 else if (swap_round(round) && sub_round(round) < 0)
116 swap.incoming(sub_round(round - 1) + 1, gid, partners, m);
// Regular swap entry: partners come from the histogram schedule queried at
// histogram.rounds().
117 else if (swap_round(round))
118 histogram.incoming(histogram.rounds(), gid, partners, m);
// Histogram entry. Sub-round 0 after the first round uses the link
// neighbors; a gap in consecutive sub-round numbers (left by an inserted swap
// entry) is bridged by querying previous sub-round + 1.
121 if (round > 0 && sub_round(round) == 0)
122 link_neighbors(-1, gid, partners, m);
123 else if (round > 0 && sub_round(round - 1) != sub_round(round) - 1)
124 histogram.incoming(sub_round(round - 1) + 1, gid, partners, m);
126 histogram.incoming(sub_round(round), gid, partners, m);
// Fills `partners` for the outgoing side of `round`; same case split as
// incoming().
130 inline void outgoing(
int round,
int gid, std::vector<int>& partners,
const diy::Master& m)
const
132 if (round == (
int) rounds())
133 swap.outgoing(sub_round(round-1) + 1, gid, partners, m);
134 else if (swap_round(round) && sub_round(round) < 0)
135 link_neighbors(-1, gid, partners, m);
136 else if (swap_round(round))
137 swap.outgoing(sub_round(round), gid, partners, m);
139 histogram.outgoing(sub_round(round), gid, partners, m);
// Appends the distinct gids of the block's link targets to `partners`, in
// ascending order. The first (unnamed) argument is ignored.
142 inline void link_neighbors(
int,
int gid, std::vector<int>& partners,
const diy::Master& m)
const
144 int lid = m.
lid(gid);
// std::set removes duplicate gids; `link` is declared on a line that is not
// present in this listing.
147 std::set<int> result;
148 for (
int i = 0; i < link->size(); ++i)
149 result.insert(link->target(i).gid);
151 for (std::set<int>::const_iterator it = result.begin(); it != result.end(); ++it)
152 partners.push_back(*it);
// Flattened schedule: rounds_[r] and dim_[r] describe round r.
161 std::vector<RoundType> rounds_;
162 std::vector<int> dim_;
// Per-round driver: inspects srp.round() against the schedule in `partners`
// and runs the matching phase.
// NOTE(review): braces, the return type, the declaration of `dim` and a few
// statements are missing from this listing; `class Po` / `int>` is
// presumably `class Point>`.
168 template<
class Block,
class Po
int>
170 diy::detail::KDTreePartition<Block,Point>::
171 operator()(Block* b,
const diy::ReduceProxy& srp,
const KDTreePartners& partners)
const
// Pick the round's dimension; in the round one past the end, reuse the
// dimension of the last scheduled round.
174 if (srp.
round() < partners.rounds())
175 dim = partners.dim(srp.
round());
177 dim = partners.dim(srp.
round() - 1);
// Round one past the end: only the links remain to be updated.
179 if (srp.
round() == partners.rounds())
180 update_links(b, srp, dim, partners.sub_round(srp.
round() - 2), partners.swap_rounds(), partners.wrap, partners.domain);
// Special swap entry (sub_round < 0): take delivery of exchanged points,
// then send the split to the neighbors.
181 else if (partners.swap_round(srp.
round()) && partners.sub_round(srp.
round()) < 0)
183 dequeue_exchange(b, srp, dim);
184 split_to_neighbors(b, srp, dim);
// Regular swap entry: obtain the histogram and exchange points across the
// split. The `histogram` local is declared on a line not shown here.
186 else if (partners.swap_round(srp.
round()))
189 receive_histogram(b, srp, histogram);
190 enqueue_exchange(b, srp, dim, histogram);
191 }
// First histogram sub-round: update links for the previous dimension, then
// start a new histogram. Lines between the next two statements are missing
// (presumably a wrap-around of prev_dim and a guard on the round number).
else if (partners.sub_round(srp.
round()) == 0)
195 int prev_dim = dim - 1;
198 update_links(b, srp, prev_dim, partners.sub_round(srp.
round() - 2), partners.swap_rounds(), partners.wrap, partners.domain);
201 compute_local_histogram(b, srp, dim);
202 }
// First half of the histogram sub-rounds: accumulate incoming histograms.
// What is done with the sum afterwards is on lines missing here.
else if (partners.sub_round(srp.
round()) < (
int) partners.histogram.rounds()/2)
204 Histogram histogram(bins_);
205 add_histogram(b, srp, histogram);
// Remaining histogram sub-rounds: accumulate, then pass the result on.
210 Histogram histogram(bins_);
211 add_histogram(b, srp, histogram);
212 forward_histogram(b, srp, histogram);
// Returns `gid` with bit (rounds - 1 - round) cleared or set.
// NOTE(review): the condition on `lower` that chooses between the two
// statements, the return type and the return statement are missing from this
// listing; presumably lower == true clears the bit. `class Po` / `int>` is
// presumably `class Point>`.
216 template<
class Block,
class Po
int>
218 diy::detail::KDTreePartition<Block,Point>::
219 divide_gid(
int gid,
bool lower,
int round,
int rounds)
const
// Clear the bit.
222 gid &= ~(1 << (rounds - 1 - round));
// Set the bit.
224 gid |= (1 << (rounds - 1 - round));
// Builds a replacement RCLink for the block after a split along `dim` and
// swaps it into place.
// NOTE(review): many lines are missing from this listing (braces, the
// declarations of gid / dir / nbr / dual / left / right, the dequeue calls
// and the branch structure), so the comments below describe only what the
// visible statements show. `class Po` / `int>` is presumably `class Point>`.
229 template<
class Block,
class Po
int>
231 diy::detail::KDTreePartition<Block,Point>::
232 update_links(Block* b,
const diy::ReduceProxy& srp,
int dim,
int round,
int rounds,
bool wrap,
const Bounds& domain)
const
235 int lid = srp.master()->
lid(gid);
236 RCLink* link =
static_cast<RCLink*
>(srp.master()->link(lid));
// Index the existing link entries by (neighbor gid, direction).
239 std::map<std::pair<int,diy::Direction>,
int> link_map;
240 for (
int i = 0; i < link->size(); ++i)
241 link_map[std::make_pair(link->target(i).gid, link->direction(i))] = i;
// One split value per existing link entry.
244 std::vector<float> splits(link->size());
245 for (
int i = 0; i < link->size(); ++i)
249 int in_gid = link->target(i).gid;
// Drain what arrived from in_gid; k is the link slot matching the received
// direction (presumably splits[k] is assigned on a line not shown).
250 while(srp.incoming(in_gid))
256 for (
int j = 0; j < dim_; ++j)
259 int k = link_map[std::make_pair(in_gid, dir)];
// The new link is built from dim_ and the current core passed twice.
264 RCLink new_link(dim_, link->core(), link->core());
// lower is true when bit (rounds - 1 - round) of gid is clear.
266 bool lower = !(gid & (1 << (rounds - 1 - round)));
269 for (
int i = 0; i < link->size(); ++i)
// Old neighbor on the negative side along dim for a lower block (or on the
// positive side for an upper block): only one half of that neighbor,
// selected with !lower, is added.
275 if ((dir[dim] < 0 && lower) || (dir[dim] > 0 && !lower))
277 int nbr_gid = divide_gid(link->target(i).gid, !lower, round, rounds);
279 new_link.add_neighbor(nbr);
281 new_link.add_direction(dir);
283 Bounds bounds = link->bounds(i);
284 update_neighbor_bounds(bounds, splits[i], dim, !lower);
285 new_link.add_bounds(bounds);
288 new_link.add_wrap(find_wrap(new_link.bounds(), bounds, domain));
// Presumably the alternative branch: consider both halves of the old
// neighbor (j == 0 selects the lower half) and keep those whose updated
// bounds intersect the new link's bounds.
294 for (
int j = 0; j < 2; ++j)
296 int nbr_gid = divide_gid(link->target(i).gid, j == 0, round, rounds);
298 Bounds bounds = link->bounds(i);
299 update_neighbor_bounds(bounds, splits[i], dim, j == 0);
301 if (intersects(bounds, new_link.bounds(), dim, wrap, domain))
304 new_link.add_neighbor(nbr);
305 new_link.add_direction(dir);
306 new_link.add_bounds(bounds);
309 new_link.add_wrap(find_wrap(new_link.bounds(), bounds, domain));
// Add the block obtained from our own gid with the opposite `lower` flag.
318 int dual_gid = divide_gid(gid, !lower, round, rounds);
320 new_link.add_neighbor(dual);
// Its bounds are the old link bounds cut at the split recovered by
// comparing the new bounds with the old ones.
322 Bounds nbr_bounds = link->bounds();
323 update_neighbor_bounds(nbr_bounds, find_split(new_link.bounds(), nbr_bounds), dim, !lower);
324 new_link.add_bounds(nbr_bounds);
// Which of the two directions is added depends on a condition not shown.
332 new_link.add_direction(right);
337 new_link.add_direction(left);
// Install the rebuilt link in place of the old one.
343 link->swap(new_link);
// Sends this block's split value, followed by the direction of the link
// entry, to every target in the block's link.
// NOTE(review): the line carrying the member name and parameter list is
// missing from this listing; from the declaration order in the class and the
// call in operator() this is presumably split_to_neighbors(b, srp, dim).
// `class Po` / `int>` is presumably `class Point>`.
346 template<
class Block,
class Po
int>
348 diy::detail::KDTreePartition<Block,Point>::
351 int lid = srp.master()->
lid(srp.gid());
352 RCLink* link =
static_cast<RCLink*
>(srp.master()->link(lid));
// The split is read back as the first coordinate where core and bounds differ.
355 float split = find_split(link->core(), link->bounds());
357 for (
int i = 0; i < link->size(); ++i)
359 srp.
enqueue(link->target(i), split);
360 srp.
enqueue(link->target(i), link->direction(i));
// Builds a bins_-bin histogram of the block's points along `dim`, with bins
// of equal width spanning the link's core extent in that dimension.
// NOTE(review): the condition guarding the throw, the action taken when
// loc >= bins_, the bin increment and whatever is done with the finished
// histogram are on lines missing from this listing. `class Po` / `int>` is
// presumably `class Point>`.
364 template<
class Block,
class Po
int>
366 diy::detail::KDTreePartition<Block,Point>::
367 compute_local_histogram(Block* b,
const diy::ReduceProxy& srp,
int dim)
const
369 int lid = srp.master()->
lid(srp.gid());
370 RCLink* link =
static_cast<RCLink*
>(srp.master()->link(lid));
373 Histogram histogram(bins_);
// Bin width: core extent along dim divided by the bin count.
375 float width = (link->core().max[dim] - link->core().min[dim])/bins_;
376 for (
size_t i = 0; i < (b->*points_).size(); ++i)
378 float x = (b->*points_)[i][dim];
// Bin index: offset from the core minimum in units of width, truncated to int.
379 int loc = (x - link->core().min[dim]) / width;
// Error path reporting loc, x and the core minimum.
381 throw std::runtime_error(fmt::format(
"{} {} {}", loc, x, link->core().min[dim]));
// Index at or past the last bin (e.g. x at the core maximum).
382 if (loc >= (
int) bins_)
// Adds, element by element, a histogram `hist` for each target of the
// incoming link into `histogram`.
// NOTE(review): the declaration of `hist` and how it is obtained (presumably
// dequeued from nbr_gid) are on lines missing from this listing. `class Po` /
// `int>` is presumably `class Point>`.
390 template<
class Block,
class Po
int>
392 diy::detail::KDTreePartition<Block,Point>::
393 add_histogram(Block* b,
const diy::ReduceProxy& srp, Histogram& histogram)
const
396 for (
int i = 0; i < srp.
in_link().size(); ++i)
398 int nbr_gid = srp.
in_link().target(i).gid;
// Indexes `histogram` by hist's size, so `histogram` must be at least as long.
402 for (
size_t i = 0; i < hist.size(); ++i)
403 histogram[i] += hist[i];
// Definition header of receive_histogram; operator() calls it to fill
// `histogram` before enqueue_exchange.
// NOTE(review): the body is entirely missing from this listing, so how the
// histogram is obtained cannot be confirmed here. `class Po` / `int>` is
// presumably `class Point>`.
407 template<
class Block,
class Po
int>
409 diy::detail::KDTreePartition<Block,Point>::
410 receive_histogram(Block* b,
const diy::ReduceProxy& srp, Histogram& histogram)
const
// Visits every target of the outgoing link.
// NOTE(review): the loop body is missing from this listing; presumably it
// enqueues `histogram` to each target. `class Po` / `int>` is presumably
// `class Point>`.
415 template<
class Block,
class Po
int>
417 diy::detail::KDTreePartition<Block,Point>::
418 forward_histogram(Block* b,
const diy::ReduceProxy& srp,
const Histogram& histogram)
const
420 for (
int i = 0; i < srp.
out_link().size(); ++i)
// Chooses a split along `dim` from `histogram`, distributes the block's
// points into one bucket per outgoing-link target, keeps the bucket addressed
// to this block and narrows the link's core along `dim` to the split.
// NOTE(review): the declarations of total / cur / split / k, the update of
// cur, the handling of buckets bound for other blocks and the condition
// choosing which core bound to move are on lines missing from this listing.
// `class Po` / `int>` is presumably `class Point>`.
424 template<
class Block,
class Po
int>
426 diy::detail::KDTreePartition<Block,Point>::
427 enqueue_exchange(Block* b,
const diy::ReduceProxy& srp,
int dim,
const Histogram& histogram)
const
429 auto log = get_logger();
431 int lid = srp.master()->
lid(srp.gid());
432 RCLink* link =
static_cast<RCLink*
>(srp.master()->link(lid));
// Total number of counted points.
441 for (
size_t i = 0; i < histogram.size(); ++i)
442 total += histogram[i];
443 log->trace(
"Histogram total: {}", total);
// Same bin width as used when the histogram was built.
446 float width = (link->core().max[dim] - link->core().min[dim])/bins_;
// Scan bins until the count including bin i exceeds half the total; the
// split is the left edge of that bin.
448 for (
size_t i = 0; i < histogram.size(); ++i)
450 if (cur + histogram[i] > total/2)
452 split = link->core().min[dim] + width*i;
457 log->trace(
"Found split: {} (dim={}) in {} - {}", split, dim, link->core().min[dim], link->core().max[dim]);
// One bucket per outgoing target; only buckets 0 and 1 are filled below.
460 std::vector< std::vector<Point> > out_points(srp.
out_link().size());
461 for (
size_t i = 0; i < (b->*points_).size(); ++i)
463 float x = (b->*points_)[i][dim];
// Points strictly below the split go to bucket 0, all others to bucket 1.
464 int loc = x < split ? 0 : 1;
465 out_points[loc].push_back((b->*points_)[i]);
468 for (
int i = 0; i < k; ++i)
// The bucket addressed to this block replaces its point set directly.
470 if (srp.
out_link().target(i).gid == srp.gid())
472 (b->*points_).swap(out_points[i]);
// Shrink the core to this block's side of the split.
479 link->core().max[dim] = split;
481 link->core().min[dim] = split;
// Receives point vectors from the targets of the incoming link, checks that
// each point lies within the link's core along `dim`, and appends it to the
// block's points.
// NOTE(review): the line carrying the member name and parameter list is
// missing from this listing; from the call in operator() this is presumably
// dequeue_exchange(b, srp, dim). The statement executed when nbr_gid equals
// this block's gid is also missing (presumably `continue`). `class Po` /
// `int>` is presumably `class Point>`.
484 template<
class Block,
class Po
int>
486 diy::detail::KDTreePartition<Block,Point>::
489 int lid = srp.master()->
lid(srp.gid());
490 RCLink* link =
static_cast<RCLink*
>(srp.master()->link(lid));
492 for (
int i = 0; i < srp.
in_link().size(); ++i)
494 int nbr_gid = srp.
in_link().target(i).gid;
// Entry that refers to this block itself.
495 if (nbr_gid == srp.gid())
498 std::vector<Point> in_points;
499 srp.
dequeue(nbr_gid, in_points);
500 for (
size_t j = 0; j < in_points.size(); ++j)
// A received point outside [core.min, core.max] along dim is an error.
502 if (in_points[j][dim] < link->core().min[dim] || in_points[j][dim] > link->core().max[dim])
503 throw std::runtime_error(fmt::format(
"Dequeued {} outside [{},{}] ({})",
504 in_points[j][dim], link->core().min[dim], link->core().max[dim], dim));
505 (b->*points_).push_back(in_points[j]);
// Moves one face of `bounds` along `dim` to `split`: either the maximum or
// the minimum is overwritten.
// NOTE(review): the condition on `lower` selecting between the two
// assignments is missing from this listing; presumably lower == true moves
// the maximum. `class Po` / `int>` is presumably `class Point>`.
510 template<
class Block,
class Po
int>
512 diy::detail::KDTreePartition<Block,Point>::
513 update_neighbor_bounds(Bounds& bounds,
float split,
int dim,
bool lower)
const
516 bounds.max[dim] = split;
518 bounds.min[dim] = split;
// Tests whether x and y overlap along `dim`. The final statement is a
// closed-interval overlap check, so intervals that merely touch count as
// intersecting.
// NOTE(review): the statements guarded by the two domain-extent checks, and
// any test of `wrap` around them, are missing from this listing; presumably
// they return true when wrapping is enabled and the two boxes sit on opposite
// faces of the domain. `class Po` / `int>` is presumably `class Point>`.
521 template<
class Block,
class Po
int>
523 diy::detail::KDTreePartition<Block,Point>::
524 intersects(
const Bounds& x,
const Bounds& y,
int dim,
bool wrap,
const Bounds& domain)
const
// x touches the domain minimum while y touches the domain maximum.
528 if (x.min[dim] == domain.min[dim] && y.max[dim] == domain.max[dim])
// The mirrored case.
530 if (y.min[dim] == domain.min[dim] && x.max[dim] == domain.max[dim])
533 return x.min[dim] <= y.max[dim] && y.min[dim] <= x.max[dim];
// Scans dimensions 0..dim_-1 in order and returns the first coordinate of
// `changed` that differs from `original`, checking min before max within each
// dimension.
// NOTE(review): what happens when no coordinate differs is on lines missing
// from this listing. `class Po` / `int>` is presumably `class Point>`.
536 template<
class Block,
class Po
int>
538 diy::detail::KDTreePartition<Block,Point>::
539 find_split(
const Bounds& changed,
const Bounds& original)
const
541 for (
int i = 0; i < dim_; ++i)
// Exact floating-point comparison: any difference counts as a change.
543 if (changed.min[i] != original.min[i])
544 return changed.min[i];
545 if (changed.max[i] != original.max[i])
546 return changed.max[i];
// For each dimension 0..dim_-1, checks whether `bounds` and `nbr_bounds` sit
// on opposite faces of `domain`.
// NOTE(review): the return type, the result variable and the statements
// executed when either check succeeds are missing from this listing;
// presumably a per-dimension wrap direction is recorded and returned.
// `class Po` / `int>` is presumably `class Point>`.
552 template<
class Block,
class Po
int>
554 diy::detail::KDTreePartition<Block,Point>::
555 find_wrap(
const Bounds& bounds,
const Bounds& nbr_bounds,
const Bounds& domain)
const
558 for (
int i = 0; i < dim_; ++i)
// bounds at the domain minimum, neighbor at the domain maximum.
560 if (bounds.min[i] == domain.min[i] && nbr_bounds.max[i] == domain.max[i])
// bounds at the domain maximum, neighbor at the domain minimum.
562 if (bounds.max[i] == domain.max[i] && nbr_bounds.min[i] == domain.min[i])
const Link & out_link() const
returns outgoing link
Definition: reduce.hpp:70
Enables communication within a group during a reduction. DIY creates the ReduceProxy for you in diy::...
Definition: reduce.hpp:15
void dequeue(int from, T &x, void(*load)(BinaryBuffer &, T &)=&::diy::load< T >) const
Dequeue data whose size can be determined automatically (e.g., STL vector) and that was previously en...
Definition: proxy.hpp:42
unsigned round() const
returns current round number
Definition: reduce.hpp:66
Decomposes a regular (discrete or continuous) domain into even blocks; creates Links with Bounds alon...
Definition: decomposition.hpp:75
virtual int rank(int gid) const =0
returns the process rank of the block with global id gid (need not be local)
Allreduce (reduction with results broadcast to all blocks) is implemented as two merge reductions...
Definition: all-reduce.hpp:20
const Link & in_link() const
returns incoming link
Definition: reduce.hpp:68
Definition: master.hpp:35
diy::DiscreteBounds interval(int from, int to)
Helper to create a 1-dimensional discrete domain with the specified extents.
Definition: types.hpp:28
const Assigner & assigner() const
returns the assigner
Definition: reduce.hpp:74
Partners for swap-reduce.
Definition: swap.hpp:16
int lid(int gid) const
return the local id of the local block with global id gid, or -1 if not local
Definition: master.hpp:221
void enqueue(const BlockID &to, const T &x, void(*save)(BinaryBuffer &, const T &)=&::diy::save< T >) const
Enqueue data whose size can be determined automatically, e.g., an STL vector.
Definition: proxy.hpp:24