5 #include <cstring> // memcpy
9 * Contains a generic implementation of static KD-trees balanced into a heap-like shape.
10 * The trees are represented by instances of KDTree class and built by a static method
11 * KDBuilder::makeTree.
18 /** Routines for computing with T[length] number fields */
/** Calls transformer(x,y) for every x from [i1;iEnd1) and the corresponding y
 *  taken from the sequence starting at \p i2 (it has to be at least as long).
 *  \return the functor after all the calls, so any state it accumulated can be read */
template<class I1,class I2,class Transf> inline
Transf transform2( I1 i1, I1 iEnd1, I2 i2, Transf transformer ) {
	for (; i1!=iEnd1; ++i1,++i2)
		transformer(*i1,*i2);
	return transformer;
}
/** Analogous to ::transform2 - calls transformer(x,y,z) for every x from [i1;iEnd1)
 *  and the corresponding y and z from the sequences starting at \p i2 and \p i3.
 *  \return the functor after all the calls, so any state it accumulated can be read */
template<class I1,class I2,class I3,class Transf> inline
Transf transform3( I1 i1, I1 iEnd1, I2 i2, I3 i3, Transf transformer ) {
	for (; i1!=iEnd1; ++i1,++i2,++i3)
		transformer(*i1,*i2,*i3);
	return transformer;
}
/** Means b[i]=a[i] for i from [0;length); only meant for POD types.
 *  The destination must not overlap the source (required by memcpy).
 *  \return \p b */
template<class T> inline T* assign(const T *a,int length,T *b) {
#ifndef NDEBUG
	std::copy( a, a+length, b ); // debugging version uses STL's copy
#else
	memcpy( b, a, length*sizeof(T) ); // release version uses C's memory-copy (faster)
#endif
	return b;
}
// Functor passed to FieldMath::transform3 by ::moveToBounds_copy: it is invoked once per
// coordinate and sums up the squared distance of the point from the bounding box.
// CheckNaNs is a compile-time switch that enables the isNaN test of every coordinate.
46 /** A helper transforming structure, only to be used in ::moveToBounds_copy */
47 template<class T
,bool CheckNaNs
> struct MoveToBounds
{
48 T sqrError
; ///< square error accumulated so far
50 /** Only sets ::sqrError to zero */
// NOTE(review): the constructor announced by the comment above is not visible in this excerpt.
54 /** Moves one coordinate of a point (in \p point) to a bounding box
55 * (in \p bounds) accumulating ::sqrError and storing result (in \p result) */
// bounds[0] is compared as the lower and bounds[1] as the upper limit of the coordinate
56 void operator()(const T
&point
,const T bounds
[2],T
&result
) {
// NOTE(review): the statement controlled by the NaN test is not visible - confirm how
// NaN coordinates are handled (and whether \p result gets written for them)
57 if ( CheckNaNs
&& isNaN(point
) )
// below the lower limit: the squared distance to it is added to the error
59 if ( point
< bounds
[0] ) {
60 sqrError
+= sqr(bounds
[0]-point
);
// above the upper limit: the squared distance to it is added to the error
63 if ( point
> bounds
[1] ) {
64 sqrError
+= sqr(point
-bounds
[1]);
// NOTE(review): the writes into \p result and the closing braces are not visible here
71 /** Copy_moves a vector (\p point) to the nearest point (\p result)
72 * within bounds (\p bounds) and returns SE (distance^2) */
73 template<class T
,bool CheckNaNs
> inline
74 T
moveToBounds_copy(const T
*point
,const T (*bounds
)[2],int length
,T
*result
) {
75 return transform3( point
, point
+length
, bounds
, result
76 , MoveToBounds
<T
,CheckNaNs
>() ) .sqrError
;
80 template<class T
> class KDBuilder
;
81 /** A generic static KD-tree. Construction is done by KDBuilder::makeTree static method.
82 * The searching for nearest neighbours is performed by the PointHeap subclass. */
83 template<class T
> class KDTree
{
// KDBuilder<T> derives from this class and fills its arrays, hence the friendship
85 friend class KDBuilder
<T
>;
86 typedef KDBuilder
<T
> Builder
;
// NOTE(review): access specifiers and the header of the Node type are not visible in this excerpt
88 /** Represents one node of the KD-tree */
90 int coord
; ///< The coordinate along which the tree splits in this node
91 T threshold
;/**< The threshold of the split
92 * (left[::coord] <= ::threshold <= right[::coord]) */
// Bounds points to an array of [lower,upper] pairs, one pair per coordinate
94 typedef T (*Bounds
)[2];
97 const int depth
/// The depth of the tree = ::log2ceil(::count)
98 , length
/// The length of the vectors
99 , count
; ///< The number of the vectors
// the three arrays below are allocated with new[] in the (length,count) constructor
101 Node
*nodes
; ///< The array of the tree-nodes (heap-like topology of the tree)
102 int *dataIDs
; ///< Data IDs for children of "leaf" nodes
103 Bounds bounds
; ///< The bounding box for all the data
105 /** Prepares to build a new KD-tree from \p count_ vectors of \p length_ elements */
// Allocates (with new[]) count_ nodes, count_ data IDs and length_ pairs of bounds;
// ::depth is derived from the count by log2ceil.
// NOTE(review): the constructor's body is not visible in this excerpt.
106 KDTree(int length_
,int count_
)
107 : depth( log2ceil(count_
) ), length(length_
), count(count_
)
108 , nodes( new Node
[count_
] ), dataIDs( new int[count_
] ), bounds( new T
[length_
][2] ) {
111 /** Copy constructor with moving semantics (destroys its argument) */
// The three array pointers are copied shallowly, so the new instance refers to \p other's arrays.
// NOTE(review): the body is not visible here; according to the comment above it presumably
// resets \p other's pointers so that the arrays are not freed twice - confirm.
112 KDTree(KDTree
&other
)
113 : depth(other
.depth
), length(other
.length
), count(other
.count
)
114 , nodes(other
.nodes
), dataIDs(other
.dataIDs
), bounds(other
.bounds
) {
120 /** Takes an index of a "leaf" node (past the end of ::nodes)
121 * and returns the appropriate data ID */
122 int leafID2dataID(int leafID
) const {
123 ASSERT( count
<=leafID
&& leafID
<2*count
);
124 int index
= leafID
-powers
[depth
];
126 index
+= count
; // it is on the shallower side of the tree
127 ASSERT( 0<=index
&& index
<count
);
128 return dataIDs
[index
];
132 /** Only frees the memory */
140 /** Performs a nearest-neighbour search by managing a heap from nodes of a KDTree.
141 * It returns vectors (their indices) in the order of ascending distance (SE)
142 * from a given fixed point. It can compute a lower bound of the SEs of the remaining
143 * vectors at any time. */
// NOTE(review): the headers of PointHeap, HeapNode and HeapOrder, access specifiers and
// the headers of isEmpty/getTopSE are not visible in this excerpt.
145 /** One element of the ::heap representing a node in the KDTree ::kd */
147 int nodeIndex
; ///< Index of the node in ::kd
// layout of the field: data[0] holds the SE, data[1..] hold the coordinates
148 T
*data
; /**< Stores the SE and the coordinates of the nearest point
149 * (to this node's bounding box) */
150 /** No-init constructor */
152 /** Initializes the members from the parameters */
153 HeapNode(int nodeIndex_
,T
*data_
): nodeIndex(nodeIndex_
), data(data_
) {}
155 /** Returns reference to the SE of the nearest point to this node's bounding box */
156 T
& getSE() { return *data
; }
157 /** Returns the SE of the nearest point to this node's bounding box */
158 T
getSE() const { return *data
; }
159 /** Returns pointer to the nearest point to this node's bounding box */
160 T
* getNearest() { return data
+1; }
162 /** Defines the order of ::heap - ascending according to ::getSE */
// the comparison is "greater", so the heap functions keep the node with the lowest SE on the top
164 bool operator()(const HeapNode
&a
,const HeapNode
&b
)
165 { return a
.getSE() > b
.getSE(); }
168 const KDTree
&kd
; ///< Reference to the KDTree we operate on
169 const T
* const point
; ///< Pointer to the point we are trying to approach
170 vector
<HeapNode
> heap
; ///< The current heap of the tree nodes
171 BulkAllocator
<T
> allocator
; ///< The allocator for HeapNode::data
173 /** Builds the heap from a KDTree \p tree and vector \p point_
174 * (they've got to remain valid until the destruction of this instance) */
// \p checkNaNs selects at run time which instantiation of moveToBounds_copy is used
175 PointHeap(const KDTree
&tree
,const T
*point_
,bool checkNaNs
)
176 : kd(tree
), point(point_
) {
178 // create the root heap-node
// index 1 is the root of the tree; the field has one slot for the SE plus kd.length coordinates
179 HeapNode
rootNode( 1, allocator
.makeField(kd
.length
+1) );
180 // compute the nearest point within the global bounds and corresponding SE
181 using namespace FieldMath
;
182 rootNode
.getSE()= checkNaNs
183 ? moveToBounds_copy
<T
,true> ( point
, kd
.bounds
, kd
.length
, rootNode
.getNearest() )
184 : moveToBounds_copy
<T
,false>( point
, kd
.bounds
, kd
.length
, rootNode
.getNearest() );
185 // push it onto the heap (and reserve more to speed up the first leaf-gettings)
186 heap
.reserve(kd
.depth
*2);
187 heap
.push_back(rootNode
);
190 /** Returns whether the heap is empty ( !isEmpty() is needed for all other methods) */
192 { return heap
.empty(); }
194 /** Returns the SE of the top node (always equals the SE of the next leaf) */
196 ASSERT( !isEmpty() );
197 return heap
[0].getSE();
200 /** Removes a leaf, returns the matching vector's index,
201 * assumes it's safe to discard nodes further than \p maxSE */
202 template<bool CheckNaNs
> int popLeaf(T maxSE
) {
203 // ensure a leaf is on the top of the heap and get its ID
204 makeTopLeaf
<CheckNaNs
>(maxSE
);
205 int result
= kd
.leafID2dataID( heap
.front().nodeIndex
);
206 // remove the top from the heap (no need to free the memory - ::allocator)
// pop_heap only moves the top element to the back of the vector
// NOTE(review): the removal of that element and the return of result are not visible here
207 pop_heap( heap
.begin(), heap
.end(), HeapOrder() );
212 /** Divides the top nodes until there's a leaf on the top
213 * assumes it's safe to discard nodes further than \p maxSE */
214 template<bool CheckNaNs
> void makeTopLeaf(T maxSE
);
216 }; // PointHeap class
// Descends from the node on the top of the heap towards a leaf: in every step the working
// copy of the top node is replaced by one of its children and the other child is appended
// to the back of ::heap; the heap-property of the appended nodes is restored at the end.
// NOTE(review): several lines of this method (declarations of newSE and newHNode, the heads
// of the branches depending on validCoord, the assignment of goRight for valid coordinates
// and the test against \p maxSE) are not visible in this excerpt.
220 template<class T
> template<bool CheckNaNs
>
221 void KDTree
<T
>::PointHeap::makeTopLeaf(T maxSE
) {
222 ASSERT( !isEmpty() );
223 // exit if there's a leaf on the top already
// indices from kd.count on denote leaves (see the loop condition below)
224 if ( heap
[0].nodeIndex
>= kd
.count
)
// remembered so that the nodes appended below can be told from the original ones
226 PtrInt oldHeapSize
= heap
.size();
227 HeapNode heapRoot
= heap
[0]; // making a local working copy of the top of the heap
229 do { // while heapRoot isn't leaf ... while ( heapRoot.nodeIndex<kd.count )
230 const Node
&node
= kd
.nodes
[heapRoot
.nodeIndex
];
231 // now replace the node with its two children:
232 // one of them will have the same SE (replaces its parent on the top),
233 // the other one can have higher SE (push_back-ed on the heap)
// without CheckNaNs every coordinate counts as valid
235 bool validCoord
= !CheckNaNs
|| !isNaN(point
[node
.coord
]);
236 // the higher SE can be computed from the parent heap-node
238 bool goRight
; // will the heapRoot represent the right child or the left one?
// oldDiff: distance (in the splitting coordinate) to the parent's nearest point,
// newDiff: distance (in the same coordinate) to the splitting threshold
240 Real oldDiff
= Real(point
[node
.coord
]) - heapRoot
.getNearest()[node
.coord
];
241 Real newDiff
= Real(point
[node
.coord
]) - node
.threshold
;
// only the contribution of the splitting coordinate is exchanged in the SE
243 newSE
= heapRoot
.getSE() - sqr(oldDiff
) + sqr(newDiff
);
244 ASSERT( newSE
>= heapRoot
.getSE() );
246 newSE
= heapRoot
.getSE();
247 goRight
= false; // both boolean values are possible
249 // the root will represent its closer child - neither nearest point nor SE changes
// children of node i have indices 2*i (left) and 2*i+1 (right)
250 heapRoot
.nodeIndex
= heapRoot
.nodeIndex
*2 + goRight
;
251 // if the new SE is too high, continue (omit the child with this big SE)
254 // create a new heap-node, allocate its data and assign index of the other child
256 newHNode
.data
= allocator
.makeField(kd
.length
+1);
257 newHNode
.getSE()= newSE
;
// the sibling of heapRoot: 2*i+1 becomes 2*i and vice versa
258 newHNode
.nodeIndex
= heapRoot
.nodeIndex
-goRight
+!goRight
;
259 // the nearest point of the new heap-node only differs in one coordinate
260 FieldMath::assign( heapRoot
.getNearest(), kd
.length
, newHNode
.getNearest() );
262 newHNode
.getNearest()[node
.coord
]= node
.threshold
;
263 // add the new node to the back, restore the heap-property later
264 heap
.push_back(newHNode
);
266 } while ( heapRoot
.nodeIndex
< kd
.count
);
268 heap
[0]= heapRoot
; // restoring the working copy of the heap's top node
269 // restore the heap-property on the added nodes
270 typename vector
<HeapNode
>::iterator it
= heap
.begin()+oldHeapSize
;
// every call of push_heap sifts up the element at it[-1]
// NOTE(review): the post-increment in the condition below is also evaluated when
// it == heap.end(), which moves the iterator past end(); consider a pre-increment form
272 push_heap( heap
.begin(), it
, HeapOrder() );
273 while ( it
++ != heap
.end() );
274 } // KDTree<T>::PointHeap::makeTopLeaf<CheckNaNs>() method
276 /** Derived type used to construct KDTree instances (::makeTree static method) */
277 template<class T
> class KDBuilder
: public KDTree
<T
> {
// NOTE(review): access specifiers and the declaration of the ::data member (initialized
// in the constructor) are not visible in this excerpt.
279 typedef KDTree
<T
> Tree
;
// one [lower,upper] pair of a bounding box; Bounds points to an array of such pairs
280 typedef T BoundsPair
[2];
281 typedef typename
Tree::Bounds Bounds
;
282 /** Type for a method that chooses which coordinate to split */
283 typedef int (KDBuilder::*CoordChooser
)
284 (int nodeIndex
,int *beginIDs
,int *endIDs
,int depthLeft
) const;
// members of the dependent base class have to be made visible explicitly
286 using Tree::depth
; using Tree::length
; using Tree::count
;
287 using Tree::nodes
; using Tree::dataIDs
; using Tree::bounds
;
// the chooser invoked for every node by ::buildNode
290 const CoordChooser chooser
;
// scratch storage allocated and used by ::chooseApprox (mutable as the choosers are const)
291 mutable Bounds chooserTmp
;
// Builds the whole tree: the base class allocates the arrays, then the node with index 1
// (the root) is built over all the data IDs with the full depth.
// NOTE(review): the body of the loop, the computation of the bounding box and possibly other
// statements are not visible in this excerpt; ::buildNode asserts at least two vectors
// while only count>0 is asserted here - confirm how count==1 is handled.
293 KDBuilder(const T
*data_
,int length
,int count
,CoordChooser chooser_
)
294 : Tree(length
,count
), data(data_
), chooser(chooser_
), chooserTmp(0) {
295 ASSERT( length
>0 && count
>0 && chooser
&& data
);
296 // create the index-vector, compute the bounding box, build the tree
297 for (int i
=0; i
<count
; ++i
)
301 buildNode(1,dataIDs
,dataIDs
+count
,depth
);
// the builder-only pointers get reset when DEBUG_ONLY is active
303 DEBUG_ONLY( chooserTmp
= 0; data
= 0; )
// Two functors for FieldMath::transform2, both applied to (coordinate value, bounds pair).
// NOTE(review): the header of the first functor (used as NewBounds in ::getBounds) is not visible.
306 /** Creates bounds containing one value */
308 void operator()(const T
&val
,BoundsPair
&bounds
) const
309 { bounds
[0]= bounds
[1]= val
; }
311 /** Expands a valid bounding box to contain a point (one coordinate at once) */
312 struct BoundsExpander
{
313 void operator()(const T
&val
,BoundsPair
&bounds
) const {
// NOTE(review): the statements controlled by the two tests below are not visible here
314 if ( val
< bounds
[0] ) // lower bound
316 if ( val
> bounds
[1] ) // upper bound
320 /** Computes the bounding box of ::count vectors with length ::length stored in ::data.
321 * The vectors in ::data are stored linearly, \p boundsRes should be preallocated */
322 void getBounds(Bounds boundsRes
) const {
323 using namespace FieldMath
;
325 // make the initial bounds only contain the first point
326 transform2(data
,data
+length
,boundsRes
,NewBounds());
// a local copy of the count and a moving pointer to the vector being processed
327 int count
= Tree::count
;
328 const T
*nowData
= data
;
329 // expand the bounds by every point (except for the first one)
// NOTE(review): the head of the loop (advancing nowData and counting down) is not visible here
332 transform2( nowData
, nowData
+length
, boundsRes
, BoundsExpander() );
335 /** Like ::getBounds, but it only works on vectors with indices from [\p beginIDs;\p endIDs)
336 * instead of [0,::count), \p boundsRes should be preallocated to store the result */
337 void getBounds(const int *beginIDs
,const int *endIDs
,Bounds boundsRes
) const {
338 using namespace FieldMath
;
339 ASSERT(endIDs
>beginIDs
);
340 // make the initial bounds only contain the first point
341 const T
*begin
= data
+ *beginIDs
*length
;
342 transform2(begin
,begin
+length
,boundsRes
,NewBounds());
343 // expand the bounds by every point (except for the first one)
344 while ( ++beginIDs
!= endIDs
) {
345 begin
= data
+ *beginIDs
*length
;
346 transform2( begin
, begin
+length
, boundsRes
, BoundsExpander() );
350 /** Recursively builds node \p nodeIndex and its subtree of depth \p depthLeft
351 * (including leaves), operates on data \p data in the range [\p beginIDs,\p endIDs) */
352 void buildNode(int nodeIndex
,int *beginIDs
,int *endIDs
,int depthLeft
);
355 /** Builds a KDTree from \p count vectors of length \p length stored in \p data,
356 * splitting the nodes by \p chooser CoordChooser */
357 static Tree
* makeTree(const T
*data
,int length
,int count
,CoordChooser chooser
) {
358 KDBuilder
builder(data
,length
,count
,chooser
);
359 // moving only the necesarry data (pointers) into a new copy
360 return new Tree(builder
);
363 /** CoordChooser choosing the longest coordinate of the bounding box of the current interval */
364 int choosePrecise(int nodeIndex
,int *beginIDs
,int *endIDs
,int /*depthLeft*/) const;
365 /** CoordChooser choosing the coordinate only according to the depth */
366 int chooseFast(int /*nodeIndex*/,int* /*beginIDs*/,int* /*endIDs*/,int depthLeft
) const
367 { return depthLeft
%length
; }
368 /** CoordChooser choosing a random coordinate */
369 int chooseRand(int /*nodeIndex*/,int* /*beginIDs*/,int* /*endIDs*/,int /*depthLeft*/) const
370 { return rand()%length
; }
371 /** CoordChooser - like ::choosePrecise, but doesn't compute the real bounding box,
372 * only approximates it by splitting the parent's box (a little less accurate, but much faster) */
373 int chooseApprox(int nodeIndex
,int* /*beginIDs*/,int* /*endIDs*/,int depthLeft
) const;
374 }; // KDBuilder class
379 /** Finds the longest coordinate of a bounding box, only to be used in for_each calls */
// The instance is constructed from the pair of index 0 and then called for the remaining
// pairs; the callers read ::bestIndex (and check ::nextIndex) from the returned copy.
380 template<class T
> struct MaxDiffCoord
{
381 typedef typename KDBuilder
<T
>::BoundsPair BoundsPair
;
383 T maxDiff
; ///< the maximal difference value
384 int bestIndex
/// the index where ::maxDiff occurred
385 , nextIndex
; ///< next index to be checked
387 /** Initializes from the 0-th index value */
388 MaxDiffCoord(const BoundsPair
& bounds0
)
389 : maxDiff(bounds0
[1]-bounds0
[0]), bestIndex(0), nextIndex(1) {}
391 /** To be called successively for indices from 1-st on */
392 void operator()(const BoundsPair
& bounds_i
) {
// the extent of the box in the coordinate being checked
393 T diff
= bounds_i
[1]-bounds_i
[0];
// NOTE(review): the comparison guarding this update, the update of ::maxDiff and the
// increment of ::nextIndex are not visible in this excerpt
395 bestIndex
= nextIndex
;
402 template<class T
> int KDBuilder
<T
>
403 ::choosePrecise(int nodeIndex
,int *beginIDs
,int *endIDs
,int) const {
404 ASSERT( nodeIndex
>0 && beginIDs
&& endIDs
&& beginIDs
<endIDs
);
405 // temporary storage for computed bounding box
406 BoundsPair boundsStorage
[length
];
407 const BoundsPair
*localBounds
;
408 // find out the bounding box
409 if ( nodeIndex
>1 ) { // compute the bounding box
410 localBounds
= boundsStorage
;
411 getBounds( beginIDs
, endIDs
, boundsStorage
);
412 } else // we are in the root -> we can use already computed bounds
413 localBounds
= this->bounds
;
414 // find and return the longest coordinate
415 MaxDiffCoord
<T
> mdc
= for_each
416 ( localBounds
+1, localBounds
+length
, MaxDiffCoord
<T
>(localBounds
[0]) );
417 ASSERT( mdc
.nextIndex
== length
);
418 return mdc
.bestIndex
;
// CoordChooser approximating the bounding box of a node by cutting the box of its parent at
// the parent's threshold. ::chooserTmp stores one box (::length pairs) per level of the tree,
// so a node finds the box of its parent one slot before its own.
// NOTE(review): some lines (including closing braces) are not visible in this excerpt.
421 template<class T
> int KDBuilder
<T
>::chooseApprox(int nodeIndex
,int*,int*,int) const {
422 using namespace FieldMath
;
// the level of this node in the tree; it is zero for nodeIndex==1 (asserted below)
425 int myDepth
= log2ceil(nodeIndex
+1)-1;
426 if (!myDepth
) { // I'm in the root - copy the bounds
427 ASSERT(nodeIndex
==1);
// NOTE(review): the matching delete[] is not visible in this excerpt - confirm it exists
428 chooserTmp
= new BoundsPair
[length
*(depth
+1)]; // allocate my temporary bound-array
429 assign( bounds
, length
, chooserTmp
);
// the slot of this level
432 Bounds myBounds
= chooserTmp
+length
*myDepth
;
433 if (myDepth
) { // I'm not the root - copy parent's bounds and modify them
434 const typename
Tree::Node
&parent
= nodes
[nodeIndex
/2];
435 Bounds parentBounds
= myBounds
-length
;
436 if (nodeIndex
%2) { // I'm the right son -> bounds on this level not initialized
437 assign( parentBounds
, length
, myBounds
); // copying parent bounds
438 myBounds
[parent
.coord
][0]= parent
.threshold
; // adjusting the lower bound
439 } else // I'm the left son
// this branch reuses the slot as filled for the right brother, i.e. it relies on
// ::buildNode building the right subtree before the left one
440 if ( nodeIndex
+1 < count
) { // almost the same as brother -> only adjust the coordinate
441 myBounds
[parent
.coord
][0]= parentBounds
[parent
.coord
][0];
442 myBounds
[parent
.coord
][1]= parent
.threshold
;
443 } else { // I've got no brother
444 ASSERT( nodeIndex
+1 == count
);
445 assign( parentBounds
, length
, myBounds
);
446 myBounds
[parent
.coord
][1]= parent
.threshold
;
449 // find out the widest dimension
450 MaxDiffCoord
<T
> mdc
= for_each( myBounds
+1, myBounds
+length
, MaxDiffCoord
<T
>(myBounds
[0]) );
451 ASSERT( mdc
.nextIndex
== length
);
452 return mdc
.bestIndex
;
/** Compares vectors (given by their indices) according to a given coordinate.
 *  The vectors are stored linearly one after another, each of them taking \p length_ elements. */
template<class T> class IndexComparator {
	const T *data; ///< pointer to the right coordinate of the first vector (of index 0)
	int length; ///< length of the vectors in ::data
public:
	/** Prepares comparing by the coordinate \p index_ of vectors stored from \p data_ on */
	IndexComparator(const T *data_,int length_,int index_)
	: data(data_+index_), length(length_) {}
	/** Returns whether vector \p a has a lower value of the coordinate than vector \p b */
	bool operator()(int a,int b) const
		{ return data[a*length] < data[b*length]; }
};
// Splits the IDs from [beginIDs;endIDs) around a "median" in the chosen coordinate,
// stores the split into nodes[nodeIndex] and recurses into the children 2*nodeIndex
// and 2*nodeIndex+1 (the right one first).
// NOTE(review): the head of the switch and at least one other line are not visible in this
// excerpt. With the values asserted below, the test computing shallowRight would always be
// true as shown - presumably depthLeft is changed on a line that is not visible; confirm
// against the complete file before touching the split logic.
467 template<class T
> void KDBuilder
<T
>
468 ::buildNode(int nodeIndex
,int *beginIDs
,int *endIDs
,int depthLeft
) {
469 int count
= endIDs
-beginIDs
; // overshadowing Tree::count
470 // check we've got at least two vectors and the depth&count are adequate to each other
471 ASSERT( count
>=2 && powers
[depthLeft
-1]<count
&& count
<=powers
[depthLeft
] );
473 // find out where to split - how many items should be on the left to have the heap-shape
474 bool shallowRight
= ( count
<= powers
[depthLeft
]+powers
[depthLeft
-1] );
// either the right part gets a fixed size (counted from the end),
// or the left part gets a fixed size (counted from the beginning)
475 int *middle
= shallowRight
476 ? endIDs
-powers
[depthLeft
-1]
477 : beginIDs
+powers
[depthLeft
];
478 // find out the dividing coordinate and find the "median" in this coordinate
479 int coord
= (this->*chooser
)(nodeIndex
,beginIDs
,endIDs
,depthLeft
);
// partial sort: *middle gets into its sorted position, lower items before it, higher after it
480 nth_element( beginIDs
, middle
, endIDs
, IndexComparator
<T
>(data
,length
,coord
) );
481 // fill the node's data (dividing coordinate and its threshold)
482 nodes
[nodeIndex
].coord
= coord
;
483 nodes
[nodeIndex
].threshold
= data
[*middle
*length
+coord
]; // min. value of the right son
484 // recurse on both halves (if needed; fall-through switch)
486 default: // we've got enough nodes - build both subtrees (fall through)
487 // build the right subtree
// shallowRight (0 or 1) is subtracted from the depth passed to the right subtree
488 buildNode( 2*nodeIndex
+1, middle
, endIDs
, depthLeft
-shallowRight
);
489 case 3: // only a pair in the first half
490 // build the left subtree
491 buildNode( 2*nodeIndex
, beginIDs
, middle
, depthLeft
);
492 case 2: // nothing needs to be sorted
497 #endif // KDTREE_HEADER_