SpatialLocator.cpp 18.7 KB
Newer Older
1
2
3
#include "moab/SpatialLocator.hpp"
#include "moab/Interface.hpp"
#include "moab/ElemEvaluator.hpp"
4
#include "moab/AdaptiveKDTree.hpp"
5
#include "moab/BVHTree.hpp"
6

7
8
9
// include ScdInterface for box partitioning
#include "moab/ScdInterface.hpp"

10
#ifdef MOAB_HAVE_MPI
11
12
13
#include "moab/ParallelComm.hpp"
#endif

14
// File-wide switch for diagnostic output: when true, the point-location routines in this
// file print per-point "not found" messages and a found/total summary to std::cout.
bool debug = false;
15

16
17
18
19
namespace moab 
{

    /** \brief Construct a locator over a set of source elements
     *
     * If no tree is passed in, one is created by create_tree(). When elems is non-empty the
     * tree is built immediately and localBox is set to the tree's bounding box.
     *
     * \param impl  MOAB instance used for handle queries
     * \param elems Source elements to search; may be empty (use add_elems() later)
     * \param tree  Existing search tree, or NULL to have one created here
     * \param eval  Element evaluator handed to the tree before point searches, may be NULL
     * \throws ErrorCode if building the tree or querying its bounding box fails
     */
    SpatialLocator::SpatialLocator(Interface *impl, Range &elems, Tree *tree, ElemEvaluator *eval) 
            : mbImpl(impl), myElems(elems), myDim(-1), myTree(tree), elemEval(eval), iCreatedTree(false),
              timerInitialized(false)
    {
      create_tree();

        // nothing to build yet; localBox keeps its default state
      if (elems.empty()) return;

      myDim = mbImpl->dimension_from_handle(*elems.rbegin());

      ErrorCode rval = myTree->build_tree(myElems);
      if (MB_SUCCESS == rval)
        rval = myTree->get_bounding_box(localBox);

      if (MB_SUCCESS != rval) {
          // The destructor does not run for an object whose constructor throws, so a tree
          // allocated by create_tree() has to be released here or it is leaked.
          // A tree supplied by the caller (iCreatedTree == false) is left untouched.
        if (iCreatedTree) {
          delete myTree;
          myTree = NULL;
        }
        throw rval;
      }
    }

35
36
37
38
    void SpatialLocator::create_tree() 
    {
      if (myTree) return;
      
39
      if (myElems.empty() || mbImpl->type_from_handle(*myElems.rbegin()) == MBVERTEX) 
40
41
          // create a kdtree if only vertices
        myTree = new AdaptiveKDTree(mbImpl);
42
      else
43
          // otherwise a BVHtree, since it performs better for elements
44
        myTree = new BVHTree(mbImpl);
45
46
47
48

      iCreatedTree = true;
    }

49
50
51
52
53
54
55
56
    /** \brief Replace the source elements and rebuild the search tree
     *
     * \param elems New source elements; must be non-empty, and the first and last handles
     *              must have the same dimension
     * \return MB_FAILURE for an empty or mixed-dimension range, otherwise the result of
     *         building the tree and refreshing localBox
     */
    ErrorCode SpatialLocator::add_elems(Range &elems) 
    {
      if (elems.empty()) return MB_FAILURE;

        // only the two ends of the range are compared
      const int elem_dim = mbImpl->dimension_from_handle(*elems.begin());
      if (elem_dim != mbImpl->dimension_from_handle(*elems.rbegin()))
        return MB_FAILURE;

      myDim = elem_dim;
      myElems = elems;

      ErrorCode rval = myTree->build_tree(myElems);
      if (MB_SUCCESS != rval) return rval;

        // Keep localBox in sync with the rebuilt tree, as the constructor does; the parallel
        // location routines read localBox and would otherwise see the box of the old elements.
      return myTree->get_bounding_box(localBox);
    }
    
62
#ifdef MOAB_HAVE_MPI
63
    /** \brief Set up the intermediate (rendezvous) decomposition used for parallel location
     *
     * Reduces the local bounding boxes into globalBox, asks ScdInterface for a SQIJK
     * partition to get the number of regions per direction (regNums), and stores the size
     * of each region in regDeltaXYZ.
     *
     * \param pc Parallel communicator; must not be NULL
     * \return MB_FAILURE if pc is NULL or the reduction fails, otherwise the result of the
     *         partition computation
     */
    ErrorCode SpatialLocator::initialize_intermediate_partition(ParallelComm *pc) 
    {
      if (!pc) return MB_FAILURE;
      
        //step 2
        // get the global bounding box
      double sendbuffer[6];
      double rcvbuffer[6];

        // fill sendbuffer with the local box; the code below treats [0:2] as the min corner
        // and [3:5] as the max corner
      localBox.get(sendbuffer);

        // negate Xmin,Ymin,Zmin so their minimum can be found with MPI_MAX,
        // avoiding a second MPI_Allreduce with MPI_MIN
      for (int i = 0; i < 3; i++) sendbuffer[i] *= -1;

        // reduce over the communicator of pc; every other exchange in this class (crystal
        // router, pc->size(), pc->rank()) goes through pc, not MPI_COMM_WORLD
      int mpi_err = MPI_Allreduce(sendbuffer, rcvbuffer, 6, MPI_DOUBLE, MPI_MAX,
                                  pc->proc_config().proc_comm());
      if (MPI_SUCCESS != mpi_err) return MB_FAILURE;

        // negate Xmin,Ymin,Zmin again to get original values
      for (int i = 0; i < 3; i++) rcvbuffer[i] *= -1;

        // saving values in globalBox
      globalBox.update_max(&rcvbuffer[3]);
      globalBox.update_min(&rcvbuffer[0]);

        // compute the alternate decomposition; use ScdInterface::compute_partition for this
      ScdParData spd;
      spd.partMethod = ScdParData::SQIJK;
      spd.gPeriodic[0] = spd.gPeriodic[1] = spd.gPeriodic[2] = 0;

        // scale the box coordinates before storing them in gDims
        // NOTE(review): the scaling and gDims are taken from localBox, not globalBox, so the
        // input to compute_partition can differ between procs, and a zero-size local box
        // makes log10() return -inf. Confirm whether globalBox was intended.
      double lg = log10((localBox.bMax - localBox.bMin).length());
      double mfactor = pow(10.0, 6 - lg);
      int ldims[6], lper[3];
      double dgijk[6];
      localBox.get(dgijk);
      for (int i = 0; i < 6; i++) spd.gDims[i] = dgijk[i] * mfactor;

      ErrorCode rval = ScdInterface::compute_partition(pc->size(), pc->rank(), spd,
                                                       ldims, lper, regNums);
      if (MB_SUCCESS != rval) return rval;

        // all we're really interested in is regNums[i], #procs in each direction
      for (int i = 0; i < 3; i++)
        regDeltaXYZ[i] = (globalBox.bMax[i] - globalBox.bMin[i])/double(regNums[i]); //size of each region

      return MB_SUCCESS;
    }

//this function sets up the TupleList TLreg_o containing the registration messages
// and sends it
112
    /** \brief Register this proc's source box with the intermediate procs it overlaps
     *
     * Fills TLreg_o with one tuple per overlapped intermediate region, exchanges it through
     * the crystal router, then records the received boxes in srcProcBoxes.
     *
     * \param pc           Parallel communicator; the transfer is skipped when NULL
     * \param abs_iter_tol Tolerance used to pad the local box and passed to get_point_ijk()
     * \param TLreg_o      Tuple list used for the exchange; holds the incoming registrations
     *                     after the transfer
     */
    ErrorCode SpatialLocator::register_src_with_intermediate_procs(ParallelComm *pc, double abs_iter_tol, TupleList &TLreg_o)
    {
        // step 3: compute ijks of local box corners in intermediate partition;
        // entries [0:2] belong to the padded min corner, [3:5] to the padded max corner
      int corner_ijk[6];
      int *const lo_ijk = corner_ijk;
      int *const hi_ijk = corner_ijk + 3;

      ErrorCode status = get_point_ijk(localBox.bMin-CartVect(abs_iter_tol), abs_iter_tol, lo_ijk);
      if (MB_SUCCESS != status) return status;
      status = get_point_ijk(localBox.bMax+CartVect(abs_iter_tol), abs_iter_tol, hi_ijk);
      if (MB_SUCCESS != status) return status;

        //step 4
        //set up TLreg_o
        // TLreg_o (int destProc, real Xmin, Ymin, Zmin, Xmax, Ymax, Zmax)
      TLreg_o.initialize(1,0,0,6,0);

      double local_corners[6];
      localBox.get(local_corners);

        //iterate over all regions overlapping with my bounding box using the computed corner IDs
      for (int kk = lo_ijk[2]; kk <= hi_ijk[2]; kk++) {
        const int k_offset = kk * regNums[0]*regNums[1];
        for (int jj = lo_ijk[1]; jj <= hi_ijk[1]; jj++) {
          const int j_offset = jj * regNums[0];
          for (int ii = lo_ijk[0]; ii <= hi_ijk[0]; ii++) {
            int dest_proc = k_offset + j_offset + ii;
            TLreg_o.push_back(&dest_proc, NULL, NULL, local_corners);
          }
        }
      }

        //step 5
        //send TLreg_o, receive TLrequests_i
      if (pc) pc->proc_config().crystal_router()->gs_transfer(1, TLreg_o, 0);

        //step 6
        //TLreg_o now holds the registration requests sent to me (TLrequests_i);
        //read the tuples and fill the list of procs to forward to
      const int num_requests = TLreg_o.get_n();
      for (int r = 0; r < num_requests; r++) {
        const int src_proc = TLreg_o.vi_rd[r];
        srcProcBoxes[src_proc] = BoundBox(TLreg_o.vr_rd+6*r);
      }

      return MB_SUCCESS;
    }

160
161
    /** \brief Parallel point location for a range of vertices (not implemented)
     *
     * All arguments are ignored.
     * \return MB_UNSUPPORTED_OPERATION, always
     */
    ErrorCode SpatialLocator::par_locate_points(ParallelComm * /*pc*/,
                                                Range & /*vertices*/,
                                                const double /*rel_iter_tol*/,
                                                const double /*abs_iter_tol*/,
                                                const double /*inside_tol*/)
    {
      const ErrorCode not_implemented = MB_UNSUPPORTED_OPERATION;
      return not_implemented;
    }
167

168
    /** Predicate: true only for the value -1 (the "not located" marker written into
     * parLocTable); other negative values yield false.
     */
    bool is_neg(int i)
    {
      const int unset_marker = -1;
      return unset_marker == i;
    }
169
170
171
172
173
174
175
176
177
178
179
180
181
      
    /** \brief Locate points in parallel through an intermediate partition
     *
     * Each point is sent to the intermediate proc owning its region, forwarded from there to
     * every source proc whose registered box contains it, searched for locally, and the
     * result is returned to the requesting proc. Local hits are stored in locTable; the
     * (proc, index) pairs returned for the requested points are stored in parLocTable, with
     * -1 marking points nobody located. Timings of each phase go into myTimes.
     *
     * \param pc           Parallel communicator (a NULL pc makes the partition setup fail)
     * \param pos          Interleaved xyz coordinates of the points, 3*num_points values
     * \param num_points   Number of points
     * \param rel_iter_tol Relative tolerance, passed on to locate_points()
     * \param abs_iter_tol Absolute tolerance, also used for region and box tests
     * \param inside_tol   Inside tolerance, passed on to locate_points()
     * \return MB_SUCCESS, or the first error from setup or from the local search
     */
    ErrorCode SpatialLocator::par_locate_points(ParallelComm *pc,
                                                const double *pos, int num_points,
                                                const double rel_iter_tol, const double abs_iter_tol,
                                                const double inside_tol)
    {
      ErrorCode rval;
        //TupleLists used for communication 
      TupleList TLreg_o;   //TLregister_outbound
      TupleList TLquery_o; //TLquery_outbound
      TupleList TLforward_o; //TLforward_outbound
      TupleList TLsearch_results_o; //TLsearch_results_outbound

        // initialize timer 
      myTimer.time_elapsed();
      timerInitialized = true;
      
        // steps 1-2 - initialize the alternative decomposition box from global box
      rval = initialize_intermediate_partition(pc);
      if (rval != MB_SUCCESS) return rval;
      
        //steps 3-6 - set up TLreg_o, gs_transfer, gather registrations
      rval = register_src_with_intermediate_procs(pc, abs_iter_tol, TLreg_o);
      if (rval != MB_SUCCESS) return rval;

      myTimes.slTimes[SpatialLocatorTimes::INTMED_INIT] = myTimer.time_elapsed();

        // actual parallel point location using intermediate partition

      unsigned int my_rank = (pc? pc->proc_config().proc_rank() : 0);

        //TLquery_o: Tuples sent to forwarder proc 
        //TL (toProc, OriginalSourceProc, targetIndex, X,Y,Z)

        //TLforw_req_i: Tuples to forward to corresponding procs (forwarding requests)
        //TL (sourceProc, OriginalSourceProc, targetIndex, X,Y,Z)

      TLquery_o.initialize(3,0,0,3,0);

      int iargs[3];

      for (int pnt=0; pnt < 3*num_points; pnt+=3)
      {
          //get ID of the proc responsible for the region the point is in
        int forw_id = proc_from_point(pos+pnt, abs_iter_tol);

        iargs[0] = forw_id; 	//toProc
        iargs[1] = my_rank; 	//originalSourceProc
        iargs[2] = pnt/3;    	//targetIndex 	

        TLquery_o.push_back(iargs, NULL, NULL, const_cast<double*>(pos+pnt));
      }

        //send point search queries to forwarders
      if (pc)
        pc->proc_config().crystal_router()->gs_transfer(1, TLquery_o, 0);

      myTimes.slTimes[SpatialLocatorTimes::INTMED_SEND] = myTimer.time_elapsed();

        //now read forwarding requests and forward to corresponding procs
        //TLquery_o is now TLforw_req_i

        //TLforward_o: query messages forwarded to corresponding procs
        //TL (toProc, OriginalSourceProc, targetIndex, X,Y,Z)

      TLforward_o.initialize(3,0,0,3,0);

      int NN = TLquery_o.get_n();

      for (int i=0; i < NN; i++) {
        iargs[1] = TLquery_o.vi_rd[3*i+1];	//get OriginalSourceProc
        iargs[2] = TLquery_o.vi_rd[3*i+2];	//targetIndex
        CartVect tmp_pnt(TLquery_o.vr_rd+3*i);

          //compare coordinates to list of bounding boxes; a point is forwarded to every
          //source proc whose registered box contains it
        for (std::map<int, BoundBox>::iterator mit = srcProcBoxes.begin(); mit != srcProcBoxes.end(); ++mit) {
          if ((*mit).second.contains_point(tmp_pnt.array(), abs_iter_tol)) {
            iargs[0] = (*mit).first;
            TLforward_o.push_back(iargs, NULL, NULL, tmp_pnt.array());
          }
        }

      }

      myTimes.slTimes[SpatialLocatorTimes::INTMED_SEARCH] = myTimer.time_elapsed();

      if (pc)
        pc->proc_config().crystal_router()->gs_transfer(1, TLforward_o, 0);

      myTimes.slTimes[SpatialLocatorTimes::SRC_SEND] = myTimer.time_elapsed();

        // cache time here, because locate_points also calls elapsed functions and we want to account
        // for tuple list initialization here
      double tstart = myTimer.time_since_birth();
      
        //step 12
        //now read Point Search requests
        //TLforward_o is now TLsearch_req_i
        //TLsearch_req_i: (sourceProc, OriginalSourceProc, targetIndex, X,Y,Z)
							  
      NN = TLforward_o.get_n();

        //TLsearch_results_o
        //TL: (OriginalSourceProc, targetIndex, sourceIndex); integers only, the parametric
        //coordinates stay on the source proc in locTable
      TLsearch_results_o.initialize(3,0,0,0,0);

        //step 13 is done in test_local_box

      std::vector<double> params(3*NN);
      std::vector<int> is_inside(NN, 0);
      std::vector<EntityHandle> ents(NN, 0);

        // When no points were forwarded to this proc the vectors are empty and taking
        // &v[0] is undefined; pass NULL instead (locate_points touches nothing for 0 points).
      EntityHandle *ents_ptr = (NN > 0 ? &ents[0] : NULL);
      double *params_ptr = (NN > 0 ? &params[0] : NULL);
      int *is_inside_ptr = (NN > 0 ? &is_inside[0] : NULL);
      
      rval = locate_points(TLforward_o.vr_rd, TLforward_o.get_n(), 
                           ents_ptr, params_ptr, is_inside_ptr, 
                           rel_iter_tol, abs_iter_tol, inside_tol);
      if (MB_SUCCESS != rval)
        return rval;
      
      locTable.initialize(1, 0, 1, 3, 0);
      locTable.enableWriteAccess();
      for (int i = 0; i < NN; i++) {
        if (is_inside[i]) {
          iargs[0] = TLforward_o.vi_rd[3*i+1];  //OriginalSourceProc, destination of the result
          iargs[1] = TLforward_o.vi_rd[3*i+2];  //targetIndex
          iargs[2] = locTable.get_n();          //sourceIndex: position of this hit in locTable
          TLsearch_results_o.push_back(iargs, NULL, NULL, NULL);

          ulong ent_ulong=(ulong)ents[i];
          sint forward= (sint)TLforward_o.vi_rd[3*i+1];
          locTable.push_back(&forward, NULL, &ent_ulong, &params[3*i]);
        }
      }
      locTable.disableWriteAccess();

      myTimes.slTimes[SpatialLocatorTimes::SRC_SEARCH] =  myTimer.time_since_birth() - tstart;
      myTimer.time_elapsed(); // call this to reset last time called

        //step 14: send TLsearch_results_o and receive TLloc_i
      if (pc)
        pc->proc_config().crystal_router()->gs_transfer(1, TLsearch_results_o, 0);

      myTimes.slTimes[SpatialLocatorTimes::TARG_RETURN] = myTimer.time_elapsed();

        // store proc/index tuples in parLocTable; -1 marks points not located by any proc
      parLocTable.initialize(2, 0, 0, 0, num_points);
      parLocTable.enableWriteAccess();
      std::fill(parLocTable.vi_wr, parLocTable.vi_wr + 2*num_points, -1);
      
      for (unsigned int i = 0; i < TLsearch_results_o.get_n(); i++) {
        int idx = TLsearch_results_o.vi_rd[3*i+1];
        parLocTable.vi_wr[2*idx] = TLsearch_results_o.vi_rd[3*i];
        parLocTable.vi_wr[2*idx+1] = TLsearch_results_o.vi_rd[3*i+2];
      }

        // The tuples were written directly through vi_wr, so record the tuple count
        // explicitly, as locate_points() does for locTable; remote_num_located() iterates
        // over get_n() entries.
      parLocTable.set_n(num_points);

      if (debug) {
          // every unlocated point contributes two -1 entries, hence the factor 0.5
        int num_found = num_points - 0.5 * 
            std::count_if(parLocTable.vi_wr, parLocTable.vi_wr + 2*num_points, is_neg);
        std::cout << "Points found = " << num_found << "/" << num_points 
                  << " (" << 100.0*((double)num_found/num_points) << "%)" << std::endl;
      }
      
      myTimes.slTimes[SpatialLocatorTimes::TARG_STORE] = myTimer.time_elapsed();

      return MB_SUCCESS;
    }

#endif

339
    /** \brief Locate a range of vertices, storing the results in locTable
     *
     * Fetches the vertex coordinates and hands them to the coordinate-array overload.
     *
     * \param verts        Vertices to locate; asserted to be non-empty and to end in a vertex
     * \param rel_iter_tol Relative tolerance, passed through
     * \param abs_iter_tol Absolute tolerance, passed through
     * \param inside_tol   Inside tolerance, passed through
     * \return First error from fetching coordinates or locating, else MB_SUCCESS
     */
    ErrorCode SpatialLocator::locate_points(Range &verts,
                                            const double rel_iter_tol, const double abs_iter_tol, 
                                            const double inside_tol) 
    {
        // if nobody started the timer before, this call is the top-level one
      const bool started_timer = !timerInitialized;
      if (started_timer) {
        myTimer.time_elapsed();
        timerInitialized = true;
      }
      
      assert(!verts.empty() && mbImpl->type_from_handle(*verts.rbegin()) == MBVERTEX);

      std::vector<double> coords(3*verts.size());
      ErrorCode status = mbImpl->get_coords(verts, &coords[0]);
      if (MB_SUCCESS == status)
        status = locate_points(&coords[0], verts.size(), rel_iter_tol, abs_iter_tol, inside_tol);
      if (MB_SUCCESS != status) return status;
      
        // only record the time in the top-level call, since time_elapsed() resets the
        // last time called
      if (started_timer) 
        myTimes.slTimes[SpatialLocatorTimes::SRC_SEARCH] =  myTimer.time_elapsed();

      return MB_SUCCESS;
    }
    
    /** \brief Locate points given by coordinates, storing the results in locTable
     *
     * locTable is re-initialized with room for num_points tuples; the entity handles and
     * parametric coordinates are written straight into its storage, and the integer field
     * of every tuple is set to 0.
     *
     * \param pos          Interleaved xyz coordinates, 3*num_points values
     * \param num_points   Number of points
     * \param rel_iter_tol Relative tolerance, passed through
     * \param abs_iter_tol Absolute tolerance, passed through
     * \param inside_tol   Inside tolerance, passed through
     * \return Error from the underlying search, else MB_SUCCESS
     */
    ErrorCode SpatialLocator::locate_points(const double *pos, int num_points,
                                            const double rel_iter_tol, const double abs_iter_tol, 
                                            const double inside_tol) 
    {
        // if nobody started the timer before, this call is the top-level one
      const bool started_timer = !timerInitialized;
      if (started_timer) {
        myTimer.time_elapsed();
        timerInitialized = true;
      }

        // initialize to tuple structure (p_ui, hs_ul, r[3]_d) (see header comments for locTable)
      locTable.initialize(1, 0, 1, 3, num_points);
      locTable.enableWriteAccess();

        // pass storage directly into locate_points, since we know those arrays are contiguous
      EntityHandle *handle_storage = (EntityHandle*)locTable.vul_wr;
      double *param_storage = locTable.vr_wr;
      const ErrorCode status = locate_points(pos, num_points, handle_storage, param_storage, NULL,
                                             rel_iter_tol, abs_iter_tol, inside_tol);

        // the table is finalized even when the search reported an error
      std::fill(locTable.vi_wr, locTable.vi_wr+num_points, 0);
      locTable.set_n(num_points);
      if (MB_SUCCESS != status) return status;

        // only record the time in the top-level call, since time_elapsed() resets the
        // last time called
      if (started_timer) 
        myTimes.slTimes[SpatialLocatorTimes::SRC_SEARCH] =  myTimer.time_elapsed();

      return MB_SUCCESS;
    }
      
    /** \brief Locate a range of vertices, returning the results in caller-provided arrays
     *
     * Fetches the vertex coordinates and hands them to the coordinate-array overload.
     *
     * \param verts        Vertices to locate; asserted to be non-empty and to end in a vertex
     * \param ents         Output, one entity handle per vertex
     * \param params       Output, three parametric coordinates per vertex
     * \param is_inside    Output flags, passed through (may be NULL)
     * \param rel_iter_tol Relative tolerance, passed through
     * \param abs_iter_tol Absolute tolerance, passed through
     * \param inside_tol   Inside tolerance, passed through
     * \return Error from fetching coordinates, otherwise the result of the search
     */
    ErrorCode SpatialLocator::locate_points(Range &verts,
                                            EntityHandle *ents, double *params, int *is_inside,
                                            const double rel_iter_tol, const double abs_iter_tol, 
                                            const double inside_tol)
    {
        // if nobody started the timer before, this call is the top-level one
      const bool started_timer = !timerInitialized;
      if (started_timer) {
        myTimer.time_elapsed();
        timerInitialized = true;
      }

      assert(!verts.empty() && mbImpl->type_from_handle(*verts.rbegin()) == MBVERTEX);

      std::vector<double> coords(3*verts.size());
      const ErrorCode coord_status = mbImpl->get_coords(verts, &coords[0]);
      if (MB_SUCCESS != coord_status) return coord_status;

      const ErrorCode search_status = locate_points(&coords[0], verts.size(), ents, params, is_inside,
                                                    rel_iter_tol, abs_iter_tol, inside_tol);

        // only record the time in the top-level call, since time_elapsed() resets the
        // last time called; the time is recorded whether or not the search succeeded
      if (started_timer) 
        myTimes.slTimes[SpatialLocatorTimes::SRC_SEARCH] =  myTimer.time_elapsed();

      return search_status;
    }

418
    /** \brief Locate points given by coordinates, returning results in caller-provided arrays
     *
     * Runs one tree point search per point. A failed search does not stop the loop; the
     * status of the last failing search is returned and is_inside is not written for
     * that point.
     *
     * \param pos          Interleaved xyz coordinates, 3*num_points values
     * \param num_points   Number of points
     * \param ents         Output, one entity handle per point (0 when no entity was found)
     * \param params       Output, three values per point, filled by the tree search
     * \param is_inside    Optional output, 1 where an entity was found and 0 otherwise;
     *                     may be NULL
     * \param abs_iter_tol Absolute tolerance passed to the tree search
     * \param inside_tol   Inside tolerance passed to the tree search
     *
     * The relative tolerance argument is currently unused.
     */
    ErrorCode SpatialLocator::locate_points(const double *pos, int num_points,
                                            EntityHandle *ents, double *params, int *is_inside,
                                            const double /* rel_iter_tol */, const double abs_iter_tol,
                                            const double inside_tol)
    {
        // if nobody started the timer before, this call is the top-level one
      const bool started_timer = !timerInitialized;
      if (started_timer) {
        myTimer.time_elapsed();
        timerInitialized = true;
      }

        // make sure the tree uses this locator's evaluator, if one was given
      if (elemEval && myTree->get_eval() != elemEval)
        myTree->set_eval(elemEval);
      
      ErrorCode overall_status = MB_SUCCESS;
      for (int ipt = 0; ipt < num_points; ipt++) {
        const int offset = 3*ipt;
        const double *xyz = pos + offset;

        const ErrorCode search_status = myTree->point_search(xyz, ents[ipt], abs_iter_tol, inside_tol, NULL, NULL, 
                                                             (CartVect*)(params+offset));
        if (MB_SUCCESS == search_status) {
          const bool located = (ents[ipt] ? true : false);

          if (debug && !located) {
            std::cout << "Point " << ipt << " not found; point: (" 
                      << xyz[0] << "," << xyz[1] << "," << xyz[2] << ")" << std::endl;
          }

          if (is_inside) is_inside[ipt] = located;
        }
        else {
            // remember the failure and move on to the next point
          overall_status = search_status;
        }
      }
      
        // only record the time in the top-level call, since time_elapsed() resets the
        // last time called
      if (started_timer) 
        myTimes.slTimes[SpatialLocatorTimes::SRC_SEARCH] =  myTimer.time_elapsed();

      return overall_status;
    }
    
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
        /* Count the number of located points in locTable
         * Return the number of entries in locTable that have non-zero entity handles, which
         * represents the number of points in targetEnts that were inside one element in sourceEnts
         *
         */
    int SpatialLocator::local_num_located() 
    {
        // Number of tuples in locTable minus those whose entity handle is 0 (not located).
        // The former std::find / index lookup that followed had no effect and was removed;
        // std::find never returns a null pointer, so its result test was meaningless.
      int num_located = locTable.get_n() - std::count(locTable.vul_rd, locTable.vul_rd+locTable.get_n(), 0);
      return num_located;
    }

        /* Count the number of remotely located points in parLocTable
         * Return the number of entries in parLocTable whose first integer (the remote proc)
         * is not -1, which gives the number of points located in at least one element of a
         * remote proc's sourceEnts.
         */
    int SpatialLocator::remote_num_located()
    {
        // each tuple holds two integers; the first one is -1 when the point was not located
      const unsigned int num_tuples = parLocTable.get_n();
      int num_remote = 0;
      for (unsigned int t = 0; t != num_tuples; ++t) {
        const bool was_located = (-1 != parLocTable.vi_rd[2*t]);
        if (was_located) ++num_remote;
      }
      return num_remote;
    }
496
497
} // namespace moab