Subversion Repositories gelsvn

Rev

Rev 443 | Go to most recent revision | Details | Compare with Previous | Last modification | View Log | RSS feed

Rev Author Line No. Line
595 jab 1
/* ----------------------------------------------------------------------- *
2
 * This file is part of GEL, http://www.imm.dtu.dk/GEL
3
 * Copyright (C) the authors and DTU Informatics
4
 * For license and list of authors, see ../../doc/intro.pdf
5
 * ----------------------------------------------------------------------- */
6
 
7
/**
8
 * @file KDTree.h
9
 * @brief KD Tree implementation based on a binary heap.
10
 */
11
 
443 jab 12
#ifndef __GEOMETRY_KDTREE_H
13
#define __GEOMETRY_KDTREE_H
61 jab 14
 
15
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <vector>

#include "CGLA/CGLA.h"
#include "CGLA/ArithVec.h"
61 jab 21
 
198 bj 22
#if (_MSC_VER >= 1200)
203 jrf 23
#pragma warning (push)
198 bj 24
#pragma warning (disable: 4018)
25
#endif
26
 
61 jab 27
namespace Geometry
28
{
89 jab 29
	/** \brief A classic K-D tree. 
30
 
61 jab 31
			A K-D tree is a good data structure for storing points in space
32
			and for nearest neighbour queries. It is basically a generalized 
33
			binary tree in K dimensions. */
34
	template<class KeyT, class ValT>
35
	class KDTree
36
	{
37
		typedef typename KeyT::ScalarType ScalarType;
38
		typedef KeyT KeyType;
39
		typedef std::vector<KeyT> KeyVectorType;
40
		typedef std::vector<ValT> ValVectorType;
41
 
42
		/// KDNode struct represents node in KD tree
43
		struct KDNode
44
		{
45
			KeyT key;
46
			ValT val;
47
			short dsc;
48
 
49
			KDNode(): dsc(0) {}
50
 
51
			KDNode(const KeyT& _key, const ValT& _val):
52
				key(_key), val(_val), dsc(-1) {}
53
 
54
			ScalarType dist(const KeyType& p) const 
55
			{
56
				KeyType dist_vec = p;
57
				dist_vec  -= key;
58
				return dot(dist_vec, dist_vec);
59
			}
60
		};
61
 
62
		typedef std::vector<KDNode> NodeVecType;
330 jab 63
		bool is_built;
61 jab 64
		NodeVecType init_nodes;
65
		NodeVecType nodes;
66
 
67
		/** Comp is a class used for comparing two keys. Comp is constructed
68
				with the discriminator - i.e. the coordinate of the key that is used
69
				for comparing keys - Comp objects are passed to the sort algorithm.*/
70
		class Comp
71
		{
72
			const int dsc;
73
		public:
74
			Comp(int _dsc): dsc(_dsc) {}
75
			bool operator()(const KeyType& k0, const KeyType& k1) const
76
			{
77
				int dim=KeyType::get_dim();
78
				for(int i=0;i<dim;i++)
79
					{
80
						int j=(dsc+i)%dim;
81
						if(k0[j]<k1[j])
82
							return true;
83
						if(k0[j]>k1[j])
84
							return false;
85
					}
86
				return false;
87
			}
88
 
89
			bool operator()(const KDNode& k0, const KDNode& k1) const
90
			{
91
				return (*this)(k0.key,k1.key);
92
			}
93
		};
94
 
95
 
96
		/** Passed a vector of keys, this function will construct an optimal tree.
330 jab 97
				It is called recursively */
98
		void optimize(int, int, int);
61 jab 99
 
100
		/** Finde nearest neighbour. */
101
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
102
 
103
 
104
		void in_sphere_priv(int n, 
105
												const KeyType& p, 
106
												const ScalarType& dist,
107
												std::vector<KeyT>& keys,
108
												std::vector<ValT>& vals) const;
109
 
110
		/** Finds the optimal discriminator. There are more ways, but this 
111
				function traverses the vector and finds out what dimension has
112
				the greatest difference between min and max element. That dimension
113
				is used for discriminator */
114
		int opt_disc(int,int) const;
115
 
116
	public:
117
 
118
		/** Build tree from vector of keys passed as argument. */
330 jab 119
		KDTree(): is_built(false), init_nodes(1) {}
61 jab 120
 
121
		/** Insert a key value pair into the tree. Note that the tree needs to 
122
				be built - by calling the build function - before you can search. */
123
		void insert(const KeyT& key, const ValT& val)
124
		{
330 jab 125
				if(is_built)
126
				{
127
						assert(init_nodes.size()==1);
128
						init_nodes.swap(nodes);
129
						is_built=false;
130
				}
131
				init_nodes.push_back(KDNode(key,val));
61 jab 132
		}
133
 
330 jab 134
		/** Build the tree. After this function has been called, it is no longer 
61 jab 135
				legal to insert elements, but you can perform searches. */
136
		void build()
137
		{
138
			assert(!is_built);
330 jab 139
			nodes.resize(init_nodes.size());
140
			if(init_nodes.size() > 1)	
141
				optimize(1,1,init_nodes.size());
142
			NodeVecType v(1);
61 jab 143
			init_nodes.swap(v);
144
			is_built = true;
145
		}
146
 
147
		/** Find the key value pair closest to the key given as first 
334 jab 148
				argument. The second argument is the maximum search distance. Upon
149
				return this value is changed to the distance to the found point.
61 jab 150
				The final two arguments contain the closest key and its 
151
				associated value upon return. */
324 jab 152
		bool closest_point(const KeyT& p, ScalarType& dist, KeyT&k, ValT&v) const
61 jab 153
		{
154
			assert(is_built);
330 jab 155
			if(nodes.size()>1)
156
			{
157
					ScalarType max_sq_dist = CGLA::sqr(dist);
158
					if(int n = closest_point_priv(1, p, max_sq_dist))
159
					{
160
							k = nodes[n].key;
161
							v = nodes[n].val;
162
							dist = std::sqrt(max_sq_dist);
163
							return true;
164
					}
165
			}
61 jab 166
			return false;
167
		}
168
 
169
		/** Find all the elements within a given radius (second argument) of
170
				the key (first argument). The key value pairs inside the sphere are
330 jab 171
				returned in a pair of vectors passed as the two last arguments.
172
				Note that we don't resize the two last arguments to zero - so either
173
				they should be empty vectors or you should desire appending the newly
174
				found elements onto these vectors.				
175
		*/
61 jab 176
		int in_sphere(const KeyType& p, 
324 jab 177
									ScalarType dist,
61 jab 178
									std::vector<KeyT>& keys,
179
									std::vector<ValT>& vals) const
180
		{
181
			assert(is_built);
330 jab 182
			if(nodes.size()>1)
183
			{
184
					ScalarType max_sq_dist = CGLA::sqr(dist);
185
					in_sphere_priv(1,p,max_sq_dist,keys,vals);
186
					return keys.size();
187
			}
188
			return 0;
61 jab 189
		}
190
 
191
 
192
	};
193
 
194
	template<class KeyT, class ValT>
195
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
196
																	int kvec_end) const 
197
	{
198
		KeyType vmin = init_nodes[kvec_beg].key;
199
		KeyType vmax = init_nodes[kvec_beg].key;
200
		for(int i=kvec_beg;i<kvec_end;i++)
201
			{
202
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
203
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
204
			}
205
		int od=0;
206
		KeyType ave_v = vmax-vmin;
207
		for(int i=1;i<KeyType::get_dim();i++)
208
			if(ave_v[i]>ave_v[od]) od = i;
209
		return od;
210
	} 
211
 
212
	template<class KeyT, class ValT>
213
	void KDTree<KeyT,ValT>::optimize(int cur,
214
																	 int kvec_beg,  
330 jab 215
																	 int kvec_end)
61 jab 216
	{
217
		// Assert that we are not inserting beyond capacity.
218
		assert(cur < nodes.size());
219
 
220
		// If there is just a single element, we simply insert.
221
		if(kvec_beg+1==kvec_end) 
222
			{
223
				nodes[cur] = init_nodes[kvec_beg];
224
				nodes[cur].dsc = -1;
225
				return;
226
			}
227
 
228
		// Find the axis that best separates the data.
229
		int disc = opt_disc(kvec_beg, kvec_end);
230
 
231
		// Compute the median element. See my document on how to do this
232
		// www.imm.dtu.dk/~jab/publications.html
233
		int N = kvec_end-kvec_beg;
234
		int M = 1<< (CGLA::two_to_what_power(N));
235
		int R = N-(M-1);
236
		int left_size  = (M-2)/2;
237
		int right_size = (M-2)/2;
238
		if(R < M/2)
239
			{
240
				left_size += R;
241
			}
242
		else
243
			{
244
				left_size += M/2;
245
				right_size += R-M/2;
246
			}
247
 
248
		int median = kvec_beg + left_size;
249
 
250
		// Sort elements but use nth_element (which is cheaper) than
251
		// a sorting algorithm. All elements to the left of the median
252
		// will be smaller than or equal the median. All elements to the right
253
		// will be greater than or equal to the median.
254
		const Comp comp(disc);
255
		std::nth_element(&init_nodes[kvec_beg], 
256
										 &init_nodes[median], 
257
										 &init_nodes[kvec_end], comp);
258
 
259
		// Insert the node in the final data structure.
260
		nodes[cur] = init_nodes[median];
261
		nodes[cur].dsc = disc;
262
 
263
		// Recursively build left and right tree.
264
		if(left_size>0)	
330 jab 265
			optimize(2*cur, kvec_beg, median);
61 jab 266
 
267
		if(right_size>0) 
330 jab 268
			optimize(2*cur+1, median+1, kvec_end);
61 jab 269
	}
270
 
271
	template<class KeyT, class ValT>
272
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
273
																						ScalarType& dist) const
274
	{
275
		int ret_node = 0;
276
		ScalarType this_dist = nodes[n].dist(p);
277
 
278
		if(this_dist<dist)
279
			{
280
				dist = this_dist;
281
				ret_node = n;
282
			}
283
		if(nodes[n].dsc != -1)
284
			{
285
				int dsc         = nodes[n].dsc;
324 jab 286
				ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
61 jab 287
				bool left_son   = Comp(dsc)(p,nodes[n].key);
288
 
289
				if(left_son||dsc_dist<dist)
290
					{
291
						int left_child = 2*n;
292
						if(left_child < nodes.size())
293
							if(int nl=closest_point_priv(left_child, p, dist))
294
								ret_node = nl;
295
					}
296
				if(!left_son||dsc_dist<dist)
297
					{
298
						int right_child = 2*n+1;
299
						if(right_child < nodes.size())
300
							if(int nr=closest_point_priv(right_child, p, dist))
301
								ret_node = nr;
302
					}
303
			}
304
		return ret_node;
305
	}
306
 
307
	template<class KeyT, class ValT>
308
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
309
																				 const KeyType& p, 
310
																				 const ScalarType& dist,
311
																				 std::vector<KeyT>& keys,
312
																				 std::vector<ValT>& vals) const
313
	{
314
		ScalarType this_dist = nodes[n].dist(p);
315
		assert(n<nodes.size());
316
		if(this_dist<dist)
317
			{
318
				keys.push_back(nodes[n].key);
319
				vals.push_back(nodes[n].val);
320
			}
321
		if(nodes[n].dsc != -1)
322
			{
323
				const int dsc         = nodes[n].dsc;
324 jab 324
				const ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
61 jab 325
 
326
				bool left_son = Comp(dsc)(p,nodes[n].key);
327
 
328
				if(left_son||dsc_dist<dist)
329
					{
330
						int left_child = 2*n;
331
						if(left_child < nodes.size())
332
							in_sphere_priv(left_child, p, dist, keys, vals);
333
					}
334
				if(!left_son||dsc_dist<dist)
335
					{
336
						int right_child = 2*n+1;
337
						if(right_child < nodes.size())
338
							in_sphere_priv(right_child, p, dist, keys, vals);
339
					}
340
			}
341
	}
342
}
343
/// Short alias: lets client code write GEO::KDTree instead of Geometry::KDTree.
namespace GEO = Geometry;
344
 
198 bj 345
#if (_MSC_VER >= 1200)
203 jrf 346
#pragma warning (pop)
61 jab 347
#endif
198 bj 348
 
349
 
350
#endif