Ocean
Loading...
Searching...
No Matches
KdTree.h
Go to the documentation of this file.
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8#ifndef META_OCEAN_BASE_KD_TREE_H
9#define META_OCEAN_BASE_KD_TREE_H
10
11#include "ocean/base/Base.h"
12#include "ocean/base/DataType.h"
13#include "ocean/base/Median.h"
15
16#include <limits>
17
18namespace Ocean
19{
20
21/**
22 * This class implements a k-d tree.
23 * In general, k-d trees should be applied for problems with small dimensions only as the performance benefit decreases with increasing dimension significantly.<br>
24 * That means for the number of nodes (n) and the dimension (k) the following should hold: n >> 2^k.
25 * @tparam T The data type of one element for all dimensions.
26 * @ingroup base
27 */
28template <typename T>
29class KdTree
30{
31 protected:
32
33 /**
34 * This class defines a node of the k-d tree.
35 */
36 class Node
37 {
38 public:
39
40 /**
41 * Creates an empty node.
42 */
43 Node() = default;
44
45 /**
46 * Destructs a node.
47 */
48 ~Node() = default;
49
50 /**
51 * Creates a new node with given value.
52 * @param value Node value
53 */
54 explicit inline Node(const T* value);
55
56 /**
57 * Returns the left child of this node.
58 * @return Left child
59 */
60 inline Node* left();
61
62 /**
63 * Returns the right child of this node.
64 * @return Right child
65 */
66 inline Node* right();
67
68 /**
69 * Returns value of this node.
70 * @return Node value
71 */
72 inline const T* value();
73
74 /**
75 * Sets the left child of this node.
76 * @param left The left child
77 */
78 inline void setLeft(std::unique_ptr<Node>&& left);
79
80 /**
81 * Sets the right child of this node.
82 * @param right The right child
83 */
84 inline void setRight(std::unique_ptr<Node>&& right);
85
86 /**
87 * Returns whether this node holds a valid value.
88 * @return True, if so
89 */
90 explicit inline operator bool() const;
91
92 protected:
93
94 /// Node value.
95 const T* value_ = nullptr;
96
97 /// Left node.
98 std::unique_ptr<Node> left_ = nullptr;
99
100 /// Right node.
101 std::unique_ptr<Node> right_ = nullptr;
102 };
103
104 /**
105 * Definition of a vector holding single elements.
106 */
107 typedef std::vector<T> Elements;
108
109 /**
110 * Definition of a vector holding single pointers.
111 */
112 typedef std::vector<const T*> Pointers;
113
114 public:
115
116 /**
117 * Creates a new k-d tree.
118 * @param dimension Number of dimensions the tree will have, with range [1, infinity)
119 */
120 explicit inline KdTree(const unsigned int dimension);
121
122 /**
123 * Destructs a k-d tree.
124 */
125 ~KdTree() = default;
126
127 /**
128 * Inserts a set of values to this empty tree.
129 * Beware: Adding elements to an already existing tree with nodes is not supported.
130 * @param values The values to be added
131 * @param number The number of elements to be added, with range [0, infinity)
132 * @return True, if succeeded
133 */
134 bool insert(const T** values, const size_t number);
135
136 /**
137 * Applies a nearest neighbor search for a given value.
138 * @param value The value to be searched
139 * @param distance Resulting minimal distance
140 * @return Resulting nearest neighbor
141 */
142 const T* nearestNeighbor(const T* value, typename SquareValueTyper<T>::Type& distance) const;
143
144 /**
145 * Applies a radius search for neighbors of a given value.
146 * Beware: Function offers performance boost over brute force search only if radius is so small that relatively few values are returned.
147 * @param value The value to be searched
148 * @param radius The neighborhood radius, with range [0, infinity)
149 * @param values Found values within radius distance from a given value
150 * @param maxValues Limit number of returned values
151 * @return Number of returned values
152 */
153 size_t radiusSearch(const T* value, const typename SquareValueTyper<T>::Type radius, const T** values, const size_t maxValues) const;
154
155 /**
156 * Returns the dimension of the tree's values.
157 * @return The tree's dimension
158 */
159 inline unsigned int dimension() const;
160
161 /**
162 * Returns the number of tree nodes.
163 * @return Tree size
164 */
165 inline size_t size() const;
166
167 protected:
168
169 /**
170 * Inserts a set of values to a given parent as left children.
171 * @param parent The parent node
172 * @param values The values to be added
173 * @param number The number of elements to be added
174 * @param depth Current depth of the new nodes
175 */
176 void insertLeft(Node& parent, const T** values, const size_t number, const unsigned int depth);
177
178 /**
179 * Inserts a set of values to a given parent as right children.
180 * @param parent The parent node
181 * @param values The values to be added
182 * @param number The number of elements to be added
183 * @param depth Current depth of the new nodes
184 */
185 void insertRight(Node& parent, const T** values, const size_t number, const unsigned int depth);
186
187 /**
188 * Applies a nearest neighbor search for a given node and value.
189 * @param node Reference node
190 * @param value The value to be searched
191 * @param nearest Resulting nearest value
192 * @param distance Resulting minimal distance
193 * @param index The index of the dimension for the recent search, with range [0, dimension())
194 */
195 void nearestNeighbor(Node& node, const T* value, const T*& nearest, typename SquareValueTyper<T>::Type& distance, const unsigned int index) const;
196
197 /**
198 * Applies a radius search for neighbors of a given value.
199 * @param node Reference node
200 * @param value The value to be searched
201 * @param radius The neighborhood radius, with range [0, infinity)
202 * @param values Found values within radius distance from a given value
203 * @param maxValues Limit number of returned values
204 * @param index The index of the dimension for the recent search, with range [0, dimension())
205 * @return Number of returned values
206 */
207 size_t radiusSearch(Node& node, const T* value, const typename SquareValueTyper<T>::Type radius, const T** values, const size_t maxValues, const unsigned int index) const;
208
209 /**
210 * Returns the median for a given set of values.
211 * @param values The values to return the median for
212 * @param number The number of values
213 * @param index Dimension index
214 * @return Resulting median
215 */
216 static T median(const T** values, const size_t number, const unsigned int index);
217
218 /**
219 * Distributes a given set of values into two subset according to the median.
220 * @param values The values to distribute
221 * @param number The number of values
222 * @param index Dimension index to be used for distribution
223 * @param medianValue Resulting median value
224 * @param leftValues Resulting left values
225 * @param rightValues Resulting right values
226 */
227 static void distribute(const T** values, const size_t number, const unsigned int index, const T*& medianValue, Pointers& leftValues, Pointers& rightValues);
228
229 /**
230 * Determines the square distance between two values.
231 * @param first The first value
232 * @param second The second value
233 * @return Square distance
234 */
235 inline typename SquareValueTyper<T>::Type determineSquareDistance(const T* first, const T* second) const;
236
237 protected:
238
239 /// Root node of this tree.
240 std::unique_ptr<Node> root_;
241
242 /// Number of nodes.
243 size_t size_ = 0;
244
245 /// Number of dimensions.
246 const unsigned int dimension_ = 0u;
247};
248
249template <typename T>
250inline KdTree<T>::Node::Node(const T* value) :
251 value_(value)
252{
253 // nothing to do here
254}
255
256template <typename T>
258{
259 return left_.get();
260}
261
262template <typename T>
264{
265 return right_.get();
266}
267
268template <typename T>
269inline const T* KdTree<T>::Node::value()
270{
271 return value_;
272}
273
274template <typename T>
275inline void KdTree<T>::Node::setLeft(std::unique_ptr<Node>&& left)
276{
277 ocean_assert(!left_);
278 left_ = std::move(left);
279}
280
281template <typename T>
282inline void KdTree<T>::Node::setRight(std::unique_ptr<Node>&& right)
283{
284 ocean_assert(!right_);
285 right_ = std::move(right);
286}
287
288template <typename T>
289inline KdTree<T>::Node::operator bool() const
290{
291 return value_ != nullptr;
292}
293
294template <typename T>
295inline KdTree<T>::KdTree(const unsigned int dimension) :
297{
298 ocean_assert(dimension >= 1u);
299}
300
301template <typename T>
302bool KdTree<T>::insert(const T** values, const size_t number)
303{
304 if (number == 0)
305 {
306 return true;
307 }
308
309 if (root_)
310 {
311 return false;
312 }
313
314 ocean_assert(values);
315
316 const T* median = nullptr;
317 Pointers left, right;
318
319 distribute(values, number, 0u, median, left, right);
320
321 root_ = std::make_unique<Node>(median);
322
323 if (!left.empty())
324 {
325 insertLeft(*root_, left.data(), left.size(), 1u);
326 }
327
328 if (!right.empty())
329 {
330 insertRight(*root_, right.data(), right.size(), 1u);
331 }
332
333 size_ = number;
334
335 return true;
336}
337
338template <typename T>
339const T* KdTree<T>::nearestNeighbor(const T* value, typename SquareValueTyper<T>::Type& distance) const
340{
341 ocean_assert(value);
342
343 distance = std::numeric_limits<T>::max();
344
345 if (!root_)
346 {
347 return nullptr;
348 }
349
350 const T* nearest = nullptr;
351
352 nearestNeighbor(*root_, value, nearest, distance, 0u);
353 return nearest;
354}
355
356template <typename T>
357size_t KdTree<T>::radiusSearch(const T* value, const typename SquareValueTyper<T>::Type radius, const T** values, const size_t maxValues) const
358{
359 ocean_assert(value);
360 ocean_assert(values);
361
362 if (!root_ || maxValues == 0)
363 {
364 return 0;
365 }
366
367 size_t found = radiusSearch(*root_, value, radius, values, maxValues, 0u);
368
369 ocean_assert(found <= maxValues);
370
371 return found;
372}
373
374template <typename T>
375inline unsigned int KdTree<T>::dimension() const
376{
377 return dimension_;
378}
379
380template <typename T>
381inline size_t KdTree<T>::size() const
382{
383 return size_;
384}
385
386template <typename T>
387void KdTree<T>::insertLeft(Node& parent, const T** values, const size_t number, const unsigned int depth)
388{
389#ifdef OCEAN_DEBUG
390 ocean_assert(parent);
391 ocean_assert(values && number > 0);
392
393 for (size_t n = 0; n < number; ++n)
394 {
395 ocean_assert(values[n][(depth - 1) % dimension_] <= parent.value()[(depth - 1) % dimension_]);
396 }
397#endif
398
399 ocean_assert(!parent.left());
400
401 const unsigned int index = depth % dimension_;
402
403 const T* median = nullptr;
404 Pointers left, right;
405
406 distribute(values, number, index, median, left, right);
407
408 parent.setLeft(std::make_unique<Node>(median));
409
410 if (!left.empty())
411 {
412 insertLeft(*parent.left(), left.data(), left.size(), depth + 1);
413 }
414
415 if (!right.empty())
416 {
417 insertRight(*parent.left(), right.data(), right.size(), depth + 1);
418 }
419}
420
421template <typename T>
422void KdTree<T>::insertRight(Node& parent, const T** values, const size_t number, const unsigned int depth)
423{
424#ifdef OCEAN_DEBUG
425 ocean_assert(parent);
426 ocean_assert(values && number > 0);
427
428 for (size_t n = 0; n < number; ++n)
429 {
430 ocean_assert(parent.value()[(depth - 1) % dimension_] < values[n][(depth - 1) % dimension_]);
431 }
432#endif
433
434 ocean_assert(!parent.right());
435
436 const unsigned int index = depth % dimension_;
437
438 const T* median = nullptr;
439 Pointers left, right;
440
441 distribute(values, number, index, median, left, right);
442
443 parent.setRight(std::make_unique<Node>(median));
444
445 if (!left.empty())
446 {
447 insertLeft(*parent.right(), left.data(), left.size(), depth + 1);
448 }
449
450 if (!right.empty())
451 {
452 insertRight(*parent.right(), right.data(), right.size(), depth + 1);
453 }
454}
455
456template <typename T>
457void KdTree<T>::nearestNeighbor(Node& node, const T* value, const T*& nearest, typename SquareValueTyper<T>::Type& distance, const unsigned int index) const
458{
459 ocean_assert(node && value);
460 ocean_assert(index < dimension_);
461
462 const typename SquareValueTyper<T>::Type localDistance = determineSquareDistance(value, node.value());
463 if (localDistance < distance)
464 {
465 distance = localDistance;
466 nearest = node.value();
467 }
468
469 const unsigned int nextIndex = (index + 1u) % dimension_;
470
471 // depth-first-search
472 if (value[index] <= node.value()[index])
473 {
474 if (node.left())
475 {
476 nearestNeighbor(*node.left(), value, nearest, distance, nextIndex);
477 }
478
479 // check the neighboring branch not covered by the depth-first-search
480 if (node.right() && sqr(value[index] - node.value()[index]) < distance)
481 {
482 nearestNeighbor(*node.right(), value, nearest, distance, nextIndex);
483 }
484 }
485 else
486 {
487 ocean_assert(value[index] > node.value()[index]);
488
489 if (node.right())
490 {
491 nearestNeighbor(*node.right(), value, nearest, distance, nextIndex);
492 }
493
494 // check the neighboring branch not covered by the depth-first-search
495 if (node.left() && sqr(value[index] - node.value()[index]) < distance)
496 {
497 nearestNeighbor(*node.left(), value, nearest, distance, nextIndex);
498 }
499 }
500}
501
502template <typename T>
503size_t KdTree<T>::radiusSearch(Node& node, const T* value, const typename SquareValueTyper<T>::Type radius, const T** values, const size_t maxValues, const unsigned int index) const
504{
505 ocean_assert(node && value && values);
506 ocean_assert(index < dimension_);
507
508 if (maxValues == 0)
509 {
510 return 0;
511 }
512
513 T const ** current = values;
514 T const ** const end = values + maxValues;
515
516 const typename SquareValueTyper<T>::Type localDistance = determineSquareDistance(value, node.value());
517 if (localDistance <= radius)
518 {
519 *current++ = node.value();
520
521 if (current == end)
522 {
523 return maxValues;
524 }
525 }
526
527 const unsigned int nextIndex = (index + 1u) % dimension_;
528
529 // depth-first-search
530 if (value[index] <= node.value()[index])
531 {
532 if (node.left())
533 {
534 current += radiusSearch(*node.left(), value, radius, current, end - current, nextIndex);
535 if (current == end)
536 {
537 return maxValues;
538 }
539 }
540
541 // check the neighboring branch not covered by the depth-first-search
542 if (node.right() && sqr(value[index] - node.value()[index]) <= radius)
543 {
544 current += radiusSearch(*node.right(), value, radius, current, end - current, nextIndex);
545 }
546 }
547 else
548 {
549 ocean_assert(value[index] > node.value()[index]);
550
551 if (node.right())
552 {
553 current += radiusSearch(*node.right(), value, radius, current, end - current, nextIndex);
554 if (current == end)
555 {
556 return maxValues;
557 }
558 }
559
560 // check the neighboring branch not covered by the depth-first-search
561 if (node.left() && sqr(value[index] - node.value()[index]) <= radius)
562 {
563 current += radiusSearch(*node.left(), value, radius, current, end - current, nextIndex);
564 }
565 }
566
567 ocean_assert(current <= end);
568
569 return current - values;
570}
571
572template <typename T>
573T KdTree<T>::median(const T** values, const size_t number, const unsigned int index)
574{
575 ocean_assert(values && number > 0);
576
577 Elements elements(number);
578 for (size_t n = 0; n < number; ++n)
579 {
580 elements[n] = values[n][index];
581 }
582
583 return Median::median<T>((T*)elements.data(), elements.size());
584}
585
586template <typename T>
587void KdTree<T>::distribute(const T** values, const size_t number, const unsigned int index, const T*& medianValue, Pointers& leftValues, Pointers& rightValues)
588{
589 ocean_assert(values && number > 0);
590
591 const T middle = median(values, number, index);
592
593 leftValues.reserve(number / 2);
594 rightValues.reserve(number / 2);
595
596 ocean_assert(leftValues.empty() && rightValues.empty());
597
598 bool medianFound = false;
599
600 for (size_t n = 0; n < number; ++n)
601 {
602 if (values[n][index] < middle || (values[n][index] == middle && medianFound))
603 {
604 leftValues.push_back(values[n]);
605 }
606 else if (middle < values[n][index])
607 {
608 rightValues.push_back(values[n]);
609 }
610 else
611 {
612 ocean_assert(values[n][index] == middle);
613 ocean_assert(!medianFound);
614
615 medianValue = values[n];
616 medianFound = true;
617 }
618 }
619
620 ocean_assert(medianFound);
621}
622
623template <typename T>
624inline typename SquareValueTyper<T>::Type KdTree<T>::determineSquareDistance(const T* first, const T* second) const
625{
626 ocean_assert(first && second);
627
628 typename SquareValueTyper<T>::Type ssd = 0;
629
630 for (size_t n = 0; n < dimension_; ++n)
631 {
632 ssd += sqr(first[n] - second[n]);
633 }
634
635 return ssd;
636}
637
638}
639
640#endif // META_OCEAN_BASE_KD_TREE_H
This class defines a node of the k-d tree.
Definition KdTree.h:37
const T * value_
Node value.
Definition KdTree.h:95
Node * left()
Returns the left child of this node.
Definition KdTree.h:257
std::unique_ptr< Node > left_
Left node.
Definition KdTree.h:98
Node * right()
Returns the right child of this node.
Definition KdTree.h:263
~Node()=default
Destructs a node.
void setRight(std::unique_ptr< Node > &&right)
Sets the right child of this node.
Definition KdTree.h:282
void setLeft(std::unique_ptr< Node > &&left)
Sets the left child of this node.
Definition KdTree.h:275
Node()=default
Creates an empty node.
std::unique_ptr< Node > right_
Right node.
Definition KdTree.h:101
const T * value()
Returns value of this node.
Definition KdTree.h:269
This class implements a k-d tree.
Definition KdTree.h:30
std::vector< const T * > Pointers
Definition of a vector holding single pointers.
Definition KdTree.h:112
const T * nearestNeighbor(const T *value, typename SquareValueTyper< T >::Type &distance) const
Applies a nearest neighbor search for a given value.
Definition KdTree.h:339
const unsigned int dimension_
Number of dimensions.
Definition KdTree.h:246
std::vector< T > Elements
Definition of a vector holding single elements.
Definition KdTree.h:107
std::unique_ptr< Node > root_
Root node of this tree.
Definition KdTree.h:240
static T median(const T **values, const size_t number, const unsigned int index)
Returns the median for a given set of values.
Definition KdTree.h:573
~KdTree()=default
Destructs a k-d tree.
static void distribute(const T **values, const size_t number, const unsigned int index, const T *&medianValue, Pointers &leftValues, Pointers &rightValues)
Distributes a given set of values into two subsets according to the median.
Definition KdTree.h:587
KdTree(const unsigned int dimension)
Creates a new k-d tree.
Definition KdTree.h:295
size_t size() const
Returns the number of tree nodes.
Definition KdTree.h:381
void insertLeft(Node &parent, const T **values, const size_t number, const unsigned int depth)
Inserts a set of values to a given parent as left children.
Definition KdTree.h:387
bool insert(const T **values, const size_t number)
Inserts a set of values into this empty tree.
Definition KdTree.h:302
size_t radiusSearch(const T *value, const typename SquareValueTyper< T >::Type radius, const T **values, const size_t maxValues) const
Applies a radius search for neighbors of a given value; the radius is compared against square distances.
Definition KdTree.h:357
size_t size_
Number of nodes.
Definition KdTree.h:243
SquareValueTyper< T >::Type determineSquareDistance(const T *first, const T *second) const
Determines the square distance between two values.
Definition KdTree.h:624
unsigned int dimension() const
Returns the dimension of the tree's values.
Definition KdTree.h:375
void insertRight(Node &parent, const T **values, const size_t number, const unsigned int depth)
Inserts a set of values to a given parent as right children.
Definition KdTree.h:422
T Type
Definition of the data type for the square value.
Definition DataType.h:132
unsigned int sqr(const char value)
Returns the square value of a given value.
Definition base/Utilities.h:1029
The namespace covering the entire Ocean framework.
Definition Accessor.h:15