/** * @file SquareMatrix.hpp * * A square matrix * * @author James Goppert */ #pragma once #include #include #include #include #include #include "math.hpp" namespace matrix { template class Matrix; template class Vector; template class SquareMatrix : public Matrix { public: SquareMatrix() : Matrix() { } SquareMatrix(const Type *data_) : Matrix(data_) { } SquareMatrix(const Matrix &other) : Matrix(other) { } // inverse alias inline SquareMatrix I() const { return inv(*this); } Vector diag() const { Vector res; const SquareMatrix &self = *this; for (size_t i = 0; i < M; i++) { res(i) = self(i, i); } return res; } Type trace() const { Type res = 0; const SquareMatrix &self = *this; for (size_t i = 0; i < M; i++) { res += self(i, i); } return res; } }; typedef SquareMatrix SquareMatrix3f; template SquareMatrix eye() { SquareMatrix m; m.setIdentity(); return m; } template SquareMatrix diag(Vector d) { SquareMatrix m; for (size_t i=0; i SquareMatrix expm(const Matrix & A, size_t order=5) { SquareMatrix res; SquareMatrix A_pow = A; res.setIdentity(); size_t i_factorial = 1; for (size_t i=1; i<=order; i++) { i_factorial *= i; res += A_pow / Type(i_factorial); A_pow *= A_pow; } return res; } /** * inverse based on LU factorization with partial pivotting */ template SquareMatrix inv(const SquareMatrix & A) { SquareMatrix L; L.setIdentity(); SquareMatrix U = A; SquareMatrix P; P.setIdentity(); //printf("A:\n"); A.print(); // for all diagonal elements for (size_t n = 0; n < M; n++) { // if diagonal is zero, swap with row below if (fabsf(U(n, n)) < 1e-8f) { //printf("trying pivot for row %d\n",n); for (size_t i = 0; i < M; i++) { if (i == n) { continue; } //printf("\ttrying row %d\n",i); if (fabsf(U(i, n)) > 1e-8f) { //printf("swapped %d\n",i); U.swapRows(i, n); P.swapRows(i, n); } } } #ifdef MATRIX_ASSERT //printf("A:\n"); A.print(); //printf("U:\n"); U.print(); //printf("P:\n"); P.print(); //fflush(stdout); ASSERT(fabsf(U(n, n)) > 1e-8f); #endif // failsafe, return zero matrix if (fabsf(U(n, n)) < 1e-8f) { SquareMatrix zero; zero.setZero(); return zero; } // for all rows below diagonal for (size_t i = (n + 1); i < M; i++) { L(i, n) = U(i, n) / U(n, n); // add i-th row and n-th row // multiplied by: -a(i,n)/a(n,n) for (size_t k = n; k < M; k++) { U(i, k) -= L(i, n) * U(n, k); } } } //printf("L:\n"); L.print(); //printf("U:\n"); U.print(); // solve LY=P*I for Y by forward subst SquareMatrix Y = P; // for all columns of Y for (size_t c = 0; c < M; c++) { // for all rows of L for (size_t i = 0; i < M; i++) { // for all columns of L for (size_t j = 0; j < i; j++) { // for all existing y // subtract the component they // contribute to the solution Y(i, c) -= L(i, j) * Y(j, c); } // divide by the factor // on current // term to be solved // Y(i,c) /= L(i,i); // but L(i,i) = 1.0 } } //printf("Y:\n"); Y.print(); // solve Ux=y for x by back subst SquareMatrix X = Y; // for all columns of X for (size_t c = 0; c < M; c++) { // for all rows of U for (size_t k = 0; k < M; k++) { // have to go in reverse order size_t i = M - 1 - k; // for all columns of U for (size_t j = i + 1; j < M; j++) { // for all existing x // subtract the component they // contribute to the solution X(i, c) -= U(i, j) * X(j, c); } // divide by the factor // on current // term to be solved X(i, c) /= U(i, i); } } //printf("X:\n"); X.print(); return X; } typedef SquareMatrix Matrix3f; } // namespace matrix /* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */