// Nearest-neighbour query: returns the object closest to `target`.
// The search starts from a "no object" candidate (*end_) whose squared
// distance is seeded from maxRadius.
// An empty tree raises LASS_THROW (the guarding condition is not visible in
// this excerpt).
150template <
class O,
class OT>
151typename KdTree<O, OT>::Neighbour
156 LASS_THROW(
"can't locate nearest neighbour in empty KdTree");
// NOTE(review): Neighbour's constructor names its second parameter
// `squaredDistance`. best.squaredDistance() is compared against squared
// distances below, yet maxRadius is passed here unsquared. The rangeSearch
// overloads square maxRadius before use. Confirm whether this should be
// maxRadius * maxRadius.
159 Neighbour best(*end_, maxRadius);
// Explicit depth-first stack of (node index, squared distance to the parent's
// splitting plane). The root is pushed with sqrDelta = -1, so it is never
// pruned while best.squaredDistance() is non-negative.
// NOTE(review): the declaration and capacity of `stack` are not visible here.
// Each pop pushes at most two entries; confirm the capacity covers the
// deepest tree this is called on.
167 size_t stackSize = 0;
168 stack[stackSize].index = 0;
169 stack[stackSize++].sqrDelta = -1;
171 while (stackSize > 0)
173 const Visit visit = stack[--stackSize];
// Prune: the splitting plane is already farther than the current best.
174 if (visit.sqrDelta > best.squaredDistance())
// Skip indices past the node array and empty slots (object == *end_).
178 if (visit.index >= heap_.size() || heap_[visit.index].object() == *end_)
182 const Node& node = heap_[visit.index];
183 const TPoint& pivot = node.position();
185 const TValue sqrDistance = squaredDistance(pivot, target);
186 if (sqrDistance < best.squaredDistance())
188 best = Neighbour(node.object(), sqrDistance);
// Nodes tagged dummyAxis_ (single-element leaves built by balance) have no split.
191 const TAxis split = node.axis();
192 if (split == dummyAxis_)
197 const TValue delta = target[split] - pivot[split];
// Children of node i live at 2*i+1 and 2*i+2. In each branch, one child is
// pushed with the squared plane distance and then the other with 0. The stack
// is LIFO, so the 0-distance child is visited first. The condition choosing
// between the two branches is not visible in this excerpt.
201 stack[stackSize].index = 2 * visit.index + 2;
202 stack[stackSize++].sqrDelta =
num::sqr(delta);
203 stack[stackSize].index = 2 * visit.index + 1;
204 stack[stackSize++].sqrDelta = 0;
209 stack[stackSize].index = 2 * visit.index + 1;
210 stack[stackSize++].sqrDelta =
num::sqr(delta);
211 stack[stackSize].index = 2 * visit.index + 2;
212 stack[stackSize++].sqrDelta = 0;
// Recursive search from the root. How this is selected relative to the loop
// above is not visible here (presumably a preprocessor alternative).
216 doNearestNeighbour(0, target, best);
// Range search that fills `neighbourhood` with objects near `target` and
// returns the largest squared distance among the results kept.
// An empty tree raises LASS_THROW (the guarding condition is not visible).
237template <
class O,
class OT>
238typename KdTree<O, OT>::TValue
240 const TPoint& target, TParam maxRadius,
size_t maxCount,
241 TNeighbourhood& neighbourhood)
const
245 LASS_THROW(
"can't perform range search in empty KdTree");
// Unbounded path (the condition that selects it is not visible). Collect every
// object strictly closer than maxRadius, then return the largest squared
// distance found, or TValue() when nothing was found.
250 neighbourhood.clear();
251 rangeSearch(target, maxRadius, std::back_inserter(neighbourhood));
254 TValue maxSquaredDistance = TValue();
255 const typename TNeighbourhood::const_iterator end = neighbourhood.end();
256 for (
typename TNeighbourhood::const_iterator i = neighbourhood.begin(); i != end; ++i)
258 maxSquaredDistance = std::max(maxSquaredDistance, i->squaredDistance());
260 return maxSquaredDistance;
// Bounded path: keep at most maxCount of the nearest objects.
// heap_.size() counts node slots, including the empty Node(*end_) fillers
// added by assignNode. It is therefore an upper bound, not the object count.
263 maxCount = std::min(maxCount, heap_.size());
// One spare slot: the bounded doRangeSearch pushes a new candidate onto the
// heap before it pops the farthest one when more than maxCount are held.
264 neighbourhood.resize(maxCount + 1);
266 typename TNeighbourhood::iterator last =
rangeSearch(
267 target, maxRadius, maxCount, neighbourhood.begin());
// Drop the unused tail of the pre-sized buffer.
268 neighbourhood.erase(last, neighbourhood.end());
270 if (neighbourhood.empty())
// The kept range was built with std::push_heap/pop_heap using
// Neighbour::operator< (squared distance), so front() is the farthest kept.
274 return neighbourhood.front().squaredDistance();
// Range search that writes a Neighbour for every object strictly closer than
// maxRadius to `target` into `first`, in traversal order (not sorted).
// An empty tree or a zero radius takes an early-out (its return is not
// visible in this excerpt).
294template <
class O,
class OT>
295template <
typename OutputIterator>
299 if (
isEmpty() || maxRadius == 0)
// The recursive search compares squared distances, so square the radius once.
303 const TValue squaredRadius = maxRadius * maxRadius;
304 return doRangeSearch(0, target, squaredRadius, first);
// Bounded range search: finds at most maxCount objects strictly closer than
// maxRadius. They are stored as a heap (farthest at the front) in
// [first, returned iterator).
// The caller must provide room for maxCount + 1 elements at `first`, because
// a candidate is pushed before the farthest one is popped.
// An empty tree or a zero radius takes an early-out (its return is not visible).
331template <
class O,
class OT>
332template <
typename RandomAccessIterator>
335 RandomAccessIterator first)
const
337 if (
isEmpty() || maxRadius == 0)
// Non-const on purpose: the bounded doRangeSearch takes it as TReference and
// tightens it to the farthest kept candidate once maxCount is exceeded.
341 TValue squaredRadius = maxRadius * maxRadius;
// [first, first) is the initially empty heap.
342 return doRangeSearch(0, target, squaredRadius, maxCount, first, first);
// Exchanges the node array and the end marker with `other`.
// NOTE(review): only these two members are swapped in the visible lines;
// confirm no other members need exchanging.
349template <
class O,
class OT>
352 heap_.swap(other.heap_);
353 end_.swap(other.end_);
// True when the tree holds no node slots at all.
360template <
class O,
class OT>
363 return heap_.empty();
// Template header of a KdTree<O, OT> member whose declaration and body are
// not visible in this excerpt.
370template <
class O,
class OT>
// Returns the "no object" iterator. Empty node slots and unset search results
// compare equal to *end_. The body is not visible here; it presumably
// returns *end_.
379template <
class O,
class OT>
inline
380const typename KdTree<O, OT>::TObjectIterator
381KdTree<O, OT>::end()
const
// Default-constructs a Neighbour. The member initialisers are not visible in
// this excerpt.
390template <
class O,
class OT>
391KdTree<O, OT>::Neighbour::Neighbour():
// Pairs an object iterator with a squared distance. The searches in this file
// pass squaredDistance(pivot, target) here.
// The initialiser for the object member is not visible in this excerpt.
399template <
class O,
class OT>
400KdTree<O, OT>::Neighbour::Neighbour(TObjectIterator
object, TValue squaredDistance):
402 squaredDistance_(squaredDistance)
// Returns the stored object iterator. The body is not visible here;
// presumably it returns object_.
408template <
class O,
class OT>
409inline typename KdTree<O, OT>::TObjectIterator
410KdTree<O, OT>::Neighbour::object()
const
// Position of the stored object, looked up through TObjectTraits.
417template <
class O,
class OT>
421 return TObjectTraits::position(object_);
// Returns the squared distance stored at construction.
426template <
class O,
class OT>
430 return squaredDistance_;
// Template headers of two KdTree<O, OT> members whose declarations and bodies
// are not visible in this excerpt.
435template <
class O,
class OT>
444template <
class O,
class OT>
// Orders neighbours by squared distance.
// The bounded doRangeSearch relies on this in std::push_heap/std::pop_heap,
// which keeps the farthest kept neighbour at the heap's front.
453template <
class O,
class OT>
457 return squaredDistance_ < other.squaredDistance_;
// Recursively builds the implicit tree in heap_ from the object iterators in
// [first, last). The node at `index` has its children at 2*index+1 and
// 2*index+2.
// Handling of an empty range is not visible in this excerpt.
468template <
class O,
class OT>
// A single element becomes a leaf, tagged with dummyAxis_ instead of a split axis.
476 if (last == first + 1)
479 assignNode(index, *first, dummyAxis_);
// Otherwise split on the axis chosen by findSplitAxis. std::nth_element
// partitions the range around its middle element using LessDim(split). That
// median becomes this node, and the two halves (excluding the median) become
// the subtrees.
485 const TAxis
split = findSplitAxis(first, last);
486 const size_t size =
static_cast<size_t>(last - first);
487 const TIteratorIterator median = first + size / 2;
488 std::nth_element(first, median, last, LessDim(split));
489 assignNode(index, *median, split);
491 balance(2 * index + 1, first, median);
492 balance(2 * index + 2, median + 1, last);
// Computes the bounding box of the positions in [first, last) and tracks the
// axis with the largest extent. Because the comparison is strict `>`, ties
// keep the lower axis.
// The recording of the chosen axis and the return are not visible here;
// presumably it returns that widest axis.
// *first is dereferenced unconditionally, so the range must be non-empty.
497template <
class O,
class OT>
501 TPoint min = TObjectTraits::position(*first);
504 for (TIteratorIterator i = first + 1; i != last; ++i)
506 const TPoint position = TObjectTraits::position(*i);
507 for (TAxis k = 0; k < dimension; ++k)
509 min[k] = std::min(min[k], position[k]);
510 max[k] = std::max(max[k], position[k]);
515 TValue maxDistance = max[0] - min[0];
516 for (TAxis k = 1; k < dimension; ++k)
518 const TValue distance = max[k] - min[k];
519 if (distance > maxDistance)
522 maxDistance = distance;
// Stores `object` at slot `index` of heap_, caching its position and split axis.
// heap_ grows on demand. New slots are filled with Node(*end_), which the
// search routines treat as empty.
531template <
class O,
class OT>
534 if (heap_.size() <= index)
536 heap_.resize(index + 1, Node(*end_));
538 heap_[index] = Node(
object, TObjectTraits::position(
object), splitAxis);
// Descends from `index` toward `target`. At each split it follows the child on
// target's side: delta < 0 goes to 2*index+1, otherwise to 2*index+2.
// It returns the deepest existing node on that path.
// A child result equal to heap_.size() is treated as "not found", so the early
// returns (not visible here) presumably return heap_.size().
543template <
class O,
class OT>
546 const size_t size = heap_.size();
547 if (index >= size || heap_[index].
object() == *end_)
551 const Node& node = heap_[index];
553 const TAxis
split = node.axis();
554 if (split == dummyAxis_)
559 const TPoint& pivot = node.position();
560 const TValue delta = target[
split] - pivot[
split];
561 const bool isLeftSide = delta < 0;
562 const size_t result = findNode(2 * index + (isLeftSide ? 1 : 2), target);
563 return result != size ? result : index;
// Recursive nearest-neighbour step: updates `best` in place with any closer
// object found in the subtree at `index`.
568template <
class O,
class OT>
// Indices past the node array and empty slots (object == *end_) contribute nothing.
571 if (index >= heap_.size() || heap_[index].object() == *end_)
575 const Node& node = heap_[index];
576 const TPoint& pivot = node.position();
578 const TValue sqrDistance = squaredDistance(pivot, target);
579 if (sqrDistance < best.squaredDistance())
581 best = Neighbour(node.object(), sqrDistance);
// Leaves (dummyAxis_) have no children. Otherwise one child is searched
// unconditionally. The other is searched only if the splitting plane is closer
// than the current best, using the updated best after the first recursion.
// The condition picking which child goes first is not visible in this excerpt.
584 const TAxis
split = node.axis();
585 if (split != dummyAxis_)
587 const TValue delta = target[
split] - pivot[
split];
591 doNearestNeighbour(2 * index + 1, target, best);
592 if (
num::sqr(delta) < best.squaredDistance())
594 doNearestNeighbour(2 * index + 2, target, best);
600 doNearestNeighbour(2 * index + 2, target, best);
601 if (
num::sqr(delta) < best.squaredDistance())
603 doNearestNeighbour(2 * index + 1, target, best);
// Recursive unbounded range step. For every object in the subtree at `index`
// whose squared distance to `target` is strictly below `squaredDistance`, it
// writes a Neighbour to `output`. It returns the advanced output iterator.
// Children are emitted before their parent's own object, so the output is not
// sorted.
611template <
class O,
class OT>
612template <
typename OutputIterator>
614 size_t index,
const TPoint& target, TParam squaredDistance,
615 OutputIterator output)
const
617 if (index >= heap_.size() || heap_[index].object() == *end_)
621 const Node& node = heap_[index];
623 const TPoint& pivot = node.position();
624 const TAxis
split = node.axis();
625 if (split != dummyAxis_)
627 const TValue delta = target[
split] - pivot[
split];
// Visit target's side of the plane first (delta < 0 goes to 2*index+1).
// Cross to the other side only if the plane lies within the search radius.
628 if (delta < TValue())
631 output = doRangeSearch(2 * index + 1, target, squaredDistance, output);
632 if (
num::sqr(delta) < squaredDistance)
634 output = doRangeSearch(2 * index + 2, target, squaredDistance, output);
640 output = doRangeSearch(2 * index + 2, target, squaredDistance, output);
641 if (
num::sqr(delta) < squaredDistance)
643 output = doRangeSearch(2 * index + 1, target, squaredDistance, output);
// `this->` is needed because the parameter `squaredDistance` hides the member function.
648 const TValue sqrDistance = this->squaredDistance(pivot, target);
649 if (sqrDistance < squaredDistance)
651 *output++ = Neighbour(node.object(), sqrDistance);
// Recursive bounded range step. It maintains [first, last) as a max-heap
// (ordered by Neighbour::operator<) of at most maxCount candidates strictly
// closer than squaredRadius, and returns the new `last`.
// squaredRadius is taken as TReference. It is tightened to the farthest kept
// candidate once the heap overflows, which narrows later pruning.
658template <
class O,
class OT>
659template <
typename RandomIterator>
661 size_t index,
const TPoint& target, TReference squaredRadius,
size_t maxCount,
662 RandomIterator first, RandomIterator last)
const
664 if (index >= heap_.size() || heap_[index].object() == *end_)
668 const Node& node = heap_[index];
670 const TPoint& pivot = node.position();
671 const TAxis
split = node.axis();
672 if (split != dummyAxis_)
674 const TValue delta = target[
split] - pivot[
split];
// Visit target's side first (delta < 0 goes to 2*index+1). The cross-plane
// test then uses squaredRadius as possibly tightened by that first recursion.
675 if (delta < TValue())
678 last = doRangeSearch(
679 2 * index + 1, target, squaredRadius, maxCount, first, last);
680 if (
num::sqr(delta) < squaredRadius)
682 last = doRangeSearch(
683 2 * index + 2, target, squaredRadius, maxCount, first, last);
689 last = doRangeSearch(
690 2 * index + 2, target, squaredRadius, maxCount, first, last);
691 if (
num::sqr(delta) < squaredRadius)
693 last = doRangeSearch(
694 2 * index + 1, target, squaredRadius, maxCount, first, last);
699 const TValue sqrDistance = squaredDistance(pivot, target);
700 if (sqrDistance < squaredRadius)
// Push first, then evict the farthest if over capacity. This is why callers
// reserve maxCount + 1 slots.
702 *last++ = Neighbour(node.object(), sqrDistance);
703 std::push_heap(first, last);
704 LASS_ASSERT(last >= first);
705 if (
static_cast<size_t>(last - first) > maxCount)
// pop_heap moves the farthest to the back. The matching shrink of `last` is
// not visible in this excerpt (presumably --last). first then holds the new
// farthest candidate, which bounds further search.
707 std::pop_heap(first, last);
709 squaredRadius = first->squaredDistance();
// Squared distance between two points, accumulated over `dimension` axes
// starting from TValue(). The loop body is not visible here; presumably it
// sums the squared per-axis differences.
717template <
class O,
class OT>
inline
721 TValue result = TValue();
722 for (
unsigned k = 0; k < dimension; ++k)
// Diagnostics, compiled only when LASS_SPAT_KD_TREE_DIAGNOSTICS is defined.
// A local Visitor owns an XML writer opened on "kdtree.xml" (root
// "diagnostics"). It walks the tree from the root, carrying a bounding box
// that is split into "less" and "greater" boxes at each node.
731#ifdef LASS_SPAT_KD_TREE_DIAGNOSTICS
732template <
class O,
class OT>
// Box type: 2D when dimension == 2, 3D otherwise.
735 typedef typename meta::Select< meta::Bool<dimension == 2>, prim::Aabb2D<TValue>, prim::Aabb3D<TValue> >::Type TAabb;
739 Visitor(
const TNodes& heap, TObjectIterator end):
740 xml_(
"kdtree.xml",
"diagnostics"),
746 void visit(
size_t index,
const TAabb& aabb)
// MSVC-only using-declaration, presumably so prim's stream operator<< is found.
748#if LASS_COMPILER_TYPE == LASS_COMPILER_TYPE_MSVC
749 using lass::prim::operator<<;
753 if (index >= heap_.size() || heap_[index].object() == end_)
757 const Node& node = heap_[index];
759 const typename TObjectTraits::TPoint& pivot = node.position();
762 const TAxis
split = node.axis();
763 if (split == dummyAxis_)
769 typename TAabb::TPoint max = less.max();
772 visit(2 * index + 1, less);
774 TAabb greater = aabb;
775 typename TAabb::TPoint min = greater.min();
778 visit(2 * index + 2, greater);
784 TObjectIterator end_;
// Grow the root box to enclose every stored position.
788 for (
size_t i = 0, n = heap_.size(); i < n; ++i)
// NOTE(review): this loop iterates with `i`, but the empty-slot test reads
// heap_[index]. No `index` is declared in the visible part of this function.
// It most likely should be heap_[i], matching the statement below.
790 if (heap_[index].
object() == *end_)
794 aabb += TObjectTraits::position(heap_[i].
object());
797 Visitor visitor(heap_, *end_);
798 visitor.visit(0, aabb);