1 #ifndef DIY_DETAIL_ALGORITHMS_KDTREE_SAMPLING_HPP
2 #define DIY_DETAIL_ALGORITHMS_KDTREE_SAMPLING_HPP
6 #include "../../partners/all-reduce.hpp"
7 #include "../../log.hpp"
// Functor holding the per-round logic of a sampling-based k-d tree partition.
// Its call operator receives a block, a diy::ReduceProxy and a KDTreePartners
// object; the other const member functions are the helpers it dispatches to.
// NOTE(review): this listing is incomplete (the embedded line numbers skip), so
// braces, some return types and some members are not visible here.
23 template<
class Block,
class Po
int>
24 struct KDTreeSamplingPartition
// Sample coordinates are kept as single-precision floats.
29 typedef std::vector<float> Samples;
// Initializes dim_, points_ and samples_ from the constructor arguments.
// `points` is a pointer-to-member selecting the point vector inside a Block.
31 KDTreeSamplingPartition(
int dim,
32 std::vector<Point> Block::* points,
34 dim_(dim), points_(points), samples_(samples) {}
// Entry point: selects which helper(s) to run based on the proxy's round.
36 void operator()(Block* b,
const diy::ReduceProxy& srp,
const KDTreePartners& partners)
const;
// Clears or sets bit (rounds - 1 - round) of gid (see the definition).
38 int divide_gid(
int gid,
bool lower,
int round,
int rounds)
const;
// Builds a replacement link for the block and swaps it in.
39 void update_links(Block* b,
const diy::ReduceProxy& srp,
int dim,
int round,
int rounds,
bool wrap,
const Bounds& domain)
const;
40 void split_to_neighbors(Block* b,
const diy::ReduceProxy& srp,
int dim)
const;
// NOTE(review): the return type of find_wrap is on a line not visible here.
42 find_wrap(
const Bounds& bounds,
const Bounds& nbr_bounds,
const Bounds& domain)
const;
44 void compute_local_samples(Block* b,
const diy::ReduceProxy& srp,
int dim)
const;
45 void add_samples(Block* b,
const diy::ReduceProxy& srp, Samples& samples)
const;
46 void receive_samples(Block* b,
const diy::ReduceProxy& srp, Samples& samples)
const;
47 void forward_samples(Block* b,
const diy::ReduceProxy& srp,
const Samples& samples)
const;
// Buckets the block's points around samples[0] and adjusts the link core.
49 void enqueue_exchange(Block* b,
const diy::ReduceProxy& srp,
int dim,
const Samples& samples)
const;
// Overwrites bounds.max[dim] or bounds.min[dim] with split.
52 void update_neighbor_bounds(Bounds& bounds,
float split,
int dim,
bool lower)
const;
53 bool intersects(
const Bounds& x,
const Bounds& y,
int dim,
bool wrap,
const Bounds& domain)
const;
// Returns the first coordinate at which `changed` differs from `original`.
54 float find_split(
const Bounds& changed,
const Bounds& original)
const;
// Pointer-to-member used as (b->*points_) to reach a block's point vector.
57 std::vector<Point> Block::* points_;
// Call operator: dispatches on srp.round() and on the partners' round
// classification (swap_round / sub_round) to the helper(s) for that round.
// NOTE(review): braces, `else` lines and the declarations of `dim` and
// `samples` fall on lines not visible in this listing.
65 template<
class Block,
class Po
int>
67 diy::detail::KDTreeSamplingPartition<Block,Point>::
68 operator()(Block* b,
const diy::ReduceProxy& srp,
const KDTreePartners& partners)
const
// Pick the dimension: partners.dim(round) while round < rounds(); the other
// visible assignment uses partners.dim(round - 1).
71 if (srp.
round() < partners.rounds())
72 dim = partners.dim(srp.
round());
74 dim = partners.dim(srp.
round() - 1);
// Round equal to partners.rounds(): only the links are updated.
76 if (srp.
round() == partners.rounds())
77 update_links(b, srp, dim, partners.sub_round(srp.
round() - 2), partners.swap_rounds(), partners.wrap, partners.domain);
// Swap round whose sub_round is negative: take in exchanged points, then send
// the split value to the link neighbors.
78 else if (partners.swap_round(srp.
round()) && partners.sub_round(srp.
round()) < 0)
80 dequeue_exchange(b, srp, dim);
81 split_to_neighbors(b, srp, dim);
// Any other swap round: receive the samples and distribute points around them.
83 else if (partners.swap_round(srp.
round()))
86 receive_samples(b, srp, samples);
87 enqueue_exchange(b, srp, dim, samples);
88 }
else if (partners.sub_round(srp.
round()) == 0)
// Links are updated along the previous dimension before sampling this one.
// NOTE(review): any guard or wrap-around applied to prev_dim is not visible.
92 int prev_dim = dim - 1;
95 update_links(b, srp, prev_dim, partners.sub_round(srp.
round() - 2), partners.swap_rounds(), partners.wrap, partners.domain);
98 compute_local_samples(b, srp, dim);
99 }
else if (partners.sub_round(srp.
round()) < (
int) partners.histogram.rounds()/2)
102 add_samples(b, srp, samples);
107 add_samples(b, srp, samples);
// With more than one sample, nth_element places the element of rank size()/2
// at that index; the swap then moves it to samples[0], which is the slot
// enqueue_exchange reads as the split.
108 if (samples.size() != 1)
111 std::nth_element(samples.begin(), samples.begin() + samples.size()/2, samples.end());
112 std::swap(samples[0], samples[samples.size()/2]);
117 forward_samples(b, srp, samples);
// Modifies bit (rounds - 1 - round) of gid: one visible statement clears it,
// the other sets it.
// NOTE(review): the condition choosing between them and the return statement
// are on lines not visible here; presumably `lower` selects the cleared bit —
// confirm.
121 template<
class Block,
class Po
int>
123 diy::detail::KDTreeSamplingPartition<Block,Point>::
124 divide_gid(
int gid,
bool lower,
int round,
int rounds)
const
// Clear the bit.
127 gid &= ~(1 << (rounds - 1 - round));
// Set the bit.
129 gid |= (1 << (rounds - 1 - round));
// Rebuilds the block's RCLink after a split along `dim` and swaps the new link
// into place.
// NOTE(review): many lines of this definition are not visible (declarations of
// gid, dir, split, nbr, dual, left, right; the dequeue calls; braces and
// else-branches), so the comments below cover only the visible statements.
134 template<
class Block,
class Po
int>
136 diy::detail::KDTreeSamplingPartition<Block,Point>::
137 update_links(Block* b,
const diy::ReduceProxy& srp,
int dim,
int round,
int rounds,
bool wrap,
const Bounds& domain)
const
139 auto log = get_logger();
141 int lid = srp.master()->
lid(gid);
142 RCLink* link =
static_cast<RCLink*
>(srp.master()->link(lid));
// Index the current link: (neighbor gid, direction) -> position in the link.
145 std::map<std::pair<int,diy::Direction>,
int> link_map;
146 for (
int i = 0; i < link->size(); ++i)
147 link_map[std::make_pair(link->target(i).gid, link->direction(i))] = i;
// One split value per current neighbor, filled while draining incoming queues.
150 std::vector<float> splits(link->size());
151 for (
int i = 0; i < link->size(); ++i)
155 int in_gid = link->target(i).gid;
156 while(srp.incoming(in_gid))
162 for (
int j = 0; j < dim_; ++j)
// k is the link position recorded for (in_gid, dir).
165 int k = link_map[std::make_pair(in_gid, dir)];
166 log->trace(
"{} {} {} -> {}", in_gid, dir, split, k);
// The new link is constructed from dim_ and link->core() passed twice.
171 RCLink new_link(dim_, link->core(), link->core());
// lower is true when bit (rounds - 1 - round) of gid is clear.
173 bool lower = !(gid & (1 << (rounds - 1 - round)));
176 for (
int i = 0; i < link->size(); ++i)
// Old neighbor lies in the negative direction along dim while this block is
// the lower half, or in the positive direction while it is the upper half:
// a single neighbor is added, using divide_gid(..., !lower, ...).
182 if ((dir[dim] < 0 && lower) || (dir[dim] > 0 && !lower))
184 int nbr_gid = divide_gid(link->target(i).gid, !lower, round, rounds);
186 new_link.add_neighbor(nbr);
188 new_link.add_direction(dir);
190 Bounds bounds = link->bounds(i);
191 update_neighbor_bounds(bounds, splits[i], dim, !lower);
192 new_link.add_bounds(bounds);
195 new_link.add_wrap(find_wrap(new_link.bounds(), bounds, domain));
// Otherwise both halves of the old neighbor (j == 0 and j == 1) are candidates;
// a half is added only if its bounds intersect the new link's bounds.
201 for (
int j = 0; j < 2; ++j)
203 int nbr_gid = divide_gid(link->target(i).gid, j == 0, round, rounds);
205 Bounds bounds = link->bounds(i);
206 update_neighbor_bounds(bounds, splits[i], dim, j == 0);
208 if (intersects(bounds, new_link.bounds(), dim, wrap, domain))
211 new_link.add_neighbor(nbr);
212 new_link.add_direction(dir);
213 new_link.add_bounds(bounds);
216 new_link.add_wrap(find_wrap(new_link.bounds(), bounds, domain));
// Add the dual block: this gid with the split bit taken from !lower.
225 int dual_gid = divide_gid(gid, !lower, round, rounds);
227 new_link.add_neighbor(dual);
// Its bounds start from the old link bounds and are cut at the coordinate
// where the new link's bounds differ from them.
229 Bounds nbr_bounds = link->bounds();
230 update_neighbor_bounds(nbr_bounds, find_split(new_link.bounds(), nbr_bounds), dim, !lower);
231 new_link.add_bounds(nbr_bounds);
// NOTE(review): the condition choosing `right` versus `left` is not visible.
239 new_link.add_direction(right);
244 new_link.add_direction(left);
// Install the rebuilt link in place of the old one.
250 link->swap(new_link);
// Sends this block's split value to every neighbor in its RCLink.
// NOTE(review): the line carrying the function name and signature is not
// visible; presumably this is split_to_neighbors — confirm.
253 template<
class Block,
class Po
int>
255 diy::detail::KDTreeSamplingPartition<Block,Point>::
258 int lid = srp.master()->
lid(srp.gid());
259 RCLink* link =
static_cast<RCLink*
>(srp.master()->link(lid));
// The split is the first coordinate where the link's core differs from its bounds.
262 float split = find_split(link->core(), link->bounds());
// Each neighbor gets two messages: the split, then the link direction for that neighbor.
264 for (
int i = 0; i < link->size(); ++i)
266 srp.
enqueue(link->target(i), split);
267 srp.
enqueue(link->target(i), link->direction(i));
// Draws sample coordinates from the block's points.
// NOTE(review): the line carrying the function name and signature, and the
// declaration of `samples`, are not visible; presumably this is
// compute_local_samples — confirm.
271 template<
class Block,
class Po
int>
273 diy::detail::KDTreeSamplingPartition<Block,Point>::
278 size_t points_size = (b->*points_).size();
// Never draw more samples than there are points; n is 0 for an empty block,
// so the modulo below is not evaluated in that case.
279 size_t n = std::min(points_size, samples_);
281 for (
size_t i = 0; i < n; ++i)
// Each draw indexes independently with rand(), so a point can be picked more than once.
283 float x = (b->*points_)[rand() % points_size][dim];
284 samples.push_back(x);
// Appends values obtained per incoming-link neighbor to `samples`.
// NOTE(review): the line carrying the function name and signature, and the
// lines that declare and fill `smpls`, are not visible; presumably this is
// add_samples and smpls is dequeued from nbr_gid — confirm.
290 template<
class Block,
class Po
int>
292 diy::detail::KDTreeSamplingPartition<Block,Point>::
296 for (
int i = 0; i < srp.
in_link().size(); ++i)
298 int nbr_gid = srp.
in_link().target(i).gid;
// Copy every element of smpls onto the end of samples.
302 for (
size_t i = 0; i < smpls.size(); ++i)
303 samples.push_back(smpls[i]);
// receive_samples: takes the block, the reduce proxy and an output Samples
// reference.
// NOTE(review): the body of this definition is not visible in this listing.
307 template<
class Block,
class Po
int>
309 diy::detail::KDTreeSamplingPartition<Block,Point>::
310 receive_samples(Block* b,
const diy::ReduceProxy& srp, Samples& samples)
const
// forward_samples: iterates over every target of the proxy's outgoing link.
// NOTE(review): the loop body is not visible; presumably it enqueues `samples`
// to each target — confirm.
315 template<
class Block,
class Po
int>
317 diy::detail::KDTreeSamplingPartition<Block,Point>::
318 forward_samples(Block* b,
const diy::ReduceProxy& srp,
const Samples& samples)
const
320 for (
int i = 0; i < srp.
out_link().size(); ++i)
// Partitions the block's points around the split value samples[0] along `dim`,
// keeps the bucket addressed to this block, and moves one side of the link
// core to the split.
// NOTE(review): the declaration of `k`, any early return, the enqueue of the
// other bucket and the condition choosing max versus min are on lines not
// visible here.
324 template<
class Block,
class Po
int>
326 diy::detail::KDTreeSamplingPartition<Block,Point>::
327 enqueue_exchange(Block* b,
const diy::ReduceProxy& srp,
int dim,
const Samples& samples)
const
329 int lid = srp.master()->
lid(srp.gid());
330 RCLink* link =
static_cast<RCLink*
>(srp.master()->link(lid));
// NOTE(review): reads samples[0] unconditionally — assumes samples is
// non-empty; confirm against the caller.
338 float split = samples[0];
// One bucket per outgoing-link target.
341 std::vector< std::vector<Point> > out_points(srp.
out_link().size());
342 for (
size_t i = 0; i < (b->*points_).size(); ++i)
344 float x = (b->*points_)[i][dim];
// Coordinates below the split go to bucket 0, all others to bucket 1.
// NOTE(review): bucket 1 exists only if out_link().size() >= 2 — confirm this
// is guaranteed on the path that reaches here.
345 int loc = x < split ? 0 : 1;
346 out_points[loc].push_back((b->*points_)[i]);
349 for (
int i = 0; i < k; ++i)
// The bucket addressed to this block replaces the block's point vector.
351 if (srp.
out_link().target(i).gid == srp.gid())
353 (b->*points_).swap(out_points[i]);
// The link core is cut at the split on one side of dim.
360 link->core().max[dim] = split;
362 link->core().min[dim] = split;
// Receives point vectors from the incoming-link neighbors and appends them to
// the block's points, validating each point against the link core along `dim`.
// NOTE(review): the line carrying the function name and signature is not
// visible; presumably this is dequeue_exchange — confirm.
365 template<
class Block,
class Po
int>
367 diy::detail::KDTreeSamplingPartition<Block,Point>::
370 int lid = srp.master()->
lid(srp.gid());
371 RCLink* link =
static_cast<RCLink*
>(srp.master()->link(lid));
373 for (
int i = 0; i < srp.
in_link().size(); ++i)
375 int nbr_gid = srp.
in_link().target(i).gid;
// NOTE(review): the statement guarded by this test is not visible; presumably
// the block's own entry is skipped — confirm.
376 if (nbr_gid == srp.gid())
379 std::vector<Point> in_points;
380 srp.
dequeue(nbr_gid, in_points);
381 for (
size_t j = 0; j < in_points.size(); ++j)
// A point outside [core.min[dim], core.max[dim]] aborts with std::runtime_error.
383 if (in_points[j][dim] < link->core().min[dim] || in_points[j][dim] > link->core().max[dim])
384 throw std::runtime_error(fmt::format(
"Dequeued {} outside [{},{}] ({})",
385 in_points[j][dim], link->core().min[dim], link->core().max[dim], dim));
386 (b->*points_).push_back(in_points[j]);
// Clips `bounds` along `dim` at `split` by overwriting either its max or its
// min coordinate.
// NOTE(review): the condition choosing between the two assignments is not
// visible; presumably `lower` selects the max assignment — confirm.
391 template<
class Block,
class Po
int>
393 diy::detail::KDTreeSamplingPartition<Block,Point>::
394 update_neighbor_bounds(Bounds& bounds,
float split,
int dim,
bool lower)
const
397 bounds.max[dim] = split;
399 bounds.min[dim] = split;
// Tests whether bounds x and y overlap along `dim`.
// NOTE(review): the statements guarded by the two domain-edge tests, and any
// check of `wrap`, are on lines not visible here.
402 template<
class Block,
class Po
int>
404 diy::detail::KDTreeSamplingPartition<Block,Point>::
405 intersects(
const Bounds& x,
const Bounds& y,
int dim,
bool wrap,
const Bounds& domain)
const
// x starts at the domain minimum while y ends at the domain maximum.
409 if (x.min[dim] == domain.min[dim] && y.max[dim] == domain.max[dim])
// The mirrored case: y starts at the domain minimum, x ends at the maximum.
411 if (y.min[dim] == domain.min[dim] && x.max[dim] == domain.max[dim])
// Closed-interval overlap: intervals that only touch at an endpoint count.
414 return x.min[dim] <= y.max[dim] && y.min[dim] <= x.max[dim];
// Scans dimensions 0 .. dim_-1 and returns the first coordinate of `changed`
// that differs from `original`, testing min before max in each dimension.
// NOTE(review): what happens when no coordinate differs is on lines not
// visible here.
417 template<
class Block,
class Po
int>
419 diy::detail::KDTreeSamplingPartition<Block,Point>::
420 find_split(
const Bounds& changed,
const Bounds& original)
const
422 for (
int i = 0; i < dim_; ++i)
424 if (changed.min[i] != original.min[i])
425 return changed.min[i];
426 if (changed.max[i] != original.max[i])
427 return changed.max[i];
// For each dimension, checks whether `bounds` and `nbr_bounds` sit on opposite
// edges of `domain`.
// NOTE(review): the return type, the statements guarded by the two tests and
// the return statement are on lines not visible here.
433 template<
class Block,
class Po
int>
435 diy::detail::KDTreeSamplingPartition<Block,Point>::
436 find_wrap(
const Bounds& bounds,
const Bounds& nbr_bounds,
const Bounds& domain)
const
439 for (
int i = 0; i < dim_; ++i)
// bounds touches the domain minimum while the neighbor touches the maximum.
441 if (bounds.min[i] == domain.min[i] && nbr_bounds.max[i] == domain.max[i])
// bounds touches the domain maximum while the neighbor touches the minimum.
443 if (bounds.max[i] == domain.max[i] && nbr_bounds.min[i] == domain.min[i])
const Link & out_link() const
returns outgoing link
Definition: reduce.hpp:70
Enables communication within a group during a reduction. DIY creates the ReduceProxy for you in diy::...
Definition: reduce.hpp:15
void dequeue(int from, T &x, void(*load)(BinaryBuffer &, T &)=&::diy::load< T >) const
Dequeue data whose size can be determined automatically (e.g., STL vector) and that was previously en...
Definition: proxy.hpp:42
unsigned round() const
returns current round number
Definition: reduce.hpp:66
virtual int rank(int gid) const =0
returns the process rank of the block with global id gid (need not be local)
const Link & in_link() const
returns incoming link
Definition: reduce.hpp:68
const Assigner & assigner() const
returns the assigner
Definition: reduce.hpp:74
int lid(int gid) const
returns the local id of the local block with global id gid, or -1 if not local
Definition: master.hpp:221
void enqueue(const BlockID &to, const T &x, void(*save)(BinaryBuffer &, const T &)=&::diy::save< T >) const
Enqueue data whose size can be determined automatically, e.g., an STL vector.
Definition: proxy.hpp:24