You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

146 lines
3.8 KiB

/**
* @file LeastSquaresSolver.hpp
*
* Least Squares Solver using QR householder decomposition.
* It calculates x for Ax = b.
* A = Q*R
* where R is an upper triangular matrix.
*
* R*x = Q^T*b
* This is efficiently solved for x because of the upper triangular property of R.
*
* @author Bart Slinger <bartslinger@gmail.com>
*/
#pragma once
#include "math.hpp"
namespace matrix {
template<typename Type, size_t M, size_t N>
class LeastSquaresSolver
{
public:
/**
* @brief Class calculates QR decomposition which can be used for linear
* least squares
* @param A Matrix of size MxN
*
* Initialize the class with a MxN matrix. The constructor starts the
* QR decomposition. This class does not check the rank of the matrix.
* The user needs to make sure that rank(A) = N and M >= N.
*/
LeastSquaresSolver(const Matrix<Type, M, N> &A)
{
static_assert(M >= N, "Matrix dimension should be M >= N");
// Copy contentents of matrix A
_A = A;
for (size_t j = 0; j < N; j++) {
Type normx = Type(0);
for (size_t i = j; i < M; i++) {
normx += _A(i,j) * _A(i,j);
}
normx = sqrt(normx);
Type s = _A(j,j) > 0 ? Type(-1) : Type(1);
Type u1 = _A(j,j) - s*normx;
// prevent divide by zero
// also covers u1. normx is never negative
if (normx < Type(1e-8)) {
break;
}
Type w[M] = {};
w[0] = Type(1);
for (size_t i = j+1; i < M; i++) {
w[i-j] = _A(i,j) / u1;
_A(i,j) = w[i-j];
}
_A(j,j) = s*normx;
_tau(j) = -s*u1/normx;
for (size_t k = j+1; k < N; k++) {
Type tmp = Type(0);
for (size_t i = j; i < M; i++) {
tmp += w[i-j] * _A(i,k);
}
for (size_t i = j; i < M; i++) {
_A(i,k) -= _tau(j) * w[i-j] * tmp;
}
}
}
}
/**
* @brief qtb Calculate Q^T * b
* @param b
* @return Q^T*b
*
* This function calculates Q^T * b. This is useful for the solver
* because R*x = Q^T*b.
*/
Vector<Type, M> qtb(const Vector<Type, M> &b) {
Vector<Type, M> qtbv = b;
for (size_t j = 0; j < N; j++) {
Type w[M];
w[0] = Type(1);
// fill vector w
for (size_t i = j+1; i < M; i++) {
w[i-j] = _A(i,j);
}
Type tmp = Type(0);
for (size_t i = j; i < M; i++) {
tmp += w[i-j] * qtbv(i);
}
for (size_t i = j; i < M; i++) {
qtbv(i) -= _tau(j) * w[i-j] * tmp;
}
}
return qtbv;
}
/**
* @brief Solve Ax=b for x
* @param b
* @return Vector x
*
* Find x in the equation Ax = b.
* A is provided in the initializer of the class.
*/
Vector<Type, N> solve(const Vector<Type, M> &b) {
Vector<Type, M> qtbv = qtb(b);
Vector<Type, N> x;
// size_t is unsigned and wraps i = 0 - 1 to i > N
for (size_t i = N - 1; i < N; i--) {
printf("i %d\n", static_cast<int>(i));
x(i) = qtbv(i);
for (size_t r = i+1; r < N; r++) {
x(i) -= _A(i,r) * x(r);
}
// divide by zero, return vector of zeros
if (isEqualF(_A(i,i), Type(0), Type(1e-8))) {
for (size_t z = 0; z < N; z++) {
x(z) = Type(0);
}
break;
}
x(i) /= _A(i,i);
}
return x;
}
private:
Matrix<Type, M, N> _A;
Vector<Type, N> _tau;
};
} // namespace matrix
/* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */