/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.math.decompositions;
import org.apache.ignite.ml.math.Destroyable;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.exceptions.SingularMatrixException;
import org.apache.ignite.ml.math.functions.Functions;
import static org.apache.ignite.ml.math.util.MatrixUtil.copy;
import static org.apache.ignite.ml.math.util.MatrixUtil.like;
/**
* For an {@code m x n} matrix {@code A} with {@code m >= n}, the QR decomposition
* is an {@code m x n} orthogonal matrix {@code Q} and an {@code n x n} upper
* triangular matrix {@code R} so that {@code A = Q*R}.
*/
public class QRDecomposition implements Destroyable {
/** */
private final Matrix q;
/** */
private final Matrix r;
/** */
private final Matrix mType;
/** */
private final boolean fullRank;
/** */
private final int rows;
/** */
private final int cols;
/** */
private double threshold;
/**
* @param v Value to be checked for being an ordinary double.
*/
private void checkDouble(double v) {
if (Double.isInfinite(v) || Double.isNaN(v))
throw new ArithmeticException("Invalid intermediate result");
}
/**
* Constructs a new QR decomposition object computed by Householder reflections.
* Threshold for singularity check used in this case is 0.
*
* @param mtx A rectangular matrix.
*/
public QRDecomposition(Matrix mtx) {
this(mtx, 0.0);
}
/**
* Constructs a new QR decomposition object computed by Householder reflections.
*
* @param mtx A rectangular matrix.
* @param threshold Value used for detecting singularity of {@code R} matrix in decomposition.
*/
public QRDecomposition(Matrix mtx, double threshold) {
assert mtx != null;
rows = mtx.rowSize();
int min = Math.min(mtx.rowSize(), mtx.columnSize());
cols = mtx.columnSize();
mType = like(mtx, 1, 1);
Matrix qTmp = copy(mtx);
boolean fullRank = true;
r = like(mtx, min, cols);
this.threshold = threshold;
for (int i = 0; i < min; i++) {
Vector qi = qTmp.viewColumn(i);
double alpha = qi.kNorm(2);
if (Math.abs(alpha) > Double.MIN_VALUE)
qi.map(Functions.div(alpha));
else {
checkDouble(alpha);
fullRank = false;
}
r.set(i, i, alpha);
for (int j = i + 1; j < cols; j++) {
Vector qj = qTmp.viewColumn(j);
double norm = qj.kNorm(2);
if (Math.abs(norm) > Double.MIN_VALUE) {
double beta = qi.dot(qj);
r.set(i, j, beta);
if (j < min)
qj.map(qi, Functions.plusMult(-beta));
}
else
checkDouble(norm);
}
}
if (cols > min)
q = qTmp.viewPart(0, rows, 0, min).copy();
else
q = qTmp;
this.fullRank = fullRank;
}
/** {@inheritDoc} */
@Override public void destroy() {
q.destroy();
r.destroy();
mType.destroy();
}
/**
* Gets orthogonal factor {@code Q}.
*/
public Matrix getQ() {
return q;
}
/**
* Gets triangular factor {@code R}.
*/
public Matrix getR() {
return r;
}
/**
* Returns whether the matrix {@code A} has full rank.
*
* @return true if {@code R}, and hence {@code A} , has full rank.
*/
public boolean hasFullRank() {
return fullRank;
}
/**
* Least squares solution of {@code A*X = B}; {@code returns X}.
*
* @param mtx A matrix with as many rows as {@code A} and any number of cols.
* @return {@code X<} that minimizes the two norm of {@code Q*R*X - B}.
* @throws IllegalArgumentException if {@code B.rows() != A.rows()}.
*/
public Matrix solve(Matrix mtx) {
if (mtx.rowSize() != rows)
throw new IllegalArgumentException("Matrix row dimensions must agree.");
int cols = mtx.columnSize();
Matrix r = getR();
checkSingular(r, threshold, true);
Matrix x = like(mType, this.cols, cols);
Matrix qt = getQ().transpose();
Matrix y = qt.times(mtx);
for (int k = Math.min(this.cols, rows) - 1; k >= 0; k--) {
// X[k,] = Y[k,] / R[k,k], note that X[k,] starts with 0 so += is same as =
x.viewRow(k).map(y.viewRow(k), Functions.plusMult(1 / r.get(k, k)));
if (k == 0)
continue;
// Y[0:(k-1),] -= R[0:(k-1),k] * X[k,]
Vector rCol = r.viewColumn(k).viewPart(0, k);
for (int c = 0; c < cols; c++)
y.viewColumn(c).viewPart(0, k).map(rCol, Functions.plusMult(-x.get(k, c)));
}
return x;
}
/**
* Least squares solution of {@code A*X = B}; {@code returns X}.
*
* @param vec A vector with as many rows as {@code A}.
* @return {@code X<} that minimizes the two norm of {@code Q*R*X - B}.
* @throws IllegalArgumentException if {@code B.rows() != A.rows()}.
*/
public Vector solve(Vector vec) {
Matrix res = solve(vec.likeMatrix(vec.size(), 1).assignColumn(0, vec));
return vec.like(res.rowSize()).assign(res.viewColumn(0));
}
/**
* Returns a rough string rendition of a QR.
*/
@Override public String toString() {
return String.format("QR(%d x %d, fullRank=%s)", rows, cols, hasFullRank());
}
/**
* Check singularity.
*
* @param r R matrix.
* @param min Singularity threshold.
* @param raise Whether to raise a {@link SingularMatrixException} if any element of the diagonal fails the check.
* @return {@code true} if any element of the diagonal is smaller or equal to {@code min}.
* @throws SingularMatrixException if the matrix is singular and {@code raise} is {@code true}.
*/
private static boolean checkSingular(Matrix r, double min, boolean raise) {
// TODO: Not a very fast approach for distributed matrices. would be nice if we could independently check
// parts on different nodes for singularity and do fold with 'or'.
final int len = r.columnSize();
for (int i = 0; i < len; i++) {
final double d = r.getX(i, i);
if (Math.abs(d) <= min)
if (raise)
throw new SingularMatrixException("Number is too small (%f, while " +
"threshold is %f). Index of diagonal element is (%d, %d)", d, min, i, i);
else
return true;
}
return false;
}
}