Subversion Repositories gelsvn

Rev

Rev 324 | Rev 334 | Go to most recent revision | Details | Compare with Previous | Last modification | View Log | RSS feed

Rev Author Line No. Line
61 jab 1
#ifndef __KDTREE_H
2
#define __KDTREE_H
3
 
4
#include <cmath>
5
#include <iostream>
6
#include <vector>
7
#include <algorithm>
8
#include "CGLA/CGLA.h"
125 jab 9
#include "CGLA/ArithVec.h"
61 jab 10
 
198 bj 11
#if (_MSC_VER >= 1200)
203 jrf 12
#pragma warning (push)
198 bj 13
#pragma warning (disable: 4018)
14
#endif
15
 
61 jab 16
namespace Geometry
17
{
89 jab 18
	/** \brief A classic K-D tree. 
19
 
61 jab 20
			A K-D tree is a good data structure for storing points in space
21
			and for nearest neighbour queries. It is basically a generalized 
22
			binary tree in K dimensions. */
23
	template<class KeyT, class ValT>
24
	class KDTree
25
	{
26
		typedef typename KeyT::ScalarType ScalarType;
27
		typedef KeyT KeyType;
28
		typedef std::vector<KeyT> KeyVectorType;
29
		typedef std::vector<ValT> ValVectorType;
30
 
31
		/// KDNode struct represents node in KD tree
32
		struct KDNode
33
		{
34
			KeyT key;
35
			ValT val;
36
			short dsc;
37
 
38
			KDNode(): dsc(0) {}
39
 
40
			KDNode(const KeyT& _key, const ValT& _val):
41
				key(_key), val(_val), dsc(-1) {}
42
 
43
			ScalarType dist(const KeyType& p) const 
44
			{
45
				KeyType dist_vec = p;
46
				dist_vec  -= key;
47
				return dot(dist_vec, dist_vec);
48
			}
49
		};
50
 
51
		typedef std::vector<KDNode> NodeVecType;
330 jab 52
		bool is_built;
61 jab 53
		NodeVecType init_nodes;
54
		NodeVecType nodes;
55
 
56
		/** Comp is a class used for comparing two keys. Comp is constructed
57
				with the discriminator - i.e. the coordinate of the key that is used
58
				for comparing keys - Comp objects are passed to the sort algorithm.*/
59
		class Comp
60
		{
61
			const int dsc;
62
		public:
63
			Comp(int _dsc): dsc(_dsc) {}
64
			bool operator()(const KeyType& k0, const KeyType& k1) const
65
			{
66
				int dim=KeyType::get_dim();
67
				for(int i=0;i<dim;i++)
68
					{
69
						int j=(dsc+i)%dim;
70
						if(k0[j]<k1[j])
71
							return true;
72
						if(k0[j]>k1[j])
73
							return false;
74
					}
75
				return false;
76
			}
77
 
78
			bool operator()(const KDNode& k0, const KDNode& k1) const
79
			{
80
				return (*this)(k0.key,k1.key);
81
			}
82
		};
83
 
84
 
85
		/** Passed a vector of keys, this function will construct an optimal tree.
330 jab 86
				It is called recursively */
87
		void optimize(int, int, int);
61 jab 88
 
89
		/** Finde nearest neighbour. */
90
		int closest_point_priv(int, const KeyType&, ScalarType&) const;
91
 
92
 
93
		void in_sphere_priv(int n, 
94
												const KeyType& p, 
95
												const ScalarType& dist,
96
												std::vector<KeyT>& keys,
97
												std::vector<ValT>& vals) const;
98
 
99
		/** Finds the optimal discriminator. There are more ways, but this 
100
				function traverses the vector and finds out what dimension has
101
				the greatest difference between min and max element. That dimension
102
				is used for discriminator */
103
		int opt_disc(int,int) const;
104
 
105
	public:
106
 
107
		/** Build tree from vector of keys passed as argument. */
330 jab 108
		KDTree(): is_built(false), init_nodes(1) {}
61 jab 109
 
110
		/** Insert a key value pair into the tree. Note that the tree needs to 
111
				be built - by calling the build function - before you can search. */
112
		void insert(const KeyT& key, const ValT& val)
113
		{
330 jab 114
				if(is_built)
115
				{
116
						assert(init_nodes.size()==1);
117
						init_nodes.swap(nodes);
118
						is_built=false;
119
				}
120
				init_nodes.push_back(KDNode(key,val));
61 jab 121
		}
122
 
330 jab 123
		/** Build the tree. After this function has been called, it is no longer 
61 jab 124
				legal to insert elements, but you can perform searches. */
125
		void build()
126
		{
127
			assert(!is_built);
330 jab 128
			nodes.resize(init_nodes.size());
129
			if(init_nodes.size() > 1)	
130
				optimize(1,1,init_nodes.size());
131
			NodeVecType v(1);
61 jab 132
			init_nodes.swap(v);
133
			is_built = true;
134
		}
135
 
136
		/** Find the key value pair closest to the key given as first 
137
				argument. The second argument is the maximum search distance.
138
				The final two arguments contain the closest key and its 
139
				associated value upon return. */
324 jab 140
		bool closest_point(const KeyT& p, ScalarType& dist, KeyT&k, ValT&v) const
61 jab 141
		{
142
			assert(is_built);
330 jab 143
			if(nodes.size()>1)
144
			{
145
					ScalarType max_sq_dist = CGLA::sqr(dist);
146
					if(int n = closest_point_priv(1, p, max_sq_dist))
147
					{
148
							k = nodes[n].key;
149
							v = nodes[n].val;
150
							dist = std::sqrt(max_sq_dist);
151
							return true;
152
					}
153
			}
61 jab 154
			return false;
155
		}
156
 
157
		/** Find all the elements within a given radius (second argument) of
158
				the key (first argument). The key value pairs inside the sphere are
330 jab 159
				returned in a pair of vectors passed as the two last arguments.
160
				Note that we don't resize the two last arguments to zero - so either
161
				they should be empty vectors or you should desire appending the newly
162
				found elements onto these vectors.				
163
		*/
61 jab 164
		int in_sphere(const KeyType& p, 
324 jab 165
									ScalarType dist,
61 jab 166
									std::vector<KeyT>& keys,
167
									std::vector<ValT>& vals) const
168
		{
169
			assert(is_built);
330 jab 170
			if(nodes.size()>1)
171
			{
172
					ScalarType max_sq_dist = CGLA::sqr(dist);
173
					in_sphere_priv(1,p,max_sq_dist,keys,vals);
174
					return keys.size();
175
			}
176
			return 0;
61 jab 177
		}
178
 
179
 
180
	};
181
 
182
	template<class KeyT, class ValT>
183
	int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
184
																	int kvec_end) const 
185
	{
186
		KeyType vmin = init_nodes[kvec_beg].key;
187
		KeyType vmax = init_nodes[kvec_beg].key;
188
		for(int i=kvec_beg;i<kvec_end;i++)
189
			{
190
				vmin = CGLA::v_min(vmin,init_nodes[i].key);
191
				vmax = CGLA::v_max(vmax,init_nodes[i].key);
192
			}
193
		int od=0;
194
		KeyType ave_v = vmax-vmin;
195
		for(int i=1;i<KeyType::get_dim();i++)
196
			if(ave_v[i]>ave_v[od]) od = i;
197
		return od;
198
	} 
199
 
200
	template<class KeyT, class ValT>
201
	void KDTree<KeyT,ValT>::optimize(int cur,
202
																	 int kvec_beg,  
330 jab 203
																	 int kvec_end)
61 jab 204
	{
205
		// Assert that we are not inserting beyond capacity.
206
		assert(cur < nodes.size());
207
 
208
		// If there is just a single element, we simply insert.
209
		if(kvec_beg+1==kvec_end) 
210
			{
211
				nodes[cur] = init_nodes[kvec_beg];
212
				nodes[cur].dsc = -1;
213
				return;
214
			}
215
 
216
		// Find the axis that best separates the data.
217
		int disc = opt_disc(kvec_beg, kvec_end);
218
 
219
		// Compute the median element. See my document on how to do this
220
		// www.imm.dtu.dk/~jab/publications.html
221
		int N = kvec_end-kvec_beg;
222
		int M = 1<< (CGLA::two_to_what_power(N));
223
		int R = N-(M-1);
224
		int left_size  = (M-2)/2;
225
		int right_size = (M-2)/2;
226
		if(R < M/2)
227
			{
228
				left_size += R;
229
			}
230
		else
231
			{
232
				left_size += M/2;
233
				right_size += R-M/2;
234
			}
235
 
236
		int median = kvec_beg + left_size;
237
 
238
		// Sort elements but use nth_element (which is cheaper) than
239
		// a sorting algorithm. All elements to the left of the median
240
		// will be smaller than or equal the median. All elements to the right
241
		// will be greater than or equal to the median.
242
		const Comp comp(disc);
243
		std::nth_element(&init_nodes[kvec_beg], 
244
										 &init_nodes[median], 
245
										 &init_nodes[kvec_end], comp);
246
 
247
		// Insert the node in the final data structure.
248
		nodes[cur] = init_nodes[median];
249
		nodes[cur].dsc = disc;
250
 
251
		// Recursively build left and right tree.
252
		if(left_size>0)	
330 jab 253
			optimize(2*cur, kvec_beg, median);
61 jab 254
 
255
		if(right_size>0) 
330 jab 256
			optimize(2*cur+1, median+1, kvec_end);
61 jab 257
	}
258
 
259
	template<class KeyT, class ValT>
260
	int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
261
																						ScalarType& dist) const
262
	{
263
		int ret_node = 0;
264
		ScalarType this_dist = nodes[n].dist(p);
265
 
266
		if(this_dist<dist)
267
			{
268
				dist = this_dist;
269
				ret_node = n;
270
			}
271
		if(nodes[n].dsc != -1)
272
			{
273
				int dsc         = nodes[n].dsc;
324 jab 274
				ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
61 jab 275
				bool left_son   = Comp(dsc)(p,nodes[n].key);
276
 
277
				if(left_son||dsc_dist<dist)
278
					{
279
						int left_child = 2*n;
280
						if(left_child < nodes.size())
281
							if(int nl=closest_point_priv(left_child, p, dist))
282
								ret_node = nl;
283
					}
284
				if(!left_son||dsc_dist<dist)
285
					{
286
						int right_child = 2*n+1;
287
						if(right_child < nodes.size())
288
							if(int nr=closest_point_priv(right_child, p, dist))
289
								ret_node = nr;
290
					}
291
			}
292
		return ret_node;
293
	}
294
 
295
	template<class KeyT, class ValT>
296
	void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
297
																				 const KeyType& p, 
298
																				 const ScalarType& dist,
299
																				 std::vector<KeyT>& keys,
300
																				 std::vector<ValT>& vals) const
301
	{
302
		ScalarType this_dist = nodes[n].dist(p);
303
		assert(n<nodes.size());
304
		if(this_dist<dist)
305
			{
306
				keys.push_back(nodes[n].key);
307
				vals.push_back(nodes[n].val);
308
			}
309
		if(nodes[n].dsc != -1)
310
			{
311
				const int dsc         = nodes[n].dsc;
324 jab 312
				const ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
61 jab 313
 
314
				bool left_son = Comp(dsc)(p,nodes[n].key);
315
 
316
				if(left_son||dsc_dist<dist)
317
					{
318
						int left_child = 2*n;
319
						if(left_child < nodes.size())
320
							in_sphere_priv(left_child, p, dist, keys, vals);
321
					}
322
				if(!left_son||dsc_dist<dist)
323
					{
324
						int right_child = 2*n+1;
325
						if(right_child < nodes.size())
326
							in_sphere_priv(right_child, p, dist, keys, vals);
327
					}
328
			}
329
	}
330
}
331
/// Short alias: everything in namespace Geometry can also be reached as GEO::.
namespace GEO = Geometry;
332
 
198 bj 333
#if (_MSC_VER >= 1200)
203 jrf 334
#pragma warning (pop)
61 jab 335
#endif
198 bj 336
 
337
 
338
#endif