Subversion Repositories gelsvn

Rev

Rev 630 | Details | Compare with Previous | Last modification | View Log | RSS feed

Rev Author Line No. Line
595 jab 1
/* ----------------------------------------------------------------------- *
2
 * This file is part of GEL, http://www.imm.dtu.dk/GEL
3
 * Copyright (C) the authors and DTU Informatics
4
 * For license and list of authors, see ../../doc/intro.pdf
5
 * ----------------------------------------------------------------------- */
6
 
632 janba 7
/** @file ArithMatFloat.h
8
 * Abstract matrix class
9
 */
595 jab 10
 
2 bj 11
#ifndef __CGLA_ARITHMATFLOAT_H__
12
#define __CGLA_ARITHMATFLOAT_H__
13
 
14
#include <algorithm>
#include <cassert>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

#include "CGLA.h"
19
 
20
 
632 janba 21
namespace CGLA
12 jab 22
{
632 janba 23
 
24
    /** \brief Basic class template for matrices.
25
 
26
     In this template a matrix is defined as an array of vectors. This may
27
     not in all cases be the most efficient but it has the advantage that
28
     it is possible to use the double subscripting notation:
29
 
30
     T x = m[i][j]
31
 
32
     This template should be used through inheritance just like the
33
     vector template */
34
    template <class VVT, class HVT, class MT, unsigned int ROWS>
12 jab 35
    class ArithMatFloat
632 janba 36
    {
12 jab 37
    public:
632 janba 38
 
39
        /// Horizontal vector type
40
        typedef HVT HVectorType;
41
 
42
        /// Vertical vector type
43
        typedef VVT VVectorType;
44
 
45
        /// The type of a matrix element
46
        typedef typename HVT::ScalarType ScalarType;
2 bj 47
 
12 jab 48
    protected:
632 janba 49
 
50
        /// The actual contents of the matrix.
51
        HVT data[ROWS];
52
 
12 jab 53
    protected:
632 janba 54
 
55
        /// Construct 0 matrix
56
        ArithMatFloat()
57
        {
10 jab 58
#ifndef NDEBUG
632 janba 59
            std::fill_n(data, ROWS, HVT(CGLA_INIT_VALUE));
10 jab 60
#endif
632 janba 61
        }
62
 
63
        /// Construct a matrix where all entries are the same.
64
        explicit ArithMatFloat(ScalarType x)
65
        {
66
            std::fill_n(data, ROWS, HVT(x));
67
        }
68
 
69
        /// Construct a matrix where all rows are the same.
70
        explicit ArithMatFloat(HVT _a)
71
        {
72
            std::fill_n(data, ROWS, _a);
73
        }
74
 
75
        /// Construct a matrix with two rows.
76
        ArithMatFloat(HVT _a, HVT _b)
77
        {
78
            assert(ROWS==2);
79
            data[0] = _a;
80
            data[1] = _b;
81
        }
82
 
83
        /// Construct a matrix with three rows.
84
        ArithMatFloat(HVT _a, HVT _b, HVT _c)
85
        {
86
            assert(ROWS==3);
87
            data[0] = _a;
88
            data[1] = _b;
89
            data[2] = _c;
90
        }
91
 
92
        /// Construct a matrix with four rows.
93
        ArithMatFloat(HVT _a, HVT _b, HVT _c, HVT _d)
94
        {
95
            assert(ROWS==4);
96
            data[0] = _a;
97
            data[1] = _b;
98
            data[2] = _c;
99
            data[3] = _d;
100
        }
2 bj 101
 
12 jab 102
    public:
632 janba 103
 
104
        /// Get vertical dimension of matrix
105
        static unsigned int get_v_dim() {return VVT::get_dim();}
106
 
107
        /// Get horizontal dimension of matrix
108
        static unsigned int get_h_dim() {return HVT::get_dim();}
109
 
110
 
111
        /** Get const pointer to data array.
112
         This function may be useful when interfacing with some other API
113
         such as OpenGL (TM). */
114
        const ScalarType* get() const
115
        {
116
            return data[0].get();
117
        }
118
 
119
        /** Get pointer to data array.
120
         This function may be useful when interfacing with some other API
121
         such as OpenGL (TM). */
122
        ScalarType* get()
123
        {
124
            return data[0].get();
125
        }
126
 
127
        //----------------------------------------------------------------------
128
        // index operators
129
        //----------------------------------------------------------------------
130
 
131
        /// Const index operator. Returns i'th row of matrix.
132
        const HVT& operator [] ( unsigned int i ) const
133
        {
134
            assert(i<ROWS);
135
            return data[i];
136
        }
137
 
138
        /// Non-const index operator. Returns i'th row of matrix.
139
        HVT& operator [] ( unsigned int i )
140
        {
141
            assert(i<ROWS);
142
            return data[i];
143
        }
144
 
145
        //----------------------------------------------------------------------
146
 
147
        /// Equality operator.
148
        bool operator==(const MT& v) const
149
        {
150
            return std::inner_product(data, &data[ROWS], &v[0], true,
151
                                      std::logical_and<bool>(), std::equal_to<HVT>());
152
        }
153
 
154
        /// Inequality operator.
155
        bool operator!=(const MT& v) const
156
        {
157
            return !(*this==v);
158
        }
159
 
160
        //----------------------------------------------------------------------
161
 
162
        /// Multiply scalar onto matrix. All entries are multiplied by scalar.
163
        const MT operator * (ScalarType k) const
164
        {
165
            MT v_new;
166
            std::transform(data, &data[ROWS], &v_new[0], std::bind2nd(std::multiplies<HVT>(), HVT(k)));
167
            return v_new;
168
        }
169
 
170
        /// Divide all entries in matrix by scalar.
171
        const MT operator / (ScalarType k) const
172
        {
173
            MT v_new;
174
            std::transform(data, &data[ROWS], &v_new[0], std::bind2nd(std::divides<HVT>(), HVT(k)));
175
            return v_new;
176
        }
177
 
178
        /// Assignment multiplication of matrix by scalar.
179
        const MT& operator *=(ScalarType k)
180
        {
181
            std::transform(data, &data[ROWS], data, std::bind2nd(std::multiplies<HVT>(), HVT(k)));
182
            return static_cast<const MT&>(*this);
183
        }
184
 
185
        /// Assignment division of matrix by scalar.
186
        const MT& operator /=(ScalarType k)
187
        {
188
            std::transform(data, &data[ROWS], data, std::bind2nd(std::divides<HVT>(), HVT(k)));
189
            return static_cast<const MT&>(*this);
190
        }
191
 
192
        //----------------------------------------------------------------------
193
 
194
        /// Add two matrices.
195
        const MT operator + (const MT& m1) const
196
        {
197
            MT v_new;
198
            std::transform(data, &data[ROWS], &m1[0], &v_new[0], std::plus<HVT>());
199
            return v_new;
200
        }
201
 
202
        /// Subtract two matrices.
203
        const MT operator - (const MT& m1) const
204
        {
205
            MT v_new;
206
            std::transform(data, &data[ROWS], &m1[0], &v_new[0], std::minus<HVT>());
207
            return v_new;
208
        }
209
 
210
        /// Assigment addition of matrices.
211
        const MT& operator +=(const MT& v)
212
        {
213
            std::transform(data, &data[ROWS], &v[0], data, std::plus<HVT>());
214
            return static_cast<const MT&>(*this);
215
        }
216
 
217
        /// Assigment subtraction of matrices.
218
        const MT& operator -=(const MT& v)
219
        {
220
            std::transform(data, &data[ROWS], &v[0], data, std::minus<HVT>());
221
            return static_cast<const MT&>(*this);
222
        }
223
 
224
        //----------------------------------------------------------------------
225
 
226
        /// Negate matrix.
227
        const MT operator - () const
228
        {
229
            MT v_new;
230
            std::transform(data, &data[ROWS], &v_new[0], std::negate<HVT>());
231
            return v_new;
232
        }
12 jab 233
    };
632 janba 234
 
235
    /// Multiply scalar onto matrix
236
    template <class VVT, class HVT, class MT, unsigned int ROWS>
237
    inline const MT operator * (double k, const ArithMatFloat<VVT,HVT,MT,ROWS>& v)
12 jab 238
    {
632 janba 239
        return v * k;
12 jab 240
    }
632 janba 241
 
242
    /// Multiply scalar onto matrix
243
    template <class VVT, class HVT, class MT, unsigned int ROWS>
244
    inline const MT operator * (float k, const ArithMatFloat<VVT,HVT,MT,ROWS>& v)
12 jab 245
    {
632 janba 246
        return v * k;
12 jab 247
    }
632 janba 248
 
249
    /// Multiply scalar onto matrix
250
    template <class VVT, class HVT, class MT, unsigned int ROWS>
251
    inline const MT operator * (int k, const ArithMatFloat<VVT,HVT,MT,ROWS>& v)
12 jab 252
    {
632 janba 253
        return v * k;
12 jab 254
    }
632 janba 255
 
256
    /// Multiply vector onto matrix
257
    template <class VVT, class HVT, class MT, unsigned int ROWS>
258
    inline VVT operator*(const ArithMatFloat<VVT,HVT,MT,ROWS>& m,const HVT& v)
12 jab 259
    {
632 janba 260
        VVT v2;
261
        for(unsigned int i=0;i<ROWS;i++) v2[i] = dot(m[i], v);
262
            return v2;
12 jab 263
    }
632 janba 264
 
265
 
2 bj 266
#ifndef WIN32
632 janba 267
    /** Multiply two arbitrary matrices.
268
     In principle, this function could return a matrix, but in general
269
     the new matrix will be of a type that is different from either of
270
     the two matrices that are multiplied together. We do not want to
271
     return an ArithMatFloat - so it seems best to let the return value be
272
     a reference arg.
273
 
274
     This template can only be instantiated if the dimensions of the
275
     matrices match -- i.e. if the multiplication can actually be
276
     carried out. This is more type safe than the win32 version below.
277
     */
278
 
279
    template <class VVT, class HVT,
12 jab 280
    class HV1T, class VV2T,
281
    class MT1, class MT2, class MT,
45 jab 282
    unsigned int ROWS1, unsigned int ROWS2>
12 jab 283
    inline void mul(const ArithMatFloat<VVT,HV1T,MT1,ROWS1>& m1,
632 janba 284
                    const ArithMatFloat<VV2T,HVT,MT2,ROWS2>& m2,
285
                    ArithMatFloat<VVT,HVT,MT,ROWS1>& m)
12 jab 286
    {
632 janba 287
        unsigned int cols = ArithMatFloat<VVT,HVT,MT,ROWS1>::get_h_dim();
288
        for(unsigned int i=0;i<ROWS1;i++)
289
            for(unsigned int j=0;j<cols;j++)
290
            {
291
                m[i][j] = 0;
292
                for(unsigned int k=0;k<ROWS2;k++)
293
                    m[i][j] += m1[i][k] * m2[k][j];
294
            }
12 jab 295
    }
632 janba 296
 
297
 
298
    /** Transpose. See the discussion on mul if you are curious as to why
299
     I don't simply return the transpose. */
300
    template <class VVT, class HVT, class M1T, class M2T, unsigned int ROWS, unsigned int COLS>
12 jab 301
    inline void transpose(const ArithMatFloat<VVT,HVT,M1T,ROWS>& m,
632 janba 302
                          ArithMatFloat<HVT,VVT,M2T,COLS>& m_new)
12 jab 303
    {
632 janba 304
        for(unsigned int i=0;i<M2T::get_v_dim();++i)
305
            for(unsigned int j=0;j<M2T::get_h_dim();++j)
306
                m_new[i][j] = m[j][i];
12 jab 307
    }
632 janba 308
 
2 bj 309
#else
632 janba 310
 
311
    //----------------- win32 -------------------------------
312
    // Visual Studio is not good at deducing the arguments of these template
    // functions. This means that the two functions below can be called with
    // matrices of the wrong dimensions.
315
 
316
    /** Multiply two matrices: m = m1 * m2 (win32 variant).

     The template parameters are unconstrained, so the dimensions are
     checked with asserts in debug builds instead. m must not be the same
     object as m1 or m2: entries of m are overwritten while m1 and m2 are
     still being read. */
    template <class M1, class M2, class M>
    inline void mul(const M1& m1, const M2& m2, M& m)
    {
        unsigned int cols = M::get_h_dim();
        unsigned int rows1 = M1::get_v_dim();
        unsigned int rows2 = M2::get_v_dim();

        // The signature cannot enforce matching dimensions; check them here.
        assert(M1::get_h_dim() == rows2);  // inner dimensions agree
        assert(M::get_v_dim() == rows1);   // result has m1's row count
        assert(M2::get_h_dim() == cols);   // result has m2's column count

        for(unsigned int i=0;i<rows1;++i)
            for(unsigned int j=0;j<cols;++j)
            {
                m[i][j] = 0;
                for(unsigned int k=0;k<rows2;++k)
                    m[i][j] += m1[i][k] * m2[k][j];
            }
    }
632 janba 331
 
332
 
333
    /** Transpose m1 into m2, so that m2[i][j] = m1[j][i] (win32 variant).
     See the discussion on mul if you are curious as to why I don't simply
     return the transpose. The loop bounds come from M2's dimensions only.
     Nothing checks that m1 has matching dimensions, and m2 must not be
     the same object as m1. */
    template <class M1, class M2>
    inline void transpose(const M1& m1, M2& m2)
    {
        const unsigned int out_rows = M2::get_v_dim();
        const unsigned int out_cols = M2::get_h_dim();
        for (unsigned int r = 0; r != out_rows; ++r)
        {
            for (unsigned int c = 0; c != out_cols; ++c)
                m2[r][c] = m1[c][r];
        }
    }
632 janba 342
 
2 bj 343
#endif
632 janba 344
 
345
    /** Compute the outer product of a and b: a * transpose(b). This is
346
     a matrix with a::rows and b::columns. */
347
    template <class VVT, class HVT, class MT, unsigned int ROWS>
348
    void outer_product(const VVT& a, const HVT& b,
349
                       ArithMatFloat<VVT,HVT,MT,ROWS>& m)
12 jab 350
    {
632 janba 351
        unsigned int R = VVT::get_dim();
352
        unsigned int C = HVT::get_dim();
353
        for(unsigned int i=0;i<R;++i)
354
            for(unsigned int j=0;j<C;++j)
355
            {
356
                m[i][j] = a[i] * b[j];
357
            }
12 jab 358
    }
632 janba 359
 
360
    /** Compute the outer product of a and b using an arbitrary
361
     binary operation: op(a, transpose(b)). This is
362
     a matrix with a::rows and b::columns. */
363
    template <class VVT, class HVT, class MT, int ROWS, class BinOp>
364
    void outer_product(const VVT& a, const HVT& b,
365
                       ArithMatFloat<VVT,HVT,MT,ROWS>& m, BinOp op)
46 jrf 366
    {
632 janba 367
        int R = VVT::get_dim();
368
        int C = HVT::get_dim();
369
        for(int i=0;i<R;++i)
370
            for(int j=0;j<C;++j)
371
            {
372
                m[i][j] = op(a[i], b[j]);
373
            }
46 jrf 374
    }
632 janba 375
 
376
    /** Copy a matrix to another matrix, cell by cell.
     This conversion takes a const matrix as first argument
     (source) and a non-const matrix as second argument
     (destination). The contents of the first matrix are simply copied
     to the second matrix.

     However, if the first matrix is larger than the second,
     the cells outside the range of the destination are simply not
     copied. If the destination is larger, the cells outside the
     range of the source matrix are not touched.

     An obvious use of this function is to copy a 3x3 rotation matrix
     into a 4x4 transformation matrix.
     */

    template <class M1, class M2>
    void copy_matrix(const M1& inmat, M2& outmat)
    {
        // Only the overlapping top-left region is copied.
        const unsigned int src_rows = inmat.get_v_dim();
        const unsigned int dst_rows = outmat.get_v_dim();
        const unsigned int src_cols = inmat.get_h_dim();
        const unsigned int dst_cols = outmat.get_h_dim();
        const unsigned int rows = (dst_rows < src_rows) ? dst_rows : src_rows;
        const unsigned int cols = (dst_cols < src_cols) ? dst_cols : src_cols;

        unsigned int r = 0;
        while (r < rows)
        {
            for (unsigned int c = 0; c < cols; ++c)
                outmat[r][c] = inmat[r][c];
            ++r;
        }
    }
632 janba 400
 
401
    /** Put to operator */
402
    template <class VVT, class HVT, class MT, unsigned int ROWS>
12 jab 403
    inline std::ostream& 
404
    operator<<(std::ostream&os, const ArithMatFloat<VVT,HVT,MT,ROWS>& m)
405
    {
632 janba 406
        os << "[\n";
407
        for(unsigned int i=0;i<ROWS;i++) os << "  " << m[i] << "\n";
408
        os << "]\n";
409
        return os;
410
        }
411
 
412
        /** Get from operator */
413
        template <class VVT, class HVT, class MT, unsigned int ROWS>
414
        inline std::istream& operator>>(std::istream&is, 
415
                                        const ArithMatFloat<VVT,HVT,MT,ROWS>& m)
416
        {
417
            for(unsigned int i=0;i<ROWS;i++) is>>m[i];
418
            return is;
419
        }
420
        }
2 bj 421
#endif