/*
* MultiDimensionalScalingLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.multidimensionalscaling;
import dr.evomodel.antigenic.MultidimensionalScalingLikelihood;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.*;
import dr.util.DataTable;
import dr.xml.*;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
/**
* @author Andrew Rambaut
* @author Marc Suchard
* @version $Id$
*/
public class MultiDimensionalScalingLikelihood extends AbstractModelLikelihood implements Reportable,
GradientWrtParameterProvider {
public static final String REQUIRED_FLAGS_PROPERTY = "mds.required.flags";
@Override
public String getReport() {
StringBuilder sb = new StringBuilder();
sb.append(getId() + ": " + getLogLikelihood());
return sb.toString();
}
@Override
public Likelihood getLikelihood() {
return this;
}
@Override
public Parameter getParameter() {
return locationsParameter;
}
@Override
public int getDimension() {
return locationsParameter.getDimension();
}
@Override
public double[] getGradientLogDensity() {
// TODO Cache !!!
if (gradient == null) {
gradient = new double[locationsParameter.getDimension()];
}
mdsCore.getGradient(gradient);
return gradient; // TODO Do not expose internals
}
public enum ObservationType {
POINT,
UPPER_BOUND,
LOWER_BOUND,
MISSING
}
public final static String MULTIDIMENSIONAL_SCALING_LIKELIHOOD = "multiDimensionalScalingLikelihood";
// public MultiDimensionalScalingLikelihood(
// int mdsDimension,
// Parameter mdsPrecision,
// MatrixParameter locationsParameter,
// DataTable<double[]> dataTable,
// boolean reorderData) {
// this(mdsDimension, mdsPrecision, locationsParameter, dataTable, false, reorderData);
// }
/**
* A simple constructor for a fully specified symmetrical data matrix
* @param mdsDimension
* @param mdsPrecision
* @param locationsParameter
* @param dataTable
* @param isLeftTruncated
* @param reorderData
*/
public MultiDimensionalScalingLikelihood(
int mdsDimension,
Parameter mdsPrecision,
MatrixParameterInterface locationsParameter,
DataTable<double[]> dataTable,
boolean isLeftTruncated,
boolean reorderData) {
super(MULTIDIMENSIONAL_SCALING_LIKELIHOOD);
this.mdsDimension = mdsDimension;
this.isLeftTruncated = isLeftTruncated;
// construct a compact data table
String[] rowLabelsOriginal = dataTable.getRowLabels();
// String[] columnLabels = dataTable.getRowLabels();
int rowCount = dataTable.getRowCount();
locationCount = rowCount;
boolean allowMissing = true;
int[] permute;
if (reorderData) {
permute = getPermutation(rowLabelsOriginal, locationsParameter, allowMissing);
} else {
permute = new int[locationCount];
for (int i = 0; i < locationCount; ++i) {
permute[i] = i; // identity
}
}
String[] rowLabels = new String[locationCount];
int observationCount = rowCount * rowCount;
// double[] observations = new double[observationCount];
observations = new double[observationCount];
ObservationType[] observationTypes = new ObservationType[observationCount];
double[][] tmp = new double[rowCount][rowCount];
for (int i = 0; i < rowCount; i++) {
rowLabels[i] = rowLabelsOriginal[permute[i]];
double[] dataRow = dataTable.getRow(permute[i]);
for (int j = i + 1; j < rowCount; j++) {
tmp[i][j] = tmp[j][i] = dataRow[permute[j]];
}
}
int u = 0;
for (int i = 0; i < rowCount; i++) {
for (int j = 0; j < rowCount; j++) {
observations[u] = (i == j ? 0 : tmp[i][j]);
observationTypes[u] = ObservationType.POINT;
u++;
}
}
initialize(mdsDimension, mdsPrecision, isLeftTruncated, locationsParameter,
rowLabels, observations, observationTypes);
}
// private class Data {
// int observationCount;
// double[] observations;
// ObservationType[] observationTypes;
//
// Data(int observationCount, double[] observations, ObservationType[] observationTypes) {
// this.observationCount = observationCount;
// this.observations = observations;
// this.observationTypes = observationTypes;
// }
// }
public double[] getObservations() { return observations; } // TODO Grab from core when needed to save space
public MatrixParameterInterface getMatrixParameter() { return locationsParameter; }
private int[] getPermutation(String[] source, MatrixParameterInterface destination, boolean allowMissing) {
if (source.length != destination.getColumnDimension()) {
throw new IllegalArgumentException("Dimension mismatch");
}
final int length = source.length;
Map<String,Integer> map = new HashMap<String, Integer>(destination.getColumnDimension());
for (int i = 0; i < length; ++i) {
map.put(source[i],i);
}
int[] permute = new int[length];
for (int i = 0; i < length; ++i) {
Integer p = map.get(destination.getParameter(i).getParameterName());
if (p == null) {
if (allowMissing) {
Logger.getLogger("dr.app.beagle").info("Missing label!!!");
} else {
throw new IllegalArgumentException("Missing label");
}
} else {
permute[i] = p;
}
}
return permute;
}
private MultiDimensionalScalingCore getCore() {
long computeMode = 0;
String r = System.getProperty(REQUIRED_FLAGS_PROPERTY);
if (r != null) {
computeMode = Long.parseLong(r.trim());
}
MultiDimensionalScalingCore core;
if (computeMode >= MultiDimensionalScalingCore.USE_NATIVE_MDS) {
System.err.println("Attempting to use a native MDS core with flag: " + computeMode + "; may the force be with you ....");
core = new MassivelyParallelMDSImpl();
flags = computeMode;
} else {
core = new MultiDimensionalScalingCoreImpl();
}
return core;
}
public int getMdsDimension() { return mdsDimension; }
public int getLocationCount() { return locationCount; }
protected void initialize(
final int mdsDimension,
final Parameter mdsPrecision,
final boolean isLeftTruncated,
final MatrixParameterInterface locationsParameter,
final String[] locationLabels,
final double[] observations,
final ObservationType[] observationTypes) {
this.mdsCore = getCore();
if (isLeftTruncated) {
flags |= MultiDimensionalScalingCore.LEFT_TRUNCATION;
}
System.err.println("Initializing with flags: " + flags);
this.mdsCore.initialize(mdsDimension, locationCount, flags);
this.locationLabels = locationLabels;
this.locationsParameter = locationsParameter;
setupLocationsParameter(this.locationsParameter);
addVariable(locationsParameter);
this.mdsPrecisionParameter = mdsPrecision;
addVariable(mdsPrecision);
mdsCore.setParameters(mdsPrecisionParameter.getParameterValues());
mdsCore.setPairwiseData(observations);
// for (int i = 0; i < locationCount; i++) {
// mdsCore.updateLocation(i, locationsParameter.getColumnValues(i));
// }
mdsCore.updateLocation(-1, locationsParameter.getParameterValues());
// make sure everything is calculated on first evaluation
makeDirty();
}
protected void setupLocationsParameter(MatrixParameterInterface locationsParameter) {
final boolean exisitingParameter = locationsParameter.getColumnDimension() > 0;
if (exisitingParameter){
if (locationsParameter.getColumnDimension() != locationCount){
throw new RuntimeException("locationsParameter column dimension ("+locationsParameter.getColumnDimension()+") is not equal to the locationCount ("+locationCount+")");
}
if (locationsParameter.getRowDimension() != mdsDimension){
throw new RuntimeException("locationsParameter row dimension ("+locationsParameter.getRowDimension()+") is not equal to the mdsDimension ("+mdsDimension+")");
}
} else {
// locationsParameter.setColumnDimension(mdsDimension);
// locationsParameter.setRowDimension(locationCount);
throw new IllegalArgumentException("Dimensions on matrix must be set");
}
for (int i = 0; i < locationLabels.length; i++) {
if (exisitingParameter) {
if (locationsParameter.getParameter(i).getParameterName().compareTo(locationLabels[i]) != 0) {
throw new RuntimeException("Mismatched trait parameter name (" + locationsParameter.getParameter(i).getParameterName() +
") and data dimension name (" + locationLabels[i] + ")");
}
} else {
locationsParameter.getParameter(i).setId(locationLabels[i]);
}
}
for (int i = 0; i < locationsParameter.getColumnDimension(); ++i) {
Parameter param = locationsParameter.getParameter(i);
try {
if (param.getBounds() != null) {
// Do nothing
}
} catch (NullPointerException exception) {
param.addBounds(new Parameter.DefaultBounds(
Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, param.getDimension()));
}
}
}
@Override
protected void handleModelChangedEvent(Model model, Object object, int index) { }
@Override
protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
// TODO Flag which cachedDistances or mdsPrecision need to be updated
if (variable == locationsParameter) {
if (index == -1) {
mdsCore.updateLocation(-1, locationsParameter.getParameterValues());
} else {
int locationIndex = index / mdsDimension;
mdsCore.updateLocation(locationIndex, locationsParameter.getColumnValues(locationIndex));
}
} else if (variable == mdsPrecisionParameter) {
mdsCore.setParameters(mdsPrecisionParameter.getParameterValues());
} else {
// could be a derived class's parameter
// throw new IllegalArgumentException("Unknown parameter");
}
likelihoodKnown = false;
}
@Override
protected void storeState() {
storedLogLikelihood = logLikelihood;
mdsCore.storeState();
}
@Override
protected void restoreState() {
logLikelihood = storedLogLikelihood;
likelihoodKnown = true;
mdsCore.restoreState();
}
@Override
protected void acceptState() {
mdsCore.acceptState();
// do nothing
}
public void makeDirty() {
likelihoodKnown = false;
mdsCore.makeDirty();
}
public Model getModel() {
return this;
}
public double getLogLikelihood() {
if (!likelihoodKnown) {
logLikelihood = mdsCore.calculateLogLikelihood();
likelihoodKnown = true;
}
return logLikelihood;
}
// **************************************************************
// XMLObjectParser
// **************************************************************
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
public final static String FILE_NAME = "fileName";
public final static String TIP_TRAIT = "tipTrait";
public final static String LOCATIONS = "locations";
public static final String MDS_DIMENSION = "mdsDimension";
public static final String MDS_PRECISION = "mdsPrecision";
public static final String INCLUDE_TRUNCATION = "includeTruncation";
public static final String USE_OLD = "useOld";
public static final String FORCE_REORDER = "forceReorder";
public String getParserName() {
return MULTIDIMENSIONAL_SCALING_LIKELIHOOD;
}
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
String fileName = xo.getStringAttribute(FILE_NAME);
DataTable<double[]> distanceTable;
try {
distanceTable = DataTable.Double.parse(new FileReader(fileName));
} catch (IOException e) {
throw new XMLParseException("Unable to read assay data from file: " + e.getMessage());
}
if (distanceTable.getRowCount() != distanceTable.getColumnCount()) {
throw new XMLParseException("Data table is not symmetrical.");
}
int mdsDimension = xo.getIntegerAttribute(MDS_DIMENSION);
MatrixParameterInterface locationsParameter = (MatrixParameterInterface) xo.getElementFirstChild(LOCATIONS);
Parameter mdsPrecision = (Parameter) xo.getElementFirstChild(MDS_PRECISION);
boolean useOld = xo.getAttribute(USE_OLD, false);
boolean includeTrauncation = xo.getAttribute(INCLUDE_TRUNCATION, false);
boolean forceReorder = xo.getAttribute(FORCE_REORDER, false);
if (useOld) {
System.err.println("USE OLD");
return new MultidimensionalScalingLikelihood(mdsDimension, includeTrauncation, mdsPrecision, (MatrixParameter)locationsParameter, distanceTable);
} else {
return new MultiDimensionalScalingLikelihood(mdsDimension, mdsPrecision, locationsParameter,
distanceTable, includeTrauncation, forceReorder);
}
}
//************************************************************************
// AbstractXMLObjectParser implementation
//************************************************************************
public String getParserDescription() {
return "Provides the likelihood of pairwise distance given vectors of coordinates" +
"for points according to the multidimensional scaling scheme of XXX & Rafferty (xxxx).";
}
public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}
private final XMLSyntaxRule[] rules = {
AttributeRule.newStringRule(FILE_NAME, false, "The name of the file containing the assay table"),
AttributeRule.newIntegerRule(MDS_DIMENSION, false, "The dimension of the space for MDS"),
new ElementRule(LOCATIONS, MatrixParameterInterface.class),
AttributeRule.newBooleanRule(USE_OLD, true),
AttributeRule.newBooleanRule(INCLUDE_TRUNCATION, true),
AttributeRule.newBooleanRule(FORCE_REORDER, true),
new ElementRule(MDS_PRECISION, Parameter.class)
};
public Class getReturnType() {
return MultiDimensionalScalingLikelihood.class;
}
};
public double getMDSPrecision() {
return mdsPrecisionParameter.getParameterValue(0);
}
private final int mdsDimension;
private final int locationCount;
private final boolean isLeftTruncated;
private MultiDimensionalScalingCore mdsCore;
private String[] locationLabels;
private Parameter mdsPrecisionParameter;
private MatrixParameterInterface locationsParameter;
private boolean likelihoodKnown = false;
private double logLikelihood;
private double storedLogLikelihood;
private long flags = 0;
private double[] observations;
private double[] gradient;
}