/* * GaussianProcessSkytrackLikelihood.java * * Copyright (c) 2002-2011 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.coalescent; //import com.lowagie.text.Paragraph; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evolution.tree.TreeTrait; import dr.evomodel.tree.TreeModel; import dr.evomodelxml.coalescent.GaussianProcessSkytrackLikelihoodParser; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.math.MathUtils; import no.uib.cipr.matrix.*; import java.util.ArrayList; import java.util.List; // //import dr.evolution.tree.NodeRef; //import dr.evolution.tree.Tree; //import dr.evomodel.tree.TreeModel; //import dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser; //import dr.inference.model.MatrixParameter; //import dr.inference.model.Parameter; //import dr.inference.model.Variable; //import dr.math.MathUtils; //import no.uib.cipr.matrix.DenseVector; //import no.uib.cipr.matrix.NotConvergedException; //import no.uib.cipr.matrix.SymmTridiagEVD; //import no.uib.cipr.matrix.SymmTridiagMatrix; // //import java.util.ArrayList; //import java.util.List; /** * @author Vladimir Minin * @author Marc Suchard * @author Julia Palacios * @author Mandev */ public class GaussianProcessSkytrackLikelihood extends OldAbstractCoalescentLikelihood { // protected Parameter groupSizeParameter; public static final double LOG_TWO_TIMES_PI = 1.837877; protected Parameter precisionParameter; protected Parameter lambda_boundParameter; protected Parameter numGridPoints; protected Parameter lambdaParameter; //prior for lambda_bound, will be used in operators only protected Parameter betaParameter; protected Parameter alphaParameter; // Those that do not change in size - fixed per tree -hence need to store/restore protected Parameter popSizeParameter; //before called GPvalues protected Parameter changePoints; // protected double [] GPchangePoints; // protected double [] storedGPchangePoints; protected double [] GPcoalfactor; protected double [] storedGPcoalfactor; protected double [] coalfactor; protected double [] storedcoalfactor; protected int [] GPcounts; //It changes values, no need to storage protected int [] storedGPcounts; protected int numintervals; protected int numcoalpoints; protected double constlik; // Those that change size, they are initialized per tree, no need to store them // use as Parameter since they will be changing by operators // protected Parameter GPtimepoints; //tree + latent // protected double GPintervalkey; // membership that links with those that do not change in size // protected Parameter GPcoalfactor2; // choose(k,2) depending on membership protected int[] GPtype; // 1 if observed, -1 if latent protected int[] storedGPtype; // public double[] GPvalues; //may need to change type: Parameter? didn't know how to work with it protected double logGPLikelihood; protected double storedLogGPLikelihood; protected SymmTridiagMatrix weightMatrix; //this now changes in dimension, no need to storage // protected MatrixParameter dMatrix; protected boolean rescaleByRootHeight; private static List<Tree> wrapTree(Tree tree) { List<Tree> treeList = new ArrayList<Tree>(); treeList.add(tree); return treeList; } public GaussianProcessSkytrackLikelihood(Tree tree, Parameter precParameter, boolean rescaleByRootHeight, Parameter numGridPoints, Parameter lambda_bound, Parameter lambda_parameter, Parameter popParameter, Parameter alpha_parameter, Parameter beta_parameter, Parameter change_points) { this(wrapTree(tree), precParameter, rescaleByRootHeight, numGridPoints, lambda_bound, lambda_parameter, popParameter, alpha_parameter, beta_parameter, change_points); } public GaussianProcessSkytrackLikelihood(String name) { super(name); } public GaussianProcessSkytrackLikelihood(List<Tree> treeList, Parameter precParameter, boolean rescaleByRootHeight, Parameter numGridPoints, Parameter lambda_bound, Parameter lambda_parameter, Parameter popParameter, Parameter alpha_parameter, Parameter beta_parameter, Parameter change_points) { super(GaussianProcessSkytrackLikelihoodParser.SKYTRACK_LIKELIHOOD); this.popSizeParameter = popParameter; this.changePoints=change_points; // this.groupSizeParameter = groupParameter; this.precisionParameter = precParameter; this.lambdaParameter = lambda_parameter; this.betaParameter = beta_parameter; this.alphaParameter=alpha_parameter; // this.dMatrix = dMatrix; this.rescaleByRootHeight = rescaleByRootHeight; this.numGridPoints = numGridPoints; this.lambda_boundParameter= lambda_bound; // addVariable(GPvalues); addVariable(precisionParameter); // addVariable(lambdaParameter); // addVariable(lambda_boundParameter); // if (betaParameter != null) { // addVariable(betaParameter); // } setTree(treeList); wrapSetupIntervals(); // intervalCount = the size for constant vectors // int fieldLength = getCorrectFieldLength(); numintervals= getIntervalCount(); numcoalpoints=getCorrectFieldLength(); GPcoalfactor = new double[numintervals]; storedGPcoalfactor = new double[numintervals]; GPcounts = new int[numintervals]; storedGPcounts= new int[numintervals]; GPtype=new int[numcoalpoints]; storedGPtype = new int[numcoalpoints]; popSizeParameter.setDimension(numcoalpoints); changePoints.setDimension(numcoalpoints); coalfactor= new double[numcoalpoints]; storedcoalfactor= new double[numcoalpoints]; initializationReport(); setupSufficientStatistics(); setupGPvalues(); } // Methods that override existent methods protected void setTree(List<Tree> treeList) { if (treeList.size() != 1) { throw new RuntimeException("GP-based method only implemented for one tree"); } this.tree = treeList.get(0); this.treesSet = null; if (tree instanceof TreeModel) { addModel((TreeModel) tree); } } protected void wrapSetupIntervals() { setupIntervals(); } // // protected int getCorrectFieldLength() { // return tree.getExternalNodeCount() - 1; // } // // //I will use specific input public double calculateLogLikelihood(Parameter Gfunction, int[] latentCounts, int [] eventType, Parameter upper_Bound, double [] Gfactor) { double upperBound = upper_Bound.getParameterValue(0); logGPLikelihood=-upperBound*getConstlik(); for (int i=0; i<latentCounts.length; i++){ if (Gfactor[i]>0) { logGPLikelihood+=latentCounts[i]*Math.log(upperBound*Gfactor[i]); } } double[] currentGfunction = Gfunction.getParameterValues(); for (int i=0; i<Gfunction.getSize();i++){ logGPLikelihood+= -Math.log(1+Math.exp(-eventType[i]*currentGfunction[i])); } return logGPLikelihood; } // protected double calculateLogCoalescentLikelihood() { // // if (!intervalsKnown) { // // intervalsKnown -> false when handleModelChanged event occurs in super. // wrapSetupIntervals(); // setupGMRFWeights(); // intervalsKnown = true; // } // // // Matrix operations taken from block update sampler to calculate data likelihood and field prior // // double currentLike = 0; // double[] currentGamma = popSizeParameter.getParameterValues(); // // for (int i = 0; i < fieldLength; i++) { // currentLike += -currentGamma[i] - sufficientStatistics[i] * Math.exp(-currentGamma[i]); // } // // return currentLike;// + LogNormalDistribution.logPdf(Math.exp(popSizeParameter.getParameterValue(coalescentIntervals.length - 1)), mu, sigma); // return 0.0; // } // public double getConstlik(){ return constlik; } public double getLogLikelihood() { if (!likelihoodKnown) { logLikelihood = calculateLogLikelihood(popSizeParameter,GPcounts,GPtype,lambda_boundParameter,GPcoalfactor)+calculateLogGP(); likelihoodKnown = true; } return logLikelihood; // return 0.0; } protected double calculateLogGP() { if (!intervalsKnown) { // intervalsKnown -> false when handleModelChanged event occurs in super. wrapSetupIntervals(); setupQmatrix(precisionParameter.getParameterValue(0)); intervalsKnown = true; } double currentLike = 0; DenseVector diagonal1 = new DenseVector(numcoalpoints); DenseVector currentGamma = new DenseVector(popSizeParameter.getParameterValues()); SymmTridiagMatrix currentQ = weightMatrix; currentQ.mult(currentGamma, diagonal1); currentLike = -0.5 * logGeneralizedDeterminant(currentQ) - 0.5 * currentGamma.dot(diagonal1) - 0.5 * (numcoalpoints - 1) * LOG_TWO_TIMES_PI; return currentLike; } //log pseudo-determinant public static double logGeneralizedDeterminant(SymmTridiagMatrix X) { //Set up the eigenvalue solver SymmTridiagEVD eigen = new SymmTridiagEVD(X.numRows(), false); //Solve for the eigenvalues try { eigen.factor(X); } catch (NotConvergedException e) { throw new RuntimeException("Not converged error in generalized determinate calculation.\n" + e.getMessage()); } //Get the eigenvalues double[] x = eigen.getEigenvalues(); double a = 0; for (double d : x) { if (d > 0.00001) a += Math.log(d); } return a; } protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type){ likelihoodKnown = false; // Parameters (precision and popsizes do not change intervals or GMRF Q matrix (I DON'T UNDERSTAND THIS) } protected void restoreState() { super.restoreState(); System.arraycopy(storedcoalfactor, 0, coalfactor, 0, storedcoalfactor.length); System.arraycopy(storedGPtype,0,GPtype,0,storedGPtype.length); System.arraycopy(storedGPcoalfactor,0,GPcoalfactor,0,storedGPcoalfactor.length); System.arraycopy(storedGPcounts,0,GPcounts,0,storedGPcounts.length); // weightMatrix = storedWeightMatrix; logGPLikelihood = storedLogGPLikelihood; } protected void storeState() { super.storeState(); System.arraycopy(GPtype, 0, storedGPtype, 0, GPtype.length); System.arraycopy(GPcoalfactor,0,storedGPcoalfactor,0,GPcoalfactor.length); System.arraycopy(coalfactor, 0, storedcoalfactor, 0, coalfactor.length); System.arraycopy(GPcounts, 0, storedGPcounts,0,GPcounts.length); // storedWeightMatrix = weightMatrix.copy(); storedLogGPLikelihood = logGPLikelihood; } // I don't understand this public String toString() { return getId() + "(" + Double.toString(getLogLikelihood()) + ")"; } //// private final Parameter latentPoints; // // private final Parameter lambda_bound; // // // public void initializationReport() { System.out.println("Creating a GP based estimation of effective population trajectories:"); // System.out.println("\tPopulation sizes: " + popSizeParameter.getDimension()); System.out.println("\tIf you publish results using this model, please reference: Minin, Palacios, Suchard (XXXX), AAA"); } public static void checkTree(TreeModel treeModel) { // todo Should only be run if there exists a zero-length interval // TreeModel treeModel = (TreeModel) tree; for (int i = 0; i < treeModel.getInternalNodeCount(); i++) { NodeRef node = treeModel.getInternalNode(i); if (node != treeModel.getRoot()) { double parentHeight = treeModel.getNodeHeight(treeModel.getParent(node)); double childHeight0 = treeModel.getNodeHeight(treeModel.getChild(node, 0)); double childHeight1 = treeModel.getNodeHeight(treeModel.getChild(node, 1)); double maxChild = childHeight0; if (childHeight1 > maxChild) maxChild = childHeight1; double newHeight = maxChild + MathUtils.nextDouble() * (parentHeight - maxChild); treeModel.setNodeHeight(node, newHeight); } } treeModel.pushTreeChangedEvent(); } //Sufficient Statistics for GP - coal+sampling protected void setupSufficientStatistics() { double length = 0.0; int countcoal = 0; constlik= 0; for (int i = 0; i < getIntervalCount(); i++) { length += getInterval(i); GPcounts[i]=0; GPcoalfactor[i] =getLineageCount(i)*(getLineageCount(i)-1) / 2.0; constlik+=GPcoalfactor[i]*getInterval(i); if (getIntervalType(i) == CoalescentEventType.COALESCENT) { GPcounts[i]=1; GPtype[countcoal]=1; changePoints.setParameterValue(countcoal,length); coalfactor[countcoal]=getLineageCount(i)*(getLineageCount(i)-1)/2.0; countcoal++; } } } protected int getCorrectFieldLength() { return tree.getExternalNodeCount() - 1; } protected void setupQmatrix(double precision) { //Set up the weight Matrix double trick=0.000001; double[] offdiag = new double[getCorrectFieldLength() - 1]; double[] diag = new double[getCorrectFieldLength()]; for (int i = 0; i < getCorrectFieldLength() - 1; i++) { offdiag[i] = precision*(-1.0 / (changePoints.getParameterValue(i+1)-changePoints.getParameterValue(i))); if (i<getCorrectFieldLength()-2){ diag[i+1]= -offdiag[i]+precision*(1.0/(changePoints.getParameterValue(i+2)-changePoints.getParameterValue(i+1))+trick); } } // Diffuse prior correction - intrinsic //Take care of the endpoints diag[0] = -offdiag[0]+precision*trick; diag[getCorrectFieldLength() - 1] = -offdiag[getCorrectFieldLength() - 2]+precision*(trick); weightMatrix = new SymmTridiagMatrix(diag, offdiag); } protected void setupGPvalues() { setupQmatrix(precisionParameter.getParameterValue(0)); int length = getCorrectFieldLength(); DenseVector StandNorm = new DenseVector(length); DenseVector MultiNorm = new DenseVector(length); for (int i=0; i<length;i++){ StandNorm.set(i,MathUtils.nextGaussian()); // StandNorm.set(i,0.1); } UpperSPDBandMatrix Qcurrent = new UpperSPDBandMatrix(weightMatrix, 1); BandCholesky U = new BandCholesky(length,1,true); U.factor(Qcurrent); UpperTriangBandMatrix CholeskyUpper = U.getU(); CholeskyUpper.solve(StandNorm,MultiNorm); for (int i=0; i<length;i++){ popSizeParameter.setParameterValue(i,MultiNorm.get(i)); } } public Parameter getPrecisionParameter() { return precisionParameter; } public Parameter getPopSizeParameter() { return popSizeParameter; } public Parameter getLambdaParameter() { return lambdaParameter; } public Parameter getLambdaBoundParameter() { return lambda_boundParameter; } public Parameter getChangePoints() { return changePoints; } public double getAlphaParameter(){ return alphaParameter.getParameterValue(0); } public double getBetaParameter(){ return betaParameter.getParameterValue(0); } public double [] getGPcoalfactor(){ return GPcoalfactor; } public double [] getcoalfactor(){ return coalfactor; } public int [] getGPtype(){ return GPtype; } public int [] getGPcounts(){ return GPcounts; } public SymmTridiagMatrix getWeightMatrix() { return weightMatrix.copy(); } // Methods needed for GP-based }