Subversion Repositories gelsvn

Rev

Rev 89 | Rev 198 | Go to most recent revision | Only display areas with differences | Ignore whitespace | Details | Blame | Last modification | View Log | RSS feed

Rev 89 Rev 125
1
#ifndef __KDTREE_H
1
#ifndef __KDTREE_H
2
#define __KDTREE_H
2
#define __KDTREE_H
3
 
3
 
4
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <vector>
#include "CGLA/CGLA.h"
#include "CGLA/ArithVec.h"

10
namespace Geometry
{
	/** \brief A classic K-D tree.

			A K-D tree is a good data structure for storing points in space
			and for nearest neighbour queries. It is basically a generalized
			binary tree in K dimensions.

			Usage: insert() all key/value pairs, call build() once, then query
			with closest_point() or in_sphere(). The built tree is stored
			implicitly in a vector: the root is at index 1 and node i has its
			children at 2i and 2i+1 (index 0 is unused).

			Requirements on KeyT: a nested ScalarType typedef, a static
			get_dim() returning the number of coordinates, a default
			constructor, and operator[] giving const read access and non-const
			write access to each coordinate. ValT must be default constructible
			and copyable. */
	template<class KeyT, class ValT>
	class KDTree
	{
		typedef typename KeyT::ScalarType ScalarType;
		typedef KeyT KeyType;
		typedef std::vector<KeyT> KeyVectorType;
		typedef std::vector<ValT> ValVectorType;

		/// KDNode struct represents node in KD tree
		struct KDNode
		{
			KeyT key;
			ValT val;
			/// Index of the discriminating coordinate; -1 marks a leaf.
			short dsc;

			KDNode(): dsc(0) {}

			KDNode(const KeyT& _key, const ValT& _val):
				key(_key), val(_val), dsc(-1) {}

			/// Squared Euclidean distance from this node's key to p.
			ScalarType dist(const KeyType& p) const
			{
				ScalarType sq_sum = ScalarType(0);
				const int dim = KeyType::get_dim();
				for(int i=0;i<dim;++i)
					{
						const ScalarType d = p[i]-key[i];
						sq_sum += d*d;
					}
				return sq_sum;
			}
		};

		typedef std::vector<KDNode> NodeVecType;

		/// Nodes collected by insert(); released by build().
		NodeVecType init_nodes;

		/// The built tree in implicit heap layout; slot 0 is unused.
		NodeVecType nodes;

		/// True once build() has been called.
		bool is_built;

		/// The greatest depth of KD tree.
		int max_depth;

		/// Total number of elements in tree (set by build()).
		int elements;

		/** Comp is a class used for comparing two keys. Comp is constructed
				with the discriminator - i.e. the coordinate of the key that is used
				for comparing keys - Comp objects are passed to the sort algorithm.
				Ties on the discriminator are broken by comparing the following
				coordinates cyclically, giving a strict weak ordering. */
		class Comp
		{
			const int dsc;
		public:
			Comp(int _dsc): dsc(_dsc) {}
			bool operator()(const KeyType& k0, const KeyType& k1) const
			{
				const int dim=KeyType::get_dim();
				for(int i=0;i<dim;i++)
					{
						const int j=(dsc+i)%dim;
						if(k0[j]<k1[j])
							return true;
						if(k0[j]>k1[j])
							return false;
					}
				return false;
			}

			bool operator()(const KDNode& k0, const KDNode& k1) const
			{
				return (*this)(k0.key,k1.key);
			}
		};

		/// The dimension -- K
		const int DIM;

		/** Recursively builds the subtree rooted at nodes[cur] from the range
				[kvec_beg, kvec_end) of init_nodes (reordering that range in
				place). The last argument is the depth of cur in the tree. */
		void optimize(int cur, int kvec_beg, int kvec_end, int level);

		/** Find nearest neighbour in the subtree rooted at node n. dist holds
				the best squared distance so far and is lowered whenever a
				strictly closer node is found. Returns the index of the best node
				found in this subtree, or 0 if none beat the incoming dist. */
		int closest_point_priv(int n, const KeyType& p, ScalarType& dist) const;

		/** Appends every key/value pair in the subtree rooted at node n whose
				squared distance to p is strictly less than dist. */
		void in_sphere_priv(int n, 
												const KeyType& p, 
												const ScalarType& dist,
												std::vector<KeyT>& keys,
												std::vector<ValT>& vals) const;

		/** Finds the optimal discriminator. There are more ways, but this 
				function traverses the range and finds out what dimension has
				the greatest difference between min and max element. That dimension
				is used for discriminator */
		int opt_disc(int kvec_beg, int kvec_end) const;

	public:

		/** Construct an empty tree. Fill it with insert(), then call build(). */
		KDTree():
			is_built(false), max_depth(0), elements(0), DIM(KeyType::get_dim())
		{
		}

		/** Insert a key value pair into the tree. Note that the tree needs to 
				be built - by calling the build function - before you can search. */
		void insert(const KeyT& key, const ValT& val)
		{
			assert(!is_built);
			init_nodes.push_back(KDNode(key,val));
		}

		/** Build the tree. After this function has been called, it is no longer 
				legal to insert elements, but you can perform searches. Building an
				empty tree is allowed; searches on it simply find nothing. */
		void build()
		{
			assert(!is_built);
			elements = static_cast<int>(init_nodes.size());
			// One extra slot because the implicit tree is rooted at index 1.
			nodes.resize(init_nodes.size()+1);
			if(!init_nodes.empty())
				optimize(1,0,elements,0);
			// Swap with an empty vector to actually release the memory.
			NodeVecType().swap(init_nodes);
			is_built = true;
		}

		/** Find the key value pair closest to the key given as first 
				argument. The second argument is the maximum search distance;
				only keys strictly closer than it are considered.
				The final two arguments contain the closest key and its 
				associated value upon return.
				\return true if a key was found. In that case dist is overwritten
				with the distance to the found key; otherwise dist, k and v are
				left unchanged. */
		bool closest_point(const KeyT& p, float& dist, KeyT&k, ValT&v) const
		{
			assert(is_built);
			// nodes holds only the unused slot 0 (or nothing) when the tree is empty.
			if(nodes.size() < 2)
				return false;
			// Work in ScalarType so that non-float keys (e.g. double) are supported.
			const ScalarType max_dist = static_cast<ScalarType>(dist);
			ScalarType max_sq_dist = max_dist*max_dist;
			if(int n = closest_point_priv(1, p, max_sq_dist))
				{
					k = nodes[n].key;
					v = nodes[n].val;
					dist = static_cast<float>(std::sqrt(max_sq_dist));
					return true;
				}
			return false;
		}

		/** Find all the elements within a given radius (second argument) of
				the key (first argument), i.e. strictly closer than the radius.
				The key value pairs inside the sphere are appended to the pair of
				vectors passed as the two last arguments.
				\return keys.size() after appending; this includes any elements
				keys already held before the call. */
		int in_sphere(const KeyType& p, 
									float dist,
									std::vector<KeyT>& keys,
									std::vector<ValT>& vals) const
		{
			assert(is_built);
			// Skip the search entirely on an empty tree (root slot 1 does not exist).
			if(nodes.size() > 1)
				{
					const ScalarType max_dist = static_cast<ScalarType>(dist);
					const ScalarType max_sq_dist = max_dist*max_dist;
					in_sphere_priv(1,p,max_sq_dist,keys,vals);
				}
			return static_cast<int>(keys.size());
		}

	};

	template<class KeyT, class ValT>
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
																	int kvec_end) const 
	{
		// Component-wise bounding box of the keys in [kvec_beg, kvec_end).
		KeyType vmin = init_nodes[kvec_beg].key;
		KeyType vmax = vmin;
		for(int i=kvec_beg+1;i<kvec_end;++i)
			{
				const KeyType& k = init_nodes[i].key;
				for(int d=0;d<DIM;++d)
					{
						if(k[d] < vmin[d]) vmin[d] = k[d];
						if(vmax[d] < k[d]) vmax[d] = k[d];
					}
			}
		// Pick the coordinate with the largest extent (lowest index wins ties).
		int od=0;
		for(int d=1;d<DIM;++d)
			if(vmax[d]-vmin[d] > vmax[od]-vmin[od]) od = d;
		return od;
	} 

	template<class KeyT, class ValT>
	void KDTree<KeyT,ValT>::optimize(int cur,
																	 int kvec_beg,  
																	 int kvec_end,  
																	 int level)
	{
		// Assert that we are not inserting beyond capacity.
		assert(cur < static_cast<int>(nodes.size()));
		assert(kvec_beg < kvec_end);

		// If there is just a single element, we simply insert.
		if(kvec_beg+1==kvec_end) 
			{
				max_depth  = std::max(level,max_depth);
				nodes[cur] = init_nodes[kvec_beg];
				nodes[cur].dsc = -1;
				return;
			}
	
		// Find the axis that best separates the data.
		const int disc = opt_disc(kvec_beg, kvec_end);

		// Compute the median element so that the subtree is a left-balanced
		// complete binary tree (which is what makes the implicit 2i / 2i+1
		// layout dense). See www.imm.dtu.dk/~jab/publications.html
		const int N = kvec_end-kvec_beg;
		int M = 1;                 // largest power of two <= N
		while(M <= N/2)
			M *= 2;
		const int R = N-(M-1);     // nodes on the last, partially filled level
		int left_size  = (M-2)/2;
		int right_size = (M-2)/2;
		if(R < M/2)
			{
				left_size += R;
			}
		else
			{
				left_size += M/2;
				right_size += R-M/2;
			}

		const int median = kvec_beg + left_size;

		// Partition with nth_element (cheaper than a full sort). All elements
		// to the left of the median will be smaller than or equal to the
		// median; all elements to the right will be greater than or equal.
		// Iterators are used so that kvec_end == size() is never dereferenced.
		const Comp comp(disc);
		typename NodeVecType::iterator first = init_nodes.begin();
		std::nth_element(first + kvec_beg, first + median, first + kvec_end, comp);

		// Insert the node in the final data structure.
		nodes[cur] = init_nodes[median];
		nodes[cur].dsc = disc;

		// Recursively build left and right tree.
		if(left_size>0)	
			optimize(2*cur, kvec_beg, median, level+1);
		
		if(right_size>0) 
			optimize(2*cur+1, median+1, kvec_end, level+1);
	}

	template<class KeyT, class ValT>
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
																						ScalarType& dist) const
	{
		assert(n < static_cast<int>(nodes.size()));
		int ret_node = 0;
		const KDNode& node = nodes[n];
		const ScalarType this_dist = node.dist(p);

		if(this_dist<dist)
			{
				dist = this_dist;
				ret_node = n;
			}
		if(node.dsc != -1)
			{
				const int dsc         = node.dsc;
				const ScalarType diff = node.key[dsc]-p[dsc];
				// Squared distance from p to the splitting plane.
				const ScalarType dsc_dist = diff*diff;
				const bool left_son   = Comp(dsc)(p,node.key);
				const int size        = static_cast<int>(nodes.size());

				// Visit the side containing p first so dist shrinks early and
				// the far side is pruned more often.
				const int near_child = left_son ? 2*n   : 2*n+1;
				const int far_child  = left_son ? 2*n+1 : 2*n;

				if(near_child < size)
					if(int nn = closest_point_priv(near_child, p, dist))
						ret_node = nn;

				// The far side can only hold a closer key if the plane is closer
				// than the best distance found so far.
				if(dsc_dist < dist && far_child < size)
					if(int nf = closest_point_priv(far_child, p, dist))
						ret_node = nf;
			}
		return ret_node;
	}

	template<class KeyT, class ValT>
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
																				 const KeyType& p, 
																				 const ScalarType& dist,
																				 std::vector<KeyT>& keys,
																				 std::vector<ValT>& vals) const
	{
		// Check the index before touching the node.
		assert(n < static_cast<int>(nodes.size()));
		const KDNode& node = nodes[n];
		if(node.dist(p) < dist)
			{
				keys.push_back(node.key);
				vals.push_back(node.val);
			}
		if(node.dsc != -1)
			{
				const int dsc         = node.dsc;
				const ScalarType diff = node.key[dsc]-p[dsc];
				const ScalarType dsc_dist = diff*diff;
				const int size        = static_cast<int>(nodes.size());

				const bool left_son = Comp(dsc)(p,node.key);

				if(left_son||dsc_dist<dist)
					{
						const int left_child = 2*n;
						if(left_child < size)
							in_sphere_priv(left_child, p, dist, keys, vals);
					}
				if(!left_son||dsc_dist<dist)
					{
						const int right_child = 2*n+1;
						if(right_child < size)
							in_sphere_priv(right_child, p, dist, keys, vals);
					}
			}
	}
}
322
/// Short alias so client code can refer to Geometry as GEO (e.g. GEO::KDTree).
namespace GEO = ::Geometry;
323
 
324
 
324
#endif
325
#endif
325
 
326