Subversion Repositories gelsvn

Rev

Rev 443 | Go to most recent revision | Only display areas with differences | Ignore whitespace | Details | Blame | Last modification | View Log | RSS feed

Rev 443 Rev 595
-
 
1
/* ----------------------------------------------------------------------- *
-
 
2
 * This file is part of GEL, http://www.imm.dtu.dk/GEL
-
 
3
 * Copyright (C) the authors and DTU Informatics
-
 
4
 * For license and list of authors, see ../../doc/intro.pdf
-
 
5
 * ----------------------------------------------------------------------- */
-
 
6
 
-
 
7
/**
-
 
8
 * @file KDTree.h
-
 
9
 * @brief KD Tree implementation based on a binary heap.
-
 
10
 */
-
 
11
 
1
#ifndef __GEOMETRY_KDTREE_H
12
#ifndef __GEOMETRY_KDTREE_H
2
#define __GEOMETRY_KDTREE_H
13
#define __GEOMETRY_KDTREE_H
3
 
14
 
4
#include <cmath>
15
#include <cmath>
5
#include <iostream>
16
#include <iostream>
6
#include <vector>
17
#include <vector>
7
#include <algorithm>
18
#include <algorithm>
8
#include "CGLA/CGLA.h"
19
#include "CGLA/CGLA.h"
9
#include "CGLA/ArithVec.h"
20
#include "CGLA/ArithVec.h"
10
 
21
 
11
#if (_MSC_VER >= 1200)
22
#if (_MSC_VER >= 1200)
12
#pragma warning (push)
23
#pragma warning (push)
13
#pragma warning (disable: 4018)
24
#pragma warning (disable: 4018)
14
#endif
25
#endif
15
 
26
 
16
namespace Geometry
27
namespace Geometry
17
{
28
{
18
	/** \brief A classic K-D tree. 
29
	/** \brief A classic K-D tree. 
19
 
30
 
20
			A K-D tree is a good data structure for storing points in space
31
			A K-D tree is a good data structure for storing points in space
21
			and for nearest neighbour queries. It is basically a generalized 
32
			and for nearest neighbour queries. It is basically a generalized 
22
			binary tree in K dimensions. */
33
			binary tree in K dimensions. */
23
	template<class KeyT, class ValT>
34
	template<class KeyT, class ValT>
24
	class KDTree
35
	class KDTree
25
	{
36
	{
26
		typedef typename KeyT::ScalarType ScalarType;
37
		typedef typename KeyT::ScalarType ScalarType;
27
		typedef KeyT KeyType;
38
		typedef KeyT KeyType;
28
		typedef std::vector<KeyT> KeyVectorType;
39
		typedef std::vector<KeyT> KeyVectorType;
29
		typedef std::vector<ValT> ValVectorType;
40
		typedef std::vector<ValT> ValVectorType;
30
	
41
	
31
		/// KDNode struct represents node in KD tree
42
		/// KDNode struct represents node in KD tree
32
		struct KDNode
43
		struct KDNode
33
		{
44
		{
34
			KeyT key;
45
			KeyT key;
35
			ValT val;
46
			ValT val;
36
			short dsc;
47
			short dsc;
37
 
48
 
38
			KDNode(): dsc(0) {}
49
			KDNode(): dsc(0) {}
39
 
50
 
40
			KDNode(const KeyT& _key, const ValT& _val):
51
			KDNode(const KeyT& _key, const ValT& _val):
41
				key(_key), val(_val), dsc(-1) {}
52
				key(_key), val(_val), dsc(-1) {}
42
 
53
 
43
			ScalarType dist(const KeyType& p) const 
54
			ScalarType dist(const KeyType& p) const 
44
			{
55
			{
45
				KeyType dist_vec = p;
56
				KeyType dist_vec = p;
46
				dist_vec  -= key;
57
				dist_vec  -= key;
47
				return dot(dist_vec, dist_vec);
58
				return dot(dist_vec, dist_vec);
48
			}
59
			}
49
		};
60
		};
50
 
61
 
51
		typedef std::vector<KDNode> NodeVecType;
62
		typedef std::vector<KDNode> NodeVecType;
52
		bool is_built;
63
		bool is_built;
53
		NodeVecType init_nodes;
64
		NodeVecType init_nodes;
54
		NodeVecType nodes;
65
		NodeVecType nodes;
55
 
66
 
56
		/** Comp is a class used for comparing two keys. Comp is constructed
67
		/** Comp is a class used for comparing two keys. Comp is constructed
57
				with the discriminator - i.e. the coordinate of the key that is used
68
				with the discriminator - i.e. the coordinate of the key that is used
58
				for comparing keys - Comp objects are passed to the sort algorithm.*/
69
				for comparing keys - Comp objects are passed to the sort algorithm.*/
59
		class Comp
70
		class Comp
60
		{
71
		{
61
			const int dsc;
72
			const int dsc;
62
		public:
73
		public:
63
			Comp(int _dsc): dsc(_dsc) {}
74
			Comp(int _dsc): dsc(_dsc) {}
64
			bool operator()(const KeyType& k0, const KeyType& k1) const
75
			bool operator()(const KeyType& k0, const KeyType& k1) const
65
			{
76
			{
66
				int dim=KeyType::get_dim();
77
				int dim=KeyType::get_dim();
67
				for(int i=0;i<dim;i++)
78
				for(int i=0;i<dim;i++)
68
					{
79
					{
69
						int j=(dsc+i)%dim;
80
						int j=(dsc+i)%dim;
70
						if(k0[j]<k1[j])
81
						if(k0[j]<k1[j])
71
							return true;
82
							return true;
72
						if(k0[j]>k1[j])
83
						if(k0[j]>k1[j])
73
							return false;
84
							return false;
74
					}
85
					}
75
				return false;
86
				return false;
76
			}
87
			}
77
 
88
 
78
			bool operator()(const KDNode& k0, const KDNode& k1) const
89
			bool operator()(const KDNode& k0, const KDNode& k1) const
79
			{
90
			{
80
				return (*this)(k0.key,k1.key);
91
				return (*this)(k0.key,k1.key);
81
			}
92
			}
82
		};
93
		};
83
 
94
 
84
 
95
 
85
		/** Passed a vector of keys, this function will construct an optimal tree.
96
		/** Passed a vector of keys, this function will construct an optimal tree.
86
				It is called recursively */
97
				It is called recursively */
87
		void optimize(int, int, int);
98
		void optimize(int, int, int);
88
 
99
 
89
		/** Finde nearest neighbour. */
100
		/** Finde nearest neighbour. */
90
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
101
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
91
 
102
 
92
							
103
							
93
		void in_sphere_priv(int n, 
104
		void in_sphere_priv(int n, 
94
												const KeyType& p, 
105
												const KeyType& p, 
95
												const ScalarType& dist,
106
												const ScalarType& dist,
96
												std::vector<KeyT>& keys,
107
												std::vector<KeyT>& keys,
97
												std::vector<ValT>& vals) const;
108
												std::vector<ValT>& vals) const;
98
	
109
	
99
		/** Finds the optimal discriminator. There are more ways, but this 
110
		/** Finds the optimal discriminator. There are more ways, but this 
100
				function traverses the vector and finds out what dimension has
111
				function traverses the vector and finds out what dimension has
101
				the greatest difference between min and max element. That dimension
112
				the greatest difference between min and max element. That dimension
102
				is used for discriminator */
113
				is used for discriminator */
103
		int opt_disc(int,int) const;
114
		int opt_disc(int,int) const;
104
	
115
	
105
	public:
116
	public:
106
 
117
 
107
		/** Build tree from vector of keys passed as argument. */
118
		/** Build tree from vector of keys passed as argument. */
108
		KDTree(): is_built(false), init_nodes(1) {}
119
		KDTree(): is_built(false), init_nodes(1) {}
109
 
120
 
110
		/** Insert a key value pair into the tree. Note that the tree needs to 
121
		/** Insert a key value pair into the tree. Note that the tree needs to 
111
				be built - by calling the build function - before you can search. */
122
				be built - by calling the build function - before you can search. */
112
		void insert(const KeyT& key, const ValT& val)
123
		void insert(const KeyT& key, const ValT& val)
113
		{
124
		{
114
				if(is_built)
125
				if(is_built)
115
				{
126
				{
116
						assert(init_nodes.size()==1);
127
						assert(init_nodes.size()==1);
117
						init_nodes.swap(nodes);
128
						init_nodes.swap(nodes);
118
						is_built=false;
129
						is_built=false;
119
				}
130
				}
120
				init_nodes.push_back(KDNode(key,val));
131
				init_nodes.push_back(KDNode(key,val));
121
		}
132
		}
122
 
133
 
123
		/** Build the tree. After this function has been called, it is no longer 
134
		/** Build the tree. After this function has been called, it is no longer 
124
				legal to insert elements, but you can perform searches. */
135
				legal to insert elements, but you can perform searches. */
125
		void build()
136
		void build()
126
		{
137
		{
127
			assert(!is_built);
138
			assert(!is_built);
128
			nodes.resize(init_nodes.size());
139
			nodes.resize(init_nodes.size());
129
			if(init_nodes.size() > 1)	
140
			if(init_nodes.size() > 1)	
130
				optimize(1,1,init_nodes.size());
141
				optimize(1,1,init_nodes.size());
131
			NodeVecType v(1);
142
			NodeVecType v(1);
132
			init_nodes.swap(v);
143
			init_nodes.swap(v);
133
			is_built = true;
144
			is_built = true;
134
		}
145
		}
135
 
146
 
136
		/** Find the key value pair closest to the key given as first 
147
		/** Find the key value pair closest to the key given as first 
137
				argument. The second argument is the maximum search distance. Upon
148
				argument. The second argument is the maximum search distance. Upon
138
				return this value is changed to the distance to the found point.
149
				return this value is changed to the distance to the found point.
139
				The final two arguments contain the closest key and its 
150
				The final two arguments contain the closest key and its 
140
				associated value upon return. */
151
				associated value upon return. */
141
		bool closest_point(const KeyT& p, ScalarType& dist, KeyT&k, ValT&v) const
152
		bool closest_point(const KeyT& p, ScalarType& dist, KeyT&k, ValT&v) const
142
		{
153
		{
143
			assert(is_built);
154
			assert(is_built);
144
			if(nodes.size()>1)
155
			if(nodes.size()>1)
145
			{
156
			{
146
					ScalarType max_sq_dist = CGLA::sqr(dist);
157
					ScalarType max_sq_dist = CGLA::sqr(dist);
147
					if(int n = closest_point_priv(1, p, max_sq_dist))
158
					if(int n = closest_point_priv(1, p, max_sq_dist))
148
					{
159
					{
149
							k = nodes[n].key;
160
							k = nodes[n].key;
150
							v = nodes[n].val;
161
							v = nodes[n].val;
151
							dist = std::sqrt(max_sq_dist);
162
							dist = std::sqrt(max_sq_dist);
152
							return true;
163
							return true;
153
					}
164
					}
154
			}
165
			}
155
			return false;
166
			return false;
156
		}
167
		}
157
 
168
 
158
		/** Find all the elements within a given radius (second argument) of
169
		/** Find all the elements within a given radius (second argument) of
159
				the key (first argument). The key value pairs inside the sphere are
170
				the key (first argument). The key value pairs inside the sphere are
160
				returned in a pair of vectors passed as the two last arguments.
171
				returned in a pair of vectors passed as the two last arguments.
161
				Note that we don't resize the two last arguments to zero - so either
172
				Note that we don't resize the two last arguments to zero - so either
162
				they should be empty vectors or you should desire appending the newly
173
				they should be empty vectors or you should desire appending the newly
163
				found elements onto these vectors.				
174
				found elements onto these vectors.				
164
		*/
175
		*/
165
		int in_sphere(const KeyType& p, 
176
		int in_sphere(const KeyType& p, 
166
									ScalarType dist,
177
									ScalarType dist,
167
									std::vector<KeyT>& keys,
178
									std::vector<KeyT>& keys,
168
									std::vector<ValT>& vals) const
179
									std::vector<ValT>& vals) const
169
		{
180
		{
170
			assert(is_built);
181
			assert(is_built);
171
			if(nodes.size()>1)
182
			if(nodes.size()>1)
172
			{
183
			{
173
					ScalarType max_sq_dist = CGLA::sqr(dist);
184
					ScalarType max_sq_dist = CGLA::sqr(dist);
174
					in_sphere_priv(1,p,max_sq_dist,keys,vals);
185
					in_sphere_priv(1,p,max_sq_dist,keys,vals);
175
					return keys.size();
186
					return keys.size();
176
			}
187
			}
177
			return 0;
188
			return 0;
178
		}
189
		}
179
		
190
		
180
 
191
 
181
	};
192
	};
182
 
193
 
183
	template<class KeyT, class ValT>
194
	template<class KeyT, class ValT>
184
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
195
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
185
																	int kvec_end) const 
196
																	int kvec_end) const 
186
	{
197
	{
187
		KeyType vmin = init_nodes[kvec_beg].key;
198
		KeyType vmin = init_nodes[kvec_beg].key;
188
		KeyType vmax = init_nodes[kvec_beg].key;
199
		KeyType vmax = init_nodes[kvec_beg].key;
189
		for(int i=kvec_beg;i<kvec_end;i++)
200
		for(int i=kvec_beg;i<kvec_end;i++)
190
			{
201
			{
191
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
202
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
192
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
203
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
193
			}
204
			}
194
		int od=0;
205
		int od=0;
195
		KeyType ave_v = vmax-vmin;
206
		KeyType ave_v = vmax-vmin;
196
		for(int i=1;i<KeyType::get_dim();i++)
207
		for(int i=1;i<KeyType::get_dim();i++)
197
			if(ave_v[i]>ave_v[od]) od = i;
208
			if(ave_v[i]>ave_v[od]) od = i;
198
		return od;
209
		return od;
199
	} 
210
	} 
200
 
211
 
201
	template<class KeyT, class ValT>
212
	template<class KeyT, class ValT>
202
	void KDTree<KeyT,ValT>::optimize(int cur,
213
	void KDTree<KeyT,ValT>::optimize(int cur,
203
																	 int kvec_beg,  
214
																	 int kvec_beg,  
204
																	 int kvec_end)
215
																	 int kvec_end)
205
	{
216
	{
206
		// Assert that we are not inserting beyond capacity.
217
		// Assert that we are not inserting beyond capacity.
207
		assert(cur < nodes.size());
218
		assert(cur < nodes.size());
208
 
219
 
209
		// If there is just a single element, we simply insert.
220
		// If there is just a single element, we simply insert.
210
		if(kvec_beg+1==kvec_end) 
221
		if(kvec_beg+1==kvec_end) 
211
			{
222
			{
212
				nodes[cur] = init_nodes[kvec_beg];
223
				nodes[cur] = init_nodes[kvec_beg];
213
				nodes[cur].dsc = -1;
224
				nodes[cur].dsc = -1;
214
				return;
225
				return;
215
			}
226
			}
216
	
227
	
217
		// Find the axis that best separates the data.
228
		// Find the axis that best separates the data.
218
		int disc = opt_disc(kvec_beg, kvec_end);
229
		int disc = opt_disc(kvec_beg, kvec_end);
219
 
230
 
220
		// Compute the median element. See my document on how to do this
231
		// Compute the median element. See my document on how to do this
221
		// www.imm.dtu.dk/~jab/publications.html
232
		// www.imm.dtu.dk/~jab/publications.html
222
		int N = kvec_end-kvec_beg;
233
		int N = kvec_end-kvec_beg;
223
		int M = 1<< (CGLA::two_to_what_power(N));
234
		int M = 1<< (CGLA::two_to_what_power(N));
224
		int R = N-(M-1);
235
		int R = N-(M-1);
225
		int left_size  = (M-2)/2;
236
		int left_size  = (M-2)/2;
226
		int right_size = (M-2)/2;
237
		int right_size = (M-2)/2;
227
		if(R < M/2)
238
		if(R < M/2)
228
			{
239
			{
229
				left_size += R;
240
				left_size += R;
230
			}
241
			}
231
		else
242
		else
232
			{
243
			{
233
				left_size += M/2;
244
				left_size += M/2;
234
				right_size += R-M/2;
245
				right_size += R-M/2;
235
			}
246
			}
236
 
247
 
237
		int median = kvec_beg + left_size;
248
		int median = kvec_beg + left_size;
238
 
249
 
239
		// Sort elements but use nth_element (which is cheaper) than
250
		// Sort elements but use nth_element (which is cheaper) than
240
		// a sorting algorithm. All elements to the left of the median
251
		// a sorting algorithm. All elements to the left of the median
241
		// will be smaller than or equal the median. All elements to the right
252
		// will be smaller than or equal the median. All elements to the right
242
		// will be greater than or equal to the median.
253
		// will be greater than or equal to the median.
243
		const Comp comp(disc);
254
		const Comp comp(disc);
244
		std::nth_element(&init_nodes[kvec_beg], 
255
		std::nth_element(&init_nodes[kvec_beg], 
245
										 &init_nodes[median], 
256
										 &init_nodes[median], 
246
										 &init_nodes[kvec_end], comp);
257
										 &init_nodes[kvec_end], comp);
247
 
258
 
248
		// Insert the node in the final data structure.
259
		// Insert the node in the final data structure.
249
		nodes[cur] = init_nodes[median];
260
		nodes[cur] = init_nodes[median];
250
		nodes[cur].dsc = disc;
261
		nodes[cur].dsc = disc;
251
 
262
 
252
		// Recursively build left and right tree.
263
		// Recursively build left and right tree.
253
		if(left_size>0)	
264
		if(left_size>0)	
254
			optimize(2*cur, kvec_beg, median);
265
			optimize(2*cur, kvec_beg, median);
255
		
266
		
256
		if(right_size>0) 
267
		if(right_size>0) 
257
			optimize(2*cur+1, median+1, kvec_end);
268
			optimize(2*cur+1, median+1, kvec_end);
258
	}
269
	}
259
 
270
 
260
	template<class KeyT, class ValT>
271
	template<class KeyT, class ValT>
261
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
272
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
262
																						ScalarType& dist) const
273
																						ScalarType& dist) const
263
	{
274
	{
264
		int ret_node = 0;
275
		int ret_node = 0;
265
		ScalarType this_dist = nodes[n].dist(p);
276
		ScalarType this_dist = nodes[n].dist(p);
266
 
277
 
267
		if(this_dist<dist)
278
		if(this_dist<dist)
268
			{
279
			{
269
				dist = this_dist;
280
				dist = this_dist;
270
				ret_node = n;
281
				ret_node = n;
271
			}
282
			}
272
		if(nodes[n].dsc != -1)
283
		if(nodes[n].dsc != -1)
273
			{
284
			{
274
				int dsc         = nodes[n].dsc;
285
				int dsc         = nodes[n].dsc;
275
				ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
286
				ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
276
				bool left_son   = Comp(dsc)(p,nodes[n].key);
287
				bool left_son   = Comp(dsc)(p,nodes[n].key);
277
 
288
 
278
				if(left_son||dsc_dist<dist)
289
				if(left_son||dsc_dist<dist)
279
					{
290
					{
280
						int left_child = 2*n;
291
						int left_child = 2*n;
281
						if(left_child < nodes.size())
292
						if(left_child < nodes.size())
282
							if(int nl=closest_point_priv(left_child, p, dist))
293
							if(int nl=closest_point_priv(left_child, p, dist))
283
								ret_node = nl;
294
								ret_node = nl;
284
					}
295
					}
285
				if(!left_son||dsc_dist<dist)
296
				if(!left_son||dsc_dist<dist)
286
					{
297
					{
287
						int right_child = 2*n+1;
298
						int right_child = 2*n+1;
288
						if(right_child < nodes.size())
299
						if(right_child < nodes.size())
289
							if(int nr=closest_point_priv(right_child, p, dist))
300
							if(int nr=closest_point_priv(right_child, p, dist))
290
								ret_node = nr;
301
								ret_node = nr;
291
					}
302
					}
292
			}
303
			}
293
		return ret_node;
304
		return ret_node;
294
	}
305
	}
295
 
306
 
296
	template<class KeyT, class ValT>
307
	template<class KeyT, class ValT>
297
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
308
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
298
																				 const KeyType& p, 
309
																				 const KeyType& p, 
299
																				 const ScalarType& dist,
310
																				 const ScalarType& dist,
300
																				 std::vector<KeyT>& keys,
311
																				 std::vector<KeyT>& keys,
301
																				 std::vector<ValT>& vals) const
312
																				 std::vector<ValT>& vals) const
302
	{
313
	{
303
		ScalarType this_dist = nodes[n].dist(p);
314
		ScalarType this_dist = nodes[n].dist(p);
304
		assert(n<nodes.size());
315
		assert(n<nodes.size());
305
		if(this_dist<dist)
316
		if(this_dist<dist)
306
			{
317
			{
307
				keys.push_back(nodes[n].key);
318
				keys.push_back(nodes[n].key);
308
				vals.push_back(nodes[n].val);
319
				vals.push_back(nodes[n].val);
309
			}
320
			}
310
		if(nodes[n].dsc != -1)
321
		if(nodes[n].dsc != -1)
311
			{
322
			{
312
				const int dsc         = nodes[n].dsc;
323
				const int dsc         = nodes[n].dsc;
313
				const ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
324
				const ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
314
 
325
 
315
				bool left_son = Comp(dsc)(p,nodes[n].key);
326
				bool left_son = Comp(dsc)(p,nodes[n].key);
316
 
327
 
317
				if(left_son||dsc_dist<dist)
328
				if(left_son||dsc_dist<dist)
318
					{
329
					{
319
						int left_child = 2*n;
330
						int left_child = 2*n;
320
						if(left_child < nodes.size())
331
						if(left_child < nodes.size())
321
							in_sphere_priv(left_child, p, dist, keys, vals);
332
							in_sphere_priv(left_child, p, dist, keys, vals);
322
					}
333
					}
323
				if(!left_son||dsc_dist<dist)
334
				if(!left_son||dsc_dist<dist)
324
					{
335
					{
325
						int right_child = 2*n+1;
336
						int right_child = 2*n+1;
326
						if(right_child < nodes.size())
337
						if(right_child < nodes.size())
327
							in_sphere_priv(right_child, p, dist, keys, vals);
338
							in_sphere_priv(right_child, p, dist, keys, vals);
328
					}
339
					}
329
			}
340
			}
330
	}
341
	}
331
}
342
}
332
namespace GEO = Geometry;
343
namespace GEO = Geometry;
333
 
344
 
334
#if (_MSC_VER >= 1200)
345
#if (_MSC_VER >= 1200)
335
#pragma warning (pop)
346
#pragma warning (pop)
336
#endif
347
#endif
337
 
348
 
338
 
349
 
339
#endif
350
#endif
340
 
351