Subversion Repositories gelsvn

Rev

Rev 334 | Only display areas with differences | Ignore whitespace | Details | Blame | Last modification | View Log | RSS feed

Rev 334 Rev 393
1
#ifndef __KDTREE_H
1
#ifndef __KDTREE_H
2
#define __KDTREE_H
2
#define __KDTREE_H
3
 
3
 
4
#include <cmath>
4
#include <cmath>
5
#include <iostream>
5
#include <iostream>
6
#include <vector>
6
#include <vector>
7
#include <algorithm>
7
#include <algorithm>
8
#include "CGLA/CGLA.h"
8
#include "CGLA/CGLA.h"
9
#include "CGLA/ArithVec.h"
9
#include "CGLA/ArithVec.h"
10
 
10
 
11
#if (_MSC_VER >= 1200)
11
#if (_MSC_VER >= 1200)
12
#pragma warning (push)
12
#pragma warning (push)
13
#pragma warning (disable: 4018)
13
#pragma warning (disable: 4018)
14
#endif
14
#endif
15
 
15
 
16
namespace Geometry
16
namespace Geometry
17
{
17
{
18
	/** \brief A classic K-D tree. 
18
	/** \brief A classic K-D tree. 
19
 
19
 
20
			A K-D tree is a good data structure for storing points in space
20
			A K-D tree is a good data structure for storing points in space
21
			and for nearest neighbour queries. It is basically a generalized 
21
			and for nearest neighbour queries. It is basically a generalized 
22
			binary tree in K dimensions. */
22
			binary tree in K dimensions. */
23
	template<class KeyT, class ValT>
23
	template<class KeyT, class ValT>
24
	class KDTree
24
	class KDTree
25
	{
25
	{
26
		typedef typename KeyT::ScalarType ScalarType;
26
		typedef typename KeyT::ScalarType ScalarType;
27
		typedef KeyT KeyType;
27
		typedef KeyT KeyType;
28
		typedef std::vector<KeyT> KeyVectorType;
28
		typedef std::vector<KeyT> KeyVectorType;
29
		typedef std::vector<ValT> ValVectorType;
29
		typedef std::vector<ValT> ValVectorType;
30
	
30
	
31
		/// KDNode struct represents node in KD tree
31
		/// KDNode struct represents node in KD tree
32
		struct KDNode
32
		struct KDNode
33
		{
33
		{
34
			KeyT key;
34
			KeyT key;
35
			ValT val;
35
			ValT val;
36
			short dsc;
36
			short dsc;
37
 
37
 
38
			KDNode(): dsc(0) {}
38
			KDNode(): dsc(0) {}
39
 
39
 
40
			KDNode(const KeyT& _key, const ValT& _val):
40
			KDNode(const KeyT& _key, const ValT& _val):
41
				key(_key), val(_val), dsc(-1) {}
41
				key(_key), val(_val), dsc(-1) {}
42
 
42
 
43
			ScalarType dist(const KeyType& p) const 
43
			ScalarType dist(const KeyType& p) const 
44
			{
44
			{
45
				KeyType dist_vec = p;
45
				KeyType dist_vec = p;
46
				dist_vec  -= key;
46
				dist_vec  -= key;
47
				return dot(dist_vec, dist_vec);
47
				return dot(dist_vec, dist_vec);
48
			}
48
			}
49
		};
49
		};
50
 
50
 
51
		typedef std::vector<KDNode> NodeVecType;
51
		typedef std::vector<KDNode> NodeVecType;
52
		bool is_built;
52
		bool is_built;
53
		NodeVecType init_nodes;
53
		NodeVecType init_nodes;
54
		NodeVecType nodes;
54
		NodeVecType nodes;
55
 
55
 
56
		/** Comp is a class used for comparing two keys. Comp is constructed
56
		/** Comp is a class used for comparing two keys. Comp is constructed
57
				with the discriminator - i.e. the coordinate of the key that is used
57
				with the discriminator - i.e. the coordinate of the key that is used
58
				for comparing keys - Comp objects are passed to the sort algorithm.*/
58
				for comparing keys - Comp objects are passed to the sort algorithm.*/
59
		class Comp
59
		class Comp
60
		{
60
		{
61
			const int dsc;
61
			const int dsc;
62
		public:
62
		public:
63
			Comp(int _dsc): dsc(_dsc) {}
63
			Comp(int _dsc): dsc(_dsc) {}
64
			bool operator()(const KeyType& k0, const KeyType& k1) const
64
			bool operator()(const KeyType& k0, const KeyType& k1) const
65
			{
65
			{
66
				int dim=KeyType::get_dim();
66
				int dim=KeyType::get_dim();
67
				for(int i=0;i<dim;i++)
67
				for(int i=0;i<dim;i++)
68
					{
68
					{
69
						int j=(dsc+i)%dim;
69
						int j=(dsc+i)%dim;
70
						if(k0[j]<k1[j])
70
						if(k0[j]<k1[j])
71
							return true;
71
							return true;
72
						if(k0[j]>k1[j])
72
						if(k0[j]>k1[j])
73
							return false;
73
							return false;
74
					}
74
					}
75
				return false;
75
				return false;
76
			}
76
			}
77
 
77
 
78
			bool operator()(const KDNode& k0, const KDNode& k1) const
78
			bool operator()(const KDNode& k0, const KDNode& k1) const
79
			{
79
			{
80
				return (*this)(k0.key,k1.key);
80
				return (*this)(k0.key,k1.key);
81
			}
81
			}
82
		};
82
		};
83
 
83
 
84
 
84
 
85
		/** Passed a vector of keys, this function will construct an optimal tree.
85
		/** Passed a vector of keys, this function will construct an optimal tree.
86
				It is called recursively */
86
				It is called recursively */
87
		void optimize(int, int, int);
87
		void optimize(int, int, int);
88
 
88
 
89
		/** Finde nearest neighbour. */
89
		/** Finde nearest neighbour. */
90
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
90
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
91
 
91
 
92
							
92
							
93
		void in_sphere_priv(int n, 
93
		void in_sphere_priv(int n, 
94
												const KeyType& p, 
94
												const KeyType& p, 
95
												const ScalarType& dist,
95
												const ScalarType& dist,
96
												std::vector<KeyT>& keys,
96
												std::vector<KeyT>& keys,
97
												std::vector<ValT>& vals) const;
97
												std::vector<ValT>& vals) const;
98
	
98
	
99
		/** Finds the optimal discriminator. There are more ways, but this 
99
		/** Finds the optimal discriminator. There are more ways, but this 
100
				function traverses the vector and finds out what dimension has
100
				function traverses the vector and finds out what dimension has
101
				the greatest difference between min and max element. That dimension
101
				the greatest difference between min and max element. That dimension
102
				is used for discriminator */
102
				is used for discriminator */
103
		int opt_disc(int,int) const;
103
		int opt_disc(int,int) const;
104
	
104
	
105
	public:
105
	public:
106
 
106
 
107
		/** Build tree from vector of keys passed as argument. */
107
		/** Build tree from vector of keys passed as argument. */
108
		KDTree(): is_built(false), init_nodes(1) {}
108
		KDTree(): is_built(false), init_nodes(1) {}
109
 
109
 
110
		/** Insert a key value pair into the tree. Note that the tree needs to 
110
		/** Insert a key value pair into the tree. Note that the tree needs to 
111
				be built - by calling the build function - before you can search. */
111
				be built - by calling the build function - before you can search. */
112
		void insert(const KeyT& key, const ValT& val)
112
		void insert(const KeyT& key, const ValT& val)
113
		{
113
		{
114
				if(is_built)
114
				if(is_built)
115
				{
115
				{
116
						assert(init_nodes.size()==1);
116
						assert(init_nodes.size()==1);
117
						init_nodes.swap(nodes);
117
						init_nodes.swap(nodes);
118
						is_built=false;
118
						is_built=false;
119
				}
119
				}
120
				init_nodes.push_back(KDNode(key,val));
120
				init_nodes.push_back(KDNode(key,val));
121
		}
121
		}
122
 
122
 
123
		/** Build the tree. After this function has been called, it is no longer 
123
		/** Build the tree. After this function has been called, it is no longer 
124
				legal to insert elements, but you can perform searches. */
124
				legal to insert elements, but you can perform searches. */
125
		void build()
125
		void build()
126
		{
126
		{
127
			assert(!is_built);
127
			assert(!is_built);
128
			nodes.resize(init_nodes.size());
128
			nodes.resize(init_nodes.size());
129
			if(init_nodes.size() > 1)	
129
			if(init_nodes.size() > 1)	
130
				optimize(1,1,init_nodes.size());
130
				optimize(1,1,init_nodes.size());
131
			NodeVecType v(1);
131
			NodeVecType v(1);
132
			init_nodes.swap(v);
132
			init_nodes.swap(v);
133
			is_built = true;
133
			is_built = true;
134
		}
134
		}
135
 
135
 
136
		/** Find the key value pair closest to the key given as first 
136
		/** Find the key value pair closest to the key given as first 
137
				argument. The second argument is the maximum search distance. Upon
137
				argument. The second argument is the maximum search distance. Upon
138
				return this value is changed to the distance to the found point.
138
				return this value is changed to the distance to the found point.
139
				The final two arguments contain the closest key and its 
139
				The final two arguments contain the closest key and its 
140
				associated value upon return. */
140
				associated value upon return. */
141
		bool closest_point(const KeyT& p, ScalarType& dist, KeyT&k, ValT&v) const
141
		bool closest_point(const KeyT& p, ScalarType& dist, KeyT&k, ValT&v) const
142
		{
142
		{
143
			assert(is_built);
143
			assert(is_built);
144
			if(nodes.size()>1)
144
			if(nodes.size()>1)
145
			{
145
			{
146
					ScalarType max_sq_dist = CGLA::sqr(dist);
146
					ScalarType max_sq_dist = CGLA::sqr(dist);
147
					if(int n = closest_point_priv(1, p, max_sq_dist))
147
					if(int n = closest_point_priv(1, p, max_sq_dist))
148
					{
148
					{
149
							k = nodes[n].key;
149
							k = nodes[n].key;
150
							v = nodes[n].val;
150
							v = nodes[n].val;
151
							dist = std::sqrt(max_sq_dist);
151
							dist = std::sqrt(max_sq_dist);
152
							return true;
152
							return true;
153
					}
153
					}
154
			}
154
			}
155
			return false;
155
			return false;
156
		}
156
		}
157
 
157
 
158
		/** Find all the elements within a given radius (second argument) of
158
		/** Find all the elements within a given radius (second argument) of
159
				the key (first argument). The key value pairs inside the sphere are
159
				the key (first argument). The key value pairs inside the sphere are
160
				returned in a pair of vectors passed as the two last arguments.
160
				returned in a pair of vectors passed as the two last arguments.
161
				Note that we don't resize the two last arguments to zero - so either
161
				Note that we don't resize the two last arguments to zero - so either
162
				they should be empty vectors or you should desire appending the newly
162
				they should be empty vectors or you should desire appending the newly
163
				found elements onto these vectors.				
163
				found elements onto these vectors.				
164
		*/
164
		*/
165
		int in_sphere(const KeyType& p, 
165
		int in_sphere(const KeyType& p, 
166
									ScalarType dist,
166
									ScalarType dist,
167
									std::vector<KeyT>& keys,
167
									std::vector<KeyT>& keys,
168
									std::vector<ValT>& vals) const
168
									std::vector<ValT>& vals) const
169
		{
169
		{
170
			assert(is_built);
170
			assert(is_built);
171
			if(nodes.size()>1)
171
			if(nodes.size()>1)
172
			{
172
			{
173
					ScalarType max_sq_dist = CGLA::sqr(dist);
173
					ScalarType max_sq_dist = CGLA::sqr(dist);
174
					in_sphere_priv(1,p,max_sq_dist,keys,vals);
174
					in_sphere_priv(1,p,max_sq_dist,keys,vals);
175
					return keys.size();
175
					return keys.size();
176
			}
176
			}
177
			return 0;
177
			return 0;
178
		}
178
		}
179
		
179
		
180
 
180
 
181
	};
181
	};
182
 
182
 
183
	template<class KeyT, class ValT>
183
	template<class KeyT, class ValT>
184
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
184
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
185
																	int kvec_end) const 
185
																	int kvec_end) const 
186
	{
186
	{
187
		KeyType vmin = init_nodes[kvec_beg].key;
187
		KeyType vmin = init_nodes[kvec_beg].key;
188
		KeyType vmax = init_nodes[kvec_beg].key;
188
		KeyType vmax = init_nodes[kvec_beg].key;
189
		for(int i=kvec_beg;i<kvec_end;i++)
189
		for(int i=kvec_beg;i<kvec_end;i++)
190
			{
190
			{
191
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
191
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
192
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
192
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
193
			}
193
			}
194
		int od=0;
194
		int od=0;
195
		KeyType ave_v = vmax-vmin;
195
		KeyType ave_v = vmax-vmin;
196
		for(int i=1;i<KeyType::get_dim();i++)
196
		for(int i=1;i<KeyType::get_dim();i++)
197
			if(ave_v[i]>ave_v[od]) od = i;
197
			if(ave_v[i]>ave_v[od]) od = i;
198
		return od;
198
		return od;
199
	} 
199
	} 
200
 
200
 
201
	template<class KeyT, class ValT>
201
	template<class KeyT, class ValT>
202
	void KDTree<KeyT,ValT>::optimize(int cur,
202
	void KDTree<KeyT,ValT>::optimize(int cur,
203
																	 int kvec_beg,  
203
																	 int kvec_beg,  
204
																	 int kvec_end)
204
																	 int kvec_end)
205
	{
205
	{
206
		// Assert that we are not inserting beyond capacity.
206
		// Assert that we are not inserting beyond capacity.
207
		assert(cur < nodes.size());
207
		assert(cur < nodes.size());
208
 
208
 
209
		// If there is just a single element, we simply insert.
209
		// If there is just a single element, we simply insert.
210
		if(kvec_beg+1==kvec_end) 
210
		if(kvec_beg+1==kvec_end) 
211
			{
211
			{
212
				nodes[cur] = init_nodes[kvec_beg];
212
				nodes[cur] = init_nodes[kvec_beg];
213
				nodes[cur].dsc = -1;
213
				nodes[cur].dsc = -1;
214
				return;
214
				return;
215
			}
215
			}
216
	
216
	
217
		// Find the axis that best separates the data.
217
		// Find the axis that best separates the data.
218
		int disc = opt_disc(kvec_beg, kvec_end);
218
		int disc = opt_disc(kvec_beg, kvec_end);
219
 
219
 
220
		// Compute the median element. See my document on how to do this
220
		// Compute the median element. See my document on how to do this
221
		// www.imm.dtu.dk/~jab/publications.html
221
		// www.imm.dtu.dk/~jab/publications.html
222
		int N = kvec_end-kvec_beg;
222
		int N = kvec_end-kvec_beg;
223
		int M = 1<< (CGLA::two_to_what_power(N));
223
		int M = 1<< (CGLA::two_to_what_power(N));
224
		int R = N-(M-1);
224
		int R = N-(M-1);
225
		int left_size  = (M-2)/2;
225
		int left_size  = (M-2)/2;
226
		int right_size = (M-2)/2;
226
		int right_size = (M-2)/2;
227
		if(R < M/2)
227
		if(R < M/2)
228
			{
228
			{
229
				left_size += R;
229
				left_size += R;
230
			}
230
			}
231
		else
231
		else
232
			{
232
			{
233
				left_size += M/2;
233
				left_size += M/2;
234
				right_size += R-M/2;
234
				right_size += R-M/2;
235
			}
235
			}
236
 
236
 
237
		int median = kvec_beg + left_size;
237
		int median = kvec_beg + left_size;
238
 
238
 
239
		// Sort elements but use nth_element (which is cheaper) than
239
		// Sort elements but use nth_element (which is cheaper) than
240
		// a sorting algorithm. All elements to the left of the median
240
		// a sorting algorithm. All elements to the left of the median
241
		// will be smaller than or equal the median. All elements to the right
241
		// will be smaller than or equal the median. All elements to the right
242
		// will be greater than or equal to the median.
242
		// will be greater than or equal to the median.
243
		const Comp comp(disc);
243
		const Comp comp(disc);
244
		std::nth_element(&init_nodes[kvec_beg], 
244
		std::nth_element(&init_nodes[kvec_beg], 
245
										 &init_nodes[median], 
245
										 &init_nodes[median], 
246
										 &init_nodes[kvec_end], comp);
246
										 &init_nodes[kvec_end], comp);
247
 
247
 
248
		// Insert the node in the final data structure.
248
		// Insert the node in the final data structure.
249
		nodes[cur] = init_nodes[median];
249
		nodes[cur] = init_nodes[median];
250
		nodes[cur].dsc = disc;
250
		nodes[cur].dsc = disc;
251
 
251
 
252
		// Recursively build left and right tree.
252
		// Recursively build left and right tree.
253
		if(left_size>0)	
253
		if(left_size>0)	
254
			optimize(2*cur, kvec_beg, median);
254
			optimize(2*cur, kvec_beg, median);
255
		
255
		
256
		if(right_size>0) 
256
		if(right_size>0) 
257
			optimize(2*cur+1, median+1, kvec_end);
257
			optimize(2*cur+1, median+1, kvec_end);
258
	}
258
	}
259
 
259
 
260
	template<class KeyT, class ValT>
260
	template<class KeyT, class ValT>
261
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
261
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
262
																						ScalarType& dist) const
262
																						ScalarType& dist) const
263
	{
263
	{
264
		int ret_node = 0;
264
		int ret_node = 0;
265
		ScalarType this_dist = nodes[n].dist(p);
265
		ScalarType this_dist = nodes[n].dist(p);
266
 
266
 
267
		if(this_dist<dist)
267
		if(this_dist<dist)
268
			{
268
			{
269
				dist = this_dist;
269
				dist = this_dist;
270
				ret_node = n;
270
				ret_node = n;
271
			}
271
			}
272
		if(nodes[n].dsc != -1)
272
		if(nodes[n].dsc != -1)
273
			{
273
			{
274
				int dsc         = nodes[n].dsc;
274
				int dsc         = nodes[n].dsc;
275
				ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
275
				ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
276
				bool left_son   = Comp(dsc)(p,nodes[n].key);
276
				bool left_son   = Comp(dsc)(p,nodes[n].key);
277
 
277
 
278
				if(left_son||dsc_dist<dist)
278
				if(left_son||dsc_dist<dist)
279
					{
279
					{
280
						int left_child = 2*n;
280
						int left_child = 2*n;
281
						if(left_child < nodes.size())
281
						if(left_child < nodes.size())
282
							if(int nl=closest_point_priv(left_child, p, dist))
282
							if(int nl=closest_point_priv(left_child, p, dist))
283
								ret_node = nl;
283
								ret_node = nl;
284
					}
284
					}
285
				if(!left_son||dsc_dist<dist)
285
				if(!left_son||dsc_dist<dist)
286
					{
286
					{
287
						int right_child = 2*n+1;
287
						int right_child = 2*n+1;
288
						if(right_child < nodes.size())
288
						if(right_child < nodes.size())
289
							if(int nr=closest_point_priv(right_child, p, dist))
289
							if(int nr=closest_point_priv(right_child, p, dist))
290
								ret_node = nr;
290
								ret_node = nr;
291
					}
291
					}
292
			}
292
			}
293
		return ret_node;
293
		return ret_node;
294
	}
294
	}
295
 
295
 
296
	template<class KeyT, class ValT>
296
	template<class KeyT, class ValT>
297
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
297
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
298
																				 const KeyType& p, 
298
																				 const KeyType& p, 
299
																				 const ScalarType& dist,
299
																				 const ScalarType& dist,
300
																				 std::vector<KeyT>& keys,
300
																				 std::vector<KeyT>& keys,
301
																				 std::vector<ValT>& vals) const
301
																				 std::vector<ValT>& vals) const
302
	{
302
	{
303
		ScalarType this_dist = nodes[n].dist(p);
303
		ScalarType this_dist = nodes[n].dist(p);
304
		assert(n<nodes.size());
304
		assert(n<nodes.size());
305
		if(this_dist<dist)
305
		if(this_dist<dist)
306
			{
306
			{
307
				keys.push_back(nodes[n].key);
307
				keys.push_back(nodes[n].key);
308
				vals.push_back(nodes[n].val);
308
				vals.push_back(nodes[n].val);
309
			}
309
			}
310
		if(nodes[n].dsc != -1)
310
		if(nodes[n].dsc != -1)
311
			{
311
			{
312
				const int dsc         = nodes[n].dsc;
312
				const int dsc         = nodes[n].dsc;
313
				const ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
313
				const ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
314
 
314
 
315
				bool left_son = Comp(dsc)(p,nodes[n].key);
315
				bool left_son = Comp(dsc)(p,nodes[n].key);
316
 
316
 
317
				if(left_son||dsc_dist<dist)
317
				if(left_son||dsc_dist<dist)
318
					{
318
					{
319
						int left_child = 2*n;
319
						int left_child = 2*n;
320
						if(left_child < nodes.size())
320
						if(left_child < nodes.size())
321
							in_sphere_priv(left_child, p, dist, keys, vals);
321
							in_sphere_priv(left_child, p, dist, keys, vals);
322
					}
322
					}
323
				if(!left_son||dsc_dist<dist)
323
				if(!left_son||dsc_dist<dist)
324
					{
324
					{
325
						int right_child = 2*n+1;
325
						int right_child = 2*n+1;
326
						if(right_child < nodes.size())
326
						if(right_child < nodes.size())
327
							in_sphere_priv(right_child, p, dist, keys, vals);
327
							in_sphere_priv(right_child, p, dist, keys, vals);
328
					}
328
					}
329
			}
329
			}
330
	}
330
	}
331
}
331
}
332
namespace GEO = Geometry;
332
namespace GEO = Geometry;
333
 
333
 
334
#if (_MSC_VER >= 1200)
334
#if (_MSC_VER >= 1200)
335
#pragma warning (pop)
335
#pragma warning (pop)
336
#endif
336
#endif
337
 
337
 
338
 
338
 
339
#endif
339
#endif
340
 
340