Subversion Repositories gelsvn

Rev

Rev 10 | Go to most recent revision | Details | Last modification | View Log | RSS feed

Rev Author Line No. Line
2 bj 1
#ifndef __CGLA_ARITHMATFLOAT_H__
2
#define __CGLA_ARITHMATFLOAT_H__
3
 
4
#include <cassert>
#include <cstring>
#include <iostream>
#include <vector>

#include "CGLA.h"
8
 
9
 
10
namespace CGLA {
11
 
12
	/** Basic class template for matrices.

	In this template a matrix is defined as an array of vectors. This may
	not in all cases be the most efficient, but it has the advantage that
	it is possible to use the double subscripting notation:

	T x = m[i][j]

	This template should be used through inheritance just like the
	vector template: MT is the concrete derived matrix type, and the
	arithmetic operators below create and return values of that type.

	Template parameters:
	- VVT:  vertical (column) vector type; get_v_dim() is its dimension.
	- HVT:  horizontal (row) vector type; each stored row is an HVT.
	- MT:   the derived matrix type.
	- ROWS: the number of rows stored in the matrix. */
	template <class VVT, class HVT, class MT, int ROWS>
	class ArithMatFloat
	{
	public:

		/// Horizontal vector type
		typedef HVT HVectorType;

		/// Vertical vector type
		typedef VVT VVectorType;

		/// The type of a matrix element
		typedef typename HVT::ScalarType ScalarType;

	protected:

		/// The actual contents of the matrix, one HVT per row.
		HVT data[ROWS];

	protected:

		/// Construct 0 matrix
		ArithMatFloat()
		{
			for(int i=0;i<ROWS;i++)
				data[i] = HVT(ScalarType(0));
		}

		/// Construct a matrix where all entries are the same.
		explicit ArithMatFloat(ScalarType x)
		{
			for(int i=0;i<ROWS;i++)
				data[i] = HVT(x);
		}

		/// Construct a matrix where all rows are the same.
		explicit ArithMatFloat(HVT _a)
		{
			for(int i=0;i<ROWS;i++)
				data[i] = _a;
		}

		/** Construct a matrix with two rows. Only valid when ROWS==2;
				checked by assert, like the corresponding set() overload. */
		ArithMatFloat(HVT _a, HVT _b)
		{
			assert(ROWS==2);
			data[0] = _a;
			data[1] = _b;
		}

		/** Construct a matrix with three rows. Only valid when ROWS==3;
				checked by assert, like the corresponding set() overload. */
		ArithMatFloat(HVT _a, HVT _b, HVT _c)
		{
			assert(ROWS==3);
			data[0] = _a;
			data[1] = _b;
			data[2] = _c;
		}

		/** Construct a matrix with four rows. Only valid when ROWS==4;
				checked by assert, like the corresponding set() overload. */
		ArithMatFloat(HVT _a, HVT _b, HVT _c, HVT _d)
		{
			assert(ROWS==4);
			data[0] = _a;
			data[1] = _b;
			data[2] = _c;
			data[3] = _d;
		}

	public:

		/// Get vertical dimension of matrix (the dimension of VVT).
		static int get_v_dim() {return VVT::get_dim();}

		/// Get horizontal dimension of matrix (the dimension of HVT).
		static int get_h_dim() {return HVT::get_dim();}

		/** Get const pointer to data array.
				This function may be useful when interfacing with some other API
				such as OpenGL (TM). The pointer addresses the first element of
				row 0; using it as the whole matrix assumes the rows are stored
				back to back without padding. */
		const ScalarType* get() const
		{
			return data[0].get();
		}

		/** Get pointer to data array.
				This function may be useful when interfacing with some other API
				such as OpenGL (TM). The same contiguity assumption as for the
				const overload applies. */
		ScalarType* get()
		{
			return data[0].get();
		}

		/** Set values by passing an array to the matrix.
				The values should be ordered like [[row][row]...[row]], i.e. row
				by row; sa must point to at least ROWS*get_h_dim() values. */
		void set(const ScalarType* sa)
		{
			// Copy entry by entry instead of memcpy-ing over get(): this does
			// not depend on the rows being packed without padding, and it never
			// writes beyond the ROWS rows held in data.
			const int cols = get_h_dim();
			for(int i=0;i<ROWS;i++)
				for(int j=0;j<cols;j++)
					data[i][j] = sa[i*cols + j];
		}

		/// Construct a matrix from an array of scalar values (see set()).
		explicit ArithMatFloat(const ScalarType* sa)
		{
			set(sa);
		}

		/// Assign the rows of a 2D matrix.
		void set(HVT _a, HVT _b)
		{
			assert(ROWS==2);
			data[0] = _a;
			data[1] = _b;
		}

		/// Assign the rows of a 3D matrix.
		void set(HVT _a, HVT _b, HVT _c)
		{
			assert(ROWS==3);
			data[0] = _a;
			data[1] = _b;
			data[2] = _c;
		}

		/// Assign the rows of a 4D matrix.
		void set(HVT _a, HVT _b, HVT _c, HVT _d)
		{
			assert(ROWS==4);
			data[0] = _a;
			data[1] = _b;
			data[2] = _c;
			data[3] = _d;
		}


		//----------------------------------------------------------------------
		// index operators
		//----------------------------------------------------------------------

		/// Const index operator. Returns i'th row of matrix.
		const HVT& operator [] ( int i ) const
		{
			assert(i>=0 && i<ROWS);
			return data[i];
		}

		/// Non-const index operator. Returns i'th row of matrix.
		HVT& operator [] ( int i )
		{
			assert(i>=0 && i<ROWS);
			return data[i];
		}

		//----------------------------------------------------------------------

		/// Equality operator: true when every row compares equal.
		bool operator==(const MT& v) const
		{
			for(int i=0;i<ROWS;i++)
				if (data[i] != v[i])
					return false;
			return true;
		}

		/// Inequality operator.
		bool operator!=(const MT& v) const
		{
			return !(*this==v);
		}

		//----------------------------------------------------------------------

		/// Multiply scalar onto matrix. All entries are multiplied by scalar.
		const MT operator * (ScalarType k) const
		{
			MT v_new;
			for(int i=0;i<ROWS;i++)
				v_new[i] = data[i] * k;
			return v_new;
		}

		/// Divide all entries in matrix by scalar.
		const MT operator / (ScalarType k) const
		{
			MT v_new;
			for(int i=0;i<ROWS;i++)
				v_new[i] = data[i] / k;
			return v_new;
		}

		/// Assignment multiplication of matrix by scalar.
		const MT& operator *=(ScalarType k)
		{
			for(int i=0;i<ROWS;i++)
				data[i] *= k;
			return static_cast<const MT&>(*this);
		}

		/// Assignment division of matrix by scalar.
		const MT& operator /=(ScalarType k)
		{
			for(int i=0;i<ROWS;i++)
				data[i] /= k;
			return static_cast<const MT&>(*this);
		}

		//----------------------------------------------------------------------

		/// Add two matrices.
		const MT operator + (const MT& m1) const
		{
			MT v_new;
			for(int i=0;i<ROWS;i++)
				v_new[i] = data[i] + m1[i];
			return v_new;
		}

		/// Subtract two matrices.
		const MT operator - (const MT& m1) const
		{
			MT v_new;
			for(int i=0;i<ROWS;i++)
				v_new[i] = data[i] - m1[i];
			return v_new;
		}

		/// Assignment addition of matrices.
		const MT& operator +=(const MT& v)
		{
			for(int i=0;i<ROWS;i++)
				data[i] += v[i];
			return static_cast<const MT&>(*this);
		}

		/// Assignment subtraction of matrices.
		const MT& operator -=(const MT& v)
		{
			for(int i=0;i<ROWS;i++)
				data[i] -= v[i];
			return static_cast<const MT&>(*this);
		}

		//----------------------------------------------------------------------

		/// Negate matrix.
		const MT operator - () const
		{
			MT v_new;
			for(int i=0;i<ROWS;i++)
				v_new[i] = - data[i];
			return v_new;
		}

	};
264
 
265
	/// Multiply scalar onto matrix
266
	template <class VVT, class HVT, class MT, int ROWS>
267
	inline const MT operator * (double k, const ArithMatFloat<VVT,HVT,MT,ROWS>& v) 
268
	{
269
		return v * k;
270
	}
271
 
272
	/// Multiply scalar onto matrix
273
	template <class VVT, class HVT, class MT, int ROWS>
274
	inline const MT operator * (float k, const ArithMatFloat<VVT,HVT,MT,ROWS>& v) 
275
	{
276
		return v * k;
277
	}
278
 
279
	/// Multiply scalar onto matrix
280
	template <class VVT, class HVT, class MT, int ROWS>
281
	inline const MT operator * (int k, const ArithMatFloat<VVT,HVT,MT,ROWS>& v) 
282
	{
283
		return v * k;
284
	}
285
 
286
	/// Multiply vector onto matrix 
287
	template <class VVT, class HVT, class MT, int ROWS>
288
	inline VVT operator*(const ArithMatFloat<VVT,HVT,MT,ROWS>& m,const HVT& v) 
289
	{
290
		VVT v2;
291
		for(int i=0;i<ROWS;i++) v2[i] = dot(m[i], v);
292
		return v2;
293
	}
294
 
295
 
296
#ifndef WIN32
297
	/** Multiply two arbitrary matrices. 
298
			In principle, this function could return a matrix, but in general
299
			the new matrix will be of a type that is different from either of
300
			the two matrices that are multiplied together. We do not want to 
301
			return an ArithMatFloat - so it seems best to let the return value be
302
			a reference arg.
303
 
304
			This template can only be instantiated if the dimensions of the
305
			matrices match -- i.e. if the multiplication can actually be
306
			carried out. This is more type safe than the win32 version below.
307
	*/
308
 
309
	template <class VVT, class HVT, 
310
						class HV1T, class VV2T,
311
						class MT1, class MT2, class MT,
312
						int ROWS1, int ROWS2>
313
	inline void mul(const ArithMatFloat<VVT,HV1T,MT1,ROWS1>& m1,
314
									const ArithMatFloat<VV2T,HVT,MT2,ROWS2>& m2,
315
									ArithMatFloat<VVT,HVT,MT,ROWS1>& m)
316
	{
317
		int cols = ArithMatFloat<VVT,HVT,MT,ROWS1>::get_h_dim();
318
		for(int i=0;i<ROWS1;i++)
319
			for(int j=0;j<cols;j++)
320
				for(int k=0;k<ROWS2;k++)
321
					m[i][j] += m1[i][k] * m2[k][j]; 
322
	}
323
 
324
 
325
	/** Transpose. See the discussion on mul if you are curious as to why
326
			I don't simply return the transpose. */
327
	template <class VVT, class HVT, class M1T, class M2T, int ROWS, int COLS>
328
	inline void transpose(const ArithMatFloat<VVT,HVT,M1T,ROWS>& m,
329
												ArithMatFloat<HVT,VVT,M2T,COLS>& m_new)
330
	{
331
		for(int i=0;i<M2T::get_v_dim();i++)
332
			for(int j=0;j<M2T::get_h_dim();j++)
333
				m_new[i][j] = m[j][i];
334
	}
335
 
336
#else
337
 
338
	//----------------- win32 -------------------------------
339
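	// Visual Studio is not good at deducing the arguments of these template
	// functions, so the win32 versions below take the matrix types as plain
	// template parameters. This means that the compiler will not stop you
	// from calling the two functions below with matrices of the wrong
	// dimension.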
342
 
343
	/** Multiply two matrices: m = m1 * m2 (win32 variant).
			Every entry of m is overwritten, so m does not need to be zeroed
			beforehand. m must not be the same object as m1 or m2. Because the
			types are unconstrained here, the dimensions are checked with
			asserts: m1 must have as many columns as m2 has rows, and m must
			have m1's row count and m2's column count. */
	template <class M1, class M2, class M>
	inline void mul(const M1& m1, const M2& m2, M& m)
	{
		const int cols = M::get_h_dim();
		const int rows1 = M1::get_v_dim();
		const int rows2 = M2::get_v_dim();

		assert(M1::get_h_dim() == rows2);
		assert(M::get_v_dim() == rows1);
		assert(M2::get_h_dim() == cols);

		for(int i=0;i<rows1;i++)
			for(int j=0;j<cols;j++)
				{
					// Accumulate into a local rather than into m, so that any
					// previous contents of m do not leak into the product.
					typename M::ScalarType sum = 0;
					for(int k=0;k<rows2;k++)
						sum += m1[i][k] * m2[k][j];
					m[i][j] = sum;
				}
	}
355
 
356
 
357
	/** Transpose m1 into m2 (win32 variant), so that m2[r][c] == m1[c][r].
			See the discussion on mul if you are curious as to why the
			transpose is not simply returned. The loop bounds come from the
			destination type M2; the shapes are not checked here. */
	template <class M1, class M2>
	inline void transpose(const M1& m1, M2& m2)
	{
		const int out_rows = M2::get_v_dim();
		const int out_cols = M2::get_h_dim();
		for(int r=0; r<out_rows; ++r)
			{
				for(int c=0; c<out_cols; ++c)
					m2[r][c] = m1[c][r];
			}
	}
366
 
367
#endif
368
 
369
	/** Compute the outer product of a and b: a * transpose(b). This is 
370
			a matrix with a::rows and b::columns. */
371
 	template <class VVT, class HVT, class MT, int ROWS>
372
	void outer_product(const VVT& a, const HVT& b, 
373
										 ArithMatFloat<VVT,HVT,MT,ROWS>& m)
374
	{
375
		int R = VVT::get_dim();
376
		int C = HVT::get_dim();
377
		for(int i=0;i<R;++i)
378
			for(int j=0;j<C;++j)
379
				{
380
					m[i][j] = a[i] * b[j];
381
				}
382
	}
383
 
384
	/** Copy a matrix to another matrix, cell by cell.
			This conversion takes a const matrix as its first argument
			(source) and a non-const matrix as its second argument
			(destination). The contents of the first matrix are simply copied
			to the second matrix.

			However, if the first matrix is larger than the second, the cells
			outside the range of the destination are simply not copied. If
			the destination is larger, the cells outside the range of the
			source matrix are not touched.

			An obvious use of this function is to copy a 3x3 rotation matrix
			into a 4x4 transformation matrix.
	*/

	template <class M1, class M2>
	void copy_matrix(const M1& inmat, M2& outmat)
		{
			// Only the overlapping top-left region is visited.
			const int rows = s_min(inmat.get_v_dim(), outmat.get_v_dim());
			const int cols = s_min(inmat.get_h_dim(), outmat.get_h_dim());
			for(int r=0; r<rows; ++r)
				{
					for(int c=0; c<cols; ++c)
						outmat[r][c] = inmat[r][c];
				}
		}
408
 
409
	/** Put to operator */
410
	template <class VVT, class HVT, class MT, int ROWS>
411
	inline std::ostream& 
412
	operator<<(std::ostream&os, const ArithMatFloat<VVT,HVT,MT,ROWS>& m)
413
	{
414
		os << "[\n";
415
		for(int i=0;i<ROWS;i++) os << "  " << m[i] << "\n";
416
		os << "]\n";
417
		return os;
418
	}
419
 
420
	/** Get from operator */
421
	template <class VVT, class HVT, class MT, int ROWS>
422
	inline std::istream& operator>>(std::istream&is, 
423
																	const ArithMatFloat<VVT,HVT,MT,ROWS>& m)
424
	{
425
		for(int i=0;i<ROWS;i++) is>>m[i];
426
		return is;
427
	}
428
 
429
 
430
 
431
 
432
}
433
#endif