/* * MVOUCovarianceOperator.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inference.operators; import dr.inference.model.MatrixParameter; import dr.inference.model.Parameter; import dr.inferencexml.operators.MVOUCovarianceOperatorParser; import dr.math.distributions.WishartDistribution; import dr.math.matrixAlgebra.Matrix; /** * @author Marc Suchard */ public class MVOUCovarianceOperator extends AbstractCoercableOperator { private double mixingFactor; private MatrixParameter varMatrix; private int dim; private MatrixParameter precisionParam; private WishartDistribution priorDistribution; private int priorDf; private double[][] I; private Matrix Iinv; public MVOUCovarianceOperator(double mixingFactor, MatrixParameter varMatrix, int priorDf, double weight, CoercionMode mode) { super(mode); this.mixingFactor = mixingFactor; this.varMatrix = varMatrix; this.priorDf = priorDf; setWeight(weight); dim = varMatrix.getColumnDimension(); I = new double[dim][dim]; for (int i = 0; i < dim; i++) I[i][i] = 1.0; // I[i][i] = i; Iinv = new Matrix(I).inverse(); } public double doOperation() { double[][] draw = WishartDistribution.nextWishart(priorDf, I); // double[][] good = varMatrix.getParameterAsMatrix(); // double[][] saveOld = varMatrix.getParameterAsMatrix(); // System.err.println("draw:\n"+new Matrix(draw)); double[][] oldValue = varMatrix.getParameterAsMatrix(); for (int i = 0; i < dim; i++) { Parameter column = varMatrix.getParameter(i); for (int j = 0; j < dim; j++) column.setParameterValue(j, mixingFactor * oldValue[j][i] + (1.0 - mixingFactor) * draw[j][i] ); } // varMatrix.fireParameterChangedEvent(); // calculate Hastings ratio // System.err.println("oldValue:\n"+new Matrix(oldValue).toString()); // System.err.println("newValue:\n"+new Matrix(varMatrix.getParameterAsMatrix()).toString()); Matrix forwardDrawMatrix = new Matrix(draw); for (int i = 0; i < dim; i++) { for (int j = 0; j < dim; j++) { // saveOld[i][j] *= - mixingFactor; // saveOld[i][j] += varMatrix.getParameterValue(i,j); // saveOld[i][j] /= 1.0 - mixingFactor; oldValue[i][j] -= mixingFactor * varMatrix.getParameterValue(i, j); oldValue[i][j] /= 1.0 - mixingFactor; } } // double[][] saveNew = varMatrix.getParameterAsMatrix(); Matrix backwardDrawMatrix = new Matrix(oldValue); // System.err.println("forward:\n"+forwardDrawMatrix); // System.err.println("backward:\n"+backwardDrawMatrix); // System.err.println("calc start"); // if( Math.abs(backwardDrawMatrix.component(0,0) + 0.251) < 0.001 ) { // System.err.println("found:\n"+backwardDrawMatrix); // // System.err.println("original:\n"+new Matrix(good)); // System.err.println("draw:\n"+new Matrix(draw)); // System.err.println("proposed:\n"+new Matrix(varMatrix.getParameterAsMatrix())); // System.err.println("mixing = "+mixingFactor); // System.err.println("back[0][0] = "+backwardDrawMatrix.component(0,0)); // System.err.println("saveOld[0][0] = "+saveOld[0][0]); // // // } double bProb = WishartDistribution.logPdf(backwardDrawMatrix, Iinv, priorDf, dim, // WishartDistribution.computeNormalizationConstant(Iinv,priorDf,dim)); 0); if (bProb == Double.NEGATIVE_INFINITY) { // throw new OperatorFailedException("Not reversible"); // not clear if this means a HR of -Inf or a RuntimeException return Double.NEGATIVE_INFINITY; } double fProb = WishartDistribution.logPdf(forwardDrawMatrix, Iinv, priorDf, dim, // WishartDistribution.computeNormalizationConstant(Iinv,priorDf,dim)); 0); // System.err.println("calc end"); // if( fProb == Double.NEGATIVE_INFINITY ) { // System.err.println("forwards is problem"); // System.exit(-1); // } // if( bProb == Double.NEGATIVE_INFINITY ) { // System.err.println("backwards is problem"); // System.exit(-1); // } // System.err.println("fProb = "+fProb); // System.err.println("bProb = "+bProb); // System.exit(-1); return bProb - fProb; } //MCMCOperator INTERFACE public final String getOperatorName() { return MVOUCovarianceOperatorParser.MVOU_OPERATOR + "(" + varMatrix.getId() + ")"; } public double getCoercableParameter() { return Math.log(mixingFactor / (1.0 - mixingFactor)); // return Math.log((1.0 - mixingFactor) / mixingFactor); } public void setCoercableParameter(double value) { mixingFactor = Math.exp(value) / (1.0 + Math.exp(value)); // mixingFactor = Math.exp(-value) / (1.0 + Math.exp(-value)); } public double getRawParameter() { return mixingFactor; } public double getMixingFactor() { return mixingFactor; } public double getTargetAcceptanceProbability() { return 0.234; } public double getMinimumAcceptanceLevel() { return 0.1; } public double getMaximumAcceptanceLevel() { return 0.4; } public double getMinimumGoodAcceptanceLevel() { return 0.20; } public double getMaximumGoodAcceptanceLevel() { return 0.30; } public final String getPerformanceSuggestion() { double prob = MCMCOperator.Utils.getAcceptanceProbability(this); double targetProb = getTargetAcceptanceProbability(); dr.util.NumberFormatter formatter = new dr.util.NumberFormatter(5); double sf = OperatorUtils.optimizeWindowSize(mixingFactor, prob, targetProb); if (prob < getMinimumGoodAcceptanceLevel()) { return "Try setting mixingFactor to about " + formatter.format(sf); } else if (prob > getMaximumGoodAcceptanceLevel()) { return "Try setting mixingFactor to about " + formatter.format(sf); } else return ""; } }