Subversion Repositories gelsvn

Rev

Rev 203 | Rev 330 | Go to most recent revision | Only display areas with differences | Ignore whitespace | Details | Blame | Last modification | View Log | RSS feed

Rev 203 Rev 324
1
#ifndef __KDTREE_H
1
#ifndef __KDTREE_H
2
#define __KDTREE_H
2
#define __KDTREE_H
3
 
3
 
4
#include <cmath>
4
#include <cmath>
5
#include <iostream>
5
#include <iostream>
6
#include <vector>
6
#include <vector>
7
#include <algorithm>
7
#include <algorithm>
8
#include "CGLA/CGLA.h"
8
#include "CGLA/CGLA.h"
9
#include "CGLA/ArithVec.h"
9
#include "CGLA/ArithVec.h"
10
 
10
 
11
#if (_MSC_VER >= 1200)
11
#if (_MSC_VER >= 1200)
12
#pragma warning (push)
12
#pragma warning (push)
13
#pragma warning (disable: 4018)
13
#pragma warning (disable: 4018)
14
#endif
14
#endif
15
 
15
 
16
namespace Geometry
16
namespace Geometry
17
{
17
{
18
	/** \brief A classic K-D tree. 
18
	/** \brief A classic K-D tree. 
19
 
19
 
20
			A K-D tree is a good data structure for storing points in space
20
			A K-D tree is a good data structure for storing points in space
21
			and for nearest neighbour queries. It is basically a generalized 
21
			and for nearest neighbour queries. It is basically a generalized 
22
			binary tree in K dimensions. */
22
			binary tree in K dimensions. */
23
	template<class KeyT, class ValT>
23
	template<class KeyT, class ValT>
24
	class KDTree
24
	class KDTree
25
	{
25
	{
26
		typedef typename KeyT::ScalarType ScalarType;
26
		typedef typename KeyT::ScalarType ScalarType;
27
		typedef KeyT KeyType;
27
		typedef KeyT KeyType;
28
		typedef std::vector<KeyT> KeyVectorType;
28
		typedef std::vector<KeyT> KeyVectorType;
29
		typedef std::vector<ValT> ValVectorType;
29
		typedef std::vector<ValT> ValVectorType;
30
	
30
	
31
		/// KDNode struct represents node in KD tree
31
		/// KDNode struct represents node in KD tree
32
		struct KDNode
32
		struct KDNode
33
		{
33
		{
34
			KeyT key;
34
			KeyT key;
35
			ValT val;
35
			ValT val;
36
			short dsc;
36
			short dsc;
37
 
37
 
38
			KDNode(): dsc(0) {}
38
			KDNode(): dsc(0) {}
39
 
39
 
40
			KDNode(const KeyT& _key, const ValT& _val):
40
			KDNode(const KeyT& _key, const ValT& _val):
41
				key(_key), val(_val), dsc(-1) {}
41
				key(_key), val(_val), dsc(-1) {}
42
 
42
 
43
			ScalarType dist(const KeyType& p) const 
43
			ScalarType dist(const KeyType& p) const 
44
			{
44
			{
45
				KeyType dist_vec = p;
45
				KeyType dist_vec = p;
46
				dist_vec  -= key;
46
				dist_vec  -= key;
47
				return dot(dist_vec, dist_vec);
47
				return dot(dist_vec, dist_vec);
48
			}
48
			}
49
		};
49
		};
50
 
50
 
51
		typedef std::vector<KDNode> NodeVecType;
51
		typedef std::vector<KDNode> NodeVecType;
52
		NodeVecType init_nodes;
52
		NodeVecType init_nodes;
53
		NodeVecType nodes;
53
		NodeVecType nodes;
54
		bool is_built;
54
		bool is_built;
55
 
55
 
56
		/// The greatest depth of KD tree.
56
		/// The greatest depth of KD tree.
57
		int max_depth;
57
		int max_depth;
58
 
58
 
-
 
59
		/// The dimension -- K
-
 
60
		const int DIM;
-
 
61
 
59
		/// Total number of elements in tree
62
		/// Total number of elements in tree
60
		int elements;
63
		int elements;
61
 
64
 
62
		/** Comp is a class used for comparing two keys. Comp is constructed
65
		/** Comp is a class used for comparing two keys. Comp is constructed
63
				with the discriminator - i.e. the coordinate of the key that is used
66
				with the discriminator - i.e. the coordinate of the key that is used
64
				for comparing keys - Comp objects are passed to the sort algorithm.*/
67
				for comparing keys - Comp objects are passed to the sort algorithm.*/
65
		class Comp
68
		class Comp
66
		{
69
		{
67
			const int dsc;
70
			const int dsc;
68
		public:
71
		public:
69
			Comp(int _dsc): dsc(_dsc) {}
72
			Comp(int _dsc): dsc(_dsc) {}
70
			bool operator()(const KeyType& k0, const KeyType& k1) const
73
			bool operator()(const KeyType& k0, const KeyType& k1) const
71
			{
74
			{
72
				int dim=KeyType::get_dim();
75
				int dim=KeyType::get_dim();
73
				for(int i=0;i<dim;i++)
76
				for(int i=0;i<dim;i++)
74
					{
77
					{
75
						int j=(dsc+i)%dim;
78
						int j=(dsc+i)%dim;
76
						if(k0[j]<k1[j])
79
						if(k0[j]<k1[j])
77
							return true;
80
							return true;
78
						if(k0[j]>k1[j])
81
						if(k0[j]>k1[j])
79
							return false;
82
							return false;
80
					}
83
					}
81
				return false;
84
				return false;
82
			}
85
			}
83
 
86
 
84
			bool operator()(const KDNode& k0, const KDNode& k1) const
87
			bool operator()(const KDNode& k0, const KDNode& k1) const
85
			{
88
			{
86
				return (*this)(k0.key,k1.key);
89
				return (*this)(k0.key,k1.key);
87
			}
90
			}
88
		};
91
		};
89
 
92
 
90
		/// The dimension -- K
-
 
91
		const int DIM;
-
 
92
 
93
 
93
		/** Passed a vector of keys, this function will construct an optimal tree.
94
		/** Passed a vector of keys, this function will construct an optimal tree.
94
				It is called recursively - second argument is level in tree. */
95
				It is called recursively - second argument is level in tree. */
95
		void optimize(int, int, int, int);
96
		void optimize(int, int, int, int);
96
 
97
 
97
		/** Finde nearest neighbour. */
98
		/** Finde nearest neighbour. */
98
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
99
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
99
 
100
 
100
							
101
							
101
		void in_sphere_priv(int n, 
102
		void in_sphere_priv(int n, 
102
												const KeyType& p, 
103
												const KeyType& p, 
103
												const ScalarType& dist,
104
												const ScalarType& dist,
104
												std::vector<KeyT>& keys,
105
												std::vector<KeyT>& keys,
105
												std::vector<ValT>& vals) const;
106
												std::vector<ValT>& vals) const;
106
	
107
	
107
		/** Finds the optimal discriminator. There are more ways, but this 
108
		/** Finds the optimal discriminator. There are more ways, but this 
108
				function traverses the vector and finds out what dimension has
109
				function traverses the vector and finds out what dimension has
109
				the greatest difference between min and max element. That dimension
110
				the greatest difference between min and max element. That dimension
110
				is used for discriminator */
111
				is used for discriminator */
111
		int opt_disc(int,int) const;
112
		int opt_disc(int,int) const;
112
	
113
	
113
	public:
114
	public:
114
 
115
 
115
		/** Build tree from vector of keys passed as argument. */
116
		/** Build tree from vector of keys passed as argument. */
116
		KDTree():
117
		KDTree():
117
			is_built(false), max_depth(0), DIM(KeyType::get_dim()), elements(0)
118
			is_built(false), max_depth(0), DIM(KeyType::get_dim()), elements(0)
118
		{
119
		{
119
		}
120
		}
120
 
121
 
121
		/** Insert a key value pair into the tree. Note that the tree needs to 
122
		/** Insert a key value pair into the tree. Note that the tree needs to 
122
				be built - by calling the build function - before you can search. */
123
				be built - by calling the build function - before you can search. */
123
		void insert(const KeyT& key, const ValT& val)
124
		void insert(const KeyT& key, const ValT& val)
124
		{
125
		{
125
			assert(!is_built);
126
			assert(!is_built);
126
			init_nodes.push_back(KDNode(key,val));
127
			init_nodes.push_back(KDNode(key,val));
127
		}
128
		}
128
 
129
 
129
		/** Build the tree. After this function have been called, it is no longer 
130
		/** Build the tree. After this function have been called, it is no longer 
130
				legal to insert elements, but you can perform searches. */
131
				legal to insert elements, but you can perform searches. */
131
		void build()
132
		void build()
132
		{
133
		{
133
			assert(!is_built);
134
			assert(!is_built);
134
			nodes.resize(init_nodes.size()+1);
135
			nodes.resize(init_nodes.size()+1);
135
			if(init_nodes.size() > 0)	
136
			if(init_nodes.size() > 0)	
136
				optimize(1,0,init_nodes.size(),0);
137
				optimize(1,0,init_nodes.size(),0);
137
			NodeVecType v(0);
138
			NodeVecType v(0);
138
			init_nodes.swap(v);
139
			init_nodes.swap(v);
139
			is_built = true;
140
			is_built = true;
140
		}
141
		}
141
 
142
 
142
		/** Find the key value pair closest to the key given as first 
143
		/** Find the key value pair closest to the key given as first 
143
				argument. The second argument is the maximum search distance.
144
				argument. The second argument is the maximum search distance.
144
				The final two arguments contain the closest key and its 
145
				The final two arguments contain the closest key and its 
145
				associated value upon return. */
146
				associated value upon return. */
146
		bool closest_point(const KeyT& p, float& dist, KeyT&k, ValT&v) const
147
		bool closest_point(const KeyT& p, ScalarType& dist, KeyT&k, ValT&v) const
147
		{
148
		{
148
			assert(is_built);
149
			assert(is_built);
149
			float max_sq_dist = CGLA::sqr(dist);
150
			ScalarType max_sq_dist = CGLA::sqr(dist);
150
			if(int n = closest_point_priv(1, p, max_sq_dist))
151
			if(int n = closest_point_priv(1, p, max_sq_dist))
151
				{
152
				{
152
					k = nodes[n].key;
153
					k = nodes[n].key;
153
					v = nodes[n].val;
154
					v = nodes[n].val;
154
					dist = std::sqrt(max_sq_dist);
155
					dist = std::sqrt(max_sq_dist);
155
					return true;
156
					return true;
156
				}
157
				}
157
			return false;
158
			return false;
158
		}
159
		}
159
 
160
 
160
		/** Find all the elements within a given radius (second argument) of
161
		/** Find all the elements within a given radius (second argument) of
161
				the key (first argument). The key value pairs inside the sphere are
162
				the key (first argument). The key value pairs inside the sphere are
162
				returned in a pair of vectors passed as the two last arguments. */
163
				returned in a pair of vectors passed as the two last arguments. */
163
		int in_sphere(const KeyType& p, 
164
		int in_sphere(const KeyType& p, 
164
									float dist,
165
									ScalarType dist,
165
									std::vector<KeyT>& keys,
166
									std::vector<KeyT>& keys,
166
									std::vector<ValT>& vals) const
167
									std::vector<ValT>& vals) const
167
		{
168
		{
168
			assert(is_built);
169
			assert(is_built);
169
			float max_sq_dist = CGLA::sqr(dist);
170
			ScalarType max_sq_dist = CGLA::sqr(dist);
170
			in_sphere_priv(1,p,max_sq_dist,keys,vals);
171
			in_sphere_priv(1,p,max_sq_dist,keys,vals);
171
			return keys.size();
172
			return keys.size();
172
		}
173
		}
173
		
174
		
174
 
175
 
175
	};
176
	};
176
 
177
 
177
	template<class KeyT, class ValT>
178
	template<class KeyT, class ValT>
178
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
179
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
179
																	int kvec_end) const 
180
																	int kvec_end) const 
180
	{
181
	{
181
		KeyType vmin = init_nodes[kvec_beg].key;
182
		KeyType vmin = init_nodes[kvec_beg].key;
182
		KeyType vmax = init_nodes[kvec_beg].key;
183
		KeyType vmax = init_nodes[kvec_beg].key;
183
		for(int i=kvec_beg;i<kvec_end;i++)
184
		for(int i=kvec_beg;i<kvec_end;i++)
184
			{
185
			{
185
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
186
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
186
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
187
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
187
			}
188
			}
188
		int od=0;
189
		int od=0;
189
		KeyType ave_v = vmax-vmin;
190
		KeyType ave_v = vmax-vmin;
190
		for(int i=1;i<KeyType::get_dim();i++)
191
		for(int i=1;i<KeyType::get_dim();i++)
191
			if(ave_v[i]>ave_v[od]) od = i;
192
			if(ave_v[i]>ave_v[od]) od = i;
192
		return od;
193
		return od;
193
	} 
194
	} 
194
 
195
 
195
	template<class KeyT, class ValT>
196
	template<class KeyT, class ValT>
196
	void KDTree<KeyT,ValT>::optimize(int cur,
197
	void KDTree<KeyT,ValT>::optimize(int cur,
197
																	 int kvec_beg,  
198
																	 int kvec_beg,  
198
																	 int kvec_end,  
199
																	 int kvec_end,  
199
																	 int level)
200
																	 int level)
200
	{
201
	{
201
		// Assert that we are not inserting beyond capacity.
202
		// Assert that we are not inserting beyond capacity.
202
		assert(cur < nodes.size());
203
		assert(cur < nodes.size());
203
 
204
 
204
		// If there is just a single element, we simply insert.
205
		// If there is just a single element, we simply insert.
205
		if(kvec_beg+1==kvec_end) 
206
		if(kvec_beg+1==kvec_end) 
206
			{
207
			{
207
				max_depth  = std::max(level,max_depth);
208
				max_depth  = std::max(level,max_depth);
208
				nodes[cur] = init_nodes[kvec_beg];
209
				nodes[cur] = init_nodes[kvec_beg];
209
				nodes[cur].dsc = -1;
210
				nodes[cur].dsc = -1;
210
				return;
211
				return;
211
			}
212
			}
212
	
213
	
213
		// Find the axis that best separates the data.
214
		// Find the axis that best separates the data.
214
		int disc = opt_disc(kvec_beg, kvec_end);
215
		int disc = opt_disc(kvec_beg, kvec_end);
215
 
216
 
216
		// Compute the median element. See my document on how to do this
217
		// Compute the median element. See my document on how to do this
217
		// www.imm.dtu.dk/~jab/publications.html
218
		// www.imm.dtu.dk/~jab/publications.html
218
		int N = kvec_end-kvec_beg;
219
		int N = kvec_end-kvec_beg;
219
		int M = 1<< (CGLA::two_to_what_power(N));
220
		int M = 1<< (CGLA::two_to_what_power(N));
220
		int R = N-(M-1);
221
		int R = N-(M-1);
221
		int left_size  = (M-2)/2;
222
		int left_size  = (M-2)/2;
222
		int right_size = (M-2)/2;
223
		int right_size = (M-2)/2;
223
		if(R < M/2)
224
		if(R < M/2)
224
			{
225
			{
225
				left_size += R;
226
				left_size += R;
226
			}
227
			}
227
		else
228
		else
228
			{
229
			{
229
				left_size += M/2;
230
				left_size += M/2;
230
				right_size += R-M/2;
231
				right_size += R-M/2;
231
			}
232
			}
232
 
233
 
233
		int median = kvec_beg + left_size;
234
		int median = kvec_beg + left_size;
234
 
235
 
235
		// Sort elements but use nth_element (which is cheaper) than
236
		// Sort elements but use nth_element (which is cheaper) than
236
		// a sorting algorithm. All elements to the left of the median
237
		// a sorting algorithm. All elements to the left of the median
237
		// will be smaller than or equal the median. All elements to the right
238
		// will be smaller than or equal the median. All elements to the right
238
		// will be greater than or equal to the median.
239
		// will be greater than or equal to the median.
239
		const Comp comp(disc);
240
		const Comp comp(disc);
240
		std::nth_element(&init_nodes[kvec_beg], 
241
		std::nth_element(&init_nodes[kvec_beg], 
241
										 &init_nodes[median], 
242
										 &init_nodes[median], 
242
										 &init_nodes[kvec_end], comp);
243
										 &init_nodes[kvec_end], comp);
243
 
244
 
244
		// Insert the node in the final data structure.
245
		// Insert the node in the final data structure.
245
		nodes[cur] = init_nodes[median];
246
		nodes[cur] = init_nodes[median];
246
		nodes[cur].dsc = disc;
247
		nodes[cur].dsc = disc;
247
 
248
 
248
		// Recursively build left and right tree.
249
		// Recursively build left and right tree.
249
		if(left_size>0)	
250
		if(left_size>0)	
250
			optimize(2*cur, kvec_beg, median,level+1);
251
			optimize(2*cur, kvec_beg, median,level+1);
251
		
252
		
252
		if(right_size>0) 
253
		if(right_size>0) 
253
			optimize(2*cur+1, median+1, kvec_end,level+1);
254
			optimize(2*cur+1, median+1, kvec_end,level+1);
254
	}
255
	}
255
 
256
 
256
	template<class KeyT, class ValT>
257
	template<class KeyT, class ValT>
257
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
258
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
258
																						ScalarType& dist) const
259
																						ScalarType& dist) const
259
	{
260
	{
260
		int ret_node = 0;
261
		int ret_node = 0;
261
		ScalarType this_dist = nodes[n].dist(p);
262
		ScalarType this_dist = nodes[n].dist(p);
262
 
263
 
263
		if(this_dist<dist)
264
		if(this_dist<dist)
264
			{
265
			{
265
				dist = this_dist;
266
				dist = this_dist;
266
				ret_node = n;
267
				ret_node = n;
267
			}
268
			}
268
		if(nodes[n].dsc != -1)
269
		if(nodes[n].dsc != -1)
269
			{
270
			{
270
				int dsc         = nodes[n].dsc;
271
				int dsc         = nodes[n].dsc;
271
				float dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
272
				ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
272
				bool left_son   = Comp(dsc)(p,nodes[n].key);
273
				bool left_son   = Comp(dsc)(p,nodes[n].key);
273
 
274
 
274
				if(left_son||dsc_dist<dist)
275
				if(left_son||dsc_dist<dist)
275
					{
276
					{
276
						int left_child = 2*n;
277
						int left_child = 2*n;
277
						if(left_child < nodes.size())
278
						if(left_child < nodes.size())
278
							if(int nl=closest_point_priv(left_child, p, dist))
279
							if(int nl=closest_point_priv(left_child, p, dist))
279
								ret_node = nl;
280
								ret_node = nl;
280
					}
281
					}
281
				if(!left_son||dsc_dist<dist)
282
				if(!left_son||dsc_dist<dist)
282
					{
283
					{
283
						int right_child = 2*n+1;
284
						int right_child = 2*n+1;
284
						if(right_child < nodes.size())
285
						if(right_child < nodes.size())
285
							if(int nr=closest_point_priv(right_child, p, dist))
286
							if(int nr=closest_point_priv(right_child, p, dist))
286
								ret_node = nr;
287
								ret_node = nr;
287
					}
288
					}
288
			}
289
			}
289
		return ret_node;
290
		return ret_node;
290
	}
291
	}
291
 
292
 
292
	template<class KeyT, class ValT>
293
	template<class KeyT, class ValT>
293
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
294
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
294
																				 const KeyType& p, 
295
																				 const KeyType& p, 
295
																				 const ScalarType& dist,
296
																				 const ScalarType& dist,
296
																				 std::vector<KeyT>& keys,
297
																				 std::vector<KeyT>& keys,
297
																				 std::vector<ValT>& vals) const
298
																				 std::vector<ValT>& vals) const
298
	{
299
	{
299
		ScalarType this_dist = nodes[n].dist(p);
300
		ScalarType this_dist = nodes[n].dist(p);
300
		assert(n<nodes.size());
301
		assert(n<nodes.size());
301
		if(this_dist<dist)
302
		if(this_dist<dist)
302
			{
303
			{
303
				keys.push_back(nodes[n].key);
304
				keys.push_back(nodes[n].key);
304
				vals.push_back(nodes[n].val);
305
				vals.push_back(nodes[n].val);
305
			}
306
			}
306
		if(nodes[n].dsc != -1)
307
		if(nodes[n].dsc != -1)
307
			{
308
			{
308
				const int dsc         = nodes[n].dsc;
309
				const int dsc         = nodes[n].dsc;
309
				const float dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
310
				const ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
310
 
311
 
311
				bool left_son = Comp(dsc)(p,nodes[n].key);
312
				bool left_son = Comp(dsc)(p,nodes[n].key);
312
 
313
 
313
				if(left_son||dsc_dist<dist)
314
				if(left_son||dsc_dist<dist)
314
					{
315
					{
315
						int left_child = 2*n;
316
						int left_child = 2*n;
316
						if(left_child < nodes.size())
317
						if(left_child < nodes.size())
317
							in_sphere_priv(left_child, p, dist, keys, vals);
318
							in_sphere_priv(left_child, p, dist, keys, vals);
318
					}
319
					}
319
				if(!left_son||dsc_dist<dist)
320
				if(!left_son||dsc_dist<dist)
320
					{
321
					{
321
						int right_child = 2*n+1;
322
						int right_child = 2*n+1;
322
						if(right_child < nodes.size())
323
						if(right_child < nodes.size())
323
							in_sphere_priv(right_child, p, dist, keys, vals);
324
							in_sphere_priv(right_child, p, dist, keys, vals);
324
					}
325
					}
325
			}
326
			}
326
	}
327
	}
327
}
328
}
328
namespace GEO = Geometry;
329
namespace GEO = Geometry;
329
 
330
 
330
#if (_MSC_VER >= 1200)
331
#if (_MSC_VER >= 1200)
331
#pragma warning (pop)
332
#pragma warning (pop)
332
#endif
333
#endif
333
 
334
 
334
 
335
 
335
#endif
336
#endif
336
 
337