package dr.evomodel.antigenic;
import dr.inference.model.*;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import dr.xml.*;
import dr.math.distributions.ExponentialDistribution;
import java.util.ArrayList;
import java.util.List;
/**
 * Prior on antigenic locations in which the expected position along the first antigenic
 * dimension (AG1) is the cumulative sum of year-to-year "jumps".
 *
 * <p>Dates are truncated to whole years. {@code jumpVector} is resized to hold one jump per
 * year boundary between the earliest and latest date. Jump {@code i} is labelled with year
 * {@code earliestDate + i}. It is the change in AG1 from that year to the next, so the expected
 * AG1 position at the earliest date is 0. Each jump has an exponential prior with mean
 * {@code jumpMean}. Every location coordinate is modelled as Gaussian with precision
 * {@code locationPrecision}: around its date's expectation on AG1, and around 0 on every other
 * dimension.
 *
 * @author Andrew Rambaut
 * @author Trevor Bedford
 * @author Marc Suchard
 * @version $Id$
 */
public class AntigenicJumpPrior extends AbstractModelLikelihood implements Citable {

    public final static String ANTIGENIC_JUMP_PRIOR = "antigenicJumpPrior";

    /**
     * Builds the prior and (re)initialises the jump vector to span the observed dates.
     *
     * @param locationsParameter         one row per virus: its coordinates in antigenic space
     * @param datesParameter             one date per virus, aligned with the rows of
     *                                   {@code locationsParameter}; truncated to whole years
     * @param jumpVectorParameter        resized, relabelled and set to 1.0 per year boundary
     * @param jumpMeanParameter          mean of the exponential prior on each jump (bounded &ge; 0)
     * @param locationPrecisionParameter precision of the Gaussian location model (bounded &ge; 0)
     * @throws IllegalArgumentException if {@code datesParameter} has fewer entries than there are
     *                                  locations
     */
    public AntigenicJumpPrior(
            MatrixParameter locationsParameter,
            Parameter datesParameter,
            Parameter jumpVectorParameter,
            Parameter jumpMeanParameter,
            Parameter locationPrecisionParameter
    ) {
        super(ANTIGENIC_JUMP_PRIOR);

        this.locationsParameter = locationsParameter;
        addVariable(this.locationsParameter);

        this.datesParameter = datesParameter;
        addVariable(this.datesParameter);

        dimension = locationsParameter.getParameter(0).getDimension();
        count = locationsParameter.getParameterCount();

        if (datesParameter.getDimension() < count) {
            throw new IllegalArgumentException("dates parameter has " + datesParameter.getDimension()
                    + " values but there are " + count + " locations");
        }

        this.jumpMeanParameter = jumpMeanParameter;
        addVariable(jumpMeanParameter);
        jumpMeanParameter.addBounds(new Parameter.DefaultBounds(Double.MAX_VALUE, 0.0, 1));

        this.locationPrecisionParameter = locationPrecisionParameter;
        addVariable(locationPrecisionParameter);
        locationPrecisionParameter.addBounds(new Parameter.DefaultBounds(Double.MAX_VALUE, 0.0, 1));

        // Range of (whole-year) dates spanned by the locations, found in a single pass.
        int earliest = (int) datesParameter.getParameterValue(0);
        int latest = earliest;
        for (int i = 1; i < count; i++) {
            int date = (int) datesParameter.getParameterValue(i);
            earliest = Math.min(earliest, date);
            latest = Math.max(latest, date);
        }
        earliestDate = earliest;
        latestDate = latest;

        // One jump per year boundary; jump i is labelled with the year it starts from.
        int jumpCount = latestDate - earliestDate;
        String[] jumpNames = new String[jumpCount];
        for (int i = 0; i < jumpCount; i++) {
            jumpNames[i] = Integer.toString(earliestDate + i);
        }

        this.jumpVectorParameter = jumpVectorParameter;
        addVariable(this.jumpVectorParameter);
        // Resize and initialise first so that the non-negativity bounds cover every jump,
        // not just the first element.
        jumpVectorParameter.setDimension(jumpCount);
        jumpVectorParameter.setDimensionNames(jumpNames);
        for (int i = 0; i < jumpCount; i++) {
            jumpVectorParameter.setParameterValue(i, 1);
        }
        jumpVectorParameter.addBounds(new Parameter.DefaultBounds(Double.MAX_VALUE, 0.0, jumpCount));
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        // No sub-models are registered, so there is nothing to react to.
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        if (variable == locationsParameter || variable == datesParameter
                || variable == jumpVectorParameter || variable == jumpMeanParameter
                || variable == locationPrecisionParameter) {
            likelihoodKnown = false;
        }
    }

    @Override
    protected void storeState() {
        storedLogLikelihood = logLikelihood;
    }

    @Override
    protected void restoreState() {
        logLikelihood = storedLogLikelihood;
        // Conservatively force a recomputation on the next request.
        likelihoodKnown = false;
    }

    @Override
    protected void acceptState() {
    }

    @Override
    public Model getModel() {
        return this;
    }

    /**
     * Returns the cached log likelihood, recomputing it if any variable has changed.
     */
    @Override
    public double getLogLikelihood() {
        if (!likelihoodKnown) {
            logLikelihood = computeLogLikelihood();
            likelihoodKnown = true;
        }
        return logLikelihood;
    }

    /** Sum of the jump prior and the location likelihood. */
    private double computeLogLikelihood() {
        return jumpLogLikelihood() + locationLogLikelihood();
    }

    /**
     * Log probability of the jump vector: every jump is an independent exponential draw with
     * mean {@code jumpMean}.
     */
    protected double jumpLogLikelihood() {
        double lambda = 1.0 / jumpMeanParameter.getParameterValue(0);
        double logLikelihood = 0.0;
        // Cover every jump, including the last year boundary.
        for (int i = 0; i < jumpVectorParameter.getDimension(); i++) {
            logLikelihood += ExponentialDistribution.logPdf(jumpVectorParameter.getParameterValue(i), lambda);
        }
        return logLikelihood;
    }

    /**
     * Log probability of the virus locations given the jump vector.
     *
     * <p>The AG1 coordinate is compared with {@link #expectationFromDate(int)}, and every other
     * coordinate with 0. All {@code count * dimension} residuals are treated as i.i.d. Gaussian with
     * precision {@code locationPrecision}. The constant {@code -n/2 log(2 pi)} is omitted.
     */
    protected double locationLogLikelihood() {
        double ssr = 0.0;
        for (int i = 0; i < count; i++) {
            Parameter loc = locationsParameter.getParameter(i);
            int date = (int) datesParameter.getParameterValue(i);
            double residual = loc.getParameterValue(0) - expectationFromDate(date);
            ssr += residual * residual;
            for (int j = 1; j < dimension; j++) {
                double x = loc.getParameterValue(j);
                ssr += x * x;
            }
        }

        double precision = locationPrecisionParameter.getParameterValue(0);
        // ssr sums count * dimension squared residuals, so the normalising term must scale with
        // that total, not with count alone.
        double residualCount = (double) count * dimension;
        return 0.5 * residualCount * Math.log(precision) - 0.5 * precision * ssr;
    }

    /**
     * Expected AG1 position at a (whole-year) date: the sum of all jumps from the earliest date
     * up to, but excluding, {@code date}. Returns 0 for dates at or before the earliest date.
     */
    protected double expectationFromDate(int date) {
        int years = date - earliestDate;
        double expectation = 0.0;
        for (int i = 0; i < years; i++) {
            expectation += jumpVectorParameter.getParameterValue(i);
        }
        return expectation;
    }

    @Override
    public void makeDirty() {
        likelihoodKnown = false;
    }

    private final int dimension;
    private final int count;
    private final Parameter datesParameter;
    private final MatrixParameter locationsParameter;
    private final Parameter jumpVectorParameter;
    private final Parameter jumpMeanParameter;
    private final Parameter locationPrecisionParameter;

    private final int earliestDate;
    private final int latestDate;

    private double logLikelihood = 0.0;
    private double storedLogLikelihood = 0.0;
    private boolean likelihoodKnown = false;

    // **************************************************************
    // XMLObjectParser
    // **************************************************************

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        public final static String LOCATIONS = "locations";
        public final static String DATES = "dates";
        public final static String JUMPVECTOR = "jumpVector";
        public final static String JUMPMEAN = "jumpMean";
        public final static String LOCATIONPRECISION = "locationPrecision";

        @Override
        public String getParserName() {
            return ANTIGENIC_JUMP_PRIOR;
        }

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            MatrixParameter locationsParameter = (MatrixParameter) xo.getElementFirstChild(LOCATIONS);
            Parameter datesParameter = (Parameter) xo.getElementFirstChild(DATES);
            Parameter jumpVectorParameter = (Parameter) xo.getElementFirstChild(JUMPVECTOR);
            Parameter jumpMeanParameter = (Parameter) xo.getElementFirstChild(JUMPMEAN);
            Parameter locationPrecisionParameter = (Parameter) xo.getElementFirstChild(LOCATIONPRECISION);

            return new AntigenicJumpPrior(
                    locationsParameter,
                    datesParameter,
                    jumpVectorParameter,
                    jumpMeanParameter,
                    locationPrecisionParameter);
        }

        //************************************************************************
        // AbstractXMLObjectParser implementation
        //************************************************************************

        @Override
        public String getParserDescription() {
            return "Provides the likelihood of a vector of coordinates in some multidimensional 'antigenic' space based on an expected relationship with time.";
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private final XMLSyntaxRule[] rules = {
                new ElementRule(LOCATIONS, MatrixParameter.class),
                new ElementRule(DATES, Parameter.class),
                new ElementRule(JUMPVECTOR, Parameter.class),
                new ElementRule(JUMPMEAN, Parameter.class),
                new ElementRule(LOCATIONPRECISION, Parameter.class)
        };

        @Override
        public Class getReturnType() {
            // The parser builds an AntigenicJumpPrior; report that type to the XML framework.
            return AntigenicJumpPrior.class;
        }
    };

    @Override
    public List<Citation> getCitations() {
        List<Citation> citations = new ArrayList<Citation>();
        citations.add(new Citation(
                new Author[]{
                        new Author("T", "Bedford"),
                        new Author("MA", "Suchard"),
                        new Author("P", "Lemey"),
                        new Author("G", "Dudas"),
                        new Author("C", "Russell"),
                        new Author("D", "Smith"),
                        new Author("A", "Rambaut")
                },
                Citation.Status.IN_PREPARATION
        ));
        return citations;
    }
}