/*******************************************************************************
* Copyright 2012-2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.core.parameterizedMessages;
import static com.analog.lyric.math.MoreMatrixUtils.*;
import static java.lang.String.*;
import static java.util.Objects.*;
import java.io.PrintStream;
import java.util.AbstractList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.Normal;
import com.analog.lyric.dimple.model.domains.Domain;
import com.analog.lyric.dimple.model.domains.RealJointDomain;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.RealJoint;
import com.analog.lyric.math.LyricEigenvalueDecomposition;
import com.analog.lyric.util.misc.Matlab;
import Jama.Matrix;
@Matlab(wrapper="MultivariateNormalParameters")
public class MultivariateNormalParameters extends ParameterizedMessageBase
{
/*-------
 * State
 *
 * Representation invariant: the message is stored either as a full matrix in
 * _matrix (covariance or information form, per _isInInformationForm) or, when
 * known to be diagonal, compactly in the _precision/_variance/_mean/_infoVector
 * arrays with _matrix empty. Empty arrays mean "not currently cached".
 */
// TODO : store eigendecomposition
// TODO : use Apache math Matrix implementation
private static final long serialVersionUID = 1L;
// FIXME : can we make this smaller? I set this experimentally on the amount of error
// produced when running MATLAB test alogRolledupGraphs/testMultivariateDataSource
/**
 * Determines min eigenvalue of covariance/information matrix.
 */
public static final double MIN_EIGENVALUE = 1e-9;
// Cached constant ln(2*pi), used in normalization-energy computations.
protected static final double LOG_2PI = Math.log(2*Math.PI);
// Number of dimensions (length of mean/information vectors, order of the matrix).
private int _size = 0;
// Information vector (precision matrix times mean); empty when not cached.
private double [] _infoVector = ArrayUtil.EMPTY_DOUBLE_ARRAY;
// Mean vector; empty when not cached.
private double [] _mean = ArrayUtil.EMPTY_DOUBLE_ARRAY;
// Full matrix: covariance if !_isInInformationForm, else information (precision) matrix.
// Empty when the compact diagonal representation is in use.
private double [][] _matrix = ArrayUtil.EMPTY_DOUBLE_ARRAY_ARRAY;
/**
 * If known to be diagonal, this is the precision values along the diagonal.
 * <p>
 * Only valid if {@link #_isDiagonal} && {@link #_isDiagonalComputed}.
 */
private double[] _precision = ArrayUtil.EMPTY_DOUBLE_ARRAY;
// If known to be diagonal, the variance values along the diagonal (inverse of _precision).
private double[] _variance = ArrayUtil.EMPTY_DOUBLE_ARRAY;
// True when _matrix (if present) holds the information matrix rather than covariance.
private boolean _isInInformationForm;
// True when the matrix is known to be diagonal; only meaningful if _isDiagonalComputed.
private boolean _isDiagonal = false;
// True once diagonality has been determined (it is computed lazily by isDiagonal()).
private boolean _isDiagonalComputed = false;
/*--------------
* Constructors
*/
/**
 * Constructs a multivariate normal message from a mean vector and covariance matrix.
 * <p>
 * Both arguments are defensively copied by {@link #setMeanAndCovariance}, so the
 * caller retains ownership of the passed arrays.
 */
public MultivariateNormalParameters(double[] mean, double[][] covariance)
{
	// setMeanAndCovariance already clones both arguments; cloning them here as well
	// (as the previous version did) only produced a redundant second copy.
	setMeanAndCovariance(mean, covariance);
}
/**
 * Constructs from either mean/covariance or information vector/matrix.
 *
 * @param vector mean vector, or information vector if {@code informationForm}
 * @param matrix covariance matrix, or information matrix if {@code informationForm}
 * @param informationForm if true, interpret arguments as the information
 *        (natural-parameter) representation
 */
public MultivariateNormalParameters(double[] vector, double[][] matrix, boolean informationForm)
{
	// Both setters make their own defensive copies, so the clone() calls the previous
	// version made here were redundant double copies.
	if (informationForm)
	{
		setInformation(vector, matrix);
	}
	else
	{
		setMeanAndCovariance(vector, matrix);
	}
}
/**
 * Constructs a diagonal multivariate normal from per-dimension means and variances.
 * Both arrays are cloned by {@link #setMeanAndVariance}.
 */
public MultivariateNormalParameters(double[] mean, double[] variance)
{
	setMeanAndVariance(mean, variance);
}
/**
 * Constructs a diagonal multivariate normal from a list of univariate normal
 * parameters, one entry per dimension.
 */
public MultivariateNormalParameters(List<NormalParameters> normals)
{
	setDiagonal(normals);
}
/**
 * Multivariate normal with specified number of dimensions, zero mean, and infinite covariance.
 * @param dimensions a positive number
 * @since 0.08
 */
public MultivariateNormalParameters(int dimensions)
{
	// Infinite variance in every dimension, i.e. an uninformative (uniform) message.
	this(new double[dimensions], arrayOf(dimensions, Double.POSITIVE_INFINITY));
}
/**
 * Multivariate normal with dimensions matching variable domain, zero mean, and infinite covariance.
 * <p>
 * Delegates to {@link #MultivariateNormalParameters(int)}.
 * @param var variable whose domain determines the number of dimensions
 * @since 0.08
 */
public MultivariateNormalParameters(RealJoint var)
{
	this(var.getDomain().getDimensions());
}
/**
 * Copy constructor: deep-copies all state from {@code other} via {@link #set}.
 */
public MultivariateNormalParameters(MultivariateNormalParameters other) // Copy constructor
{
	set(other);
}
/**
 * Returns a deep copy of this message (uses the copy constructor, not Object.clone).
 */
@Override
public MultivariateNormalParameters clone()
{
	return new MultivariateNormalParameters(this);
}
/**
 * Sets parameters from a mean vector and full covariance matrix.
 * <p>
 * Both arguments are cloned. Leaves the message in covariance form and invalidates
 * the cached information vector, diagonal representation and normalization energy.
 */
public final void setMeanAndCovariance(double[] mean, double[][] covariance)
{
	// validateMatrix(covariance);
	_size = mean.length;
	_infoVector = ArrayUtil.EMPTY_DOUBLE_ARRAY; // info vector recomputed lazily on demand
	_mean = mean.clone();
	_matrix = cloneMatrix(covariance);
	_precision = ArrayUtil.EMPTY_DOUBLE_ARRAY;
	_variance = ArrayUtil.EMPTY_DOUBLE_ARRAY;
	_isInInformationForm = false;
	_isDiagonalComputed = false; // diagonality of the new matrix determined lazily
	forgetNormalizationEnergy();
}
/**
 * Sets parameters from per-dimension means and variances (diagonal covariance).
 * Both arrays are cloned before being handed to the compact-diagonal setter.
 */
public final void setMeanAndVariance(double[] mean, double[] variance)
{
	set(mean.clone(), variance.clone(), false);
}
/**
 * Shared implementation for installing a compact diagonal representation.
 * <p>
 * Depending on {@code informationForm}, the arguments are either (mean, variance)
 * or (information vector, precision); the complementary pair is derived by inversion.
 * The passed arrays are stored directly (not cloned) and {@code varianceOrPrecision}
 * may be modified in place to clamp extreme values, so callers must pass arrays they own.
 */
private final void set(double[] meanOrInfo, double[] varianceOrPrecision, boolean informationForm)
{
	final int n = _size = meanOrInfo.length;
	final double[] infoOrMean = new double[n];
	final double[] precisionOrVariance = new double[n];
	if (ArrayUtil.onlyContains(varianceOrPrecision, 0.0))
	{
		// All zero => inverse is infinite
		Arrays.fill(precisionOrVariance, Double.POSITIVE_INFINITY);
		Arrays.fill(infoOrMean, Double.POSITIVE_INFINITY);
	}
	else if (ArrayUtil.onlyContains(varianceOrPrecision, Double.POSITIVE_INFINITY))
	{
		// Infinite => inverse is zero (arrays are already zero-initialized)
	}
	else
	{
		// Except for the all zero/infinite case, condition eigenvalues to not be too small/large FIXME hacky
		for (int i = 0; i < _size; ++i)
		{
			double x = varianceOrPrecision[i], inv = 1/x;
			if (x < MIN_EIGENVALUE)
			{
				varianceOrPrecision[i] = MIN_EIGENVALUE;
				inv = 1/MIN_EIGENVALUE;
			}
			else if (inv < MIN_EIGENVALUE)
			{
				varianceOrPrecision[i] = 1/MIN_EIGENVALUE;
				inv = MIN_EIGENVALUE;
			}
			precisionOrVariance[i] = inv;
			// info = precision * mean (and symmetrically mean = variance * info)
			infoOrMean[i] = meanOrInfo[i] * inv;
		}
	}
	if (informationForm)
	{
		_infoVector = meanOrInfo;
		_mean = infoOrMean;
		_precision = varianceOrPrecision;
		_variance = precisionOrVariance;
	}
	else
	{
		_mean = meanOrInfo;
		_infoVector = infoOrMean;
		_variance = varianceOrPrecision;
		_precision = precisionOrVariance;
	}
	// Compact diagonal representation: no full matrix is kept.
	_matrix = ArrayUtil.EMPTY_DOUBLE_ARRAY_ARRAY;
	_isInInformationForm = informationForm;
	_isDiagonal = true;
	_isDiagonalComputed = true;
	forgetNormalizationEnergy();
}
/**
 * Sets this message to a diagonal multivariate normal built from a list of
 * univariate normal parameters, one per dimension.
 */
public final void setDiagonal(List<NormalParameters> normals)
{
	final int count = normals.size();
	final double[] means = new double[count];
	final double[] variances = new double[count];
	int dim = 0;
	for (NormalParameters normal : normals)
	{
		means[dim] = normal.getMean();
		variances[dim] = normal.getVariance();
		++dim;
	}
	setMeanAndVariance(means, variances);
}
/**
 * Sets parameters directly in information form: the information vector (precision
 * matrix times mean) and the information (precision) matrix. Both arguments are cloned.
 */
public final void setInformation(double[] informationVector, double[][] informationMatrix)
{
	// validateMatrix(informationMatrix);
	_size = informationVector.length;
	_infoVector = informationVector.clone();
	_mean = ArrayUtil.EMPTY_DOUBLE_ARRAY; // mean recomputed lazily on demand
	_matrix = cloneMatrix(informationMatrix);
	_isInInformationForm = true;
	_precision = ArrayUtil.EMPTY_DOUBLE_ARRAY;
	_variance = ArrayUtil.EMPTY_DOUBLE_ARRAY;
	_isDiagonalComputed = false; // diagonality of the new matrix determined lazily
	forgetNormalizationEnergy();
}
// Set from another parameter set without first extracting the components or determining which form
/**
 * Deep-copies all state from {@code other}, preserving whatever representation
 * (covariance/information, full/diagonal) it currently uses.
 */
public final void set(MultivariateNormalParameters other)
{
	// Every array is cloned so later mutation of either object cannot affect the other.
	_size = other._size;
	_mean = ArrayUtil.cloneNonNullArray(other._mean);
	_infoVector = ArrayUtil.cloneNonNullArray(other._infoVector);
	_precision = ArrayUtil.cloneNonNullArray(other._precision);
	_variance = ArrayUtil.cloneNonNullArray(other._variance);
	_matrix = cloneMatrix(other._matrix);
	_isInInformationForm = other._isInInformationForm;
	_isDiagonal = other._isDiagonal;
	_isDiagonalComputed = other._isDiagonalComputed;
	copyNormalizationEnergy(other);
}
/*-----------------
* IEquals methods
*/
/**
 * Compares messages for semantic equality.
 * <p>
 * Diagonal messages are compared by mean and variance; full-matrix messages are
 * compared element-wise, but only when both are in the same representation form.
 */
@Override
public boolean objectEquals(@Nullable Object other)
{
	if (this == other)
	{
		return true;
	}
	if (!(other instanceof MultivariateNormalParameters))
	{
		return false;
	}
	final MultivariateNormalParameters that = (MultivariateNormalParameters)other;
	// Note: isDiagonal() may lazily compact the representation on either object.
	if (!super.objectEquals(other) || _size != that._size || isDiagonal() != that.isDiagonal())
	{
		return false;
	}
	if (isDiagonal())
	{
		return Arrays.equals(_mean, that._mean) && Arrays.equals(_variance, that._variance);
	}
	if (isInInformationForm() != that.isInInformationForm())
	{
		// Full matrices in different forms are not compared.
		return false;
	}
	final double[][] myMatrix = _matrix;
	final double[][] otherMatrix = that._matrix;
	for (int row = 0; row < myMatrix.length; ++row)
	{
		if (!Arrays.equals(myMatrix[row], otherMatrix[row]))
		{
			return false;
		}
	}
	return true;
}
/*----------------------
* IUnaryFactorFunction
*/
/**
 * Computes the energy (negative log of the unnormalized density) at {@code value}:
 * ½ (x−μ)ᵀ J (x−μ), where J is the information matrix.
 */
@Override
public double evalEnergy(Value value)
{
	final int n = _size;
	final double[] mean = getMean();
	double[] x = value.getDoubleArray();
	if (isDiagonal())
	{
		// Diagonal case: sum of independent univariate energies.
		double energy = 0.0;
		final double[] precisions = _precision;
		for (int i = n; --i>=0;)
		{
			final double precision = precisions[i];
			if (precision != 0.0) // zero precision (uninformative dimension) contributes nothing
			{
				final double diff = x[i] - mean[i];
				energy += diff * diff * precision;
			}
		}
		return energy * .5;
	}
	// TODO - support degenerate covariance case
	x = x.clone(); // do not modify the caller's value array
	for (int i = n; --i>=0;)
		x[i] -= mean[i];
	final double[][] informationMatrix = getInformationMatrix();
	// Quadratic form (x-mean)' * J * (x-mean), accumulated row by row.
	double colSum = 0;
	for (int row = 0; row < n; row++)
	{
		double rowSum = 0;
		final double[] informationMatrixRow = informationMatrix[row];
		for (int col = 0; col < n; col++)
			rowSum += informationMatrixRow[col] * x[col]; // Matrix * vector
		colSum += rowSum * x[row]; // Vector * vector
	}
	return colSum * .5;
}
/*--------------------
* IPrintable methods
*/
/**
 * Prints a human-readable rendering of this message to {@code out}.
 * <p>
 * Negative verbosity prints nothing; higher verbosity adds indices, spacing,
 * extra precision and newlines.
 */
@Override
public void print(PrintStream out, int verbosity)
{
	if (verbosity < 0)
	{
		return;
	}
	// Prefer the mean vector; fall back to the information vector when mean is not cached.
	String vectorLabel = "mean";
	double[] vector = _mean;
	if (vector.length == 0)
	{
		vector = _infoVector;
		vectorLabel="info";
	}
	out.print("Normal(");
	if (verbosity > 1)
	{
		out.println();
		out.print(" ");
	}
	out.format("%s=[", vectorLabel);
	for (int i = 0, end = vector.length; i < end; ++i)
	{
		if (i > 0)
		{
			out.print(',');
			if (verbosity > 0)
			{
				out.print(' ');
			}
		}
		if (verbosity > 0)
		{
			out.format("%d=", i); // index prefix at higher verbosity
		}
		out.format(verbosity > 1 ? "%.12g" : "%g", vector[i]);
	}
	out.print(']');
	if (verbosity > 1)
	{
		out.println();
	}
	else
	{
		out.print(", ");
	}
	out.print(_isInInformationForm ? "precision=[" : "covariance=[");
	if (isDiagonal())
	{
		// Compact diagonal case: print only the diagonal entries.
		double[] diagonal = _isInInformationForm ? _precision : _variance;
		for (int i = 0, end = diagonal.length; i < end; ++i)
		{
			if (i > 0)
			{
				out.print(',');
				if (verbosity > 0)
				{
					out.print(' ');
				}
			}
			if (verbosity > 0)
			{
				out.format("%d=", i);
			}
			out.format("%g", diagonal[i]);
		}
	}
	else
	{
		// Full-matrix case: rows separated by ';', one row per line.
		final int n = _matrix.length;
		for (int row = 0; row < n; ++row)
		{
			if (row > 0)
			{
				out.print(";");
			}
			out.print("\n ");
			for (int col = 0; col < n; ++col)
			{
				if (col > 0)
				{
					out.print(',');
				}
				out.format("%g", _matrix[row][col]);
			}
		}
	}
	out.print(']');
	if (verbosity > 1)
	{
		out.println();
	}
	out.print(')');
}
/*-------------------------------
* IParameterizedMessage methods
*/
/**
 * Adds the natural parameters of {@code other} into this message.
 * @throws ClassCastException if {@code other} is not a {@link MultivariateNormalParameters}.
 */
@Override
public void addFrom(IParameterizedMessage other)
{
	addFrom((MultivariateNormalParameters)other);
}
/**
 * Adds the natural parameters (information vector and matrix) of {@code other} to this
 * message, using compact diagonal arithmetic whenever both operands allow it.
 * <p>
 * Note: may convert {@code other}'s internal representation to information format
 * as a side effect (semantically equivalent).
 * @throws IllegalArgumentException if sizes differ.
 */
public void addFrom(MultivariateNormalParameters other)
{
	// That natural parameters are the information vector and matrix
	//
	// Special cases:
	// - one message has all infinite precision - use corresponding means
	// - both messages have all infinite precision - use existing mean if close enough
	// - one message has all zero precision: use other means
	if (other.isNull())
	{
		return;
	}
	if (isNull())
	{
		set(other);
		return;
	}
	final int n = _size;
	if (n != other._size)
	{
		throw new IllegalArgumentException(format("Cannot add from %s with different size",
			other.getClass().getSimpleName()));
	}
	final boolean otherDiagonal = other.isDiagonal();
	if (otherDiagonal && (other._precision.length == 0 || other._precision[0] == 0.0))
	{
		// Other message adds no information
		return;
	}
	final boolean diagonal = isDiagonal();
	if (diagonal && (_precision.length == 0 || _precision[0] == 0.0))
	{
		// This message provide no information, copy from other
		set(other);
		return;
	}
	final double[] value = toDeterministicValueUnsafe();
	final double[] otherValue = other.toDeterministicValueUnsafe();
	if (value != null)
	{
		if (otherValue != null)
		{
			// FIXME - compare values
		}
		else
		{
			// Deterministic value in this message overrides other message
		}
		return;
	}
	if (otherValue != null)
	{
		// Other message is deterministic. Set from that
		setDeterministic(otherValue);
		return;
	}
	forgetNormalizationEnergy();
	if (diagonal && otherDiagonal)
	{
		// Just add diagonally
		for (int i = 0; i < n; ++i)
		{
			_infoVector[i] += other._infoVector[i];
			_precision[i] += other._precision[i];
			_variance[i] = 1.0 / _precision[i];
			_mean[i] = _infoVector[i] * _variance[i];
		}
		_matrix = ArrayUtil.EMPTY_DOUBLE_ARRAY_ARRAY;
		return;
	}
	if (!diagonal && otherDiagonal)
	{
		// Other's diagonal precisions only affect this matrix's diagonal entries.
		toInformationFormat();
		for (int i = 0; i < n; ++i)
		{
			_infoVector[i] += other._infoVector[i];
			_matrix[i][i] += other._precision[i];
		}
		_mean = ArrayUtil.EMPTY_DOUBLE_ARRAY; // mean is now stale; recomputed lazily
		return;
	}
	if (diagonal) // && !otherDiagonal
	{
		// This is diagonal but other is not: start from a copy of other's information
		// matrix (getInformationMatrix returns a fresh clone) and add our diagonal in.
		_matrix = other.getInformationMatrix();
		for (int i = 0; i < n; ++i)
		{
			_infoVector[i] += other._infoVector[i];
			_matrix[i][i] += _precision[i];
		}
		_mean = ArrayUtil.EMPTY_DOUBLE_ARRAY;
		_precision = ArrayUtil.EMPTY_DOUBLE_ARRAY;
		_variance = ArrayUtil.EMPTY_DOUBLE_ARRAY;
		_isDiagonal = false;
		_isDiagonalComputed = false;
		_isInInformationForm = true;
		return;
	}
	// General case: both have full matrices; add element-wise in information format.
	toInformationFormat();
	other.toInformationFormat();
	for (int i = 0; i <n; ++i)
	{
		_infoVector[i] += other._infoVector[i];
		double[] row = _matrix[i];
		double[] otherRow = other._matrix[i];
		for (int j = 0; j < n; ++j)
		{
			row[j] += otherRow[j];
		}
	}
	_mean = ArrayUtil.EMPTY_DOUBLE_ARRAY;
}
/**
 * {@inheritDoc}
 * <p>
 * For multivariate normal distributions, the formula is given by:
 *
 * <blockquote>
 * ½ {
 * trace(Σ<sub>Q</sub><sup>-1</sup>Σ<sub>P</sub>) +
 * (μ<sub>Q</sub>-μ<sub>P</sub>)<sup>T</sup>Σ<sub>Q</sub><sup>-1</sup>(μ<sub>Q</sub>-μ<sub>P</sub>)
 * -K - ln(det(Σ<sub>P</sub>)/det(Σ<sub>Q</sub>))
 * }
 * </blockquote>
 * Note that this assumes that the determinants of the covariance matrices are non-zero.
 */
@Override
public double computeKLDivergence(IParameterizedMessage that)
{
	if (that instanceof MultivariateNormalParameters)
	{
		// http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback.E2.80.93Leibler_divergence
		//
		// K: size of vectors, # rows/columns of matrices
		// up, uq: vectors of means for P and Q
		// CP, CQ: covariance matrices for P and Q
		// inv(x): inverse of x
		// det(x): determinant of x
		// tr(x): trace of x
		// x': transpose of x
		//
		// KL(P|Q) == .5 * ( tr(inv(CQ) * CP) + (uq - up)' * inv(CQ) * (uq - up) - K - ln(det(CP)/det(CQ)) )
		//
		final MultivariateNormalParameters P = this, Q = (MultivariateNormalParameters)that;
		final int K = P.getVectorLength();
		assertSameSize(Q.getVectorLength());
		if (P.isDiagonal() && Q.isDiagonal())
		{
			// If both are diagonal, we can simply add up the KL for the univariate cases along the diagonal.
			final double[] Pmeans = P._mean, Qmeans = Q._mean;
			final double[] Pprecisions = P._precision, Qprecisions = Q._precision;
			double kl = 0.0;
			for (int i = 0; i < K; ++i)
			{
				kl += NormalParameters.computeKLDiverence(Pmeans[i], Pprecisions[i], Qmeans[i], Qprecisions[i]);
			}
			return kl;
		}
		// TODO - if we ever start storing the eigendecomposition in this object, we can simply use
		// the eigenvalues to efficiently compute the trace and determinants. Perhaps it would be worthwhile
		// to save the eigenvalues if nothing else.
		RealVector mP = wrapRealVector(P.getMean());
		RealVector mQ = wrapRealVector(Q.getMean());
		RealMatrix CP = wrapRealMatrix(P.getCovariance());
		RealMatrix CQinv = wrapRealMatrix(Q.getInformationMatrix());
		RealVector mdiff = mQ.subtract(mP);
		// FIXME: do we need to worry about singular covariance matrices?
		double divergence = -K;
		// trace of product of matrices is equivalent to the dot-product of the vectorized versions
		// of the matrices - this is much faster than doing the actual matrix product
		// divergence += CQinv.multiply(CP).trace();
		for (int i = 0; i < K; ++i)
			for (int j = 0; j < K; ++j)
				divergence += CQinv.getEntry(i, j) * CP.getEntry(i, j);
		divergence += CQinv.preMultiply(mdiff).dotProduct(mdiff);
		// ln(det(CP)/det(CQ)) == ln(det(CP) * det(inv(CQ)))
		divergence -= Math.log(new EigenDecomposition(CP).getDeterminant() *
			new EigenDecomposition(CQinv).getDeterminant());
		return Math.abs(divergence/2); // use abs to guard against precision errors causing this to go negative.
	}
	throw new IllegalArgumentException(String.format("Expected '%s' but got '%s'", getClass(), that.getClass()));
}
/**
 * {@inheritDoc}
 * <p>
 * True if {@link #isDiagonal()} and {@link #getDiagonalVariance()} contains only zeros.
 * @since 0.08
 */
@Override
public boolean hasDeterministicValue()
{
	// Only need to check first element of variance, because setting diagonal format only allows zeros if all
	// are zero.
	return isDiagonal() && _variance.length > 0 && _variance[0] == 0.0;
}
/**
 * Sets this message to a deterministic (zero-variance) distribution at the given value.
 */
@Override
public void setDeterministic(Value value)
{
	setDeterministic(value.getDoubleArray());
}
/**
 * {@inheritDoc}
 * <p>
 * If {@link #isDiagonal()} and {@link #getDiagonalVariance()} contains only zeros, this returns
 * Value containing the {@linkplain #getMean() mean}.
 * @since 0.08
 * @see #toDeterministicValue()
 * @see #toDeterministicValueUnsafe()
 */
@Override
public @Nullable Value toDeterministicValue(Domain domain)
{
	final double[] mean = toDeterministicValue();
	if (mean == null)
	{
		return null;
	}
	return Value.create((RealJointDomain)domain, mean);
}
/**
 * Copies all state from {@code other}.
 * @throws ClassCastException if {@code other} is not a {@link MultivariateNormalParameters}.
 */
@Override
public void setFrom(IParameterizedMessage other)
{
	set((MultivariateNormalParameters)other);
}
/**
 * {@inheritDoc}
 * <p>
 * Sets all means to zero and all covariances to infinity
 */
@Override
public final void setUniform()
{
	setMeanAndVariance(new double[_size], arrayOf(_size, Double.POSITIVE_INFINITY));
	// Leave the message in information form (all-zero precision) rather than infinite
	// covariance. NOTE(review): presumably so downstream arithmetic sees finite zeros
	// instead of infinities -- confirm against callers.
	toInformationFormat();
}
/*---------------
* Local methods
*/
/** Synonym for {@link #getMean()}. */
public final double[] getMeans() {return getMean();} // For backward compatibility
/**
 * Returns a copy of the mean vector, converting from information form first if necessary.
 */
@Matlab
public final double[] getMean()
{
	if (_mean.length == 0)
	{
		// Mean not cached; convert representation to compute it.
		toCovarianceFormat();
	}
	return ArrayUtil.cloneNonNullArray(_mean);
}
/**
 * Returns a copy of the full covariance matrix, converting format and expanding a
 * compact diagonal representation if necessary.
 */
@Matlab
public final double [][] getCovariance()
{
	toCovarianceFormat();
	instantiateMatrix(); // expand compact diagonal form into a full matrix if needed
	return cloneMatrix(_matrix);
}
/**
 * If information/covariance matrices are diagonal, return the parameters for the indexed diagonal entry.
 * @since 0.08
 * @return {@link Normal} containing mean and precision with given index.
 * @see #getDiagonalPrecision()
 * @see #getDiagonalVariance()
 */
@Matlab
public @Nullable Normal getDiagonalNormal(int index)
{
	if (!isDiagonal())
	{
		return null;
	}
	return new Normal(_mean[index], _precision[index]);
}
/**
 * If the matrix is diagonal, returns a read-only list of per-dimension {@link Normal}
 * parameters; otherwise returns null.
 * <p>
 * The returned list is a live view backed by this object, not a snapshot.
 */
public @Nullable List<Normal> getDiagonalNormals()
{
	if (isDiagonal())
	{
		return new AbstractList<Normal>() {
			@Override
			public Normal get(int index)
			{
				return requireNonNull(getDiagonalNormal(index));
			}
			@Override
			public int size()
			{
				return _size;
			}
		};
	}
	return null;
}
/**
 * If information matrix is diagonal, returns its elements, else a zero length array.
 * @since 0.08
 */
public final double[] getDiagonalPrecision()
{
	// Called for its side effect: lazily computes and caches the compact diagonal
	// representation (filling _precision) when the matrix turns out to be diagonal.
	isDiagonal();
	return ArrayUtil.cloneNonNullArray(_precision);
}
/**
 * If covariance matrix is diagonal, returns its elements, else a zero length array.
 * @since 0.08
 */
public final double[] getDiagonalVariance()
{
	// Called for its side effect: lazily computes and caches the compact diagonal
	// representation (filling _variance) when the matrix turns out to be diagonal.
	isDiagonal();
	return ArrayUtil.cloneNonNullArray(_variance);
}
/**
 * Returns a copy of the information vector (precision matrix times mean), converting
 * to information form first if necessary.
 */
@Matlab
public final double [] getInformationVector()
{
	if (_infoVector.length == 0)
	{
		// Not cached; convert representation to compute it.
		toInformationFormat();
	}
	return ArrayUtil.cloneNonNullArray(_infoVector);
}
/**
 * Returns a copy of the full information (precision) matrix, converting format and
 * expanding a compact diagonal representation if necessary.
 */
@Matlab
public final double [][] getInformationMatrix()
{
	toInformationFormat();
	instantiateMatrix(); // expand compact diagonal form into a full matrix if needed
	return cloneMatrix(_matrix);
}
/** Returns the number of dimensions of this multivariate normal. */
public final int getVectorLength()
{
	return _size;
}
/**
 * True if information/covariance matrix only contains diagonal entries.
 * <p>
 * That is, the only non-zero matrix entries have matching column/row indices.
 * <p>
 * Lazily computed: if the matrix is found to be diagonal, the representation is
 * converted to the compact diagonal form as a side effect.
 * @since 0.08
 */
public final boolean isDiagonal()
{
	if (!_isDiagonalComputed)
	{
		final double[][] matrix = _matrix;
		final int n = _size;
		boolean isDiagonal = true;
		// NOTE: assumes matrix is symmetric (only the lower triangle is inspected)
		outer:
		for (int i = 0; i < n; ++i)
		{
			double[] row = matrix[i];
			for (int j = 0; j < i; ++j)
			{
				if (row[j] != 0.0)
				{
					isDiagonal = false;
					break outer;
				}
			}
		}
		if (isDiagonal)
		{
			// Convert to compact diagonal form
			double[] array = new double[n];
			for (int i = 0; i < n; ++i)
			{
				array[i] = matrix[i][i];
			}
			// set() derives the complementary representation and marks _isDiagonalComputed.
			set(_isInInformationForm ? _infoVector : _mean, array, _isInInformationForm);
			return true;
		}
		_isDiagonal = false;
		_isDiagonalComputed = true;
	}
	return _isDiagonal;
}
/** True when the internal representation is information form (precision matrix/vector). */
public final boolean isInInformationForm()
{
	return _isInInformationForm;
}
/**
 * True if this message carries no information: zero dimensions, or diagonal with zero
 * precision. (Precedence: {@code _size == 0 || (isDiagonal() && _precision[0] == 0.0)}.)
 */
@Override
public final boolean isNull()
{
	return _size == 0 || isDiagonal() && _precision[0] == 0.0;
}
/**
 * Sets the {@linkplain #getMean() mean} (and {@linkplain #getInformationVector()} information vector} to its
 * negation.
 *
 * @since 0.08
 */
public void negateMean()
{
	// Negate whichever of the two vectors are currently cached (empty arrays are no-ops).
	for (int i = 0; i < _mean.length; ++i)
	{
		_mean[i] = -_mean[i];
	}
	for (int i = 0; i < _infoVector.length; ++i)
	{
		_infoVector[i] = -_infoVector[i];
	}
	// Normalization does not depend on mean, so no need to reset it.
}
/**
 * Sets this message to a deterministic distribution: mean equal to {@code value}
 * with all variances zero.
 */
public void setDeterministic(double[] value)
{
	setMeanAndVariance(value, new double[value.length]);
}
/**
 * Returns unique deterministic value if any.
 * <p>
 * If {@link #isDiagonal()} and {@link #getDiagonalVariance()} contains only zeros, this returns
 * the {@linkplain #getMean() mean}.
 * @since 0.08
 */
public @Nullable double[] toDeterministicValue()
{
	if (hasDeterministicValue())
	{
		return getMean();
	}
	return null;
}
/**
 * Returns unique deterministic value if any without copying.
 * <p>
 * This is the same as {@link #toDeterministicValue()} but returns a pointer to the internal copy of
 * the mean instead of copying it. The caller must make sure not to modify the array!
 * @since 0.08
 */
public @Nullable double[] toDeterministicValueUnsafe()
{
	if (hasDeterministicValue())
	{
		return _mean;
	}
	return null;
}
/*---------
* Private
*/
/**
 * Returns a deep copy of {@code matrix} (each row array is cloned).
 */
private final double[][] cloneMatrix(double[][] matrix)
{
	final int rows = matrix.length;
	final double[][] copy = new double[rows][];
	for (int row = rows; --row >= 0;)
	{
		copy[row] = matrix[row].clone();
	}
	return copy;
}
/**
 * Force instantiation of {@link #_matrix} if not already done and in diagonal form.
 */
private final void instantiateMatrix()
{
	final int n = _size;
	if (_matrix.length != n && _isDiagonal && _isDiagonalComputed)
	{
		// Expand the compact diagonal representation into a full n-by-n matrix
		// (off-diagonal entries stay zero).
		_matrix = new double[n][n];
		final double[] diagonal = _isInInformationForm ? _precision : _variance;
		for (int i = 0; i < n; ++i)
			_matrix[i][i] = diagonal[i];
	}
}
/**
 * True if every diagonal entry of {@code m} is infinite (off-diagonals are not examined).
 * Used to detect the "infinite variance" special case when toggling formats.
 */
private final boolean isInfiniteIdentity(double[][] m)
{
	int i = 0;
	for (double[] row : m)
	{
		if (!Double.isInfinite(row[i++]))
		{
			return false;
		}
	}
	return true;
}
/**
 * Computes the normalization energy: ½(log det J − n·log 2π), where J is the
 * information matrix (the sign of the log-determinant flips for covariance form
 * since det(covariance) = 1/det(information)).
 */
@Override
protected double computeNormalizationEnergy()
{
	double energy = 0;
	if (isDiagonal())
	{
		// Simply add up the energies of the diagonals, skipping the zero entries
		for (double tau : _precision)
		{
			if (tau != 0.0)
			{
				energy += Math.log(tau) - LOG_2PI;
			}
		}
	}
	else
	{
		// TODO - support degenerate covariance
		double logdet = Math.log(new Jama.Matrix(_matrix).det());
		if (!_isInInformationForm)
			logdet = -logdet; // det(covariance) is the reciprocal of det(information)
		energy = logdet - _size * LOG_2PI;
	}
	return energy / 2;
}
/** Converts the internal representation to mean/covariance format if currently in information format. */
private final void toCovarianceFormat()
{
	if (_isInInformationForm)
	{
		toggleFormat();
	}
}
/** Converts the internal representation to information format if currently in covariance format. */
private final void toInformationFormat()
{
	if (!_isInInformationForm)
	{
		toggleFormat();
	}
}
/**
 * Validates internal state, throwing a {@link DimpleException} if any diagonal
 * variance is non-positive or the matrix fails {@link #validateMatrix} checks.
 */
public void validate()
{
	// _variance will only be non-empty if diagonal
	for (double v : _variance)
	{
		if (v <= 0)
		{
			throw notPositiveDefinite();
		}
	}
	validateMatrix(_matrix);
}
/**
 * Validates that {@code m} is square, symmetric, NaN-free and positive definite.
 * All-zero matrices and matrices with an all-infinite diagonal are accepted as
 * degenerate special cases and skip the eigenvalue check.
 * @throws DimpleException describing the first violation found.
 */
private void validateMatrix(double[][] m)
{
	final int n = m.length;
	boolean allZero = true;
	for (double[] row : m )
	{
		for (double value : row)
		{
			if (value != value) // NaN is the only value not equal to itself
			{
				throw new DimpleException("Matrix contains a NaN value");
			}
			if (value != 0.0)
			{
				allZero = false;
			}
		}
	}
	if (allZero)
	{
		// All-zero matrix allowed (e.g. zero precision == uninformative message).
		return;
	}
	boolean infiniteDiagonal = true;
	for (int i = 0; i < n; ++i)
	{
		final double[] row = m[i];
		if (row.length != n)
		{
			throw new DimpleException("Matrix is not square");
		}
		for (int j = 0; j < n; ++j)
		{
			final double vij = row[j];
			if (j == i)
			{
				infiniteDiagonal &= (vij == Double.POSITIVE_INFINITY);
			}
			else
			{
				final double vji = m[j][i];
				if (Math.abs(vji - vij) > 1e-10) // tolerate tiny asymmetry from round-off
				{
					throw new DimpleException("Matrix is not symmetric at entry (%d,%d)", i, j);
				}
			}
		}
	}
	if (infiniteDiagonal)
	{
		// All-infinite diagonal (infinite variance) allowed; skip eigenvalue check.
		return;
	}
	if (n > 0)
	{
		// Positive-definiteness check via the real eigenvalues.
		EigenDecomposition eig = new EigenDecomposition(wrapRealMatrix(m));
		for (double value : eig.getRealEigenvalues())
		{
			if (value <= 0)
			{
				throw notPositiveDefinite();
			}
		}
	}
}
/** Returns (does not throw) the exception used for positive-definiteness violations. */
private RuntimeException notPositiveDefinite()
{
	return new DimpleException("Matrix is not positive definite");
}
/**
 * Toggles between mean/covariance format and information format, which uses the matrix
 * inverse of the covariance matrix (this is also known as the precision or concentration matrix)
 *
 * @since 0.06
 */
private final void toggleFormat()
{
	outer:
	{
		if (isDiagonal())
		{
			// Only need to update the matrix, if present: the diagonal representation
			// already caches both variance and precision, so just refresh the matrix
			// diagonal with the target form's values.
			if (_matrix.length != 0)
			{
				final double[] diagonal = _isInInformationForm ? _variance : _precision;
				for (int i = 0, n = _size; i < n; ++i)
					_matrix[i][i] = diagonal[i];
			}
			break outer;
		}
		// TODO: consider using EJML or MTJ library instead of Jama
		double[] newVector;
		if (isInfiniteIdentity(_matrix))
		{
			//Handle the special case where variances are infinite
			_matrix = new double[_size][_size];
			newVector = new double[_size];
		}
		else
		{
			// FIXME - replace with Apache version if possible
			// Currently, attempting to replace with this the equivalent Apache commons implementation
			// causes the MATLAB Kalman filter tests to fail. Using the colt implementation also breaks
			// the Kalman tests although the error is not quite as great. I am not sure whether this is
			// because the Jama implementation is better in some way or because it deals with degenerate
			// covariance matrices in a more appropriate manner for the Kalman case. [cbarber 2015-07-06]
			//
			// Invert via eigendecomposition: M^-1 = V * D^-1 * V', with eigenvalues clamped.
			LyricEigenvalueDecomposition eig = new LyricEigenvalueDecomposition(new Jama.Matrix(_matrix));
			Matrix D = eig.getD();
			Matrix V = eig.getV();
			int N = D.getColumnDimension();
			for (int i=0; i<N; i++) {
				//Compute inverse of eigenvalues except for those less than eps we set to large constant.
				double d = D.get(i,i);
				d = invertEigenvalue(d);
				D.set(i,i, d);
				assert(d > 0); // Eigenvalues should always be positive for positive definite matrices
			}
			Matrix inv;
			Matrix vec;
			inv = V.times(D.times(V.transpose()));
			// Transform the vector with the inverted matrix: mean <-> information vector.
			vec = new Matrix(new double [][] {_isInInformationForm ? _infoVector : _mean}).transpose();
			vec = inv.times(vec);
			_matrix = inv.getArray();
			newVector = vec.transpose().getArray()[0];
		}
		if (_isInInformationForm)
		{
			_mean = newVector;
		}
		else
		{
			_infoVector = newVector;
		}
	}
	_isInInformationForm = !_isInInformationForm;
}
/**
 * Returns a new array of length {@code size} with every entry set to {@code value}.
 */
private static double[] arrayOf(int size, double value)
{
	final double[] result = new double[size];
	Arrays.fill(result, value);
	return result;
}
/**
 * Throws {@link IllegalArgumentException} if {@code otherSize} differs from this
 * message's vector length; otherwise does nothing.
 */
protected void assertSameSize(int otherSize)
{
	final int mySize = getVectorLength();
	if (mySize == otherSize)
	{
		return;
	}
	throw new IllegalArgumentException(
		String.format("Incompatible vector sizes '%d' and '%d'", mySize, otherSize));
}
/**
 * Returns the inverse of {@code eigenvalue}, capped at 1/{@link #MIN_EIGENVALUE} so that
 * eigenvalues smaller than MIN_EIGENVALUE do not produce an unboundedly large inverse.
 * <p>
 * NOTE(review): a non-positive eigenvalue yields a non-positive result here, which would
 * trip the caller's positive-definiteness assertion -- confirm inputs are always positive.
 */
private double invertEigenvalue(double eigenvalue)
{
	return Math.min(1/MIN_EIGENVALUE, 1/eigenvalue);
}
}