OpenTTD Source  1.11.2
kdtree.hpp
Go to the documentation of this file.
1 /*
2  * This file is part of OpenTTD.
3  * OpenTTD is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 2.
4  * OpenTTD is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
5  * See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with OpenTTD. If not, see <http://www.gnu.org/licenses/>.
6  */
7 
10 #ifndef KDTREE_HPP
11 #define KDTREE_HPP
12 
#include "../stdafx.h"
#include <algorithm>
#include <limits>
#include <utility>
#include <vector>
16 
36 template <typename T, typename TxyFunc, typename CoordT, typename DistT>
37 class Kdtree {
39  struct node {
40  T element;
41  size_t left;
42  size_t right;
43 
45  };
46 
47  static const size_t INVALID_NODE = SIZE_MAX;
48 
49  std::vector<node> nodes;
50  std::vector<size_t> free_list;
51  size_t root;
52  TxyFunc xyfunc;
53  size_t unbalanced;
54 
56  size_t AddNode(const T &element)
57  {
58  if (this->free_list.size() == 0) {
59  this->nodes.emplace_back(element);
60  return this->nodes.size() - 1;
61  } else {
62  size_t newidx = this->free_list.back();
63  this->free_list.pop_back();
64  this->nodes[newidx] = node{ element };
65  return newidx;
66  }
67  }
68 
70  template <typename It>
71  CoordT SelectSplitCoord(It begin, It end, int level)
72  {
73  It mid = begin + (end - begin) / 2;
74  std::nth_element(begin, mid, end, [&](T a, T b) { return this->xyfunc(a, level % 2) < this->xyfunc(b, level % 2); });
75  return this->xyfunc(*mid, level % 2);
76  }
77 
79  template <typename It>
80  size_t BuildSubtree(It begin, It end, int level)
81  {
82  ptrdiff_t count = end - begin;
83 
84  if (count == 0) {
85  return INVALID_NODE;
86  } else if (count == 1) {
87  return this->AddNode(*begin);
88  } else if (count > 1) {
89  CoordT split_coord = SelectSplitCoord(begin, end, level);
90  It split = std::partition(begin, end, [&](T v) { return this->xyfunc(v, level % 2) < split_coord; });
91  size_t newidx = this->AddNode(*split);
92  this->nodes[newidx].left = this->BuildSubtree(begin, split, level + 1);
93  this->nodes[newidx].right = this->BuildSubtree(split + 1, end, level + 1);
94  return newidx;
95  } else {
96  NOT_REACHED();
97  }
98  }
99 
101  bool Rebuild(const T *include_element, const T *exclude_element)
102  {
103  size_t initial_count = this->Count();
104  if (initial_count < 8) return false; // arbitrary value for "not worth rebalancing"
105 
106  T root_element = this->nodes[this->root].element;
107  std::vector<T> elements = this->FreeSubtree(this->root);
108  elements.push_back(root_element);
109 
110  if (include_element != nullptr) {
111  elements.push_back(*include_element);
112  initial_count++;
113  }
114  if (exclude_element != nullptr) {
115  typename std::vector<T>::iterator removed = std::remove(elements.begin(), elements.end(), *exclude_element);
116  elements.erase(removed, elements.end());
117  initial_count--;
118  }
119 
120  this->Build(elements.begin(), elements.end());
121  assert(initial_count == this->Count());
122  return true;
123  }
124 
126  void InsertRecursive(const T &element, size_t node_idx, int level)
127  {
128  /* Dimension index of current level */
129  int dim = level % 2;
130  /* Node reference */
131  node &n = this->nodes[node_idx];
132 
133  /* Coordinate of element splitting at this node */
134  CoordT nc = this->xyfunc(n.element, dim);
135  /* Coordinate of the new element */
136  CoordT ec = this->xyfunc(element, dim);
137  /* Which side to insert on */
138  size_t &next = (ec < nc) ? n.left : n.right;
139 
140  if (next == INVALID_NODE) {
141  /* New leaf */
142  size_t newidx = this->AddNode(element);
143  /* Vector may have been reallocated at this point, n and next are invalid */
144  node &nn = this->nodes[node_idx];
145  if (ec < nc) nn.left = newidx; else nn.right = newidx;
146  } else {
147  this->InsertRecursive(element, next, level + 1);
148  }
149  }
150 
155  std::vector<T> FreeSubtree(size_t node_idx)
156  {
157  std::vector<T> subtree_elements;
158  node &n = this->nodes[node_idx];
159 
160  /* We'll be appending items to the free_list, get index of our first item */
161  size_t first_free = this->free_list.size();
162  /* Prepare the descent with our children */
163  if (n.left != INVALID_NODE) this->free_list.push_back(n.left);
164  if (n.right != INVALID_NODE) this->free_list.push_back(n.right);
165  n.left = n.right = INVALID_NODE;
166 
167  /* Recursively free the nodes being collected */
168  for (size_t i = first_free; i < this->free_list.size(); i++) {
169  node &fn = this->nodes[this->free_list[i]];
170  subtree_elements.push_back(fn.element);
171  if (fn.left != INVALID_NODE) this->free_list.push_back(fn.left);
172  if (fn.right != INVALID_NODE) this->free_list.push_back(fn.right);
173  fn.left = fn.right = INVALID_NODE;
174  }
175 
176  return subtree_elements;
177  }
178 
186  size_t RemoveRecursive(const T &element, size_t node_idx, int level)
187  {
188  /* Node reference */
189  node &n = this->nodes[node_idx];
190 
191  if (n.element == element) {
192  /* Remove this one */
193  this->free_list.push_back(node_idx);
194  if (n.left == INVALID_NODE && n.right == INVALID_NODE) {
195  /* Simple case, leaf, new child node for parent is "none" */
196  return INVALID_NODE;
197  } else {
198  /* Complex case, rebuild the sub-tree */
199  std::vector<T> subtree_elements = this->FreeSubtree(node_idx);
200  return this->BuildSubtree(subtree_elements.begin(), subtree_elements.end(), level);;
201  }
202  } else {
203  /* Search in a sub-tree */
204  /* Dimension index of current level */
205  int dim = level % 2;
206  /* Coordinate of element splitting at this node */
207  CoordT nc = this->xyfunc(n.element, dim);
208  /* Coordinate of the element being removed */
209  CoordT ec = this->xyfunc(element, dim);
210  /* Which side to remove from */
211  size_t next = (ec < nc) ? n.left : n.right;
212  assert(next != INVALID_NODE); // node must exist somewhere and must be found before a leaf is reached
213  /* Descend */
214  size_t new_branch = this->RemoveRecursive(element, next, level + 1);
215  if (new_branch != next) {
216  /* Vector may have been reallocated at this point, n and next are invalid */
217  node &nn = this->nodes[node_idx];
218  if (ec < nc) nn.left = new_branch; else nn.right = new_branch;
219  }
220  return node_idx;
221  }
222  }
223 
224 
225  DistT ManhattanDistance(const T &element, CoordT x, CoordT y) const
226  {
227  return abs((DistT)this->xyfunc(element, 0) - (DistT)x) + abs((DistT)this->xyfunc(element, 1) - (DistT)y);
228  }
229 
231  using node_distance = std::pair<T, DistT>;
234  {
235  if (a.second < b.second) return a;
236  if (b.second < a.second) return b;
237  if (a.first < b.first) return a;
238  if (b.first < a.first) return b;
239  NOT_REACHED(); // a.first == b.first: same element must not be inserted twice
240  }
242  node_distance FindNearestRecursive(CoordT xy[2], size_t node_idx, int level, DistT limit = std::numeric_limits<DistT>::max()) const
243  {
244  /* Dimension index of current level */
245  int dim = level % 2;
246  /* Node reference */
247  const node &n = this->nodes[node_idx];
248 
249  /* Coordinate of element splitting at this node */
250  CoordT c = this->xyfunc(n.element, dim);
251  /* This node's distance to target */
252  DistT thisdist = ManhattanDistance(n.element, xy[0], xy[1]);
253  /* Assume this node is the best choice for now */
254  node_distance best = std::make_pair(n.element, thisdist);
255 
256  /* Next node to visit */
257  size_t next = (xy[dim] < c) ? n.left : n.right;
258  if (next != INVALID_NODE) {
259  /* Check if there is a better node down the tree */
260  best = SelectNearestNodeDistance(best, this->FindNearestRecursive(xy, next, level + 1));
261  }
262 
263  limit = std::min(best.second, limit);
264 
265  /* Check if the distance from current best is worse than distance from target to splitting line,
266  * if it is we also need to check the other side of the split. */
267  size_t opposite = (xy[dim] >= c) ? n.left : n.right; // reverse of above
268  if (opposite != INVALID_NODE && limit >= abs((int)xy[dim] - (int)c)) {
269  node_distance other_candidate = this->FindNearestRecursive(xy, opposite, level + 1, limit);
270  best = SelectNearestNodeDistance(best, other_candidate);
271  }
272 
273  return best;
274  }
275 
276  template <typename Outputter>
277  void FindContainedRecursive(CoordT p1[2], CoordT p2[2], size_t node_idx, int level, Outputter outputter) const
278  {
279  /* Dimension index of current level */
280  int dim = level % 2;
281  /* Node reference */
282  const node &n = this->nodes[node_idx];
283 
284  /* Coordinate of element splitting at this node */
285  CoordT ec = this->xyfunc(n.element, dim);
286  /* Opposite coordinate of element */
287  CoordT oc = this->xyfunc(n.element, 1 - dim);
288 
289  /* Test if this element is within rectangle */
290  if (ec >= p1[dim] && ec < p2[dim] && oc >= p1[1 - dim] && oc < p2[1 - dim]) outputter(n.element);
291 
292  /* Recurse left if part of rectangle is left of split */
293  if (p1[dim] < ec && n.left != INVALID_NODE) this->FindContainedRecursive(p1, p2, n.left, level + 1, outputter);
294 
295  /* Recurse right if part of rectangle is right of split */
296  if (p2[dim] > ec && n.right != INVALID_NODE) this->FindContainedRecursive(p1, p2, n.right, level + 1, outputter);
297  }
298 
300  size_t CountValue(const T &element, size_t node_idx) const
301  {
302  if (node_idx == INVALID_NODE) return 0;
303  const node &n = this->nodes[node_idx];
304  return CountValue(element, n.left) + CountValue(element, n.right) + ((n.element == element) ? 1 : 0);
305  }
306 
307  void IncrementUnbalanced(size_t amount = 1)
308  {
309  this->unbalanced += amount;
310  }
311 
314  {
315  size_t count = this->Count();
316  if (count < 8) return false;
317  return this->unbalanced > this->Count() / 4;
318  }
319 
321  void CheckInvariant(size_t node_idx, int level, CoordT min_x, CoordT max_x, CoordT min_y, CoordT max_y)
322  {
323  if (node_idx == INVALID_NODE) return;
324 
325  const node &n = this->nodes[node_idx];
326  CoordT cx = this->xyfunc(n.element, 0);
327  CoordT cy = this->xyfunc(n.element, 1);
328 
329  assert(cx >= min_x);
330  assert(cx < max_x);
331  assert(cy >= min_y);
332  assert(cy < max_y);
333 
334  if (level % 2 == 0) {
335  // split in dimension 0 = x
336  CheckInvariant(n.left, level + 1, min_x, cx, min_y, max_y);
337  CheckInvariant(n.right, level + 1, cx, max_x, min_y, max_y);
338  } else {
339  // split in dimension 1 = y
340  CheckInvariant(n.left, level + 1, min_x, max_x, min_y, cy);
341  CheckInvariant(n.right, level + 1, min_x, max_x, cy, max_y);
342  }
343  }
344 
/** Verify the invariant for the entire tree, does nothing unless KDTREE_DEBUG is defined. */
void CheckInvariant()
{
#ifdef KDTREE_DEBUG
	/* lowest() rather than min(): for a floating-point CoordT, min() is the smallest positive
	 * value and would wrongly reject negative coordinates. For integers both are identical. */
	CheckInvariant(this->root, 0, std::numeric_limits<CoordT>::lowest(), std::numeric_limits<CoordT>::max(), std::numeric_limits<CoordT>::lowest(), std::numeric_limits<CoordT>::max());
#endif
}
352 
353 public:
356 
363  template <typename It>
364  void Build(It begin, It end)
365  {
366  this->nodes.clear();
367  this->free_list.clear();
368  this->unbalanced = 0;
369  if (begin == end) return;
370  this->nodes.reserve(end - begin);
371 
372  this->root = this->BuildSubtree(begin, end, 0);
373  CheckInvariant();
374  }
375 
379  void Clear()
380  {
381  this->nodes.clear();
382  this->free_list.clear();
383  this->unbalanced = 0;
384  return;
385  }
386 
390  void Rebuild()
391  {
392  this->Rebuild(nullptr, nullptr);
393  }
394 
400  void Insert(const T &element)
401  {
402  if (this->Count() == 0) {
403  this->root = this->AddNode(element);
404  } else {
405  if (!this->IsUnbalanced() || !this->Rebuild(&element, nullptr)) {
406  this->InsertRecursive(element, this->root, 0);
407  this->IncrementUnbalanced();
408  }
409  CheckInvariant();
410  }
411  }
412 
419  void Remove(const T &element)
420  {
421  size_t count = this->Count();
422  if (count == 0) return;
423  if (!this->IsUnbalanced() || !this->Rebuild(nullptr, &element)) {
424  /* If the removed element is the root node, this modifies this->root */
425  this->root = this->RemoveRecursive(element, this->root, 0);
426  this->IncrementUnbalanced();
427  }
428  CheckInvariant();
429  }
430 
432  size_t Count() const
433  {
434  assert(this->free_list.size() <= this->nodes.size());
435  return this->nodes.size() - this->free_list.size();
436  }
437 
443  T FindNearest(CoordT x, CoordT y) const
444  {
445  assert(this->Count() > 0);
446 
447  CoordT xy[2] = { x, y };
448  return this->FindNearestRecursive(xy, this->root, 0).first;
449  }
450 
460  template <typename Outputter>
461  void FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2, Outputter outputter) const
462  {
463  assert(x1 < x2);
464  assert(y1 < y2);
465 
466  if (this->Count() == 0) return;
467 
468  CoordT p1[2] = { x1, y1 };
469  CoordT p2[2] = { x2, y2 };
470  this->FindContainedRecursive(p1, p2, this->root, 0, outputter);
471  }
472 
477  std::vector<T> FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2) const
478  {
479  std::vector<T> result;
480  this->FindContained(x1, y1, x2, y2, [&result](T e) {result.push_back(e); });
481  return result;
482  }
483 };
484 
485 #endif
Kdtree::InsertRecursive
void InsertRecursive(const T &element, size_t node_idx, int level)
Insert one element in the tree somewhere below node_idx.
Definition: kdtree.hpp:126
Kdtree::unbalanced
size_t unbalanced
Number approximating how unbalanced the tree might be.
Definition: kdtree.hpp:53
Kdtree::FreeSubtree
std::vector< T > FreeSubtree(size_t node_idx)
Free all children of the given node.
Definition: kdtree.hpp:155
Kdtree::node::right
size_t right
Index of node to the right, INVALID_NODE if none.
Definition: kdtree.hpp:42
Kdtree
K-dimensional tree, specialised for 2-dimensional space.
Definition: kdtree.hpp:37
Kdtree::INVALID_NODE
static const size_t INVALID_NODE
Index value indicating no-such-node.
Definition: kdtree.hpp:47
Kdtree::SelectSplitCoord
CoordT SelectSplitCoord(It begin, It end, int level)
Find a coordinate value to split a range of elements at.
Definition: kdtree.hpp:71
Kdtree::SelectNearestNodeDistance
static node_distance SelectNearestNodeDistance(const node_distance &a, const node_distance &b)
Ordering function for node_distance objects, elements with equal distance are ordered by less-than comparison.
Definition: kdtree.hpp:233
Kdtree::Rebuild
bool Rebuild(const T *include_element, const T *exclude_element)
Rebuild the tree with all existing elements, optionally adding or removing one more.
Definition: kdtree.hpp:101
Kdtree::IsUnbalanced
bool IsUnbalanced()
Check if the entire tree is in need of rebuilding.
Definition: kdtree.hpp:313
Kdtree::free_list
std::vector< size_t > free_list
List of dead indices in the nodes vector.
Definition: kdtree.hpp:50
Kdtree::Kdtree
Kdtree(TxyFunc xyfunc)
Construct a new Kdtree with the given xyfunc.
Definition: kdtree.hpp:355
Kdtree::Build
void Build(It begin, It end)
Clear and rebuild the tree from a new sequence of elements.
Definition: kdtree.hpp:364
Kdtree::Count
size_t Count() const
Get number of elements stored in tree.
Definition: kdtree.hpp:432
Kdtree::Remove
void Remove(const T &element)
Remove a single element from the tree, if it exists.
Definition: kdtree.hpp:419
Kdtree::FindContained
void FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2, Outputter outputter) const
Find all items contained within the given rectangle.
Definition: kdtree.hpp:461
Kdtree::Insert
void Insert(const T &element)
Insert a single element in the tree.
Definition: kdtree.hpp:400
Kdtree::xyfunc
TxyFunc xyfunc
Functor to extract a coordinate from an element.
Definition: kdtree.hpp:52
Kdtree::node_distance
std::pair< T, DistT > node_distance
A data element and its distance to a searched-for point.
Definition: kdtree.hpp:231
Kdtree::CheckInvariant
void CheckInvariant(size_t node_idx, int level, CoordT min_x, CoordT max_x, CoordT min_y, CoordT max_y)
Verify that the invariant is true for a sub-tree, assert if not.
Definition: kdtree.hpp:321
Kdtree::FindContained
std::vector< T > FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2) const
Find all items contained within the given rectangle.
Definition: kdtree.hpp:477
Kdtree::RemoveRecursive
size_t RemoveRecursive(const T &element, size_t node_idx, int level)
Find and remove one element from the tree.
Definition: kdtree.hpp:186
Kdtree::FindNearest
T FindNearest(CoordT x, CoordT y) const
Find the element closest to given coordinate, in Manhattan distance.
Definition: kdtree.hpp:443
Kdtree::BuildSubtree
size_t BuildSubtree(It begin, It end, int level)
Construct a subtree from elements between begin and end iterators, return index of root.
Definition: kdtree.hpp:80
abs
static T abs(const T a)
Returns the absolute value of (scalar) variable.
Definition: math_func.hpp:21
Kdtree::CheckInvariant
void CheckInvariant()
Verify the invariant for the entire tree, does nothing unless KDTREE_DEBUG is defined.
Definition: kdtree.hpp:346
Kdtree::Clear
void Clear()
Clear the tree.
Definition: kdtree.hpp:379
Kdtree::Rebuild
void Rebuild()
Reconstruct the tree with the same elements, letting it be fully balanced.
Definition: kdtree.hpp:390
Kdtree::node::element
T element
Element stored at node.
Definition: kdtree.hpp:40
Kdtree::node
Type of a node in the tree.
Definition: kdtree.hpp:39
Kdtree::root
size_t root
Index of root node.
Definition: kdtree.hpp:51
Kdtree::FindNearestRecursive
node_distance FindNearestRecursive(CoordT xy[2], size_t node_idx, int level, DistT limit=std::numeric_limits< DistT >::max()) const
Search a sub-tree for the element nearest to a given point.
Definition: kdtree.hpp:242
Kdtree::CountValue
size_t CountValue(const T &element, size_t node_idx) const
Debugging function, counts number of occurrences of an element regardless of its correct position in the tree.
Definition: kdtree.hpp:300
Kdtree::node::left
size_t left
Index of node to the left, INVALID_NODE if none.
Definition: kdtree.hpp:41
Kdtree::nodes
std::vector< node > nodes
Pool of all nodes in the tree.
Definition: kdtree.hpp:49
Kdtree::AddNode
size_t AddNode(const T &element)
Create one new node in the tree, return its index in the pool.
Definition: kdtree.hpp:56