|
|
|
/**
|
|
|
|
* @file LeastSquaresSolver.hpp
|
|
|
|
*
|
|
|
|
* Least Squares Solver using QR householder decomposition.
|
|
|
|
* It calculates x for Ax = b.
|
|
|
|
* A = Q*R
|
|
|
|
* where R is an upper triangular matrix.
|
|
|
|
*
|
|
|
|
* R*x = Q^T*b
|
|
|
|
* This is efficiently solved for x because of the upper triangular property of R.
|
|
|
|
*
|
|
|
|
* @author Bart Slinger <bartslinger@gmail.com>
|
|
|
|
*/
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include "math.hpp"
|
|
|
|
|
|
|
|
namespace matrix {
|
|
|
|
|
|
|
|
template<typename Type, size_t M, size_t N>
|
|
|
|
class LeastSquaresSolver
|
|
|
|
{
|
|
|
|
public:
|
|
|
|
|
|
|
|
/**
|
|
|
|
* @brief Class calculates QR decomposition which can be used for linear
|
|
|
|
* least squares
|
|
|
|
* @param A Matrix of size MxN
|
|
|
|
*
|
|
|
|
* Initialize the class with a MxN matrix. The constructor starts the
|
|
|
|
* QR decomposition. This class does not check the rank of the matrix.
|
|
|
|
* The user needs to make sure that rank(A) = N and M >= N.
|
|
|
|
*/
|
|
|
|
LeastSquaresSolver(const Matrix<Type, M, N> &A)
|
|
|
|
{
|
|
|
|
static_assert(M >= N, "Matrix dimension should be M >= N");
|
|
|
|
|
|
|
|
// Copy contentents of matrix A
|
|
|
|
_A = A;
|
|
|
|
|
|
|
|
for (size_t j = 0; j < N; j++) {
|
|
|
|
Type normx = Type(0);
|
|
|
|
for (size_t i = j; i < M; i++) {
|
|
|
|
normx += _A(i,j) * _A(i,j);
|
|
|
|
}
|
|
|
|
normx = sqrt(normx);
|
|
|
|
Type s = _A(j,j) > 0 ? Type(-1) : Type(1);
|
|
|
|
Type u1 = _A(j,j) - s*normx;
|
|
|
|
// prevent divide by zero
|
|
|
|
// also covers u1. normx is never negative
|
|
|
|
if (normx < Type(1e-8)) {
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
Type w[M] = {};
|
|
|
|
w[0] = Type(1);
|
|
|
|
for (size_t i = j+1; i < M; i++) {
|
|
|
|
w[i-j] = _A(i,j) / u1;
|
|
|
|
_A(i,j) = w[i-j];
|
|
|
|
}
|
|
|
|
_A(j,j) = s*normx;
|
|
|
|
_tau(j) = -s*u1/normx;
|
|
|
|
|
|
|
|
for (size_t k = j+1; k < N; k++) {
|
|
|
|
Type tmp = Type(0);
|
|
|
|
for (size_t i = j; i < M; i++) {
|
|
|
|
tmp += w[i-j] * _A(i,k);
|
|
|
|
}
|
|
|
|
for (size_t i = j; i < M; i++) {
|
|
|
|
_A(i,k) -= _tau(j) * w[i-j] * tmp;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* @brief qtb Calculate Q^T * b
|
|
|
|
* @param b
|
|
|
|
* @return Q^T*b
|
|
|
|
*
|
|
|
|
* This function calculates Q^T * b. This is useful for the solver
|
|
|
|
* because R*x = Q^T*b.
|
|
|
|
*/
|
|
|
|
Vector<Type, M> qtb(const Vector<Type, M> &b) {
|
|
|
|
Vector<Type, M> qtbv = b;
|
|
|
|
|
|
|
|
for (size_t j = 0; j < N; j++) {
|
|
|
|
Type w[M];
|
|
|
|
w[0] = Type(1);
|
|
|
|
// fill vector w
|
|
|
|
for (size_t i = j+1; i < M; i++) {
|
|
|
|
w[i-j] = _A(i,j);
|
|
|
|
}
|
|
|
|
Type tmp = Type(0);
|
|
|
|
for (size_t i = j; i < M; i++) {
|
|
|
|
tmp += w[i-j] * qtbv(i);
|
|
|
|
}
|
|
|
|
|
|
|
|
for (size_t i = j; i < M; i++) {
|
|
|
|
qtbv(i) -= _tau(j) * w[i-j] * tmp;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return qtbv;
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* @brief Solve Ax=b for x
|
|
|
|
* @param b
|
|
|
|
* @return Vector x
|
|
|
|
*
|
|
|
|
* Find x in the equation Ax = b.
|
|
|
|
* A is provided in the initializer of the class.
|
|
|
|
*/
|
|
|
|
Vector<Type, N> solve(const Vector<Type, M> &b) {
|
|
|
|
Vector<Type, M> qtbv = qtb(b);
|
|
|
|
Vector<Type, N> x;
|
|
|
|
|
|
|
|
// size_t is unsigned and wraps i = 0 - 1 to i > N
|
|
|
|
for (size_t i = N - 1; i < N; i--) {
|
|
|
|
printf("i %d\n", static_cast<int>(i));
|
|
|
|
x(i) = qtbv(i);
|
|
|
|
for (size_t r = i+1; r < N; r++) {
|
|
|
|
x(i) -= _A(i,r) * x(r);
|
|
|
|
}
|
|
|
|
// divide by zero, return vector of zeros
|
|
|
|
if (isEqualF(_A(i,i), Type(0), Type(1e-8))) {
|
|
|
|
for (size_t z = 0; z < N; z++) {
|
|
|
|
x(z) = Type(0);
|
|
|
|
}
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
x(i) /= _A(i,i);
|
|
|
|
}
|
|
|
|
return x;
|
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
|
|
|
Matrix<Type, M, N> _A;
|
|
|
|
Vector<Type, N> _tau;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
} // namespace matrix
|
|
|
|
|
|
|
|
/* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */
|