Subversion Repositories gelsvn

Rev

Rev 125 | Rev 203 | Go to most recent revision | Only display areas with differences | Ignore whitespace | Details | Blame | Last modification | View Log | RSS feed

Rev 125 Rev 198
1
#ifndef __KDTREE_H
1
#ifndef __KDTREE_H
2
#define __KDTREE_H
2
#define __KDTREE_H
3
 
3
 
4
#include <cmath>
4
#include <cmath>
5
#include <iostream>
5
#include <iostream>
6
#include <vector>
6
#include <vector>
7
#include <algorithm>
7
#include <algorithm>
8
#include "CGLA/CGLA.h"
8
#include "CGLA/CGLA.h"
9
#include "CGLA/ArithVec.h"
9
#include "CGLA/ArithVec.h"
10
 
10
 
-
 
11
#if (_MSC_VER >= 1200)
-
 
12
#pragma push (warning)
-
 
13
#pragma warning (disable: 4018)
-
 
14
#endif
-
 
15
 
11
namespace Geometry
16
namespace Geometry
12
{
17
{
13
	/** \brief A classic K-D tree. 
18
	/** \brief A classic K-D tree. 
14
 
19
 
15
			A K-D tree is a good data structure for storing points in space
20
			A K-D tree is a good data structure for storing points in space
16
			and for nearest neighbour queries. It is basically a generalized 
21
			and for nearest neighbour queries. It is basically a generalized 
17
			binary tree in K dimensions. */
22
			binary tree in K dimensions. */
18
	template<class KeyT, class ValT>
23
	template<class KeyT, class ValT>
19
	class KDTree
24
	class KDTree
20
	{
25
	{
21
		typedef typename KeyT::ScalarType ScalarType;
26
		typedef typename KeyT::ScalarType ScalarType;
22
		typedef KeyT KeyType;
27
		typedef KeyT KeyType;
23
		typedef std::vector<KeyT> KeyVectorType;
28
		typedef std::vector<KeyT> KeyVectorType;
24
		typedef std::vector<ValT> ValVectorType;
29
		typedef std::vector<ValT> ValVectorType;
25
	
30
	
26
		/// KDNode struct represents node in KD tree
31
		/// KDNode struct represents node in KD tree
27
		struct KDNode
32
		struct KDNode
28
		{
33
		{
29
			KeyT key;
34
			KeyT key;
30
			ValT val;
35
			ValT val;
31
			short dsc;
36
			short dsc;
32
 
37
 
33
			KDNode(): dsc(0) {}
38
			KDNode(): dsc(0) {}
34
 
39
 
35
			KDNode(const KeyT& _key, const ValT& _val):
40
			KDNode(const KeyT& _key, const ValT& _val):
36
				key(_key), val(_val), dsc(-1) {}
41
				key(_key), val(_val), dsc(-1) {}
37
 
42
 
38
			ScalarType dist(const KeyType& p) const 
43
			ScalarType dist(const KeyType& p) const 
39
			{
44
			{
40
				KeyType dist_vec = p;
45
				KeyType dist_vec = p;
41
				dist_vec  -= key;
46
				dist_vec  -= key;
42
				return dot(dist_vec, dist_vec);
47
				return dot(dist_vec, dist_vec);
43
			}
48
			}
44
		};
49
		};
45
 
50
 
46
		typedef std::vector<KDNode> NodeVecType;
51
		typedef std::vector<KDNode> NodeVecType;
47
		NodeVecType init_nodes;
52
		NodeVecType init_nodes;
48
		NodeVecType nodes;
53
		NodeVecType nodes;
49
		bool is_built;
54
		bool is_built;
50
 
55
 
51
		/// The greatest depth of KD tree.
56
		/// The greatest depth of KD tree.
52
		int max_depth;
57
		int max_depth;
53
 
58
 
54
		/// Total number of elements in tree
59
		/// Total number of elements in tree
55
		int elements;
60
		int elements;
56
 
61
 
57
		/** Comp is a class used for comparing two keys. Comp is constructed
62
		/** Comp is a class used for comparing two keys. Comp is constructed
58
				with the discriminator - i.e. the coordinate of the key that is used
63
				with the discriminator - i.e. the coordinate of the key that is used
59
				for comparing keys - Comp objects are passed to the sort algorithm.*/
64
				for comparing keys - Comp objects are passed to the sort algorithm.*/
60
		class Comp
65
		class Comp
61
		{
66
		{
62
			const int dsc;
67
			const int dsc;
63
		public:
68
		public:
64
			Comp(int _dsc): dsc(_dsc) {}
69
			Comp(int _dsc): dsc(_dsc) {}
65
			bool operator()(const KeyType& k0, const KeyType& k1) const
70
			bool operator()(const KeyType& k0, const KeyType& k1) const
66
			{
71
			{
67
				int dim=KeyType::get_dim();
72
				int dim=KeyType::get_dim();
68
				for(int i=0;i<dim;i++)
73
				for(int i=0;i<dim;i++)
69
					{
74
					{
70
						int j=(dsc+i)%dim;
75
						int j=(dsc+i)%dim;
71
						if(k0[j]<k1[j])
76
						if(k0[j]<k1[j])
72
							return true;
77
							return true;
73
						if(k0[j]>k1[j])
78
						if(k0[j]>k1[j])
74
							return false;
79
							return false;
75
					}
80
					}
76
				return false;
81
				return false;
77
			}
82
			}
78
 
83
 
79
			bool operator()(const KDNode& k0, const KDNode& k1) const
84
			bool operator()(const KDNode& k0, const KDNode& k1) const
80
			{
85
			{
81
				return (*this)(k0.key,k1.key);
86
				return (*this)(k0.key,k1.key);
82
			}
87
			}
83
		};
88
		};
84
 
89
 
85
		/// The dimension -- K
90
		/// The dimension -- K
86
		const int DIM;
91
		const int DIM;
87
 
92
 
88
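		/** Given a range of the inserted nodes (second and third arguments), this 
93
		/** Given a range of the inserted nodes (second and third arguments), this 
89
				function will construct an optimal tree. It is called recursively - first argument is the node index, last argument is level in tree. */
94
				function will construct an optimal tree. It is called recursively - first argument is the node index, last argument is level in tree. */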
90
		void optimize(int, int, int, int);
95
		void optimize(int, int, int, int);
91
 
96
 
92
		/** Find nearest neighbour. */
97
		/** Find nearest neighbour. */
93
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
98
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
94
 
99
 
95
							
100
							
96
		void in_sphere_priv(int n, 
101
		void in_sphere_priv(int n, 
97
												const KeyType& p, 
102
												const KeyType& p, 
98
												const ScalarType& dist,
103
												const ScalarType& dist,
99
												std::vector<KeyT>& keys,
104
												std::vector<KeyT>& keys,
100
												std::vector<ValT>& vals) const;
105
												std::vector<ValT>& vals) const;
101
	
106
	
102
		/** Finds the optimal discriminator. There are more ways, but this 
107
		/** Finds the optimal discriminator. There are more ways, but this 
103
				function traverses the vector and finds out what dimension has
108
				function traverses the vector and finds out what dimension has
104
				the greatest difference between min and max element. That dimension
109
				the greatest difference between min and max element. That dimension
105
				is used for discriminator */
110
				is used for discriminator */
106
		int opt_disc(int,int) const;
111
		int opt_disc(int,int) const;
107
	
112
	
108
	public:
113
	public:
109
 
114
 
110
		/** Construct an empty tree. Add key value pairs with insert and then call build. */
115
		/** Construct an empty tree. Add key value pairs with insert and then call build. */
111
		KDTree():
116
		KDTree():
112
			is_built(false), max_depth(0), DIM(KeyType::get_dim()), elements(0)
117
			is_built(false), max_depth(0), DIM(KeyType::get_dim()), elements(0)
113
		{
118
		{
114
		}
119
		}
115
 
120
 
116
		/** Insert a key value pair into the tree. Note that the tree needs to 
121
		/** Insert a key value pair into the tree. Note that the tree needs to 
117
				be built - by calling the build function - before you can search. */
122
				be built - by calling the build function - before you can search. */
118
		void insert(const KeyT& key, const ValT& val)
123
		void insert(const KeyT& key, const ValT& val)
119
		{
124
		{
120
			assert(!is_built);
125
			assert(!is_built);
121
			init_nodes.push_back(KDNode(key,val));
126
			init_nodes.push_back(KDNode(key,val));
122
		}
127
		}
123
 
128
 
124
		/** Build the tree. After this function has been called, it is no longer 
129
		/** Build the tree. After this function has been called, it is no longer 
125
				legal to insert elements, but you can perform searches. */
130
				legal to insert elements, but you can perform searches. */
126
		void build()
131
		void build()
127
		{
132
		{
128
			assert(!is_built);
133
			assert(!is_built);
129
			nodes.resize(init_nodes.size()+1);
134
			nodes.resize(init_nodes.size()+1);
130
			if(init_nodes.size() > 0)	
135
			if(init_nodes.size() > 0)	
131
				optimize(1,0,init_nodes.size(),0);
136
				optimize(1,0,init_nodes.size(),0);
132
			NodeVecType v(0);
137
			NodeVecType v(0);
133
			init_nodes.swap(v);
138
			init_nodes.swap(v);
134
			is_built = true;
139
			is_built = true;
135
		}
140
		}
136
 
141
 
137
		/** Find the key value pair closest to the key given as first 
142
		/** Find the key value pair closest to the key given as first 
138
				argument. The second argument is the maximum search distance.
143
				argument. The second argument is the maximum search distance.
139
				The final two arguments contain the closest key and its 
144
				The final two arguments contain the closest key and its 
140
				associated value upon return. */
145
				associated value upon return. */
141
		bool closest_point(const KeyT& p, float& dist, KeyT&k, ValT&v) const
146
		bool closest_point(const KeyT& p, float& dist, KeyT&k, ValT&v) const
142
		{
147
		{
143
			assert(is_built);
148
			assert(is_built);
144
			float max_sq_dist = CGLA::sqr(dist);
149
			float max_sq_dist = CGLA::sqr(dist);
145
			if(int n = closest_point_priv(1, p, max_sq_dist))
150
			if(int n = closest_point_priv(1, p, max_sq_dist))
146
				{
151
				{
147
					k = nodes[n].key;
152
					k = nodes[n].key;
148
					v = nodes[n].val;
153
					v = nodes[n].val;
149
					dist = std::sqrt(max_sq_dist);
154
					dist = std::sqrt(max_sq_dist);
150
					return true;
155
					return true;
151
				}
156
				}
152
			return false;
157
			return false;
153
		}
158
		}
154
 
159
 
155
		/** Find all the elements within a given radius (second argument) of
160
		/** Find all the elements within a given radius (second argument) of
156
				the key (first argument). The key value pairs inside the sphere are
161
				the key (first argument). The key value pairs inside the sphere are
157
				returned in a pair of vectors passed as the two last arguments. */
162
				returned in a pair of vectors passed as the two last arguments. */
158
		int in_sphere(const KeyType& p, 
163
		int in_sphere(const KeyType& p, 
159
									float dist,
164
									float dist,
160
									std::vector<KeyT>& keys,
165
									std::vector<KeyT>& keys,
161
									std::vector<ValT>& vals) const
166
									std::vector<ValT>& vals) const
162
		{
167
		{
163
			assert(is_built);
168
			assert(is_built);
164
			float max_sq_dist = CGLA::sqr(dist);
169
			float max_sq_dist = CGLA::sqr(dist);
165
			in_sphere_priv(1,p,max_sq_dist,keys,vals);
170
			in_sphere_priv(1,p,max_sq_dist,keys,vals);
166
			return keys.size();
171
			return keys.size();
167
		}
172
		}
168
		
173
		
169
 
174
 
170
	};
175
	};
171
 
176
 
172
	template<class KeyT, class ValT>
177
	template<class KeyT, class ValT>
173
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
178
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
174
																	int kvec_end) const 
179
																	int kvec_end) const 
175
	{
180
	{
176
		KeyType vmin = init_nodes[kvec_beg].key;
181
		KeyType vmin = init_nodes[kvec_beg].key;
177
		KeyType vmax = init_nodes[kvec_beg].key;
182
		KeyType vmax = init_nodes[kvec_beg].key;
178
		for(int i=kvec_beg;i<kvec_end;i++)
183
		for(int i=kvec_beg;i<kvec_end;i++)
179
			{
184
			{
180
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
185
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
181
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
186
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
182
			}
187
			}
183
		int od=0;
188
		int od=0;
184
		KeyType ave_v = vmax-vmin;
189
		KeyType ave_v = vmax-vmin;
185
		for(int i=1;i<KeyType::get_dim();i++)
190
		for(int i=1;i<KeyType::get_dim();i++)
186
			if(ave_v[i]>ave_v[od]) od = i;
191
			if(ave_v[i]>ave_v[od]) od = i;
187
		return od;
192
		return od;
188
	} 
193
	} 
189
 
194
 
190
	template<class KeyT, class ValT>
195
	template<class KeyT, class ValT>
191
	void KDTree<KeyT,ValT>::optimize(int cur,
196
	void KDTree<KeyT,ValT>::optimize(int cur,
192
																	 int kvec_beg,  
197
																	 int kvec_beg,  
193
																	 int kvec_end,  
198
																	 int kvec_end,  
194
																	 int level)
199
																	 int level)
195
	{
200
	{
196
		// Assert that we are not inserting beyond capacity.
201
		// Assert that we are not inserting beyond capacity.
197
		assert(cur < nodes.size());
202
		assert(cur < nodes.size());
198
 
203
 
199
		// If there is just a single element, we simply insert.
204
		// If there is just a single element, we simply insert.
200
		if(kvec_beg+1==kvec_end) 
205
		if(kvec_beg+1==kvec_end) 
201
			{
206
			{
202
				max_depth  = std::max(level,max_depth);
207
				max_depth  = std::max(level,max_depth);
203
				nodes[cur] = init_nodes[kvec_beg];
208
				nodes[cur] = init_nodes[kvec_beg];
204
				nodes[cur].dsc = -1;
209
				nodes[cur].dsc = -1;
205
				return;
210
				return;
206
			}
211
			}
207
	
212
	
208
		// Find the axis that best separates the data.
213
		// Find the axis that best separates the data.
209
		int disc = opt_disc(kvec_beg, kvec_end);
214
		int disc = opt_disc(kvec_beg, kvec_end);
210
 
215
 
211
		// Compute the median element. See my document on how to do this
216
		// Compute the median element. See my document on how to do this
212
		// www.imm.dtu.dk/~jab/publications.html
217
		// www.imm.dtu.dk/~jab/publications.html
213
		int N = kvec_end-kvec_beg;
218
		int N = kvec_end-kvec_beg;
214
		int M = 1<< (CGLA::two_to_what_power(N));
219
		int M = 1<< (CGLA::two_to_what_power(N));
215
		int R = N-(M-1);
220
		int R = N-(M-1);
216
		int left_size  = (M-2)/2;
221
		int left_size  = (M-2)/2;
217
		int right_size = (M-2)/2;
222
		int right_size = (M-2)/2;
218
		if(R < M/2)
223
		if(R < M/2)
219
			{
224
			{
220
				left_size += R;
225
				left_size += R;
221
			}
226
			}
222
		else
227
		else
223
			{
228
			{
224
				left_size += M/2;
229
				left_size += M/2;
225
				right_size += R-M/2;
230
				right_size += R-M/2;
226
			}
231
			}
227
 
232
 
228
		int median = kvec_beg + left_size;
233
		int median = kvec_beg + left_size;
229
 
234
 
230
		// Partially sort the elements using nth_element (which is cheaper than
235
		// Partially sort the elements using nth_element (which is cheaper than
231
		// a full sorting algorithm). All elements to the left of the median
236
		// a full sorting algorithm). All elements to the left of the median
232
		// will be smaller than or equal to the median. All elements to the right
237
		// will be smaller than or equal to the median. All elements to the right
233
		// will be greater than or equal to the median.
238
		// will be greater than or equal to the median.
234
		const Comp comp(disc);
239
		const Comp comp(disc);
235
		std::nth_element(&init_nodes[kvec_beg], 
240
		std::nth_element(&init_nodes[kvec_beg], 
236
										 &init_nodes[median], 
241
										 &init_nodes[median], 
237
										 &init_nodes[kvec_end], comp);
242
										 &init_nodes[kvec_end], comp);
238
 
243
 
239
		// Insert the node in the final data structure.
244
		// Insert the node in the final data structure.
240
		nodes[cur] = init_nodes[median];
245
		nodes[cur] = init_nodes[median];
241
		nodes[cur].dsc = disc;
246
		nodes[cur].dsc = disc;
242
 
247
 
243
		// Recursively build left and right tree.
248
		// Recursively build left and right tree.
244
		if(left_size>0)	
249
		if(left_size>0)	
245
			optimize(2*cur, kvec_beg, median,level+1);
250
			optimize(2*cur, kvec_beg, median,level+1);
246
		
251
		
247
		if(right_size>0) 
252
		if(right_size>0) 
248
			optimize(2*cur+1, median+1, kvec_end,level+1);
253
			optimize(2*cur+1, median+1, kvec_end,level+1);
249
	}
254
	}
250
 
255
 
251
	template<class KeyT, class ValT>
256
	template<class KeyT, class ValT>
252
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
257
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
253
																						ScalarType& dist) const
258
																						ScalarType& dist) const
254
	{
259
	{
255
		int ret_node = 0;
260
		int ret_node = 0;
256
		ScalarType this_dist = nodes[n].dist(p);
261
		ScalarType this_dist = nodes[n].dist(p);
257
 
262
 
258
		if(this_dist<dist)
263
		if(this_dist<dist)
259
			{
264
			{
260
				dist = this_dist;
265
				dist = this_dist;
261
				ret_node = n;
266
				ret_node = n;
262
			}
267
			}
263
		if(nodes[n].dsc != -1)
268
		if(nodes[n].dsc != -1)
264
			{
269
			{
265
				int dsc         = nodes[n].dsc;
270
				int dsc         = nodes[n].dsc;
266
				float dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
271
				float dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
267
				bool left_son   = Comp(dsc)(p,nodes[n].key);
272
				bool left_son   = Comp(dsc)(p,nodes[n].key);
268
 
273
 
269
				if(left_son||dsc_dist<dist)
274
				if(left_son||dsc_dist<dist)
270
					{
275
					{
271
						int left_child = 2*n;
276
						int left_child = 2*n;
272
						if(left_child < nodes.size())
277
						if(left_child < nodes.size())
273
							if(int nl=closest_point_priv(left_child, p, dist))
278
							if(int nl=closest_point_priv(left_child, p, dist))
274
								ret_node = nl;
279
								ret_node = nl;
275
					}
280
					}
276
				if(!left_son||dsc_dist<dist)
281
				if(!left_son||dsc_dist<dist)
277
					{
282
					{
278
						int right_child = 2*n+1;
283
						int right_child = 2*n+1;
279
						if(right_child < nodes.size())
284
						if(right_child < nodes.size())
280
							if(int nr=closest_point_priv(right_child, p, dist))
285
							if(int nr=closest_point_priv(right_child, p, dist))
281
								ret_node = nr;
286
								ret_node = nr;
282
					}
287
					}
283
			}
288
			}
284
		return ret_node;
289
		return ret_node;
285
	}
290
	}
286
 
291
 
287
	template<class KeyT, class ValT>
292
	template<class KeyT, class ValT>
288
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
293
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
289
																				 const KeyType& p, 
294
																				 const KeyType& p, 
290
																				 const ScalarType& dist,
295
																				 const ScalarType& dist,
291
																				 std::vector<KeyT>& keys,
296
																				 std::vector<KeyT>& keys,
292
																				 std::vector<ValT>& vals) const
297
																				 std::vector<ValT>& vals) const
293
	{
298
	{
294
		ScalarType this_dist = nodes[n].dist(p);
299
		ScalarType this_dist = nodes[n].dist(p);
295
		assert(n<nodes.size());
300
		assert(n<nodes.size());
296
		if(this_dist<dist)
301
		if(this_dist<dist)
297
			{
302
			{
298
				keys.push_back(nodes[n].key);
303
				keys.push_back(nodes[n].key);
299
				vals.push_back(nodes[n].val);
304
				vals.push_back(nodes[n].val);
300
			}
305
			}
301
		if(nodes[n].dsc != -1)
306
		if(nodes[n].dsc != -1)
302
			{
307
			{
303
				const int dsc         = nodes[n].dsc;
308
				const int dsc         = nodes[n].dsc;
304
				const float dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
309
				const float dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
305
 
310
 
306
				bool left_son = Comp(dsc)(p,nodes[n].key);
311
				bool left_son = Comp(dsc)(p,nodes[n].key);
307
 
312
 
308
				if(left_son||dsc_dist<dist)
313
				if(left_son||dsc_dist<dist)
309
					{
314
					{
310
						int left_child = 2*n;
315
						int left_child = 2*n;
311
						if(left_child < nodes.size())
316
						if(left_child < nodes.size())
312
							in_sphere_priv(left_child, p, dist, keys, vals);
317
							in_sphere_priv(left_child, p, dist, keys, vals);
313
					}
318
					}
314
				if(!left_son||dsc_dist<dist)
319
				if(!left_son||dsc_dist<dist)
315
					{
320
					{
316
						int right_child = 2*n+1;
321
						int right_child = 2*n+1;
317
						if(right_child < nodes.size())
322
						if(right_child < nodes.size())
318
							in_sphere_priv(right_child, p, dist, keys, vals);
323
							in_sphere_priv(right_child, p, dist, keys, vals);
319
					}
324
					}
320
			}
325
			}
321
	}
326
	}
322
}
327
}
323
namespace GEO = Geometry;
328
namespace GEO = Geometry;
324
 
329
 
-
 
330
#if (_MSC_VER >= 1200)
-
 
331
#pragma pop (warning)
-
 
332
#endif
-
 
333
 
-
 
334
 
325
#endif
335
#endif
326
 
336