/* * Copyright (c) 2009-2013, Peter Abeles. All Rights Reserved. * * This file is part of Efficient Java Matrix Library (EJML). * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package mikera.matrixx.solve.impl; import mikera.matrixx.AMatrix; import mikera.matrixx.Matrix; import mikera.matrixx.decompose.ICholeskyLDUResult; import mikera.matrixx.decompose.impl.chol.CholeskyLDU; /** * @author Peter Abeles */ public class CholeskyLDUSolver { protected Matrix A; protected int numRows; protected int numCols; private ICholeskyLDUResult ans; private int n; private double vv[]; private double el[]; private double d[]; public boolean setA(AMatrix _A) { // _setA(A); this.A = Matrix.create(_A); this.numRows = A.rowCount(); this.numCols = A.columnCount(); ans = CholeskyLDU.decompose(A); if( ans != null ){ n = A.columnCount(); // vv = decomp._getVV(); vv = new double[A.rowCount()]; el = ans.getL().toMatrix().data; d = ans.getD().getLeadingDiagonal().toDoubleArray(); return true; } else { return false; } } public double quality() { return Math.abs(diagProd(ans.getL())); } private double diagProd(AMatrix m) { double prod = 1.0; int diagonalLength = m.rowCount(); for(int i=0; i<diagonalLength; i++) { prod *= m.get(i, i); } return prod; } /** * <p> * Using the decomposition, finds the value of 'X' in the linear equation below:<br> * * A*x = b<br> * * where A has dimension of n by n, x and b are n by m dimension. * </p> * <p> * *Note* that 'b' and 'x' can be the same matrix instance. * </p> * * @param B A matrix that is n by m. Not modified. * @param X An n by m matrix where the solution is writen to. Modified. */ public AMatrix solve(AMatrix B) { Matrix X = Matrix.create(B.rowCount(), B.columnCount()); if( B.columnCount() != X.columnCount() && B.rowCount() != n && X.rowCount() != n) { throw new IllegalArgumentException("Unexpected matrix size"); } int numCols = B.columnCount(); double dataB[] = B.toMatrix().data; double dataX[] = X.data; for( int j = 0; j < numCols; j++ ) { for( int i = 0; i < n; i++ ) vv[i] = dataB[i*numCols+j]; solveInternal(); for( int i = 0; i < n; i++ ) dataX[i*numCols+j] = vv[i]; } return X; } /** * Used internally to find the solution to a single column vector. */ private void solveInternal() { // solve L*s=b storing y in x TriangularSolver.solveL(el,vv,n); // solve D*y=s for( int i = 0; i < n; i++ ) { vv[i] /= d[i]; } // solve L^T*x=y TriangularSolver.solveTranL(el,vv,n); } /** * returns the matrix 'inv' equal to the inverse of the matrix that was decomposed. * * @return inverse of matrix that was decomposed */ public AMatrix invert() { Matrix inv = Matrix.create(numRows, numCols); if( inv.rowCount() != n || inv.columnCount() != n ) { throw new RuntimeException("Unexpected matrix dimension"); } double a[] = inv.data; // solve L*z = b for( int i =0; i < n; i++ ) { for( int j = 0; j <= i; j++ ) { double sum = (i==j) ? 1.0 : 0.0; for( int k=i-1; k >=j; k-- ) { sum -= el[i*n+k]*a[j*n+k]; } a[j*n+i] = sum; } } // solve D*y=z for( int i =0; i < n; i++ ) { double inv_d = 1.0/d[i]; for( int j = 0; j <= i; j++ ) { a[j*n+i] *= inv_d; } } // solve L^T*x = y for( int i=n-1; i>=0; i-- ) { for( int j = 0; j <= i; j++ ) { double sum = (i<j) ? 0 : a[j*n+i]; for( int k=i+1;k<n;k++) { sum -= el[k*n+i]*a[j*n+k]; } a[i*n+j] = a[j*n+i] = sum; } } return inv; } }