@@ -165,12 +165,14 @@ class search_nearest_topological {
165
165
v,
166
166
node->data .branch .left_min ,
167
167
node->data .branch .left_max ,
168
- node->data .branch .split_dim );
168
+ node->data .branch .split_dim ,
169
+ euclidean_space_tag{});
169
170
scalar_type const d2 = metric_ (
170
171
v,
171
172
node->data .branch .right_min ,
172
173
node->data .branch .right_max ,
173
- node->data .branch .split_dim );
174
+ node->data .branch .split_dim ,
175
+ euclidean_space_tag{});
174
176
node_type const * node_1st;
175
177
node_type const * node_2nd;
176
178
scalar_type new_offset;
@@ -210,27 +212,28 @@ class search_nearest_topological {
210
212
Visitor_& visitor_;
211
213
};
212
214
213
- // ! \brief A functor that provides range searches for Euclidean spaces. Query
214
- // ! time is bounded by O(n^(1-1/dimension)+k).
215
+ // ! \brief A functor that provides range searches for both Euclidean and
216
+ // ! topological spaces. Query time is bounded by O(n^(1-1/dimension)+k).
215
217
// ! \details Many tree nodes are excluded by checking if they intersect with the
216
218
// ! box of the query. We don't store the bounding box of each node but calculate
217
219
// ! them at run time. This slows down search_box in favor of having faster
218
220
// ! nearest neighbor searches.
219
221
template <typename SpaceWrapper_, typename Metric_, typename Index_>
220
- class search_box_euclidean {
221
- public:
222
- // TODO Perhaps we can support it for both topological and Euclidean spaces.
223
- static_assert (
224
- std::is_same_v<typename Metric_::space_category, euclidean_space_tag>,
225
- " SEARCH_BOX_ONLY_SUPPORTED_FOR_EUCLIDEAN_SPACES" );
222
+ class search_box {
223
+ using space_category = typename Metric_::space_category;
224
+ static constexpr bool is_euclidean_space_v =
225
+ std::is_same_v<space_category, euclidean_space_tag>;
226
226
227
+ public:
227
228
using index_type = Index_;
228
229
using scalar_type = typename SpaceWrapper_::scalar_type;
229
230
static size_t constexpr dim = SpaceWrapper_::dim;
230
231
using box_type = box<scalar_type, dim>;
231
232
using box_map_type = box_map<scalar_type const , dim>;
233
+ using node_type = typename kd_tree_space_tag_traits<
234
+ space_category>::template node_type<index_type, scalar_type>;
232
235
233
- inline search_box_euclidean (
236
+ inline search_box (
234
237
SpaceWrapper_ space,
235
238
Metric_ metric,
236
239
std::vector<index_type> const & indices,
@@ -245,13 +248,12 @@ class search_box_euclidean {
245
248
idxs_(idxs) {}
246
249
247
250
// ! \brief Range search starting from \p node.
248
- template <typename Node_>
249
- inline void operator ()(Node_ const * const node) {
251
+ inline void operator ()(node_type const * const node) {
250
252
if (node->is_leaf ()) {
251
253
auto begin = indices_.begin () + node->data .leaf .begin_idx ;
252
254
auto const end = indices_.begin () + node->data .leaf .end_idx ;
253
255
for (; begin < end; ++begin) {
254
- if (query_. contains (space_[ *begin] )) {
256
+ if (contains (*begin)) {
255
257
idxs_.push_back (*begin);
256
258
}
257
259
}
@@ -265,7 +267,7 @@ class search_box_euclidean {
265
267
// down the left node.
266
268
if (query_.contains (box_)) {
267
269
report_node (node->left );
268
- } else if (query_. min (split_dim) <= node-> data . branch . left_max ) {
270
+ } else if (intersects_left (split_dim, node) ) {
269
271
operator ()(node->left );
270
272
}
271
273
@@ -276,7 +278,7 @@ class search_box_euclidean {
276
278
// Same as the left side.
277
279
if (query_.contains (box_)) {
278
280
report_node (node->right );
279
- } else if (query_. max (split_dim) >= node-> data . branch . right_min ) {
281
+ } else if (intersects_right (split_dim, node) ) {
280
282
operator ()(node->right );
281
283
}
282
284
@@ -285,9 +287,61 @@ class search_box_euclidean {
285
287
}
286
288
287
289
private:
290
+ // TODO We could add an extra class layer to the box_base, box, and box_map
291
+ // hierarchy, to support topological boxes. However, this is a lot more
292
+ // code/work and we currently only use this feature here and in a unit test.
293
+ bool contains (scalar_type const * const p) const {
294
+ for (size_t i = 0 ; i < query_.size (); ++i) {
295
+ if (metric_ (
296
+ p[i],
297
+ query_.min (i),
298
+ query_.max (i),
299
+ static_cast <int >(i),
300
+ topological_space_tag{}) > scalar_type (0.0 )) {
301
+ return false ;
302
+ }
303
+ }
304
+ return true ;
305
+ }
306
+
307
+ bool contains (index_type const idx) const {
308
+ if constexpr (is_euclidean_space_v) {
309
+ return query_.contains (space_[idx]);
310
+ } else {
311
+ return contains (space_[idx]);
312
+ }
313
+ }
314
+
315
+ bool contains () const {
316
+ if constexpr (is_euclidean_space_v) {
317
+ return query_.contains (box_);
318
+ } else {
319
+ return contains (box_.min ()) && contains (box_.max ());
320
+ }
321
+ }
322
+
323
+ bool intersects_left (
324
+ size_t const split_dim, node_type const * const node) const {
325
+ if constexpr (is_euclidean_space_v) {
326
+ return query_.min (split_dim) <= node->data .branch .left_max ;
327
+ } else {
328
+ return query_.min (split_dim) <= node->data .branch .left_max ||
329
+ query_.max (split_dim) >= node->data .branch .left_min ;
330
+ }
331
+ }
332
+
333
+ bool intersects_right (
334
+ size_t const split_dim, node_type const * const node) const {
335
+ if constexpr (is_euclidean_space_v) {
336
+ return query_.max (split_dim) >= node->data .branch .right_min ;
337
+ } else {
338
+ return query_.max (split_dim) >= node->data .branch .right_min ||
339
+ query_.min (split_dim) <= node->data .branch .right_max ;
340
+ }
341
+ }
342
+
288
343
// ! \brief Reports all indices contained by \p node.
289
- template <typename Node_>
290
- inline void report_node (Node_ const * const node) const {
344
+ inline void report_node (node_type const * const node) const {
291
345
index_type begin;
292
346
index_type end;
293
347
@@ -309,17 +363,15 @@ class search_box_euclidean {
309
363
std::back_inserter (idxs_));
310
364
}
311
365
312
- template <typename Node_>
313
- inline index_type report_left (Node_ const * const node) const {
366
+ inline index_type report_left (node_type const * const node) const {
314
367
if (node->is_leaf ()) {
315
368
return node->data .leaf .begin_idx ;
316
369
} else {
317
370
return report_left (node->left );
318
371
}
319
372
}
320
373
321
- template <typename Node_>
322
- inline index_type report_right (Node_ const * const node) const {
374
+ inline index_type report_right (node_type const * const node) const {
323
375
if (node->is_leaf ()) {
324
376
return node->data .leaf .end_idx ;
325
377
} else {
0 commit comments