/* * ImportancePruneAndRegraft.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ /** * */ package dr.evomodel.operators; import dr.evolution.tree.MutableTree.InvalidTreeException; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evomodel.tree.ConditionalCladeFrequency; import dr.evomodel.tree.TreeModel; import dr.evomodelxml.operators.ImportancePruneAndRegraftParser; import dr.inference.operators.*; import dr.math.MathUtils; import java.util.ArrayList; import java.util.List; /** * @author Sebastian Hoehna */ // Cleaning out untouched stuff. Can be resurrected if needed @Deprecated public class ImportancePruneAndRegraft extends AbstractTreeOperator { public final int SAMPLE_EVERY = 10; private final TreeModel tree; private final int samples; private int sampleCount; private boolean burnin = false; private final ConditionalCladeFrequency probabilityEstimater; private final OperatorSchedule schedule; /** * */ public ImportancePruneAndRegraft(TreeModel tree, double weight, int samples, int epsilon) { this.tree = tree; setWeight(weight); this.samples = samples; sampleCount = 0; probabilityEstimater = new ConditionalCladeFrequency(tree, epsilon); schedule = getOperatorSchedule(tree); } /** * */ public ImportancePruneAndRegraft(TreeModel tree, double weight, int samples) { this.tree = tree; setWeight(weight); this.samples = samples; sampleCount = 0; // double epsilon = 1 - Math.pow(0.5, samples); double epsilon = 1 - Math.pow(0.5, 1.0 / samples); // double epsilon = 1; probabilityEstimater = new ConditionalCladeFrequency(tree, epsilon); schedule = getOperatorSchedule(tree); } private OperatorSchedule getOperatorSchedule(TreeModel treeModel) { ExchangeOperator narrowExchange = new ExchangeOperator( ExchangeOperator.NARROW, treeModel, 10); ExchangeOperator wideExchange = new ExchangeOperator( ExchangeOperator.WIDE, treeModel, 3); SubtreeSlideOperator subtreeSlide = new SubtreeSlideOperator(treeModel, 10.0, 1.0, true, false, false, false, CoercionMode.COERCION_ON); NNI nni = new NNI(treeModel, 10.0); WilsonBalding wilsonBalding = new WilsonBalding(treeModel, 3.0); FNPR fnpr = new FNPR(treeModel, 5.0); OperatorSchedule schedule = new SimpleOperatorSchedule(); schedule.addOperator(narrowExchange); schedule.addOperator(wideExchange); schedule.addOperator(subtreeSlide); schedule.addOperator(nni); schedule.addOperator(wilsonBalding); schedule.addOperator(fnpr); return schedule; } /* * (non-Javadoc) * * @see dr.inference.operators.SimpleMCMCOperator#doOperation() */ @Override public double doOperation() { if (!burnin) { if (sampleCount < samples * SAMPLE_EVERY) { sampleCount++; if (sampleCount % SAMPLE_EVERY == 0) { probabilityEstimater.addTree(tree); } setAcceptCount(0); setRejectCount(0); setTransitions(0); return doUnguidedOperation(); } else { return importancePruneAndRegraft(); } } else { return doUnguidedOperation(); } } private double doUnguidedOperation() { int index = schedule.getNextOperatorIndex(); SimpleMCMCOperator operator = (SimpleMCMCOperator) schedule.getOperator(index); return operator.doOperation(); } private double importancePruneAndRegraft() { final int nodeCount = tree.getNodeCount(); final NodeRef root = tree.getRoot(); NodeRef i; do { int indexI = MathUtils.nextInt(nodeCount); i = tree.getNode(indexI); } while (root == i || tree.getParent(i) == root); List<Integer> secondNodeIndices = new ArrayList<Integer>(); List<Double> probabilities = new ArrayList<Double>(); NodeRef j, iP, jP; iP = tree.getParent(i); double iParentHeight = tree.getNodeHeight(iP); double sum = 0.0; double backwardLikelihood = calculateTreeProbability(tree); int offset = (int) -backwardLikelihood; double backward = Math.exp(backwardLikelihood + offset); final NodeRef oldBrother = getOtherChild(tree, iP, i); final NodeRef oldGrandfather = tree.getParent(iP); tree.beginTreeEdit(); for (int n = 0; n < nodeCount; n++) { j = tree.getNode(n); if (j != root) { jP = tree.getParent(j); if ((iP != jP) && (tree.getNodeHeight(j) < iParentHeight && iParentHeight < tree .getNodeHeight(jP))) { secondNodeIndices.add(n); pruneAndRegraft(tree, i, iP, j, jP); double prob = Math.exp(calculateTreeProbability(tree) + offset); probabilities.add(prob); sum += prob; pruneAndRegraft(tree, i, iP, oldBrother, oldGrandfather); } } } double ran = Math.random() * sum; int index = 0; while (ran > 0.0) { ran -= probabilities.get(index); index++; } index--; j = tree.getNode(secondNodeIndices.get(index)); jP = tree.getParent(j); if (iP != jP) { pruneAndRegraft(tree, i, iP, j, jP); tree.pushTreeChangedEvent(i); } tree.endTreeEdit(); // AR - not sure whether this check is necessary try { tree.checkTreeIsValid(); } catch (InvalidTreeException e) { throw new RuntimeException(e.getMessage()); } double forward = probabilities.get(index); // tree.pushTreeChangedEvent(jP); // tree.pushTreeChangedEvent(oldGrandfather); tree.pushTreeChangedEvent(i); double forwardProb = (forward / sum); double backwardProb = (backward / (sum - forward + backward)); final double hastingsRatio = Math.log(backwardProb / forwardProb); return hastingsRatio; } private void pruneAndRegraft(TreeModel tree, NodeRef i, NodeRef iP, NodeRef j, NodeRef jP) { // tree.beginTreeEdit(); // the grandfather NodeRef iG = tree.getParent(iP); // the brother NodeRef iB = getOtherChild(tree, iP, i); // prune tree.removeChild(iP, iB); tree.removeChild(iG, iP); tree.addChild(iG, iB); // reattach tree.removeChild(jP, j); tree.addChild(iP, j); tree.addChild(jP, iP); } private double calculateTreeProbability(Tree tree) { return probabilityEstimater.getTreeProbability(tree); } public void setBurnin(boolean burnin) { this.burnin = burnin; } /* * (non-Javadoc) * * @see dr.inference.operators.SimpleMCMCOperator#getOperatorName() */ @Override public String getOperatorName() { return ImportancePruneAndRegraftParser.IMPORTANCE_PRUNE_AND_REGRAFT; } /* * (non-Javadoc) * * @see dr.inference.operators.MCMCOperator#getPerformanceSuggestion() */ public String getPerformanceSuggestion() { // TODO Auto-generated method stub return ""; } }