Subversion Repositories gelsvn

Rev

Rev 324 | Rev 334 | Go to most recent revision | Only display areas with differences | Ignore whitespace | Details | Blame | Last modification | View Log | RSS feed

Rev 324 Rev 330
1
#ifndef __KDTREE_H
1
#ifndef __KDTREE_H
2
#define __KDTREE_H
2
#define __KDTREE_H
3
 
3
 
4
#include <cmath>
4
#include <cmath>
5
#include <iostream>
5
#include <iostream>
6
#include <vector>
6
#include <vector>
7
#include <algorithm>
7
#include <algorithm>
8
#include "CGLA/CGLA.h"
8
#include "CGLA/CGLA.h"
9
#include "CGLA/ArithVec.h"
9
#include "CGLA/ArithVec.h"
10
 
10
 
11
#if (_MSC_VER >= 1200)
11
#if (_MSC_VER >= 1200)
12
#pragma warning (push)
12
#pragma warning (push)
13
#pragma warning (disable: 4018)
13
#pragma warning (disable: 4018)
14
#endif
14
#endif
15
 
15
 
16
namespace Geometry
16
namespace Geometry
17
{
17
{
18
	/** \brief A classic K-D tree. 
18
	/** \brief A classic K-D tree. 
19
 
19
 
20
			A K-D tree is a good data structure for storing points in space
20
			A K-D tree is a good data structure for storing points in space
21
			and for nearest neighbour queries. It is basically a generalized 
21
			and for nearest neighbour queries. It is basically a generalized 
22
			binary tree in K dimensions. */
22
			binary tree in K dimensions. */
23
	template<class KeyT, class ValT>
23
	template<class KeyT, class ValT>
24
	class KDTree
24
	class KDTree
25
	{
25
	{
26
		typedef typename KeyT::ScalarType ScalarType;
26
		typedef typename KeyT::ScalarType ScalarType;
27
		typedef KeyT KeyType;
27
		typedef KeyT KeyType;
28
		typedef std::vector<KeyT> KeyVectorType;
28
		typedef std::vector<KeyT> KeyVectorType;
29
		typedef std::vector<ValT> ValVectorType;
29
		typedef std::vector<ValT> ValVectorType;
30
	
30
	
31
		/// KDNode struct represents node in KD tree
31
		/// KDNode struct represents node in KD tree
32
		struct KDNode
32
		struct KDNode
33
		{
33
		{
34
			KeyT key;
34
			KeyT key;
35
			ValT val;
35
			ValT val;
36
			short dsc;
36
			short dsc;
37
 
37
 
38
			KDNode(): dsc(0) {}
38
			KDNode(): dsc(0) {}
39
 
39
 
40
			KDNode(const KeyT& _key, const ValT& _val):
40
			KDNode(const KeyT& _key, const ValT& _val):
41
				key(_key), val(_val), dsc(-1) {}
41
				key(_key), val(_val), dsc(-1) {}
42
 
42
 
43
			ScalarType dist(const KeyType& p) const 
43
			ScalarType dist(const KeyType& p) const 
44
			{
44
			{
45
				KeyType dist_vec = p;
45
				KeyType dist_vec = p;
46
				dist_vec  -= key;
46
				dist_vec  -= key;
47
				return dot(dist_vec, dist_vec);
47
				return dot(dist_vec, dist_vec);
48
			}
48
			}
49
		};
49
		};
50
 
50
 
51
		typedef std::vector<KDNode> NodeVecType;
51
		typedef std::vector<KDNode> NodeVecType;
-
 
52
		bool is_built;
52
		NodeVecType init_nodes;
53
		NodeVecType init_nodes;
53
		NodeVecType nodes;
54
		NodeVecType nodes;
54
		bool is_built;
-
 
55
 
-
 
56
		/// The greatest depth of KD tree.
-
 
57
		int max_depth;
-
 
58
 
-
 
59
		/// The dimension -- K
-
 
60
		const int DIM;
-
 
61
 
-
 
62
		/// Total number of elements in tree
-
 
63
		int elements;
-
 
64
 
55
 
65
		/** Comp is a class used for comparing two keys. Comp is constructed
56
		/** Comp is a class used for comparing two keys. Comp is constructed
66
				with the discriminator - i.e. the coordinate of the key that is used
57
				with the discriminator - i.e. the coordinate of the key that is used
67
				for comparing keys - Comp objects are passed to the sort algorithm.*/
58
				for comparing keys - Comp objects are passed to the sort algorithm.*/
68
		class Comp
59
		class Comp
69
		{
60
		{
70
			const int dsc;
61
			const int dsc;
71
		public:
62
		public:
72
			Comp(int _dsc): dsc(_dsc) {}
63
			Comp(int _dsc): dsc(_dsc) {}
73
			bool operator()(const KeyType& k0, const KeyType& k1) const
64
			bool operator()(const KeyType& k0, const KeyType& k1) const
74
			{
65
			{
75
				int dim=KeyType::get_dim();
66
				int dim=KeyType::get_dim();
76
				for(int i=0;i<dim;i++)
67
				for(int i=0;i<dim;i++)
77
					{
68
					{
78
						int j=(dsc+i)%dim;
69
						int j=(dsc+i)%dim;
79
						if(k0[j]<k1[j])
70
						if(k0[j]<k1[j])
80
							return true;
71
							return true;
81
						if(k0[j]>k1[j])
72
						if(k0[j]>k1[j])
82
							return false;
73
							return false;
83
					}
74
					}
84
				return false;
75
				return false;
85
			}
76
			}
86
 
77
 
87
			bool operator()(const KDNode& k0, const KDNode& k1) const
78
			bool operator()(const KDNode& k0, const KDNode& k1) const
88
			{
79
			{
89
				return (*this)(k0.key,k1.key);
80
				return (*this)(k0.key,k1.key);
90
			}
81
			}
91
		};
82
		};
92
 
83
 
93
 
84
 
94
		/** Passed a vector of keys, this function will construct an optimal tree.
85
		/** Passed a vector of keys, this function will construct an optimal tree.
95
				It is called recursively - second argument is level in tree. */
86
				It is called recursively */
96
		void optimize(int, int, int, int);
87
		void optimize(int, int, int);
97
 
88
 
98
		/** Finde nearest neighbour. */
89
		/** Finde nearest neighbour. */
99
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
90
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
100
 
91
 
101
							
92
							
102
		void in_sphere_priv(int n, 
93
		void in_sphere_priv(int n, 
103
												const KeyType& p, 
94
												const KeyType& p, 
104
												const ScalarType& dist,
95
												const ScalarType& dist,
105
												std::vector<KeyT>& keys,
96
												std::vector<KeyT>& keys,
106
												std::vector<ValT>& vals) const;
97
												std::vector<ValT>& vals) const;
107
	
98
	
108
		/** Finds the optimal discriminator. There are more ways, but this 
99
		/** Finds the optimal discriminator. There are more ways, but this 
109
				function traverses the vector and finds out what dimension has
100
				function traverses the vector and finds out what dimension has
110
				the greatest difference between min and max element. That dimension
101
				the greatest difference between min and max element. That dimension
111
				is used for discriminator */
102
				is used for discriminator */
112
		int opt_disc(int,int) const;
103
		int opt_disc(int,int) const;
113
	
104
	
114
	public:
105
	public:
115
 
106
 
116
		/** Build tree from vector of keys passed as argument. */
107
		/** Build tree from vector of keys passed as argument. */
117
		KDTree():
-
 
118
			is_built(false), max_depth(0), DIM(KeyType::get_dim()), elements(0)
108
		KDTree(): is_built(false), init_nodes(1) {}
119
		{
-
 
120
		}
-
 
121
 
109
 
122
		/** Insert a key value pair into the tree. Note that the tree needs to 
110
		/** Insert a key value pair into the tree. Note that the tree needs to 
123
				be built - by calling the build function - before you can search. */
111
				be built - by calling the build function - before you can search. */
124
		void insert(const KeyT& key, const ValT& val)
112
		void insert(const KeyT& key, const ValT& val)
125
		{
113
		{
126
			assert(!is_built);
114
				if(is_built)
-
 
115
				{
-
 
116
						assert(init_nodes.size()==1);
-
 
117
						init_nodes.swap(nodes);
-
 
118
						is_built=false;
-
 
119
				}
127
			init_nodes.push_back(KDNode(key,val));
120
				init_nodes.push_back(KDNode(key,val));
128
		}
121
		}
129
 
122
 
130
		/** Build the tree. After this function have been called, it is no longer 
123
		/** Build the tree. After this function has been called, it is no longer 
131
				legal to insert elements, but you can perform searches. */
124
				legal to insert elements, but you can perform searches. */
132
		void build()
125
		void build()
133
		{
126
		{
134
			assert(!is_built);
127
			assert(!is_built);
135
			nodes.resize(init_nodes.size()+1);
128
			nodes.resize(init_nodes.size());
136
			if(init_nodes.size() > 0)	
129
			if(init_nodes.size() > 1)	
137
				optimize(1,0,init_nodes.size(),0);
130
				optimize(1,1,init_nodes.size());
138
			NodeVecType v(0);
131
			NodeVecType v(1);
139
			init_nodes.swap(v);
132
			init_nodes.swap(v);
140
			is_built = true;
133
			is_built = true;
141
		}
134
		}
142
 
135
 
143
		/** Find the key value pair closest to the key given as first 
136
		/** Find the key value pair closest to the key given as first 
144
				argument. The second argument is the maximum search distance.
137
				argument. The second argument is the maximum search distance.
145
				The final two arguments contain the closest key and its 
138
				The final two arguments contain the closest key and its 
146
				associated value upon return. */
139
				associated value upon return. */
147
		bool closest_point(const KeyT& p, ScalarType& dist, KeyT&k, ValT&v) const
140
		bool closest_point(const KeyT& p, ScalarType& dist, KeyT&k, ValT&v) const
148
		{
141
		{
149
			assert(is_built);
142
			assert(is_built);
-
 
143
			if(nodes.size()>1)
-
 
144
			{
150
			ScalarType max_sq_dist = CGLA::sqr(dist);
145
					ScalarType max_sq_dist = CGLA::sqr(dist);
151
			if(int n = closest_point_priv(1, p, max_sq_dist))
146
					if(int n = closest_point_priv(1, p, max_sq_dist))
152
				{
147
					{
153
					k = nodes[n].key;
148
							k = nodes[n].key;
154
					v = nodes[n].val;
149
							v = nodes[n].val;
155
					dist = std::sqrt(max_sq_dist);
150
							dist = std::sqrt(max_sq_dist);
156
					return true;
151
							return true;
157
				}
152
					}
-
 
153
			}
158
			return false;
154
			return false;
159
		}
155
		}
160
 
156
 
161
		/** Find all the elements within a given radius (second argument) of
157
		/** Find all the elements within a given radius (second argument) of
162
				the key (first argument). The key value pairs inside the sphere are
158
				the key (first argument). The key value pairs inside the sphere are
163
				returned in a pair of vectors passed as the two last arguments. */
159
				returned in a pair of vectors passed as the two last arguments.
-
 
160
				Note that we don't resize the two last arguments to zero - so either
-
 
161
				they should be empty vectors or you should desire appending the newly
-
 
162
				found elements onto these vectors.				
-
 
163
		*/
164
		int in_sphere(const KeyType& p, 
164
		int in_sphere(const KeyType& p, 
165
									ScalarType dist,
165
									ScalarType dist,
166
									std::vector<KeyT>& keys,
166
									std::vector<KeyT>& keys,
167
									std::vector<ValT>& vals) const
167
									std::vector<ValT>& vals) const
168
		{
168
		{
169
			assert(is_built);
169
			assert(is_built);
-
 
170
			if(nodes.size()>1)
-
 
171
			{
170
			ScalarType max_sq_dist = CGLA::sqr(dist);
172
					ScalarType max_sq_dist = CGLA::sqr(dist);
171
			in_sphere_priv(1,p,max_sq_dist,keys,vals);
173
					in_sphere_priv(1,p,max_sq_dist,keys,vals);
172
			return keys.size();
174
					return keys.size();
-
 
175
			}
-
 
176
			return 0;
173
		}
177
		}
174
		
178
		
175
 
179
 
176
	};
180
	};
177
 
181
 
178
	template<class KeyT, class ValT>
182
	template<class KeyT, class ValT>
179
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
183
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
180
																	int kvec_end) const 
184
																	int kvec_end) const 
181
	{
185
	{
182
		KeyType vmin = init_nodes[kvec_beg].key;
186
		KeyType vmin = init_nodes[kvec_beg].key;
183
		KeyType vmax = init_nodes[kvec_beg].key;
187
		KeyType vmax = init_nodes[kvec_beg].key;
184
		for(int i=kvec_beg;i<kvec_end;i++)
188
		for(int i=kvec_beg;i<kvec_end;i++)
185
			{
189
			{
186
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
190
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
187
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
191
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
188
			}
192
			}
189
		int od=0;
193
		int od=0;
190
		KeyType ave_v = vmax-vmin;
194
		KeyType ave_v = vmax-vmin;
191
		for(int i=1;i<KeyType::get_dim();i++)
195
		for(int i=1;i<KeyType::get_dim();i++)
192
			if(ave_v[i]>ave_v[od]) od = i;
196
			if(ave_v[i]>ave_v[od]) od = i;
193
		return od;
197
		return od;
194
	} 
198
	} 
195
 
199
 
196
	template<class KeyT, class ValT>
200
	template<class KeyT, class ValT>
197
	void KDTree<KeyT,ValT>::optimize(int cur,
201
	void KDTree<KeyT,ValT>::optimize(int cur,
198
																	 int kvec_beg,  
202
																	 int kvec_beg,  
199
																	 int kvec_end,  
203
																	 int kvec_end)
200
																	 int level)
-
 
201
	{
204
	{
202
		// Assert that we are not inserting beyond capacity.
205
		// Assert that we are not inserting beyond capacity.
203
		assert(cur < nodes.size());
206
		assert(cur < nodes.size());
204
 
207
 
205
		// If there is just a single element, we simply insert.
208
		// If there is just a single element, we simply insert.
206
		if(kvec_beg+1==kvec_end) 
209
		if(kvec_beg+1==kvec_end) 
207
			{
210
			{
208
				max_depth  = std::max(level,max_depth);
-
 
209
				nodes[cur] = init_nodes[kvec_beg];
211
				nodes[cur] = init_nodes[kvec_beg];
210
				nodes[cur].dsc = -1;
212
				nodes[cur].dsc = -1;
211
				return;
213
				return;
212
			}
214
			}
213
	
215
	
214
		// Find the axis that best separates the data.
216
		// Find the axis that best separates the data.
215
		int disc = opt_disc(kvec_beg, kvec_end);
217
		int disc = opt_disc(kvec_beg, kvec_end);
216
 
218
 
217
		// Compute the median element. See my document on how to do this
219
		// Compute the median element. See my document on how to do this
218
		// www.imm.dtu.dk/~jab/publications.html
220
		// www.imm.dtu.dk/~jab/publications.html
219
		int N = kvec_end-kvec_beg;
221
		int N = kvec_end-kvec_beg;
220
		int M = 1<< (CGLA::two_to_what_power(N));
222
		int M = 1<< (CGLA::two_to_what_power(N));
221
		int R = N-(M-1);
223
		int R = N-(M-1);
222
		int left_size  = (M-2)/2;
224
		int left_size  = (M-2)/2;
223
		int right_size = (M-2)/2;
225
		int right_size = (M-2)/2;
224
		if(R < M/2)
226
		if(R < M/2)
225
			{
227
			{
226
				left_size += R;
228
				left_size += R;
227
			}
229
			}
228
		else
230
		else
229
			{
231
			{
230
				left_size += M/2;
232
				left_size += M/2;
231
				right_size += R-M/2;
233
				right_size += R-M/2;
232
			}
234
			}
233
 
235
 
234
		int median = kvec_beg + left_size;
236
		int median = kvec_beg + left_size;
235
 
237
 
236
		// Sort elements but use nth_element (which is cheaper) than
238
		// Sort elements but use nth_element (which is cheaper) than
237
		// a sorting algorithm. All elements to the left of the median
239
		// a sorting algorithm. All elements to the left of the median
238
		// will be smaller than or equal the median. All elements to the right
240
		// will be smaller than or equal the median. All elements to the right
239
		// will be greater than or equal to the median.
241
		// will be greater than or equal to the median.
240
		const Comp comp(disc);
242
		const Comp comp(disc);
241
		std::nth_element(&init_nodes[kvec_beg], 
243
		std::nth_element(&init_nodes[kvec_beg], 
242
										 &init_nodes[median], 
244
										 &init_nodes[median], 
243
										 &init_nodes[kvec_end], comp);
245
										 &init_nodes[kvec_end], comp);
244
 
246
 
245
		// Insert the node in the final data structure.
247
		// Insert the node in the final data structure.
246
		nodes[cur] = init_nodes[median];
248
		nodes[cur] = init_nodes[median];
247
		nodes[cur].dsc = disc;
249
		nodes[cur].dsc = disc;
248
 
250
 
249
		// Recursively build left and right tree.
251
		// Recursively build left and right tree.
250
		if(left_size>0)	
252
		if(left_size>0)	
251
			optimize(2*cur, kvec_beg, median,level+1);
253
			optimize(2*cur, kvec_beg, median);
252
		
254
		
253
		if(right_size>0) 
255
		if(right_size>0) 
254
			optimize(2*cur+1, median+1, kvec_end,level+1);
256
			optimize(2*cur+1, median+1, kvec_end);
255
	}
257
	}
256
 
258
 
257
	template<class KeyT, class ValT>
259
	template<class KeyT, class ValT>
258
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
260
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
259
																						ScalarType& dist) const
261
																						ScalarType& dist) const
260
	{
262
	{
261
		int ret_node = 0;
263
		int ret_node = 0;
262
		ScalarType this_dist = nodes[n].dist(p);
264
		ScalarType this_dist = nodes[n].dist(p);
263
 
265
 
264
		if(this_dist<dist)
266
		if(this_dist<dist)
265
			{
267
			{
266
				dist = this_dist;
268
				dist = this_dist;
267
				ret_node = n;
269
				ret_node = n;
268
			}
270
			}
269
		if(nodes[n].dsc != -1)
271
		if(nodes[n].dsc != -1)
270
			{
272
			{
271
				int dsc         = nodes[n].dsc;
273
				int dsc         = nodes[n].dsc;
272
				ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
274
				ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
273
				bool left_son   = Comp(dsc)(p,nodes[n].key);
275
				bool left_son   = Comp(dsc)(p,nodes[n].key);
274
 
276
 
275
				if(left_son||dsc_dist<dist)
277
				if(left_son||dsc_dist<dist)
276
					{
278
					{
277
						int left_child = 2*n;
279
						int left_child = 2*n;
278
						if(left_child < nodes.size())
280
						if(left_child < nodes.size())
279
							if(int nl=closest_point_priv(left_child, p, dist))
281
							if(int nl=closest_point_priv(left_child, p, dist))
280
								ret_node = nl;
282
								ret_node = nl;
281
					}
283
					}
282
				if(!left_son||dsc_dist<dist)
284
				if(!left_son||dsc_dist<dist)
283
					{
285
					{
284
						int right_child = 2*n+1;
286
						int right_child = 2*n+1;
285
						if(right_child < nodes.size())
287
						if(right_child < nodes.size())
286
							if(int nr=closest_point_priv(right_child, p, dist))
288
							if(int nr=closest_point_priv(right_child, p, dist))
287
								ret_node = nr;
289
								ret_node = nr;
288
					}
290
					}
289
			}
291
			}
290
		return ret_node;
292
		return ret_node;
291
	}
293
	}
292
 
294
 
293
	template<class KeyT, class ValT>
295
	template<class KeyT, class ValT>
294
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
296
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
295
																				 const KeyType& p, 
297
																				 const KeyType& p, 
296
																				 const ScalarType& dist,
298
																				 const ScalarType& dist,
297
																				 std::vector<KeyT>& keys,
299
																				 std::vector<KeyT>& keys,
298
																				 std::vector<ValT>& vals) const
300
																				 std::vector<ValT>& vals) const
299
	{
301
	{
300
		ScalarType this_dist = nodes[n].dist(p);
302
		ScalarType this_dist = nodes[n].dist(p);
301
		assert(n<nodes.size());
303
		assert(n<nodes.size());
302
		if(this_dist<dist)
304
		if(this_dist<dist)
303
			{
305
			{
304
				keys.push_back(nodes[n].key);
306
				keys.push_back(nodes[n].key);
305
				vals.push_back(nodes[n].val);
307
				vals.push_back(nodes[n].val);
306
			}
308
			}
307
		if(nodes[n].dsc != -1)
309
		if(nodes[n].dsc != -1)
308
			{
310
			{
309
				const int dsc         = nodes[n].dsc;
311
				const int dsc         = nodes[n].dsc;
310
				const ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
312
				const ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
311
 
313
 
312
				bool left_son = Comp(dsc)(p,nodes[n].key);
314
				bool left_son = Comp(dsc)(p,nodes[n].key);
313
 
315
 
314
				if(left_son||dsc_dist<dist)
316
				if(left_son||dsc_dist<dist)
315
					{
317
					{
316
						int left_child = 2*n;
318
						int left_child = 2*n;
317
						if(left_child < nodes.size())
319
						if(left_child < nodes.size())
318
							in_sphere_priv(left_child, p, dist, keys, vals);
320
							in_sphere_priv(left_child, p, dist, keys, vals);
319
					}
321
					}
320
				if(!left_son||dsc_dist<dist)
322
				if(!left_son||dsc_dist<dist)
321
					{
323
					{
322
						int right_child = 2*n+1;
324
						int right_child = 2*n+1;
323
						if(right_child < nodes.size())
325
						if(right_child < nodes.size())
324
							in_sphere_priv(right_child, p, dist, keys, vals);
326
							in_sphere_priv(right_child, p, dist, keys, vals);
325
					}
327
					}
326
			}
328
			}
327
	}
329
	}
328
}
330
}
329
namespace GEO = Geometry;
331
namespace GEO = Geometry;
330
 
332
 
331
#if (_MSC_VER >= 1200)
333
#if (_MSC_VER >= 1200)
332
#pragma warning (pop)
334
#pragma warning (pop)
333
#endif
335
#endif
334
 
336
 
335
 
337
 
336
#endif
338
#endif
337
 
339