Subversion Repositories gelsvn

Rev

Rev 330 | Rev 443 | Go to most recent revision | Only display areas with differences | Ignore whitespace | Details | Blame | Last modification | View Log | RSS feed

Rev 330 Rev 334
1
#ifndef __KDTREE_H
1
#ifndef __KDTREE_H
2
#define __KDTREE_H
2
#define __KDTREE_H
3
 
3
 
4
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <vector>

#include "CGLA/CGLA.h"
#include "CGLA/ArithVec.h"
10
 
10
 
11
#if (_MSC_VER >= 1200)
11
#if (_MSC_VER >= 1200)
12
#pragma warning (push)
12
#pragma warning (push)
13
#pragma warning (disable: 4018)
13
#pragma warning (disable: 4018)
14
#endif
14
#endif
15
 
15
 
16
namespace Geometry
16
namespace Geometry
17
{
17
{
18
	/** \brief A classic K-D tree. 
18
	/** \brief A classic K-D tree. 
19
 
19
 
20
			A K-D tree is a good data structure for storing points in space
20
			A K-D tree is a good data structure for storing points in space
21
			and for nearest neighbour queries. It is basically a generalized 
21
			and for nearest neighbour queries. It is basically a generalized 
22
			binary tree in K dimensions. */
22
			binary tree in K dimensions. */
23
	template<class KeyT, class ValT>
23
	template<class KeyT, class ValT>
24
	class KDTree
24
	class KDTree
25
	{
25
	{
26
		typedef typename KeyT::ScalarType ScalarType;
26
		typedef typename KeyT::ScalarType ScalarType;
27
		typedef KeyT KeyType;
27
		typedef KeyT KeyType;
28
		typedef std::vector<KeyT> KeyVectorType;
28
		typedef std::vector<KeyT> KeyVectorType;
29
		typedef std::vector<ValT> ValVectorType;
29
		typedef std::vector<ValT> ValVectorType;
30
	
30
	
31
		/// KDNode struct represents node in KD tree
31
		/// KDNode struct represents node in KD tree
32
		struct KDNode
32
		struct KDNode
33
		{
33
		{
34
			KeyT key;
34
			KeyT key;
35
			ValT val;
35
			ValT val;
36
			short dsc;
36
			short dsc;
37
 
37
 
38
			KDNode(): dsc(0) {}
38
			KDNode(): dsc(0) {}
39
 
39
 
40
			KDNode(const KeyT& _key, const ValT& _val):
40
			KDNode(const KeyT& _key, const ValT& _val):
41
				key(_key), val(_val), dsc(-1) {}
41
				key(_key), val(_val), dsc(-1) {}
42
 
42
 
43
			ScalarType dist(const KeyType& p) const 
43
			ScalarType dist(const KeyType& p) const 
44
			{
44
			{
45
				KeyType dist_vec = p;
45
				KeyType dist_vec = p;
46
				dist_vec  -= key;
46
				dist_vec  -= key;
47
				return dot(dist_vec, dist_vec);
47
				return dot(dist_vec, dist_vec);
48
			}
48
			}
49
		};
49
		};
50
 
50
 
51
		typedef std::vector<KDNode> NodeVecType;
51
		typedef std::vector<KDNode> NodeVecType;
52
		bool is_built;
52
		bool is_built;
53
		NodeVecType init_nodes;
53
		NodeVecType init_nodes;
54
		NodeVecType nodes;
54
		NodeVecType nodes;
55
 
55
 
56
		/** Comp is a class used for comparing two keys. Comp is constructed
56
		/** Comp is a class used for comparing two keys. Comp is constructed
57
				with the discriminator - i.e. the coordinate of the key that is used
57
				with the discriminator - i.e. the coordinate of the key that is used
58
				for comparing keys - Comp objects are passed to the sort algorithm.*/
58
				for comparing keys - Comp objects are passed to the sort algorithm.*/
59
		class Comp
59
		class Comp
60
		{
60
		{
61
			const int dsc;
61
			const int dsc;
62
		public:
62
		public:
63
			Comp(int _dsc): dsc(_dsc) {}
63
			Comp(int _dsc): dsc(_dsc) {}
64
			bool operator()(const KeyType& k0, const KeyType& k1) const
64
			bool operator()(const KeyType& k0, const KeyType& k1) const
65
			{
65
			{
66
				int dim=KeyType::get_dim();
66
				int dim=KeyType::get_dim();
67
				for(int i=0;i<dim;i++)
67
				for(int i=0;i<dim;i++)
68
					{
68
					{
69
						int j=(dsc+i)%dim;
69
						int j=(dsc+i)%dim;
70
						if(k0[j]<k1[j])
70
						if(k0[j]<k1[j])
71
							return true;
71
							return true;
72
						if(k0[j]>k1[j])
72
						if(k0[j]>k1[j])
73
							return false;
73
							return false;
74
					}
74
					}
75
				return false;
75
				return false;
76
			}
76
			}
77
 
77
 
78
			bool operator()(const KDNode& k0, const KDNode& k1) const
78
			bool operator()(const KDNode& k0, const KDNode& k1) const
79
			{
79
			{
80
				return (*this)(k0.key,k1.key);
80
				return (*this)(k0.key,k1.key);
81
			}
81
			}
82
		};
82
		};
83
 
83
 
84
 
84
 
85
		/** Passed a vector of keys, this function will construct an optimal tree.
85
		/** Passed a vector of keys, this function will construct an optimal tree.
86
				It is called recursively */
86
				It is called recursively */
87
		void optimize(int, int, int);
87
		void optimize(int, int, int);
88
 
88
 
89
		/** Finde nearest neighbour. */
89
		/** Finde nearest neighbour. */
90
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
90
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
91
 
91
 
92
							
92
							
93
		void in_sphere_priv(int n, 
93
		void in_sphere_priv(int n, 
94
												const KeyType& p, 
94
												const KeyType& p, 
95
												const ScalarType& dist,
95
												const ScalarType& dist,
96
												std::vector<KeyT>& keys,
96
												std::vector<KeyT>& keys,
97
												std::vector<ValT>& vals) const;
97
												std::vector<ValT>& vals) const;
98
	
98
	
99
		/** Finds the optimal discriminator. There are more ways, but this 
99
		/** Finds the optimal discriminator. There are more ways, but this 
100
				function traverses the vector and finds out what dimension has
100
				function traverses the vector and finds out what dimension has
101
				the greatest difference between min and max element. That dimension
101
				the greatest difference between min and max element. That dimension
102
				is used for discriminator */
102
				is used for discriminator */
103
		int opt_disc(int,int) const;
103
		int opt_disc(int,int) const;
104
	
104
	
105
	public:
105
	public:
106
 
106
 
107
		/** Build tree from vector of keys passed as argument. */
107
		/** Build tree from vector of keys passed as argument. */
108
		KDTree(): is_built(false), init_nodes(1) {}
108
		KDTree(): is_built(false), init_nodes(1) {}
109
 
109
 
110
		/** Insert a key value pair into the tree. Note that the tree needs to 
110
		/** Insert a key value pair into the tree. Note that the tree needs to 
111
				be built - by calling the build function - before you can search. */
111
				be built - by calling the build function - before you can search. */
112
		void insert(const KeyT& key, const ValT& val)
112
		void insert(const KeyT& key, const ValT& val)
113
		{
113
		{
114
				if(is_built)
114
				if(is_built)
115
				{
115
				{
116
						assert(init_nodes.size()==1);
116
						assert(init_nodes.size()==1);
117
						init_nodes.swap(nodes);
117
						init_nodes.swap(nodes);
118
						is_built=false;
118
						is_built=false;
119
				}
119
				}
120
				init_nodes.push_back(KDNode(key,val));
120
				init_nodes.push_back(KDNode(key,val));
121
		}
121
		}
122
 
122
 
123
		/** Build the tree. After this function has been called, it is no longer 
123
		/** Build the tree. After this function has been called, it is no longer 
124
				legal to insert elements, but you can perform searches. */
124
				legal to insert elements, but you can perform searches. */
125
		void build()
125
		void build()
126
		{
126
		{
127
			assert(!is_built);
127
			assert(!is_built);
128
			nodes.resize(init_nodes.size());
128
			nodes.resize(init_nodes.size());
129
			if(init_nodes.size() > 1)	
129
			if(init_nodes.size() > 1)	
130
				optimize(1,1,init_nodes.size());
130
				optimize(1,1,init_nodes.size());
131
			NodeVecType v(1);
131
			NodeVecType v(1);
132
			init_nodes.swap(v);
132
			init_nodes.swap(v);
133
			is_built = true;
133
			is_built = true;
134
		}
134
		}
135
 
135
 
136
		/** Find the key value pair closest to the key given as first 
136
		/** Find the key value pair closest to the key given as first 
137
				argument. The second argument is the maximum search distance.
137
				argument. The second argument is the maximum search distance. Upon
-
 
138
				return this value is changed to the distance to the found point.
138
				The final two arguments contain the closest key and its 
139
				The final two arguments contain the closest key and its 
139
				associated value upon return. */
140
				associated value upon return. */
140
		bool closest_point(const KeyT& p, ScalarType& dist, KeyT&k, ValT&v) const
141
		bool closest_point(const KeyT& p, ScalarType& dist, KeyT&k, ValT&v) const
141
		{
142
		{
142
			assert(is_built);
143
			assert(is_built);
143
			if(nodes.size()>1)
144
			if(nodes.size()>1)
144
			{
145
			{
145
					ScalarType max_sq_dist = CGLA::sqr(dist);
146
					ScalarType max_sq_dist = CGLA::sqr(dist);
146
					if(int n = closest_point_priv(1, p, max_sq_dist))
147
					if(int n = closest_point_priv(1, p, max_sq_dist))
147
					{
148
					{
148
							k = nodes[n].key;
149
							k = nodes[n].key;
149
							v = nodes[n].val;
150
							v = nodes[n].val;
150
							dist = std::sqrt(max_sq_dist);
151
							dist = std::sqrt(max_sq_dist);
151
							return true;
152
							return true;
152
					}
153
					}
153
			}
154
			}
154
			return false;
155
			return false;
155
		}
156
		}
156
 
157
 
157
		/** Find all the elements within a given radius (second argument) of
158
		/** Find all the elements within a given radius (second argument) of
158
				the key (first argument). The key value pairs inside the sphere are
159
				the key (first argument). The key value pairs inside the sphere are
159
				returned in a pair of vectors passed as the two last arguments.
160
				returned in a pair of vectors passed as the two last arguments.
160
				Note that we don't resize the two last arguments to zero - so either
161
				Note that we don't resize the two last arguments to zero - so either
161
				they should be empty vectors or you should desire appending the newly
162
				they should be empty vectors or you should desire appending the newly
162
				found elements onto these vectors.				
163
				found elements onto these vectors.				
163
		*/
164
		*/
164
		int in_sphere(const KeyType& p, 
165
		int in_sphere(const KeyType& p, 
165
									ScalarType dist,
166
									ScalarType dist,
166
									std::vector<KeyT>& keys,
167
									std::vector<KeyT>& keys,
167
									std::vector<ValT>& vals) const
168
									std::vector<ValT>& vals) const
168
		{
169
		{
169
			assert(is_built);
170
			assert(is_built);
170
			if(nodes.size()>1)
171
			if(nodes.size()>1)
171
			{
172
			{
172
					ScalarType max_sq_dist = CGLA::sqr(dist);
173
					ScalarType max_sq_dist = CGLA::sqr(dist);
173
					in_sphere_priv(1,p,max_sq_dist,keys,vals);
174
					in_sphere_priv(1,p,max_sq_dist,keys,vals);
174
					return keys.size();
175
					return keys.size();
175
			}
176
			}
176
			return 0;
177
			return 0;
177
		}
178
		}
178
		
179
		
179
 
180
 
180
	};
181
	};
181
 
182
 
182
	template<class KeyT, class ValT>
183
	template<class KeyT, class ValT>
183
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
184
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
184
																	int kvec_end) const 
185
																	int kvec_end) const 
185
	{
186
	{
186
		KeyType vmin = init_nodes[kvec_beg].key;
187
		KeyType vmin = init_nodes[kvec_beg].key;
187
		KeyType vmax = init_nodes[kvec_beg].key;
188
		KeyType vmax = init_nodes[kvec_beg].key;
188
		for(int i=kvec_beg;i<kvec_end;i++)
189
		for(int i=kvec_beg;i<kvec_end;i++)
189
			{
190
			{
190
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
191
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
191
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
192
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
192
			}
193
			}
193
		int od=0;
194
		int od=0;
194
		KeyType ave_v = vmax-vmin;
195
		KeyType ave_v = vmax-vmin;
195
		for(int i=1;i<KeyType::get_dim();i++)
196
		for(int i=1;i<KeyType::get_dim();i++)
196
			if(ave_v[i]>ave_v[od]) od = i;
197
			if(ave_v[i]>ave_v[od]) od = i;
197
		return od;
198
		return od;
198
	} 
199
	} 
199
 
200
 
200
	template<class KeyT, class ValT>
201
	template<class KeyT, class ValT>
201
	void KDTree<KeyT,ValT>::optimize(int cur,
202
	void KDTree<KeyT,ValT>::optimize(int cur,
202
																	 int kvec_beg,  
203
																	 int kvec_beg,  
203
																	 int kvec_end)
204
																	 int kvec_end)
204
	{
205
	{
205
		// Assert that we are not inserting beyond capacity.
206
		// Assert that we are not inserting beyond capacity.
206
		assert(cur < nodes.size());
207
		assert(cur < nodes.size());
207
 
208
 
208
		// If there is just a single element, we simply insert.
209
		// If there is just a single element, we simply insert.
209
		if(kvec_beg+1==kvec_end) 
210
		if(kvec_beg+1==kvec_end) 
210
			{
211
			{
211
				nodes[cur] = init_nodes[kvec_beg];
212
				nodes[cur] = init_nodes[kvec_beg];
212
				nodes[cur].dsc = -1;
213
				nodes[cur].dsc = -1;
213
				return;
214
				return;
214
			}
215
			}
215
	
216
	
216
		// Find the axis that best separates the data.
217
		// Find the axis that best separates the data.
217
		int disc = opt_disc(kvec_beg, kvec_end);
218
		int disc = opt_disc(kvec_beg, kvec_end);
218
 
219
 
219
		// Compute the median element. See my document on how to do this
220
		// Compute the median element. See my document on how to do this
220
		// www.imm.dtu.dk/~jab/publications.html
221
		// www.imm.dtu.dk/~jab/publications.html
221
		int N = kvec_end-kvec_beg;
222
		int N = kvec_end-kvec_beg;
222
		int M = 1<< (CGLA::two_to_what_power(N));
223
		int M = 1<< (CGLA::two_to_what_power(N));
223
		int R = N-(M-1);
224
		int R = N-(M-1);
224
		int left_size  = (M-2)/2;
225
		int left_size  = (M-2)/2;
225
		int right_size = (M-2)/2;
226
		int right_size = (M-2)/2;
226
		if(R < M/2)
227
		if(R < M/2)
227
			{
228
			{
228
				left_size += R;
229
				left_size += R;
229
			}
230
			}
230
		else
231
		else
231
			{
232
			{
232
				left_size += M/2;
233
				left_size += M/2;
233
				right_size += R-M/2;
234
				right_size += R-M/2;
234
			}
235
			}
235
 
236
 
236
		int median = kvec_beg + left_size;
237
		int median = kvec_beg + left_size;
237
 
238
 
238
		// Sort elements but use nth_element (which is cheaper) than
239
		// Sort elements but use nth_element (which is cheaper) than
239
		// a sorting algorithm. All elements to the left of the median
240
		// a sorting algorithm. All elements to the left of the median
240
		// will be smaller than or equal the median. All elements to the right
241
		// will be smaller than or equal the median. All elements to the right
241
		// will be greater than or equal to the median.
242
		// will be greater than or equal to the median.
242
		const Comp comp(disc);
243
		const Comp comp(disc);
243
		std::nth_element(&init_nodes[kvec_beg], 
244
		std::nth_element(&init_nodes[kvec_beg], 
244
										 &init_nodes[median], 
245
										 &init_nodes[median], 
245
										 &init_nodes[kvec_end], comp);
246
										 &init_nodes[kvec_end], comp);
246
 
247
 
247
		// Insert the node in the final data structure.
248
		// Insert the node in the final data structure.
248
		nodes[cur] = init_nodes[median];
249
		nodes[cur] = init_nodes[median];
249
		nodes[cur].dsc = disc;
250
		nodes[cur].dsc = disc;
250
 
251
 
251
		// Recursively build left and right tree.
252
		// Recursively build left and right tree.
252
		if(left_size>0)	
253
		if(left_size>0)	
253
			optimize(2*cur, kvec_beg, median);
254
			optimize(2*cur, kvec_beg, median);
254
		
255
		
255
		if(right_size>0) 
256
		if(right_size>0) 
256
			optimize(2*cur+1, median+1, kvec_end);
257
			optimize(2*cur+1, median+1, kvec_end);
257
	}
258
	}
258
 
259
 
259
	template<class KeyT, class ValT>
260
	template<class KeyT, class ValT>
260
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
261
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
261
																						ScalarType& dist) const
262
																						ScalarType& dist) const
262
	{
263
	{
263
		int ret_node = 0;
264
		int ret_node = 0;
264
		ScalarType this_dist = nodes[n].dist(p);
265
		ScalarType this_dist = nodes[n].dist(p);
265
 
266
 
266
		if(this_dist<dist)
267
		if(this_dist<dist)
267
			{
268
			{
268
				dist = this_dist;
269
				dist = this_dist;
269
				ret_node = n;
270
				ret_node = n;
270
			}
271
			}
271
		if(nodes[n].dsc != -1)
272
		if(nodes[n].dsc != -1)
272
			{
273
			{
273
				int dsc         = nodes[n].dsc;
274
				int dsc         = nodes[n].dsc;
274
				ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
275
				ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
275
				bool left_son   = Comp(dsc)(p,nodes[n].key);
276
				bool left_son   = Comp(dsc)(p,nodes[n].key);
276
 
277
 
277
				if(left_son||dsc_dist<dist)
278
				if(left_son||dsc_dist<dist)
278
					{
279
					{
279
						int left_child = 2*n;
280
						int left_child = 2*n;
280
						if(left_child < nodes.size())
281
						if(left_child < nodes.size())
281
							if(int nl=closest_point_priv(left_child, p, dist))
282
							if(int nl=closest_point_priv(left_child, p, dist))
282
								ret_node = nl;
283
								ret_node = nl;
283
					}
284
					}
284
				if(!left_son||dsc_dist<dist)
285
				if(!left_son||dsc_dist<dist)
285
					{
286
					{
286
						int right_child = 2*n+1;
287
						int right_child = 2*n+1;
287
						if(right_child < nodes.size())
288
						if(right_child < nodes.size())
288
							if(int nr=closest_point_priv(right_child, p, dist))
289
							if(int nr=closest_point_priv(right_child, p, dist))
289
								ret_node = nr;
290
								ret_node = nr;
290
					}
291
					}
291
			}
292
			}
292
		return ret_node;
293
		return ret_node;
293
	}
294
	}
294
 
295
 
295
	template<class KeyT, class ValT>
296
	template<class KeyT, class ValT>
296
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
297
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
297
																				 const KeyType& p, 
298
																				 const KeyType& p, 
298
																				 const ScalarType& dist,
299
																				 const ScalarType& dist,
299
																				 std::vector<KeyT>& keys,
300
																				 std::vector<KeyT>& keys,
300
																				 std::vector<ValT>& vals) const
301
																				 std::vector<ValT>& vals) const
301
	{
302
	{
302
		ScalarType this_dist = nodes[n].dist(p);
303
		ScalarType this_dist = nodes[n].dist(p);
303
		assert(n<nodes.size());
304
		assert(n<nodes.size());
304
		if(this_dist<dist)
305
		if(this_dist<dist)
305
			{
306
			{
306
				keys.push_back(nodes[n].key);
307
				keys.push_back(nodes[n].key);
307
				vals.push_back(nodes[n].val);
308
				vals.push_back(nodes[n].val);
308
			}
309
			}
309
		if(nodes[n].dsc != -1)
310
		if(nodes[n].dsc != -1)
310
			{
311
			{
311
				const int dsc         = nodes[n].dsc;
312
				const int dsc         = nodes[n].dsc;
312
				const ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
313
				const ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
313
 
314
 
314
				bool left_son = Comp(dsc)(p,nodes[n].key);
315
				bool left_son = Comp(dsc)(p,nodes[n].key);
315
 
316
 
316
				if(left_son||dsc_dist<dist)
317
				if(left_son||dsc_dist<dist)
317
					{
318
					{
318
						int left_child = 2*n;
319
						int left_child = 2*n;
319
						if(left_child < nodes.size())
320
						if(left_child < nodes.size())
320
							in_sphere_priv(left_child, p, dist, keys, vals);
321
							in_sphere_priv(left_child, p, dist, keys, vals);
321
					}
322
					}
322
				if(!left_son||dsc_dist<dist)
323
				if(!left_son||dsc_dist<dist)
323
					{
324
					{
324
						int right_child = 2*n+1;
325
						int right_child = 2*n+1;
325
						if(right_child < nodes.size())
326
						if(right_child < nodes.size())
326
							in_sphere_priv(right_child, p, dist, keys, vals);
327
							in_sphere_priv(right_child, p, dist, keys, vals);
327
					}
328
					}
328
			}
329
			}
329
	}
330
	}
330
}
331
}
331
// Short alias so client code can write GEO::KDTree.
namespace GEO = ::Geometry;
332
 
333
 
333
#if (_MSC_VER >= 1200)
334
#if (_MSC_VER >= 1200)
334
#pragma warning (pop)
335
#pragma warning (pop)
335
#endif
336
#endif
336
 
337
 
337
 
338
 
338
#endif
339
#endif
339
 
340