/*
* MultiDimensionalScalingMM.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.multidimensionalscaling.mm;
import dr.inference.multidimensionalscaling.MultiDimensionalScalingLikelihood;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.operators.EllipticalSliceOperator;
import dr.math.distributions.GaussianProcessRandomGenerator;
import dr.xml.*;
/**
 * Mode finder for a multidimensional-scaling (MDS) likelihood, optionally combined with a
 * Gaussian-process penalty, built on the iteration scheme provided by {@link MMAlgorithm}.
 * <p>
 * Locations are handled as a flat array in which coordinate {@code k} of location {@code i}
 * lives at index {@code i * P + k}; the observed distances are read as a flat {@code Q x Q}
 * array indexed by {@code i * Q + j}.
 * <p>
 * Progress is reported on {@code System.err}.
 * <p>
 * Created by msuchard on 12/15/15.
 */
public class MultiDimensionalScalingMM extends MMAlgorithm {

    private final MultiDimensionalScalingLikelihood likelihood;
    private final GaussianProcessRandomGenerator gp; // Optional; may be null
    private final int P; // Embedding dimension
    private final int Q; // Data dimension
    private final double tolerance;

    // Work buffers, lazily allocated on the first call to mmUpdate
    private double[] XtX = null;      // Q x Q inner products between locations
    private double[] D = null;        // Q x Q embedded distances (diagonal fixed at 1.0)
    private double[] distance = null; // Q x Q observed distances, taken from the likelihood

    // Gaussian-process penalty; precision stays null when no process was supplied
    private double[][] precision = null;
    private double[] precisionStatistics = null; // Per row: sum of |off-diagonal| precision entries
    private boolean ignoreGP = false;
    private double weightTree; // TODO Formally compute

    /**
     * @param likelihood MDS likelihood whose location matrix is optimised
     * @param gp         Gaussian process supplying the penalty, or {@code null} for none
     * @param tolerance  convergence tolerance handed to {@code findMode}
     */
    public MultiDimensionalScalingMM(MultiDimensionalScalingLikelihood likelihood,
                                     GaussianProcessRandomGenerator gp,
                                     double tolerance) {
        super();
        this.likelihood = likelihood;
        this.gp = gp;
        this.P = likelihood.getMdsDimension();
        this.Q = likelihood.getLocationCount();
        this.tolerance = tolerance;
    }

    public MultiDimensionalScalingLikelihood getLikelihood() { return likelihood; }

    /** @return the Gaussian process given at construction; may be {@code null} */
    public GaussianProcessRandomGenerator getGaussianProcess() { return gp; }

    public double getTolerance() { return tolerance; }

    /** Runs the mode finder with a limit of 100000 iterations. */
    public void run() {
        run(100000);
    }

    /**
     * Searches for a mode starting from the likelihood's current locations and writes the
     * result back into the likelihood's matrix parameter.
     *
     * @param maxIterations iteration limit passed to {@code findMode}; {@code 0} makes this a no-op
     * @throws IllegalStateException if {@code findMode} reports that it did not converge
     */
    public void run(final int maxIterations) {
        if (maxIterations == 0) return;

        if (gp != null) {
            setPrecision(gp.getPrecisionMatrix());
        }

        weightTree = 1.0 / likelihood.getMDSPrecision();

        final double[] start = likelihood.getMatrixParameter().getParameterValues();
        System.err.println("Start: " + printArray(start));
        final double penaltyStart = printLogObjective();

        final double[] mode;
        try {
            mode = findMode(start, tolerance, maxIterations);
        } catch (NotConvergedException e) {
            // Previously the stack trace was printed and a null mode was then written back,
            // which failed with an unrelated NullPointerException. Fail explicitly instead.
            throw new IllegalStateException("MDS mode finder did not converge (tolerance = "
                    + tolerance + ", maxIterations = " + maxIterations + ")", e);
        }

        setParameterValues(likelihood.getMatrixParameter(), mode);

        final double penaltyEnd = printLogObjective();
        System.err.println("Move: " + penaltyStart + " -> " + penaltyEnd + " : " + (penaltyEnd - penaltyStart));
    }

    /**
     * Prints and returns the current objective: the MDS log-likelihood plus, when
     * {@code weightTree} is non-zero, the Gaussian-process log-likelihood. Without a
     * Gaussian process the penalty is reported as {@code 0.0}.
     */
    private double printLogObjective() {
        final double logLike = likelihood.getLogLikelihood();
        // The Gaussian process is optional (see the parser's syntax rules), so guard against null
        final double logPenalty = (gp != null) ? gp.getLikelihood().getLogLikelihood() : 0.0;

        double total = logLike;
        if (weightTree != 0.0) {
            total += logPenalty;
        }
        System.err.println("obj: " + total + " = " + logLike + " + " + logPenalty);
        return total;
    }

    /** Overwrites every value of {@code mat} quietly, then signals the change once. */
    private void setParameterValues(MatrixParameterInterface mat, double[] values) {
        mat.setAllParameterValuesQuietly(values, 0);
        mat.setParameterValueNotifyChangedAll(0, 0, values[0]); // Fire changed
    }

    private double[] getDistanceMatrix() {
        return likelihood.getObservations();
    }

    /**
     * Stores the precision matrix and caches, for each row, the sum of the absolute values of
     * its off-diagonal entries. Does nothing when {@code ignoreGP} is set.
     *
     * @throws IllegalArgumentException if the matrix does not have {@code Q * P} rows
     */
    private void setPrecision(double[][] matrix) {
        if (ignoreGP) return;

        final int QP = matrix.length;
        if (QP != this.Q * this.P) throw new IllegalArgumentException("Invalid dimensions");

        precision = matrix;
        precisionStatistics = new double[QP];

        for (int ik = 0; ik < QP; ++ik) {
            double sum = 0.0;
            for (int jl = 0; jl < QP; ++jl) {
                if (ik != jl) {
                    sum += Math.abs(precision[ik][jl]);
                }
            }
            precisionStatistics[ik] = sum;
        }
    }

    /**
     * Performs one update step, reading the locations in {@code current} and writing the
     * updated locations into {@code next}.
     *
     * @param current current locations, length {@code Q * P}; not modified here
     * @param next    receives the updated locations, length {@code Q * P}
     * @throws IllegalStateException if a pairwise increment evaluates to NaN
     */
    protected void mmUpdate(final double[] current, double[] next) {

        if (XtX == null) {
            XtX = new double[Q * Q];
        }

        if (D == null) {
            D = new double[Q * Q];
            for (int i = 0; i < Q; ++i) {
                D[i * Q + i] = 1.0; // To protect against divide-by-zero
            }
        }

        if (distance == null) {
            distance = getDistanceMatrix();
        }

        // Compute XtX
        for (int i = 0; i < Q; ++i) {
            for (int j = i; j < Q; ++j) {
                double innerProduct = 0.0;
                for (int k = 0; k < P; ++k) {
                    innerProduct += current[i * P + k] * current[j * P + k];
                }
                XtX[j * Q + i] = XtX[i * Q + j] = innerProduct;
            }
        }

        // Compute D
        for (int i = 0; i < Q; ++i) {
            for (int j = i + 1; j < Q; ++j) { // TODO XtX is not a necessary intermediate
                final double norm2 = XtX[i * Q + i] + XtX[j * Q + j] - 2 * XtX[i * Q + j];
                // A non-positive (or NaN) squared norm maps to 0.0, so norm itself is never NaN
                // and the floored entry below cannot be NaN either.
                final double norm = norm2 > 0.0 ? Math.sqrt(norm2) : 0.0;
                D[j * Q + i] = D[i * Q + j] = Math.max(norm, 1E-10);
            }
        }

        // Compute update
        for (int i = 0; i < Q; ++i) { // TODO Embarrassingly parallel
            for (int k = 0; k < P; ++k) { // TODO Embarrassingly parallel

                final int ik = i * P + k;
                final double ci = current[ik];

                double numerator = 0.0;
                for (int j = 0; j < Q; ++j) {

                    final double cj = current[j * P + k];

                    double inc = 0.0;
                    if (i != j) {
                        inc += distance[i * Q + j] * (ci - cj) / D[i * Q + j]
                                + (ci + cj);
                    }

                    if (Double.isNaN(inc)) {
                        // Report through an exception rather than terminating the whole JVM
                        throw new IllegalStateException("Bomb at " + i + " " + k + " " + j
                                + ": Distance = " + distance[i * Q + j]
                                + ", Ci = " + ci
                                + ", Cj = " + cj
                                + ", D = " + D[i * Q + j]);
                    }

                    if (precision != null) {
                        for (int l = 0; l < P; ++l) {
                            final int jl = j * P + l;
                            final double prec = precision[ik][jl];
                            final int sign = (prec > 0.0) ? 1 : -1;
                            inc += weightTree * Math.abs(prec) * (ci - sign * current[jl]);
                        }
                    }

                    numerator += inc;
                }

                double denominator = 2 * (Q - 1);
                if (precision != null) {
                    denominator += weightTree * (2 * precision[ik][ik] + precisionStatistics[ik]);
                }

                next[ik] = numerator / denominator;
            }
        }

        // Force translation, rotation, reflection symmetry
        EllipticalSliceOperator.transformPoint(next, true, true, P);
    }

    // **************************************************************
    // XMLObjectParser
    // **************************************************************

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        public static final String MDS_STARTING_VALUES = "mdsModeFinder";
        public static final String TOLERANCE = "tolerance";

        public String getParserName() {
            return MDS_STARTING_VALUES;
        }

        /**
         * Builds the mode finder from its XML element and runs it immediately, so the
         * likelihood's locations are already updated when the object is returned.
         */
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {

            MultiDimensionalScalingLikelihood likelihood =
                    (MultiDimensionalScalingLikelihood) xo.getChild(MultiDimensionalScalingLikelihood.class);

            // Declared optional in the syntax rules below
            GaussianProcessRandomGenerator gp =
                    (GaussianProcessRandomGenerator) xo.getChild(GaussianProcessRandomGenerator.class);

            double tolerance = xo.getAttribute(TOLERANCE, 1E-3);

            MultiDimensionalScalingMM mm = new MultiDimensionalScalingMM(likelihood, gp, tolerance);
            mm.run();

            return mm;
        }

        //************************************************************************
        // AbstractXMLObjectParser implementation
        //************************************************************************

        public String getParserDescription() {
            return "Provides a mode finder for a MDS model on a tree";
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private final XMLSyntaxRule[] rules = {
                new ElementRule(MultiDimensionalScalingLikelihood.class),
                new ElementRule(GaussianProcessRandomGenerator.class, true),
                AttributeRule.newDoubleRule(TOLERANCE, true),
        };

        public Class getReturnType() {
            return MMAlgorithm.class;
        }
    };
}