/*******************************************************************************
* Copyright 2012 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.sumproduct.customFactors;
import static com.analog.lyric.math.MoreMatrixUtils.*;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.MultivariateNormalParameters;
/**
 * Sum-product message computation for the linear relation {@code y = A*x}, where {@code A} is an
 * M-by-N matrix, {@code x} is a length-N Gaussian variable and {@code y} is a length-M Gaussian
 * variable.
 * <p>
 * The constructor factors {@code A} once with a singular value decomposition and caches both a
 * reconstruction of {@code A} and a regularized pseudo-inverse; {@link #ComputeMsg} then uses
 * them to propagate messages forward (x to y) or in reverse (y to x).
 */
public class MutlivariateGaussianMatrixProduct
{
	// TODO: use Apache math Matrix implementation
	private final int M, N;

	// U*S*V^T from the SVD of A (a reconstruction of A); M rows by N columns.
	private final double [/*M*/][/*N*/] A_clean;

	// Regularized pseudo-inverse V*S^+*U^T, where each singular value is clamped into
	// [eps, 1/eps] before being inverted; N rows by M columns.
	private final double [/*N*/][/*M*/] A_pinv;

	// Clamp bound for singular values (kept within [eps, 1/eps] when inverted). Also added to the
	// diagonal of every computed output matrix as a regularizer.
	private static final double eps = 1e-7;

	/**
	 * Precomputes the matrices used by {@link #ComputeMsg}.
	 *
	 * @param A the M-by-N matrix of the relation {@code y = A*x}; must have at least one row and
	 *        one column. Its column count is taken from the first row.
	 * @throws IllegalArgumentException if {@code A} has no rows or its first row is empty
	 */
	public MutlivariateGaussianMatrixProduct(double[][] A)
	{
		if (A.length == 0 || A[0].length == 0)
			throw new IllegalArgumentException("Matrix A must have at least one row and one column");

		M = A.length;
		N = A[0].length;

		// Factor A = U*S*V^T once; both cached matrices are derived from this decomposition.
		final SingularValueDecomposition svd = new SingularValueDecomposition(wrapRealMatrix(A));

		// Reconstruct A as U*(S*V^T).
		A_clean = matrixGetDataRef(svd.getU().multiply(svd.getS().multiply(svd.getVT())));

		// Build S^+ by inverting the diagonal of S^T. Each singular value is clamped into
		// [eps, 1/eps] first so that zero or tiny singular values do not produce infinities.
		// The edits are made on the transpose rather than on the matrix returned by getS().
		final RealMatrix sInv = svd.getS().transpose();
		final int numS = Math.min(sInv.getColumnDimension(), sInv.getRowDimension());
		for (int i = 0; i < numS; i++)
		{
			final double d = Math.min(Math.max(sInv.getEntry(i, i), eps), 1 / eps);
			sInv.setEntry(i, i, 1.0 / d);
		}
		A_pinv = matrixGetDataRef(svd.getV().multiply(sInv.multiply(svd.getUT())));
	}

	/**
	 * Computes the outgoing message across the relation {@code y = A*x}.
	 * <p>
	 * Deterministic inputs are mapped directly ({@code A*x} forward, {@code A_pinv*y} in
	 * reverse), and null inputs produce a null output. Otherwise the output is produced in the
	 * same (covariance or information) form as the input, with {@code eps} added to the diagonal
	 * of the output matrix.
	 *
	 * @param inMsg incoming message: on {@code x} (length N) for 'F', on {@code y} (length M) for 'R'
	 * @param outMsg receives the result: on {@code y} (length M) for 'F', on {@code x} (length N) for 'R'
	 * @param direction 'F' for forward or 'R' for reverse. It is only checked by an assertion;
	 *        when assertions are disabled, any value other than 'F' is treated as reverse.
	 * @throws DimpleException with message "is this tested?" when a non-deterministic, non-null
	 *         input is in covariance form. This path is deliberately disabled as untested, and
	 *         it may already have written to {@code outMsg} before throwing.
	 */
	public void ComputeMsg(MultivariateNormalParameters inMsg, MultivariateNormalParameters outMsg, char direction /* F for Forward, R for Reverse */)
	{
		assert (direction == 'F' || direction == 'R'); //only two directions possible
		final boolean forward = direction == 'F';

		// Special case for deterministic input: propagate the point value directly.
		final double[] deterministicInput = inMsg.toDeterministicValueUnsafe();
		if (deterministicInput != null)
		{
			// Forward: y = A*x (length M). Reverse: x = A_pinv*y (length N). Both are plain
			// matrix-vector products; preMultiply (v^T * A_pinv) would be dimensionally wrong here.
			final double[][] op = forward ? A_clean : A_pinv;
			outMsg.setDeterministic(new Array2DRowRealMatrix(op, false).operate(deterministicInput));
			return;
		}

		// Special case for null input
		if (inMsg.isNull())
		{
			outMsg.setNull();
			return;
		}

		//TODO: this is really hacky! the conversion to and from the inverse makes sure the matrix doesn't
		//grow too large.
		inMsg.getCovariance();
		inMsg.getInformationMatrix();

		if (forward)
		{
			assert(inMsg.getVectorLength() == N); //dimensions should match for it to be a valid forward matrix multiply
			if (!inMsg.isInInformationForm())
				forwardCovariance(inMsg, outMsg);
			else
				forwardInformation(inMsg, outMsg);
		}
		else
		{
			assert(inMsg.getVectorLength() == M); //dimensions should match for it to be a valid reverse matrix multiply
			if (!inMsg.isInInformationForm())
				reverseCovariance(inMsg, outMsg);
			else
				reverseInformation(inMsg, outMsg);
		}
	}

	// Forward, covariance form: m_y = A*m_x and V_y = A*V_x*A^T + eps*I; then rejects as untested.
	private void forwardCovariance(MultivariateNormalParameters inMsg, MultivariateNormalParameters outMsg)
	{
		final double[] mean = times(A_clean, inMsg.getMean());
		// Deterministic inputs normally return early in ComputeMsg; this check is retained so the
		// behavior is unchanged if it ever disagrees with toDeterministicValueUnsafe().
		if (!inMsg.hasDeterministicValue())
		{
			final double[][] covariance = MatrixMult(A_clean, MatrixMult(inMsg.getCovariance(), Transpose(A_clean)));
			addToDiagonal(covariance, eps);
			outMsg.setMeanAndCovariance(mean, covariance);
		}
		// Deliberate guard: this path has not been validated.
		throw new DimpleException("is this tested?");
	}

	// Forward, information form: W_y = A_pinv^T*W_x*A_pinv + eps*I and h_y = A_pinv^T*h_x.
	private void forwardInformation(MultivariateNormalParameters inMsg, MultivariateNormalParameters outMsg)
	{
		final double[][] information = MatrixMult(Transpose(A_pinv), MatrixMult(inMsg.getInformationMatrix(), A_pinv));
		addToDiagonal(information, eps);
		final double[] infoVector = transposeTimes(A_pinv, inMsg.getInformationVector());
		outMsg.setInformation(infoVector, information);
	}

	// Reverse, covariance form: V_x = A_pinv*V_y*A_pinv^T + eps*I and m_x = A_pinv*m_y; then
	// rejects as untested.
	private void reverseCovariance(MultivariateNormalParameters inMsg, MultivariateNormalParameters outMsg)
	{
		final double[][] covariance = MatrixMult(A_pinv, MatrixMult(inMsg.getCovariance(), Transpose(A_pinv)));
		addToDiagonal(covariance, eps);
		final double[] mean = times(A_pinv, inMsg.getMean());
		outMsg.setMeanAndCovariance(mean, covariance);
		// Deliberate guard: this path has not been validated.
		throw new DimpleException("is this tested?");
	}

	// Reverse, information form: W_x = A^T*W_y*A + eps*I and h_x = A^T*h_y.
	private void reverseInformation(MultivariateNormalParameters inMsg, MultivariateNormalParameters outMsg)
	{
		final double[][] information = MatrixMult(Transpose(A_clean), MatrixMult(inMsg.getInformationMatrix(), A_clean));
		addToDiagonal(information, eps);
		final double[] infoVector = transposeTimes(A_clean, inMsg.getInformationVector());
		outMsg.setInformation(infoVector, information);
	}

	// Returns a*x, summing each row left to right; a must be non-empty and rectangular.
	private static double[] times(double[][] a, double[] x)
	{
		final double[] result = new double[a.length];
		for (int row = 0; row < a.length; row++)
		{
			double sum = 0;
			for (int col = 0; col < a[row].length; col++)
				sum += a[row][col] * x[col];
			result[row] = sum;
		}
		return result;
	}

	// Returns a^T*x (x indexed by the rows of a); a must be non-empty and rectangular.
	private static double[] transposeTimes(double[][] a, double[] x)
	{
		final int cols = a[0].length;
		final double[] result = new double[cols];
		for (int col = 0; col < cols; col++)
		{
			double sum = 0;
			for (int row = 0; row < a.length; row++)
				sum += a[row][col] * x[row];
			result[col] = sum;
		}
		return result;
	}

	// Adds value to every diagonal entry of the square matrix m, in place.
	private static void addToDiagonal(double[][] m, double value)
	{
		for (int i = 0; i < m.length; i++)
			m[i][i] += value;
	}

	/**
	 * Debugging helper: prints each element followed by ", " on one line of standard output.
	 *
	 * @param vector the values to print
	 */
	public static void printVector (double [] vector)
	{
		for (double value : vector)
			System.out.print(value + ", ");
		System.out.println();
	}

	/**
	 * Debugging helper: prints the matrix row by row between "[" and "]" lines on standard output.
	 *
	 * @param matrix the rows to print
	 */
	public static void printMatrix(double [][] matrix)
	{
		System.out.println("[");
		for (double[] row : matrix)
			printVector(row);
		System.out.println("]");
	}

	/**
	 * Returns the matrix product {@code A * B}. The inputs are wrapped without copying and are not
	 * modified; the result is a newly allocated array.
	 *
	 * @param A left operand
	 * @param B right operand; its row count must equal the column count of {@code A}
	 * @return the product as a row-major array
	 */
	public double[][] MatrixMult(double[][] A, double[][] B)
	{
		return new Array2DRowRealMatrix(A, false).multiply(new Array2DRowRealMatrix(B, false)).getDataRef();
	}

	/**
	 * Returns the transpose of {@code A}. The input is wrapped without copying and is not modified.
	 *
	 * @param A the matrix to transpose
	 * @return the transpose as a row-major array
	 */
	public double[][] Transpose(double[][] A)
	{
		return ((Array2DRowRealMatrix)new Array2DRowRealMatrix(A, false).transpose()).getDataRef();
	}
}