Subversion Repositories gelsvn

Rev

Rev 330 | Rev 595 | Go to most recent revision | Details | Compare with Previous | Last modification | View Log | RSS feed

Rev Author Line No. Line
61 jab 1
#ifndef __KDTREE_H
2
#define __KDTREE_H
3
 
4
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <vector>
#include "CGLA/CGLA.h"
#include "CGLA/ArithVec.h"
61 jab 10
 
198 bj 11
#if (_MSC_VER >= 1200)
203 jrf 12
#pragma warning (push)
198 bj 13
#pragma warning (disable: 4018)
14
#endif
15
 
61 jab 16
namespace Geometry
17
{
89 jab 18
	/** \brief A classic K-D tree. 
19
 
61 jab 20
			A K-D tree is a good data structure for storing points in space
21
			and for nearest neighbour queries. It is basically a generalized 
22
			binary tree in K dimensions. */
23
	template<class KeyT, class ValT>
24
	class KDTree
25
	{
26
		typedef typename KeyT::ScalarType ScalarType;
27
		typedef KeyT KeyType;
28
		typedef std::vector<KeyT> KeyVectorType;
29
		typedef std::vector<ValT> ValVectorType;
30
 
31
		/// KDNode struct represents node in KD tree
32
		struct KDNode
33
		{
34
			KeyT key;
35
			ValT val;
36
			short dsc;
37
 
38
			KDNode(): dsc(0) {}
39
 
40
			KDNode(const KeyT& _key, const ValT& _val):
41
				key(_key), val(_val), dsc(-1) {}
42
 
43
			ScalarType dist(const KeyType& p) const 
44
			{
45
				KeyType dist_vec = p;
46
				dist_vec  -= key;
47
				return dot(dist_vec, dist_vec);
48
			}
49
		};
50
 
51
		typedef std::vector<KDNode> NodeVecType;
330 jab 52
		bool is_built;
61 jab 53
		NodeVecType init_nodes;
54
		NodeVecType nodes;
55
 
56
		/** Comp is a class used for comparing two keys. Comp is constructed
57
				with the discriminator - i.e. the coordinate of the key that is used
58
				for comparing keys - Comp objects are passed to the sort algorithm.*/
59
		class Comp
60
		{
61
			const int dsc;
62
		public:
63
			Comp(int _dsc): dsc(_dsc) {}
64
			bool operator()(const KeyType& k0, const KeyType& k1) const
65
			{
66
				int dim=KeyType::get_dim();
67
				for(int i=0;i<dim;i++)
68
					{
69
						int j=(dsc+i)%dim;
70
						if(k0[j]<k1[j])
71
							return true;
72
						if(k0[j]>k1[j])
73
							return false;
74
					}
75
				return false;
76
			}
77
 
78
			bool operator()(const KDNode& k0, const KDNode& k1) const
79
			{
80
				return (*this)(k0.key,k1.key);
81
			}
82
		};
83
 
84
 
85
		/** Passed a vector of keys, this function will construct an optimal tree.
330 jab 86
				It is called recursively */
87
		void optimize(int, int, int);
61 jab 88
 
89
		/** Finde nearest neighbour. */
90
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
91
 
92
 
93
		void in_sphere_priv(int n, 
94
												const KeyType& p, 
95
												const ScalarType& dist,
96
												std::vector<KeyT>& keys,
97
												std::vector<ValT>& vals) const;
98
 
99
		/** Finds the optimal discriminator. There are more ways, but this 
100
				function traverses the vector and finds out what dimension has
101
				the greatest difference between min and max element. That dimension
102
				is used for discriminator */
103
		int opt_disc(int,int) const;
104
 
105
	public:
106
 
107
		/** Build tree from vector of keys passed as argument. */
330 jab 108
		KDTree(): is_built(false), init_nodes(1) {}
61 jab 109
 
110
		/** Insert a key value pair into the tree. Note that the tree needs to 
111
				be built - by calling the build function - before you can search. */
112
		void insert(const KeyT& key, const ValT& val)
113
		{
330 jab 114
				if(is_built)
115
				{
116
						assert(init_nodes.size()==1);
117
						init_nodes.swap(nodes);
118
						is_built=false;
119
				}
120
				init_nodes.push_back(KDNode(key,val));
61 jab 121
		}
122
 
330 jab 123
		/** Build the tree. After this function has been called, it is no longer 
61 jab 124
				legal to insert elements, but you can perform searches. */
125
		void build()
126
		{
127
			assert(!is_built);
330 jab 128
			nodes.resize(init_nodes.size());
129
			if(init_nodes.size() > 1)	
130
				optimize(1,1,init_nodes.size());
131
			NodeVecType v(1);
61 jab 132
			init_nodes.swap(v);
133
			is_built = true;
134
		}
135
 
136
		/** Find the key value pair closest to the key given as first 
334 jab 137
				argument. The second argument is the maximum search distance. Upon
138
				return this value is changed to the distance to the found point.
61 jab 139
				The final two arguments contain the closest key and its 
140
				associated value upon return. */
324 jab 141
		bool closest_point(const KeyT& p, ScalarType& dist, KeyT&k, ValT&v) const
61 jab 142
		{
143
			assert(is_built);
330 jab 144
			if(nodes.size()>1)
145
			{
146
					ScalarType max_sq_dist = CGLA::sqr(dist);
147
					if(int n = closest_point_priv(1, p, max_sq_dist))
148
					{
149
							k = nodes[n].key;
150
							v = nodes[n].val;
151
							dist = std::sqrt(max_sq_dist);
152
							return true;
153
					}
154
			}
61 jab 155
			return false;
156
		}
157
 
158
		/** Find all the elements within a given radius (second argument) of
159
				the key (first argument). The key value pairs inside the sphere are
330 jab 160
				returned in a pair of vectors passed as the two last arguments.
161
				Note that we don't resize the two last arguments to zero - so either
162
				they should be empty vectors or you should desire appending the newly
163
				found elements onto these vectors.				
164
		*/
61 jab 165
		int in_sphere(const KeyType& p, 
324 jab 166
									ScalarType dist,
61 jab 167
									std::vector<KeyT>& keys,
168
									std::vector<ValT>& vals) const
169
		{
170
			assert(is_built);
330 jab 171
			if(nodes.size()>1)
172
			{
173
					ScalarType max_sq_dist = CGLA::sqr(dist);
174
					in_sphere_priv(1,p,max_sq_dist,keys,vals);
175
					return keys.size();
176
			}
177
			return 0;
61 jab 178
		}
179
 
180
 
181
	};
182
 
183
	template<class KeyT, class ValT>
184
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
185
																	int kvec_end) const 
186
	{
187
		KeyType vmin = init_nodes[kvec_beg].key;
188
		KeyType vmax = init_nodes[kvec_beg].key;
189
		for(int i=kvec_beg;i<kvec_end;i++)
190
			{
191
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
192
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
193
			}
194
		int od=0;
195
		KeyType ave_v = vmax-vmin;
196
		for(int i=1;i<KeyType::get_dim();i++)
197
			if(ave_v[i]>ave_v[od]) od = i;
198
		return od;
199
	} 
200
 
201
	template<class KeyT, class ValT>
202
	void KDTree<KeyT,ValT>::optimize(int cur,
203
																	 int kvec_beg,  
330 jab 204
																	 int kvec_end)
61 jab 205
	{
206
		// Assert that we are not inserting beyond capacity.
207
		assert(cur < nodes.size());
208
 
209
		// If there is just a single element, we simply insert.
210
		if(kvec_beg+1==kvec_end) 
211
			{
212
				nodes[cur] = init_nodes[kvec_beg];
213
				nodes[cur].dsc = -1;
214
				return;
215
			}
216
 
217
		// Find the axis that best separates the data.
218
		int disc = opt_disc(kvec_beg, kvec_end);
219
 
220
		// Compute the median element. See my document on how to do this
221
		// www.imm.dtu.dk/~jab/publications.html
222
		int N = kvec_end-kvec_beg;
223
		int M = 1<< (CGLA::two_to_what_power(N));
224
		int R = N-(M-1);
225
		int left_size  = (M-2)/2;
226
		int right_size = (M-2)/2;
227
		if(R < M/2)
228
			{
229
				left_size += R;
230
			}
231
		else
232
			{
233
				left_size += M/2;
234
				right_size += R-M/2;
235
			}
236
 
237
		int median = kvec_beg + left_size;
238
 
239
		// Sort elements but use nth_element (which is cheaper) than
240
		// a sorting algorithm. All elements to the left of the median
241
		// will be smaller than or equal the median. All elements to the right
242
		// will be greater than or equal to the median.
243
		const Comp comp(disc);
244
		std::nth_element(&init_nodes[kvec_beg], 
245
										 &init_nodes[median], 
246
										 &init_nodes[kvec_end], comp);
247
 
248
		// Insert the node in the final data structure.
249
		nodes[cur] = init_nodes[median];
250
		nodes[cur].dsc = disc;
251
 
252
		// Recursively build left and right tree.
253
		if(left_size>0)	
330 jab 254
			optimize(2*cur, kvec_beg, median);
61 jab 255
 
256
		if(right_size>0) 
330 jab 257
			optimize(2*cur+1, median+1, kvec_end);
61 jab 258
	}
259
 
260
	template<class KeyT, class ValT>
261
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
262
																						ScalarType& dist) const
263
	{
264
		int ret_node = 0;
265
		ScalarType this_dist = nodes[n].dist(p);
266
 
267
		if(this_dist<dist)
268
			{
269
				dist = this_dist;
270
				ret_node = n;
271
			}
272
		if(nodes[n].dsc != -1)
273
			{
274
				int dsc         = nodes[n].dsc;
324 jab 275
				ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
61 jab 276
				bool left_son   = Comp(dsc)(p,nodes[n].key);
277
 
278
				if(left_son||dsc_dist<dist)
279
					{
280
						int left_child = 2*n;
281
						if(left_child < nodes.size())
282
							if(int nl=closest_point_priv(left_child, p, dist))
283
								ret_node = nl;
284
					}
285
				if(!left_son||dsc_dist<dist)
286
					{
287
						int right_child = 2*n+1;
288
						if(right_child < nodes.size())
289
							if(int nr=closest_point_priv(right_child, p, dist))
290
								ret_node = nr;
291
					}
292
			}
293
		return ret_node;
294
	}
295
 
296
	template<class KeyT, class ValT>
297
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
298
																				 const KeyType& p, 
299
																				 const ScalarType& dist,
300
																				 std::vector<KeyT>& keys,
301
																				 std::vector<ValT>& vals) const
302
	{
303
		ScalarType this_dist = nodes[n].dist(p);
304
		assert(n<nodes.size());
305
		if(this_dist<dist)
306
			{
307
				keys.push_back(nodes[n].key);
308
				vals.push_back(nodes[n].val);
309
			}
310
		if(nodes[n].dsc != -1)
311
			{
312
				const int dsc         = nodes[n].dsc;
324 jab 313
				const ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
61 jab 314
 
315
				bool left_son = Comp(dsc)(p,nodes[n].key);
316
 
317
				if(left_son||dsc_dist<dist)
318
					{
319
						int left_child = 2*n;
320
						if(left_child < nodes.size())
321
							in_sphere_priv(left_child, p, dist, keys, vals);
322
					}
323
				if(!left_son||dsc_dist<dist)
324
					{
325
						int right_child = 2*n+1;
326
						if(right_child < nodes.size())
327
							in_sphere_priv(right_child, p, dist, keys, vals);
328
					}
329
			}
330
	}
331
}
332
// Short alias so client code can write GEO::KDTree instead of Geometry::KDTree.
namespace GEO = Geometry;
333
 
198 bj 334
#if (_MSC_VER >= 1200)
203 jrf 335
#pragma warning (pop)
61 jab 336
#endif
198 bj 337
 
338
 
339
#endif