Ocean
KdTree.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #ifndef META_OCEAN_BASE_KD_TREE_H
9 #define META_OCEAN_BASE_KD_TREE_H
10 
11 #include "ocean/base/Base.h"
12 #include "ocean/base/DataType.h"
13 #include "ocean/base/Median.h"
14 #include "ocean/base/Utilities.h"
15 
16 #include <limits>
17 
18 namespace Ocean
19 {
20 
21 /**
22  * This class implements a k-d tree.
23  * In general, k-d trees should be applied for problems with small dimensions only as the performance benefit decreases with increasing dimension significantly.<br>
24  * That means for the number of nodes (n) and the dimension (k) the following should hold: n >> 2^k.
25  * @tparam T The data type of one element for all dimensions.
26  * @ingroup base
27  */
28 template <typename T>
29 class KdTree
30 {
31  protected:
32 
33  /**
34  * This class defines a node of the k-d tree.
35  */
36  class Node
37  {
38  public:
39 
40  /**
41  * Creates an empty node.
42  */
43  Node() = default;
44 
45  /**
46  * Destructs a node.
47  */
48  ~Node() = default;
49 
50  /**
51  * Creates a new node with given value.
52  * @param value Node value
53  */
54  explicit inline Node(const T* value);
55 
56  /**
57  * Returns the left child of this node.
58  * @return Left child
59  */
60  inline Node* left();
61 
62  /**
63  * Returns the right child of this node.
64  * @return Right child
65  */
66  inline Node* right();
67 
68  /**
69  * Returns value of this node.
70  * @return Node value
71  */
72  inline const T* value();
73 
74  /**
75  * Sets the left child of this node.
76  * @param left The left child
77  */
78  inline void setLeft(std::unique_ptr<Node>&& left);
79 
80  /**
81  * Sets the right child of this node.
82  * @param right The right child
83  */
84  inline void setRight(std::unique_ptr<Node>&& right);
85 
86  /**
87  * Returns whether this node holds a valid value.
88  * @return True, if so
89  */
90  explicit inline operator bool() const;
91 
92  protected:
93 
94  /// Node value.
95  const T* value_ = nullptr;
96 
97  /// Left node.
98  std::unique_ptr<Node> left_ = nullptr;
99 
100  /// Right node.
101  std::unique_ptr<Node> right_ = nullptr;
102  };
103 
104  /**
105  * Definition of a vector holding single elements.
106  */
107  typedef std::vector<T> Elements;
108 
109  /**
110  * Definition of a vector holding single pointers.
111  */
112  typedef std::vector<const T*> Pointers;
113 
114  public:
115 
116  /**
117  * Creates a new k-d tree.
118  * @param dimension Number of dimensions the tree will have, with range [1, infinity)
119  */
120  explicit inline KdTree(const unsigned int dimension);
121 
122  /**
123  * Destructs a k-d tree.
124  */
125  ~KdTree() = default;
126 
127  /**
128  * Inserts a set of values to this empty tree.
129  * Beware: Adding elements to an already existing tree with nodes is not supported.
130  * @param values The values to be added
131  * @param number The number of elements to be added, with range [0, infinity)
132  * @return True, if succeeded
133  */
134  bool insert(const T** values, const size_t number);
135 
136  /**
137  * Applies a nearest neighbor search for a given value.
138  * @param value The value to be searched
139  * @param distance Resulting minimal distance
140  * @return Resulting nearest neighbor
141  */
142  const T* nearestNeighbor(const T* value, typename SquareValueTyper<T>::Type& distance) const;
143 
144  /**
145  * Applies a radius search for neighbors of a given value.
146  * Beware: Function offers performance boost over brute force search only if radius is so small that relatively few values are returned.
147  * @param value The value to be searched
148  * @param radius The neighborhood radius, with range [0, infinity)
149  * @param values Found values within radius distance from a given value
150  * @param maxValues Limit number of returned values
151  * @return Number of returned values
152  */
153  size_t radiusSearch(const T* value, const typename SquareValueTyper<T>::Type radius, const T** values, const size_t maxValues) const;
154 
155  /**
156  * Returns the dimension of the tree's values.
157  * @return The tree's dimension
158  */
159  inline unsigned int dimension() const;
160 
161  /**
162  * Returns the number of tree nodes.
163  * @return Tree size
164  */
165  inline size_t size() const;
166 
167  protected:
168 
169  /**
170  * Inserts a set of values to a given parent as left children.
171  * @param parent The parent node
172  * @param values The values to be added
173  * @param number The number of elements to be added
174  * @param depth Current depth of the new nodes
175  */
176  void insertLeft(Node& parent, const T** values, const size_t number, const unsigned int depth);
177 
178  /**
179  * Inserts a set of values to a given parent as right children.
180  * @param parent The parent node
181  * @param values The values to be added
182  * @param number The number of elements to be added
183  * @param depth Current depth of the new nodes
184  */
185  void insertRight(Node& parent, const T** values, const size_t number, const unsigned int depth);
186 
187  /**
188  * Applies a nearest neighbor search for a given node and value.
189  * @param node Reference node
190  * @param value The value to be searched
191  * @param nearest Resulting nearest value
192  * @param distance Resulting minimal distance
193  * @param index The index of the dimension for the recent search, with range [0, dimension())
194  */
195  void nearestNeighbor(Node& node, const T* value, const T*& nearest, typename SquareValueTyper<T>::Type& distance, const unsigned int index) const;
196 
197  /**
198  * Applies a radius search for neighbors of a given value.
199  * @param node Reference node
200  * @param value The value to be searched
201  * @param radius The neighborhood radius, with range [0, infinity)
202  * @param values Found values within radius distance from a given value
203  * @param maxValues Limit number of returned values
204  * @param index The index of the dimension for the recent search, with range [0, dimension())
205  * @return Number of returned values
206  */
207  size_t radiusSearch(Node& node, const T* value, const typename SquareValueTyper<T>::Type radius, const T** values, const size_t maxValues, const unsigned int index) const;
208 
209  /**
210  * Returns the median for a given set of values.
211  * @param values The values to return the median for
212  * @param number The number of values
213  * @param index Dimension index
214  * @return Resulting median
215  */
216  static T median(const T** values, const size_t number, const unsigned int index);
217 
218  /**
219  * Distributes a given set of values into two subset according to the median.
220  * @param values The values to distribute
221  * @param number The number of values
222  * @param index Dimension index to be used for distribution
223  * @param medianValue Resulting median value
224  * @param leftValues Resulting left values
225  * @param rightValues Resulting right values
226  */
227  static void distribute(const T** values, const size_t number, const unsigned int index, const T*& medianValue, Pointers& leftValues, Pointers& rightValues);
228 
229  /**
230  * Determines the square distance between two values.
231  * @param first The first value
232  * @param second The second value
233  * @return Square distance
234  */
235  inline typename SquareValueTyper<T>::Type determineSquareDistance(const T* first, const T* second) const;
236 
237  protected:
238 
239  /// Root node of this tree.
240  std::unique_ptr<Node> root_;
241 
242  /// Number of nodes.
243  size_t size_ = 0;
244 
245  /// Number of dimensions.
246  const unsigned int dimension_ = 0u;
247 };
248 
249 template <typename T>
250 inline KdTree<T>::Node::Node(const T* value) :
251  value_(value)
252 {
253  // nothing to do here
254 }
255 
256 template <typename T>
258 {
259  return left_.get();
260 }
261 
262 template <typename T>
264 {
265  return right_.get();
266 }
267 
268 template <typename T>
269 inline const T* KdTree<T>::Node::value()
270 {
271  return value_;
272 }
273 
274 template <typename T>
275 inline void KdTree<T>::Node::setLeft(std::unique_ptr<Node>&& left)
276 {
277  ocean_assert(!left_);
278  left_ = std::move(left);
279 }
280 
281 template <typename T>
282 inline void KdTree<T>::Node::setRight(std::unique_ptr<Node>&& right)
283 {
284  ocean_assert(!right_);
285  right_ = std::move(right);
286 }
287 
288 template <typename T>
289 inline KdTree<T>::Node::operator bool() const
290 {
291  return value_ != nullptr;
292 }
293 
294 template <typename T>
295 inline KdTree<T>::KdTree(const unsigned int dimension) :
297 {
298  ocean_assert(dimension >= 1u);
299 }
300 
301 template <typename T>
302 bool KdTree<T>::insert(const T** values, const size_t number)
303 {
304  if (number == 0)
305  {
306  return true;
307  }
308 
309  if (root_)
310  {
311  return false;
312  }
313 
314  ocean_assert(values);
315 
316  const T* median = nullptr;
317  Pointers left, right;
318 
319  distribute(values, number, 0u, median, left, right);
320 
321  root_ = std::make_unique<Node>(median);
322 
323  if (!left.empty())
324  {
325  insertLeft(*root_, left.data(), left.size(), 1u);
326  }
327 
328  if (!right.empty())
329  {
330  insertRight(*root_, right.data(), right.size(), 1u);
331  }
332 
333  size_ = number;
334 
335  return true;
336 }
337 
338 template <typename T>
339 const T* KdTree<T>::nearestNeighbor(const T* value, typename SquareValueTyper<T>::Type& distance) const
340 {
341  ocean_assert(value);
342 
343  distance = std::numeric_limits<T>::max();
344 
345  if (!root_)
346  {
347  return nullptr;
348  }
349 
350  const T* nearest = nullptr;
351 
352  nearestNeighbor(*root_, value, nearest, distance, 0u);
353  return nearest;
354 }
355 
356 template <typename T>
357 size_t KdTree<T>::radiusSearch(const T* value, const typename SquareValueTyper<T>::Type radius, const T** values, const size_t maxValues) const
358 {
359  ocean_assert(value);
360  ocean_assert(values);
361 
362  if (!root_ || maxValues == 0)
363  {
364  return 0;
365  }
366 
367  size_t found = radiusSearch(*root_, value, radius, values, maxValues, 0u);
368 
369  ocean_assert(found <= maxValues);
370 
371  return found;
372 }
373 
374 template <typename T>
375 inline unsigned int KdTree<T>::dimension() const
376 {
377  return dimension_;
378 }
379 
380 template <typename T>
381 inline size_t KdTree<T>::size() const
382 {
383  return size_;
384 }
385 
386 template <typename T>
387 void KdTree<T>::insertLeft(Node& parent, const T** values, const size_t number, const unsigned int depth)
388 {
389 #ifdef OCEAN_DEBUG
390  ocean_assert(parent);
391  ocean_assert(values && number > 0);
392 
393  for (size_t n = 0; n < number; ++n)
394  {
395  ocean_assert(values[n][(depth - 1) % dimension_] <= parent.value()[(depth - 1) % dimension_]);
396  }
397 #endif
398 
399  ocean_assert(!parent.left());
400 
401  const unsigned int index = depth % dimension_;
402 
403  const T* median = nullptr;
404  Pointers left, right;
405 
406  distribute(values, number, index, median, left, right);
407 
408  parent.setLeft(std::make_unique<Node>(median));
409 
410  if (!left.empty())
411  {
412  insertLeft(*parent.left(), left.data(), left.size(), depth + 1);
413  }
414 
415  if (!right.empty())
416  {
417  insertRight(*parent.left(), right.data(), right.size(), depth + 1);
418  }
419 }
420 
421 template <typename T>
422 void KdTree<T>::insertRight(Node& parent, const T** values, const size_t number, const unsigned int depth)
423 {
424 #ifdef OCEAN_DEBUG
425  ocean_assert(parent);
426  ocean_assert(values && number > 0);
427 
428  for (size_t n = 0; n < number; ++n)
429  {
430  ocean_assert(parent.value()[(depth - 1) % dimension_] < values[n][(depth - 1) % dimension_]);
431  }
432 #endif
433 
434  ocean_assert(!parent.right());
435 
436  const unsigned int index = depth % dimension_;
437 
438  const T* median = nullptr;
439  Pointers left, right;
440 
441  distribute(values, number, index, median, left, right);
442 
443  parent.setRight(std::make_unique<Node>(median));
444 
445  if (!left.empty())
446  {
447  insertLeft(*parent.right(), left.data(), left.size(), depth + 1);
448  }
449 
450  if (!right.empty())
451  {
452  insertRight(*parent.right(), right.data(), right.size(), depth + 1);
453  }
454 }
455 
456 template <typename T>
457 void KdTree<T>::nearestNeighbor(Node& node, const T* value, const T*& nearest, typename SquareValueTyper<T>::Type& distance, const unsigned int index) const
458 {
459  ocean_assert(node && value);
460  ocean_assert(index < dimension_);
461 
462  const typename SquareValueTyper<T>::Type localDistance = determineSquareDistance(value, node.value());
463  if (localDistance < distance)
464  {
465  distance = localDistance;
466  nearest = node.value();
467  }
468 
469  const unsigned int nextIndex = (index + 1u) % dimension_;
470 
471  // depth-first-search
472  if (value[index] <= node.value()[index])
473  {
474  if (node.left())
475  {
476  nearestNeighbor(*node.left(), value, nearest, distance, nextIndex);
477  }
478 
479  // check the neighboring branch not covered by the depth-first-search
480  if (node.right() && sqr(value[index] - node.value()[index]) < distance)
481  {
482  nearestNeighbor(*node.right(), value, nearest, distance, nextIndex);
483  }
484  }
485  else
486  {
487  ocean_assert(value[index] > node.value()[index]);
488 
489  if (node.right())
490  {
491  nearestNeighbor(*node.right(), value, nearest, distance, nextIndex);
492  }
493 
494  // check the neighboring branch not covered by the depth-first-search
495  if (node.left() && sqr(value[index] - node.value()[index]) < distance)
496  {
497  nearestNeighbor(*node.left(), value, nearest, distance, nextIndex);
498  }
499  }
500 }
501 
502 template <typename T>
503 size_t KdTree<T>::radiusSearch(Node& node, const T* value, const typename SquareValueTyper<T>::Type radius, const T** values, const size_t maxValues, const unsigned int index) const
504 {
505  ocean_assert(node && value && values);
506  ocean_assert(index < dimension_);
507 
508  if (maxValues == 0)
509  {
510  return 0;
511  }
512 
513  T const ** current = values;
514  T const ** const end = values + maxValues;
515 
516  const typename SquareValueTyper<T>::Type localDistance = determineSquareDistance(value, node.value());
517  if (localDistance <= radius)
518  {
519  *current++ = node.value();
520 
521  if (current == end)
522  {
523  return maxValues;
524  }
525  }
526 
527  const unsigned int nextIndex = (index + 1u) % dimension_;
528 
529  // depth-first-search
530  if (value[index] <= node.value()[index])
531  {
532  if (node.left())
533  {
534  current += radiusSearch(*node.left(), value, radius, current, end - current, nextIndex);
535  if (current == end)
536  {
537  return maxValues;
538  }
539  }
540 
541  // check the neighboring branch not covered by the depth-first-search
542  if (node.right() && sqr(value[index] - node.value()[index]) <= radius)
543  {
544  current += radiusSearch(*node.right(), value, radius, current, end - current, nextIndex);
545  }
546  }
547  else
548  {
549  ocean_assert(value[index] > node.value()[index]);
550 
551  if (node.right())
552  {
553  current += radiusSearch(*node.right(), value, radius, current, end - current, nextIndex);
554  if (current == end)
555  {
556  return maxValues;
557  }
558  }
559 
560  // check the neighboring branch not covered by the depth-first-search
561  if (node.left() && sqr(value[index] - node.value()[index]) <= radius)
562  {
563  current += radiusSearch(*node.left(), value, radius, current, end - current, nextIndex);
564  }
565  }
566 
567  ocean_assert(current <= end);
568 
569  return current - values;
570 }
571 
572 template <typename T>
573 T KdTree<T>::median(const T** values, const size_t number, const unsigned int index)
574 {
575  ocean_assert(values && number > 0);
576 
577  Elements elements(number);
578  for (size_t n = 0; n < number; ++n)
579  {
580  elements[n] = values[n][index];
581  }
582 
583  return Median::median<T>((T*)elements.data(), elements.size());
584 }
585 
586 template <typename T>
587 void KdTree<T>::distribute(const T** values, const size_t number, const unsigned int index, const T*& medianValue, Pointers& leftValues, Pointers& rightValues)
588 {
589  ocean_assert(values && number > 0);
590 
591  const T middle = median(values, number, index);
592 
593  leftValues.reserve(number / 2);
594  rightValues.reserve(number / 2);
595 
596  ocean_assert(leftValues.empty() && rightValues.empty());
597 
598  bool medianFound = false;
599 
600  for (size_t n = 0; n < number; ++n)
601  {
602  if (values[n][index] < middle || (values[n][index] == middle && medianFound))
603  {
604  leftValues.push_back(values[n]);
605  }
606  else if (middle < values[n][index])
607  {
608  rightValues.push_back(values[n]);
609  }
610  else
611  {
612  ocean_assert(values[n][index] == middle);
613  ocean_assert(!medianFound);
614 
615  medianValue = values[n];
616  medianFound = true;
617  }
618  }
619 
620  ocean_assert(medianFound);
621 }
622 
623 template <typename T>
624 inline typename SquareValueTyper<T>::Type KdTree<T>::determineSquareDistance(const T* first, const T* second) const
625 {
626  ocean_assert(first && second);
627 
628  typename SquareValueTyper<T>::Type ssd = 0;
629 
630  for (size_t n = 0; n < dimension_; ++n)
631  {
632  ssd += sqr(first[n] - second[n]);
633  }
634 
635  return ssd;
636 }
637 
638 }
639 
640 #endif // META_OCEAN_BASE_KD_TREE_H
This class defines a node of the k-d tree.
Definition: KdTree.h:37
const T * value_
Node value.
Definition: KdTree.h:95
Node * left()
Returns the left child of this node.
Definition: KdTree.h:257
std::unique_ptr< Node > left_
Left node.
Definition: KdTree.h:98
Node * right()
Returns the right child of this node.
Definition: KdTree.h:263
~Node()=default
Destructs a node.
void setRight(std::unique_ptr< Node > &&right)
Sets the right child of this node.
Definition: KdTree.h:282
void setLeft(std::unique_ptr< Node > &&left)
Sets the left child of this node.
Definition: KdTree.h:275
Node()=default
Creates an empty node.
std::unique_ptr< Node > right_
Right node.
Definition: KdTree.h:101
const T * value()
Returns value of this node.
Definition: KdTree.h:269
This class implements a k-d tree.
Definition: KdTree.h:30
std::vector< const T * > Pointers
Definition of a vector holding single pointers.
Definition: KdTree.h:112
const T * nearestNeighbor(const T *value, typename SquareValueTyper< T >::Type &distance) const
Applies a nearest neighbor search for a given value.
Definition: KdTree.h:339
const unsigned int dimension_
Number of dimensions.
Definition: KdTree.h:246
std::vector< T > Elements
Definition of a vector holding single elements.
Definition: KdTree.h:107
std::unique_ptr< Node > root_
Root node of this tree.
Definition: KdTree.h:240
static T median(const T **values, const size_t number, const unsigned int index)
Returns the median for a given set of values.
Definition: KdTree.h:573
~KdTree()=default
Destructs a k-d tree.
static void distribute(const T **values, const size_t number, const unsigned int index, const T *&medianValue, Pointers &leftValues, Pointers &rightValues)
Distributes a given set of values into two subset according to the median.
Definition: KdTree.h:587
KdTree(const unsigned int dimension)
Creates a new k-d tree.
Definition: KdTree.h:295
size_t size() const
Returns the number of tree nodes.
Definition: KdTree.h:381
void insertLeft(Node &parent, const T **values, const size_t number, const unsigned int depth)
Inserts a set of values to a given parent as left children.
Definition: KdTree.h:387
bool insert(const T **values, const size_t number)
Inserts a set of values to this empty tree.
Definition: KdTree.h:302
size_t radiusSearch(const T *value, const typename SquareValueTyper< T >::Type radius, const T **values, const size_t maxValues) const
Applies a radius search for neighbors of a given value.
Definition: KdTree.h:357
size_t size_
Number of nodes.
Definition: KdTree.h:243
SquareValueTyper< T >::Type determineSquareDistance(const T *first, const T *second) const
Determines the square distance between two values.
Definition: KdTree.h:624
unsigned int dimension() const
Returns the dimension of the tree's values.
Definition: KdTree.h:375
void insertRight(Node &parent, const T **values, const size_t number, const unsigned int depth)
Inserts a set of values to a given parent as right children.
Definition: KdTree.h:422
T Type
Definition of the data type for the square value.
Definition: DataType.h:132
unsigned int sqr(const char value)
Returns the square value of a given value.
Definition: base/Utilities.h:1029
The namespace covering the entire Ocean framework.
Definition: Accessor.h:15