/*
* Copyright 2011-2013, by Vladimir Kostyukov and Contributors.
*
* This file is part of la4j project (http://la4j.org)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* You may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Contributor(s): -
*
*/
package org.la4j.linear;
import org.la4j.LinearAlgebra;
import org.la4j.decomposition.MatrixDecompositor;
import org.la4j.Matrix;
import org.la4j.Vector;
import org.la4j.Vectors;
public class LeastSquaresSolver extends AbstractSolver implements LinearSystemSolver {
private static final long serialVersionUID = 4071505L;
// Matrices from RAW_QR decomposition
private final Matrix qr;
private final Matrix r;
public LeastSquaresSolver(Matrix a) {
super(a);
// we use QR for this
MatrixDecompositor decompositor = a.withDecompositor(LinearAlgebra.RAW_QR);
Matrix[] qrr = decompositor.decompose();
// TODO: Do something with it.
this.qr = qrr[0];
this.r = qrr[1];
}
@Override
public Vector solve(Vector b) {
ensureRHSIsCorrect(b);
int n = unknowns();
int m = equations();
// check whether the matrix is full-rank or not
for (int i = 0; i < r.rows(); i++) {
if (r.get(i, i) == 0.0) {
fail("This system can not be solved: coefficient matrix is rank deficient.");
}
}
Vector x = b.copy();
for (int j = 0; j < n; j++) {
double acc = 0.0;
for (int i = j; i < m; i++) {
acc += qr.get(i, j) * x.get(i);
}
acc = -acc / qr.get(j, j);
for (int i = j; i < m; i++) {
x.updateAt(i, Vectors.asPlusFunction(acc * qr.get(i, j)));
}
}
for (int j = n - 1; j >= 0; j--) {
x.updateAt(j, Vectors.asDivFunction(r.get(j, j)));
for (int i = 0; i < j; i++) {
x.updateAt(i, Vectors.asMinusFunction(x.get(j) * qr.get(i, j)));
}
}
return x.slice(0, n);
}
@Override
public boolean applicableTo(Matrix matrix) {
return matrix.rows() >= matrix.columns();
}
}