Subversion Repositories gelsvn

Rev

Rev 61 | Go to most recent revision | Only display areas with differences | Ignore whitespace | Details | Blame | Last modification | View Log | RSS feed

Rev 61 Rev 67
1
/** Author: J. Andreas Bærentzen, (c) 2003
1
/** Author: J. Andreas Bærentzen, (c) 2003
2
		This is a simple KDTree data structure. It is written using templates, and
2
		This is a simple KDTree data structure. It is written using templates, and
3
		it is parametrized by both key and value. 
3
		it is parametrized by both key and value. 
4
*/
4
*/
5
 
5
 
6
#ifndef __KDTREE_H
6
#ifndef __KDTREE_H
7
#define __KDTREE_H
7
#define __KDTREE_H
8
 
8
 
9
#include <cassert>
#include <cmath>
9
#include <cmath>
10
#include <iostream>
10
#include <iostream>
11
#include <vector>
11
#include <vector>
12
#include <algorithm>
12
#include <algorithm>
13
#include "CGLA/CGLA.h"
13
#include "CGLA/CGLA.h"
14
 
14
 
15
namespace Geometry
15
namespace Geometry
16
{
16
{
17
	/** A classic K-D tree. 
17
	/** A classic K-D tree. 
18
			A K-D tree is a good data structure for storing points in space
18
			A K-D tree is a good data structure for storing points in space
19
			and for nearest neighbour queries. It is basically a generalized 
19
			and for nearest neighbour queries. It is basically a generalized 
20
			binary tree in K dimensions. */
20
			binary tree in K dimensions. */
21
	template<class KeyT, class ValT>
21
	template<class KeyT, class ValT>
22
	class KDTree
22
	class KDTree
23
	{
23
	{
24
		typedef typename KeyT::ScalarType ScalarType;
24
		typedef typename KeyT::ScalarType ScalarType;
25
		typedef KeyT KeyType;
25
		typedef KeyT KeyType;
26
		typedef std::vector<KeyT> KeyVectorType;
26
		typedef std::vector<KeyT> KeyVectorType;
27
		typedef std::vector<ValT> ValVectorType;
27
		typedef std::vector<ValT> ValVectorType;
28
	
28
	
29
		/// KDNode struct represents node in KD tree
29
		/// KDNode struct represents node in KD tree
30
		struct KDNode
30
		struct KDNode
31
		{
31
		{
32
			KeyT key;
32
			KeyT key;
33
			ValT val;
33
			ValT val;
34
			short dsc;
34
			short dsc;
35
 
35
 
36
			KDNode(): dsc(0) {}
36
			KDNode(): dsc(0) {}
37
 
37
 
38
			KDNode(const KeyT& _key, const ValT& _val):
38
			KDNode(const KeyT& _key, const ValT& _val):
39
				key(_key), val(_val), dsc(-1) {}
39
				key(_key), val(_val), dsc(-1) {}
40
 
40
 
41
			ScalarType dist(const KeyType& p) const 
41
			ScalarType dist(const KeyType& p) const 
42
			{
42
			{
43
				KeyType dist_vec = p;
43
				KeyType dist_vec = p;
44
				dist_vec  -= key;
44
				dist_vec  -= key;
45
				return dot(dist_vec, dist_vec);
45
				return dot(dist_vec, dist_vec);
46
			}
46
			}
47
		};
47
		};
48
 
48
 
49
		typedef std::vector<KDNode> NodeVecType;
49
		typedef std::vector<KDNode> NodeVecType;
50
		NodeVecType init_nodes;
50
		NodeVecType init_nodes;
51
		NodeVecType nodes;
51
		NodeVecType nodes;
52
		bool is_built;
52
		bool is_built;
53
 
53
 
54
		/// The greatest depth of KD tree.
54
		/// The greatest depth of KD tree.
55
		int max_depth;
55
		int max_depth;
56
 
56
 
57
		/// Total number of elements in tree
57
		/// Total number of elements in tree
58
		int elements;
58
		int elements;
59
 
59
 
60
		/** Comp is a class used for comparing two keys. Comp is constructed
60
		/** Comp is a class used for comparing two keys. Comp is constructed
61
				with the discriminator - i.e. the coordinate of the key that is used
61
				with the discriminator - i.e. the coordinate of the key that is used
62
				for comparing keys - Comp objects are passed to the sort algorithm.*/
62
				for comparing keys - Comp objects are passed to the sort algorithm.*/
63
		class Comp
63
		class Comp
64
		{
64
		{
65
			const int dsc;
65
			const int dsc;
66
		public:
66
		public:
67
			Comp(int _dsc): dsc(_dsc) {}
67
			Comp(int _dsc): dsc(_dsc) {}
68
			bool operator()(const KeyType& k0, const KeyType& k1) const
68
			bool operator()(const KeyType& k0, const KeyType& k1) const
69
			{
69
			{
70
				int dim=KeyType::get_dim();
70
				int dim=KeyType::get_dim();
71
				for(int i=0;i<dim;i++)
71
				for(int i=0;i<dim;i++)
72
					{
72
					{
73
						int j=(dsc+i)%dim;
73
						int j=(dsc+i)%dim;
74
						if(k0[j]<k1[j])
74
						if(k0[j]<k1[j])
75
							return true;
75
							return true;
76
						if(k0[j]>k1[j])
76
						if(k0[j]>k1[j])
77
							return false;
77
							return false;
78
					}
78
					}
79
				return false;
79
				return false;
80
			}
80
			}
81
 
81
 
82
			bool operator()(const KDNode& k0, const KDNode& k1) const
82
			bool operator()(const KDNode& k0, const KDNode& k1) const
83
			{
83
			{
84
				return (*this)(k0.key,k1.key);
84
				return (*this)(k0.key,k1.key);
85
			}
85
			}
86
		};
86
		};
87
 
87
 
88
		/// The dimension -- K
88
		/// The dimension -- K
89
		const int DIM;
89
		const int DIM;
90
 
90
 
91
		/** Passed a vector of keys, this function will construct an optimal tree.
91
		/** Passed a vector of keys, this function will construct an optimal tree.
92
				It is called recursively - second argument is level in tree. */
92
				It is called recursively - second argument is level in tree. */
93
		void optimize(int, int, int, int);
93
		void optimize(int, int, int, int);
94
 
94
 
95
		/** Finde nearest neighbour. */
95
		/** Finde nearest neighbour. */
96
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
96
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
97
 
97
 
98
							
98
							
99
		void in_sphere_priv(int n, 
99
		void in_sphere_priv(int n, 
100
												const KeyType& p, 
100
												const KeyType& p, 
101
												const ScalarType& dist,
101
												const ScalarType& dist,
102
												std::vector<KeyT>& keys,
102
												std::vector<KeyT>& keys,
103
												std::vector<ValT>& vals) const;
103
												std::vector<ValT>& vals) const;
104
	
104
	
105
		/** Finds the optimal discriminator. There are more ways, but this 
105
		/** Finds the optimal discriminator. There are more ways, but this 
106
				function traverses the vector and finds out what dimension has
106
				function traverses the vector and finds out what dimension has
107
				the greatest difference between min and max element. That dimension
107
				the greatest difference between min and max element. That dimension
108
				is used for discriminator */
108
				is used for discriminator */
109
		int opt_disc(int,int) const;
109
		int opt_disc(int,int) const;
110
	
110
	
111
	public:
111
	public:
112
 
112
 
113
		/** Build tree from vector of keys passed as argument. */
113
		/** Build tree from vector of keys passed as argument. */
114
		KDTree():
114
		KDTree():
115
			is_built(false), max_depth(0), DIM(KeyType::get_dim()), elements(0)
115
			is_built(false), max_depth(0), DIM(KeyType::get_dim()), elements(0)
116
		{
116
		{
117
		}
117
		}
118
 
118
 
119
		/** Insert a key value pair into the tree. Note that the tree needs to 
119
		/** Insert a key value pair into the tree. Note that the tree needs to 
120
				be built - by calling the build function - before you can search. */
120
				be built - by calling the build function - before you can search. */
121
		void insert(const KeyT& key, const ValT& val)
121
		void insert(const KeyT& key, const ValT& val)
122
		{
122
		{
123
			assert(!is_built);
123
			assert(!is_built);
124
			init_nodes.push_back(KDNode(key,val));
124
			init_nodes.push_back(KDNode(key,val));
125
		}
125
		}
126
 
126
 
127
		/** Build the tree. After this function have been called, it is no longer 
127
		/** Build the tree. After this function have been called, it is no longer 
128
				legal to insert elements, but you can perform searches. */
128
				legal to insert elements, but you can perform searches. */
129
		void build()
129
		void build()
130
		{
130
		{
131
			assert(!is_built);
131
			assert(!is_built);
132
			nodes.resize(init_nodes.size()+1);
132
			nodes.resize(init_nodes.size()+1);
133
			if(init_nodes.size() > 0)	
133
			if(init_nodes.size() > 0)	
134
				optimize(1,0,init_nodes.size(),0);
134
				optimize(1,0,init_nodes.size(),0);
135
			NodeVecType v(0);
135
			NodeVecType v(0);
136
			init_nodes.swap(v);
136
			init_nodes.swap(v);
137
			is_built = true;
137
			is_built = true;
138
		}
138
		}
139
 
139
 
140
		/** Find the key value pair closest to the key given as first 
140
		/** Find the key value pair closest to the key given as first 
141
				argument. The second argument is the maximum search distance.
141
				argument. The second argument is the maximum search distance.
142
				The final two arguments contain the closest key and its 
142
				The final two arguments contain the closest key and its 
143
				associated value upon return. */
143
				associated value upon return. */
144
		bool closest_point(const KeyT& p, float& dist, KeyT&k, ValT&v) const
144
		bool closest_point(const KeyT& p, float& dist, KeyT&k, ValT&v) const
145
		{
145
		{
146
			assert(is_built);
146
			assert(is_built);
147
			float max_sq_dist = CGLA::sqr(dist);
147
			float max_sq_dist = CGLA::sqr(dist);
148
			if(int n = closest_point_priv(1, p, max_sq_dist))
148
			if(int n = closest_point_priv(1, p, max_sq_dist))
149
				{
149
				{
150
					k = nodes[n].key;
150
					k = nodes[n].key;
151
					v = nodes[n].val;
151
					v = nodes[n].val;
152
					dist = std::sqrt(max_sq_dist);
152
					dist = std::sqrt(max_sq_dist);
153
					return true;
153
					return true;
154
				}
154
				}
155
			return false;
155
			return false;
156
		}
156
		}
157
 
157
 
158
		/** Find all the elements within a given radius (second argument) of
158
		/** Find all the elements within a given radius (second argument) of
159
				the key (first argument). The key value pairs inside the sphere are
159
				the key (first argument). The key value pairs inside the sphere are
160
				returned in a pair of vectors passed as the two last arguments. */
160
				returned in a pair of vectors passed as the two last arguments. */
161
		int in_sphere(const KeyType& p, 
161
		int in_sphere(const KeyType& p, 
162
									float dist,
162
									float dist,
163
									std::vector<KeyT>& keys,
163
									std::vector<KeyT>& keys,
164
									std::vector<ValT>& vals) const
164
									std::vector<ValT>& vals) const
165
		{
165
		{
166
			assert(is_built);
166
			assert(is_built);
167
			float max_sq_dist = CGLA::sqr(dist);
167
			float max_sq_dist = CGLA::sqr(dist);
168
			in_sphere_priv(1,p,max_sq_dist,keys,vals);
168
			in_sphere_priv(1,p,max_sq_dist,keys,vals);
169
			return keys.size();
169
			return keys.size();
170
		}
170
		}
171
		
171
		
172
 
172
 
173
	};
173
	};
174
 
174
 
175
	template<class KeyT, class ValT>
175
	template<class KeyT, class ValT>
176
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
176
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
177
																	int kvec_end) const 
177
																	int kvec_end) const 
178
	{
178
	{
179
		KeyType vmin = init_nodes[kvec_beg].key;
179
		KeyType vmin = init_nodes[kvec_beg].key;
180
		KeyType vmax = init_nodes[kvec_beg].key;
180
		KeyType vmax = init_nodes[kvec_beg].key;
181
		for(int i=kvec_beg;i<kvec_end;i++)
181
		for(int i=kvec_beg;i<kvec_end;i++)
182
			{
182
			{
183
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
183
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
184
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
184
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
185
			}
185
			}
186
		int od=0;
186
		int od=0;
187
		KeyType ave_v = vmax-vmin;
187
		KeyType ave_v = vmax-vmin;
188
		for(int i=1;i<KeyType::get_dim();i++)
188
		for(int i=1;i<KeyType::get_dim();i++)
189
			if(ave_v[i]>ave_v[od]) od = i;
189
			if(ave_v[i]>ave_v[od]) od = i;
190
		return od;
190
		return od;
191
	} 
191
	} 
192
 
192
 
193
	template<class KeyT, class ValT>
193
	template<class KeyT, class ValT>
194
	void KDTree<KeyT,ValT>::optimize(int cur,
194
	void KDTree<KeyT,ValT>::optimize(int cur,
195
																	 int kvec_beg,  
195
																	 int kvec_beg,  
196
																	 int kvec_end,  
196
																	 int kvec_end,  
197
																	 int level)
197
																	 int level)
198
	{
198
	{
199
		// Assert that we are not inserting beyond capacity.
199
		// Assert that we are not inserting beyond capacity.
200
		assert(cur < nodes.size());
200
		assert(cur < nodes.size());
201
 
201
 
202
		// If there is just a single element, we simply insert.
202
		// If there is just a single element, we simply insert.
203
		if(kvec_beg+1==kvec_end) 
203
		if(kvec_beg+1==kvec_end) 
204
			{
204
			{
205
				max_depth  = std::max(level,max_depth);
205
				max_depth  = std::max(level,max_depth);
206
				nodes[cur] = init_nodes[kvec_beg];
206
				nodes[cur] = init_nodes[kvec_beg];
207
				nodes[cur].dsc = -1;
207
				nodes[cur].dsc = -1;
208
				return;
208
				return;
209
			}
209
			}
210
	
210
	
211
		// Find the axis that best separates the data.
211
		// Find the axis that best separates the data.
212
		int disc = opt_disc(kvec_beg, kvec_end);
212
		int disc = opt_disc(kvec_beg, kvec_end);
213
 
213
 
214
		// Compute the median element. See my document on how to do this
214
		// Compute the median element. See my document on how to do this
215
		// www.imm.dtu.dk/~jab/publications.html
215
		// www.imm.dtu.dk/~jab/publications.html
216
		int N = kvec_end-kvec_beg;
216
		int N = kvec_end-kvec_beg;
217
		int M = 1<< (CGLA::two_to_what_power(N));
217
		int M = 1<< (CGLA::two_to_what_power(N));
218
		int R = N-(M-1);
218
		int R = N-(M-1);
219
		int left_size  = (M-2)/2;
219
		int left_size  = (M-2)/2;
220
		int right_size = (M-2)/2;
220
		int right_size = (M-2)/2;
221
		if(R < M/2)
221
		if(R < M/2)
222
			{
222
			{
223
				left_size += R;
223
				left_size += R;
224
			}
224
			}
225
		else
225
		else
226
			{
226
			{
227
				left_size += M/2;
227
				left_size += M/2;
228
				right_size += R-M/2;
228
				right_size += R-M/2;
229
			}
229
			}
230
 
230
 
231
		int median = kvec_beg + left_size;
231
		int median = kvec_beg + left_size;
232
 
232
 
233
		// Sort elements but use nth_element (which is cheaper) than
233
		// Sort elements but use nth_element (which is cheaper) than
234
		// a sorting algorithm. All elements to the left of the median
234
		// a sorting algorithm. All elements to the left of the median
235
		// will be smaller than or equal the median. All elements to the right
235
		// will be smaller than or equal the median. All elements to the right
236
		// will be greater than or equal to the median.
236
		// will be greater than or equal to the median.
237
		const Comp comp(disc);
237
		const Comp comp(disc);
238
		std::nth_element(&init_nodes[kvec_beg], 
238
		std::nth_element(&init_nodes[kvec_beg], 
239
										 &init_nodes[median], 
239
										 &init_nodes[median], 
240
										 &init_nodes[kvec_end], comp);
240
										 &init_nodes[kvec_end], comp);
241
 
241
 
242
		// Insert the node in the final data structure.
242
		// Insert the node in the final data structure.
243
		nodes[cur] = init_nodes[median];
243
		nodes[cur] = init_nodes[median];
244
		nodes[cur].dsc = disc;
244
		nodes[cur].dsc = disc;
245
 
245
 
246
		// Recursively build left and right tree.
246
		// Recursively build left and right tree.
247
		if(left_size>0)	
247
		if(left_size>0)	
248
			optimize(2*cur, kvec_beg, median,level+1);
248
			optimize(2*cur, kvec_beg, median,level+1);
249
		
249
		
250
		if(right_size>0) 
250
		if(right_size>0) 
251
			optimize(2*cur+1, median+1, kvec_end,level+1);
251
			optimize(2*cur+1, median+1, kvec_end,level+1);
252
	}
252
	}
253
 
253
 
254
	template<class KeyT, class ValT>
254
	template<class KeyT, class ValT>
255
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
255
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
256
																						ScalarType& dist) const
256
																						ScalarType& dist) const
257
	{
257
	{
258
		int ret_node = 0;
258
		int ret_node = 0;
259
		ScalarType this_dist = nodes[n].dist(p);
259
		ScalarType this_dist = nodes[n].dist(p);
260
 
260
 
261
		if(this_dist<dist)
261
		if(this_dist<dist)
262
			{
262
			{
263
				dist = this_dist;
263
				dist = this_dist;
264
				ret_node = n;
264
				ret_node = n;
265
			}
265
			}
266
		if(nodes[n].dsc != -1)
266
		if(nodes[n].dsc != -1)
267
			{
267
			{
268
				int dsc         = nodes[n].dsc;
268
				int dsc         = nodes[n].dsc;
269
				float dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
269
				float dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
270
				bool left_son   = Comp(dsc)(p,nodes[n].key);
270
				bool left_son   = Comp(dsc)(p,nodes[n].key);
271
 
271
 
272
				if(left_son||dsc_dist<dist)
272
				if(left_son||dsc_dist<dist)
273
					{
273
					{
274
						int left_child = 2*n;
274
						int left_child = 2*n;
275
						if(left_child < nodes.size())
275
						if(left_child < nodes.size())
276
							if(int nl=closest_point_priv(left_child, p, dist))
276
							if(int nl=closest_point_priv(left_child, p, dist))
277
								ret_node = nl;
277
								ret_node = nl;
278
					}
278
					}
279
				if(!left_son||dsc_dist<dist)
279
				if(!left_son||dsc_dist<dist)
280
					{
280
					{
281
						int right_child = 2*n+1;
281
						int right_child = 2*n+1;
282
						if(right_child < nodes.size())
282
						if(right_child < nodes.size())
283
							if(int nr=closest_point_priv(right_child, p, dist))
283
							if(int nr=closest_point_priv(right_child, p, dist))
284
								ret_node = nr;
284
								ret_node = nr;
285
					}
285
					}
286
			}
286
			}
287
		return ret_node;
287
		return ret_node;
288
	}
288
	}
289
 
289
 
290
	template<class KeyT, class ValT>
290
	template<class KeyT, class ValT>
291
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
291
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
292
																				 const KeyType& p, 
292
																				 const KeyType& p, 
293
																				 const ScalarType& dist,
293
																				 const ScalarType& dist,
294
																				 std::vector<KeyT>& keys,
294
																				 std::vector<KeyT>& keys,
295
																				 std::vector<ValT>& vals) const
295
																				 std::vector<ValT>& vals) const
296
	{
296
	{
297
		ScalarType this_dist = nodes[n].dist(p);
297
		ScalarType this_dist = nodes[n].dist(p);
298
		assert(n<nodes.size());
298
		assert(n<nodes.size());
299
		if(this_dist<dist)
299
		if(this_dist<dist)
300
			{
300
			{
301
				keys.push_back(nodes[n].key);
301
				keys.push_back(nodes[n].key);
302
				vals.push_back(nodes[n].val);
302
				vals.push_back(nodes[n].val);
303
			}
303
			}
304
		if(nodes[n].dsc != -1)
304
		if(nodes[n].dsc != -1)
305
			{
305
			{
306
				const int dsc         = nodes[n].dsc;
306
				const int dsc         = nodes[n].dsc;
307
				const float dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
307
				const float dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
308
 
308
 
309
				bool left_son = Comp(dsc)(p,nodes[n].key);
309
				bool left_son = Comp(dsc)(p,nodes[n].key);
310
 
310
 
311
				if(left_son||dsc_dist<dist)
311
				if(left_son||dsc_dist<dist)
312
					{
312
					{
313
						int left_child = 2*n;
313
						int left_child = 2*n;
314
						if(left_child < nodes.size())
314
						if(left_child < nodes.size())
315
							in_sphere_priv(left_child, p, dist, keys, vals);
315
							in_sphere_priv(left_child, p, dist, keys, vals);
316
					}
316
					}
317
				if(!left_son||dsc_dist<dist)
317
				if(!left_son||dsc_dist<dist)
318
					{
318
					{
319
						int right_child = 2*n+1;
319
						int right_child = 2*n+1;
320
						if(right_child < nodes.size())
320
						if(right_child < nodes.size())
321
							in_sphere_priv(right_child, p, dist, keys, vals);
321
							in_sphere_priv(right_child, p, dist, keys, vals);
322
					}
322
					}
323
			}
323
			}
324
	}
324
	}
325
}
325
}
326
// Short alias for the Geometry namespace, e.g. GEO::KDTree<KeyT,ValT>.
namespace GEO = Geometry;
326
namespace GEO = Geometry;
327
 
327
 
328
#endif
328
#endif
329
 
329