/*
 * ImportanceNarrowExchange.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 *  BEAST is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301  USA
 */

/**
 *
 */
package dr.evomodel.operators;

import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.DataType;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.Taxon;
import dr.evomodel.tree.TreeLogger;
import dr.evomodel.tree.TreeModel;
import dr.math.MathUtils;

import java.util.HashMap;
import java.util.Map;

/**
 * A tree operator that exchanges a node with one of its nephews, where the internal node at
 * which the exchange happens is drawn with probability proportional to a per-node weight.
 * <p>
 * The weights are derived from the alignment patterns: every taxon gets a count (the summed
 * weight of the patterns in which its state differs from the most frequent state), every node
 * gets the sum of the counts of the leaves below it, and the weight of an internal node is a
 * symmetric function of the counts of the three sub-trees that take part in an exchange at
 * that node (see {@link #nodeWeight(NodeRef)}).
 *
 * @author Joseph Heled
 * @version 1.0
 */
@SuppressWarnings({"ConstantConditions"})
// Cleaning out untouched stuff. Can be resurrected if needed
@Deprecated
public class ImportanceNarrowExchange extends AbstractTreeOperator implements TreeLogger.LogUpon {

    /** Debug verbosity; 0 disables all diagnostic printing. */
    private static final int DEBUG = 0;

    /** The tree this operator modifies. */
    private final TreeModel tree;

    /** Constant added to every sub-tree count inside {@link #nodeWeight(NodeRef)}. */
    private final double epsilon;

    /** Per-node counts, indexed by node number. Leaves are set once; internal nodes hold sums. */
    private int[] nodeCounts;

    /** True if the most recent proposal of this operator was accepted. */
    private boolean justAccepted;

    /** Per-node selection weights, indexed by node number. Filled in by {@link #getNode()}. */
    private final double[] weights;

    /** Sum of the weights of all internal nodes, as of the last call to {@link #getNode()}. */
    private double totalWeight;

    /** Minimum number of states between two positive answers of {@link #logNow(long)}. */
    private final long lFreq = 1000;

    /** State at which {@link #logNow(long)} last answered true. */
    private long lastLog = -lFreq - 1;

    /**
     * @param tree     the tree to operate on
     * @param patterns site patterns from which the per-taxon counts are derived
     * @param epsilon  constant added to each sub-tree count when computing node weights
     * @param weight   the operator weight, passed to {@code setWeight}
     * @throws Exception if a taxon of a leaf of {@code tree} is not present in {@code patterns}
     */
    public ImportanceNarrowExchange(TreeModel tree, PatternList patterns, double epsilon, double weight)
            throws Exception {
        this.tree = tree;
        setWeight(weight);
        justAccepted = false;
        this.epsilon = epsilon;
        weights = new double[tree.getNodeCount()];
        setTaxaWeights(patterns);
    }

    /**
     * Tells whether a state is excluded from counting (gap, ambiguous or unknown).
     */
    private static boolean isUninformative(DataType type, int state) {
        return type.isGapState(state) || type.isAmbiguousState(state) || type.isUnknownState(state);
    }

    /**
     * Computes the per-taxon counts from the patterns and stores them in {@link #nodeCounts}
     * at the positions of the corresponding leaves of the tree.
     * <p>
     * For each pattern with more than one distinct informative state, the most frequent state
     * is determined, and the pattern weight is added to the count of every taxon whose
     * informative state differs from it.
     *
     * @param patterns the site patterns
     * @throws Exception if a taxon of a leaf of the tree is not present in {@code patterns}
     */
    private void setTaxaWeights(PatternList patterns) throws Exception {
        final DataType type = patterns.getDataType();

        // state -> number of taxa showing that state in the current pattern
        final Map<Integer, Integer> counts = new HashMap<Integer, Integer>();
        final int[] taxaCounts = new int[patterns.getPatternLength()];

        for (int nPat = 0; nPat < patterns.getPatternCount(); ++nPat) {
            final int[] pattern = patterns.getPattern(nPat);

            counts.clear();
            for (int s : pattern) {
                if (isUninformative(type, s)) {
                    continue;
                }
                final Integer seen = counts.get(s);
                counts.put(s, seen == null ? 1 : seen + 1);
            }

            // A pattern with at most one informative state does not separate any taxa.
            if (counts.size() <= 1) {
                continue;
            }

            // Most frequent state; on ties the first one met in map iteration order wins.
            Map.Entry<Integer, Integer> m = null;
            for (Map.Entry<Integer, Integer> e : counts.entrySet()) {
                if (m == null || e.getValue() > m.getValue()) {
                    m = e;
                }
            }
            assert m != null;
            final int majorityState = m.getKey();

            for (int i = 0; i < pattern.length; ++i) {
                final int s = pattern[i];
                if (!isUninformative(type, s) && s != majorityState) {
                    // Compound assignment on purpose: it narrows the pattern weight to int.
                    taxaCounts[i] += patterns.getPatternWeight(nPat);
                }
            }
        }

        nodeCounts = new int[tree.getNodeCount()];

        final Map<Taxon, Integer> taxaWeights = new HashMap<Taxon, Integer>();
        for (int i = 0; i < taxaCounts.length; ++i) {
            taxaWeights.put(patterns.getTaxon(i), taxaCounts[i]);
        }

        for (int i = 0; i < tree.getExternalNodeCount(); ++i) {
            final NodeRef leaf = tree.getExternalNode(i);
            final Taxon nodeTaxon = tree.getNodeTaxon(leaf);
            if (!taxaWeights.containsKey(nodeTaxon)) {
                throw new Exception("" + nodeTaxon + " in tree " + tree.getId()
                        + " not in patterns " + patterns.getId() + ".");
            }
            nodeCounts[leaf.getNumber()] = taxaWeights.get(nodeTaxon);
        }
    }

    /**
     * Recomputes {@link #nodeCounts} for every internal node below (and including) {@code n}
     * as the sum of the counts of its children. Counts of external nodes are left untouched.
     *
     * @param n root of the sub-tree to process
     * @return the count of {@code n}
     */
    private int traverseTree(NodeRef n) {
        final int k = n.getNumber();
        if (!tree.isExternal(n)) {
            int w = 0;
            for (int nc = 0; nc < tree.getChildCount(n); ++nc) {
                w += traverseTree(tree.getChild(n, nc));
            }
            nodeCounts[k] = w;
        }
        return nodeCounts[k];
    }

    /**
     * Computes the selection weight of an internal node from the current {@link #nodeCounts}.
     * <p>
     * A node whose two children are both external gets weight 0. Otherwise the three
     * sub-trees are the lower child and the two children of the higher child, and with counts
     * {@code a}, {@code b}, {@code c} the weight is
     * {@code (eps+a)(eps+b) + (eps+a)(eps+c) + (eps+b)(eps+c) - 3 eps^2}.
     *
     * @param node an internal node with two children
     * @return the weight of {@code node}
     */
    private double nodeWeight(final NodeRef node) {
        final NodeRef ch0 = tree.getChild(node, 0);
        final NodeRef ch1 = tree.getChild(node, 1);

        if (tree.isExternal(ch0) && tree.isExternal(ch1)) {
            return 0;
        }

        // NOTE(review): the higher child is assumed to be internal; this looks unsafe if an
        // external child can be higher than its internal sibling - confirm against callers.
        final boolean leftSubtree = tree.getNodeHeight(ch0) < tree.getNodeHeight(ch1);
        final NodeRef lower = leftSubtree ? ch0 : ch1;
        final NodeRef higher = leftSubtree ? ch1 : ch0;

        final int st0 = nodeCounts[lower.getNumber()];
        final int st1 = nodeCounts[tree.getChild(higher, 0).getNumber()];
        final int st2 = nodeCounts[tree.getChild(higher, 1).getNumber()];

        return (epsilon + st0) * (epsilon + st1)
                + (epsilon + st0) * (epsilon + st2)
                + (epsilon + st1) * (epsilon + st2)
                - 3 * epsilon * epsilon;
    }

    /**
     * Refreshes all node counts and weights, then draws an internal node with probability
     * proportional to its weight.
     * <p>
     * Side effects: updates {@link #nodeCounts}, {@link #weights} and {@link #totalWeight},
     * and consumes one random number.
     *
     * @return the index of the drawn node among the internal nodes (to be used with
     *         {@code tree.getInternalNode}), or -1 if no internal node has a positive weight
     */
    private int getNode() {
        traverseTree(tree.getRoot());

        totalWeight = 0;
        for (int k = 0; k < tree.getInternalNodeCount(); ++k) {
            final NodeRef node = tree.getInternalNode(k);
            final double w = nodeWeight(node);
            weights[node.getNumber()] = w;
            if (DEBUG > 5 && w > 0) {
                System.out.println("" + w + " " + TreeUtils.uniqueNewick(tree, node));
            }
            totalWeight += w;
        }

        double r = MathUtils.nextDouble() * totalWeight;
        int lastPositive = -1;
        for (int k = 0; k < tree.getInternalNodeCount(); ++k) {
            final NodeRef node = tree.getInternalNode(k);
            final int nodeIndex = node.getNumber();
            final double w = weights[nodeIndex];
            if (w > 0) {
                lastPositive = k;
            }
            r -= w;
            if (r < 0) {
                if (DEBUG > 0) {
                    System.out.println("" + w + "/" + totalWeight + " " + TreeUtils.uniqueNewick(tree, node));
                }
                return k;
            }
        }

        // Reached when there is nothing to draw from (lastPositive is still -1), or when
        // floating-point round-off left r at a non-negative value after all weights were
        // subtracted. In the latter case the draw belongs to the last node with a positive
        // weight, so return that node instead of reporting a failure.
        return lastPositive;
    }

    /**
     * Draws an internal node {@code p} by weight, then exchanges the lower child of {@code p}
     * with a uniformly chosen child of the higher child of {@code p}.
     *
     * @return the log Hastings ratio, {@code log(totalWeightBefore / totalWeightAfter)}
     * @throws RuntimeException if no node could be drawn
     * @see dr.inference.operators.SimpleMCMCOperator#doOperation()
     */
    @Override
    public double doOperation() {
        final int k = getNode();
        if (k < 0) {
            throw new RuntimeException("no node found");
        }
        final NodeRef p = tree.getInternalNode(k);

        if (DEBUG > 0) {
            System.out.println(TreeUtils.newick(tree));
            System.out.println("" + getAcceptCount() + " - " + getRejectCount());
        }

        assert tree.getChildCount(p) == 2;
        final NodeRef ch0 = tree.getChild(p, 0);
        final NodeRef ch1 = tree.getChild(p, 1);
        final boolean side = tree.getNodeHeight(ch0) < tree.getNodeHeight(ch1);

        final NodeRef iUncle = side ? ch0 : ch1;   // the lower child of p
        final NodeRef jP = side ? ch1 : ch0;       // the higher child of p
        final NodeRef j = tree.getChild(jP, MathUtils.nextInt(2));

        exchangeNodes(tree, iUncle, j, p, jP);

        // jP lost j and gained iUncle.
        final int jPindex = jP.getNumber();
        nodeCounts[jPindex] += -nodeCounts[j.getNumber()] + nodeCounts[iUncle.getNumber()]; // debug

        // The weight function is symmetric with respect to the counts of the three sub-trees
        // involved in the exchange, so the weight of p does not change, but the weight of jP
        // has changed.
        final double prev = weights[jPindex];
        // The counts below jP are still valid.
        final double now = nodeWeight(jP);
        double newTot = totalWeight + (now - prev);
        weights[jPindex] = now; // debug

        // The parent of p sees a different pair of grandchildren, so its weight changes too.
        final NodeRef pP = tree.getParent(p);
        if (pP != null) {
            final int pPindex = pP.getNumber();
            final double prev1 = weights[pPindex];
            final double now1 = nodeWeight(pP);
            weights[pPindex] = now1; // debug
            newTot += (now1 - prev1);
        }

        // pr(choosing p before the operation) = w/tot, pr(choosing p after the operation) = w/newTot,
        // so log(backward/forward) = log((w/newTot) / (w/tot)) = log(tot/newTot).

        // Self-check: recompute everything from scratch and compare with the incremental
        // update above. The getNode() call is kept unconditionally because it also consumes
        // a random number; dropping it would alter the random number stream of a run.
        final double saveTotalWeight = totalWeight;                  // debug
        final double[] w = new double[weights.length];               // debug
        System.arraycopy(weights, 0, w, 0, w.length);                // debug
        final int[] c = new int[nodeCounts.length];                  // debug
        System.arraycopy(nodeCounts, 0, c, 0, c.length);             // debug
        getNode();                                                   // debug
        for (int l = 0; l < c.length; ++l) {                         // debug
            assert c[l] == nodeCounts[l] : "node count mismatch at " + l;
        }
        for (int l = 0; l < w.length; ++l) {                         // debug
            // Negated comparison so that a NaN ratio (0/0 for zero weights) passes.
            assert !(Math.abs(weights[l] / w[l] - 1) > 1e-12) : "node weight mismatch at " + l;
        }
        assert Math.abs(newTot / totalWeight - 1) < 1e-10;           // debug

        return Math.log(saveTotalWeight / newTot);
    }

    /**
     * Records that the last proposal was rejected.
     */
    public void reject() {
        super.reject();
        justAccepted = false;
    }

    /**
     * Records that the last proposal was accepted.
     *
     * @param deviation passed through to the superclass
     */
    public void accept(double deviation) {
        super.accept(deviation);
        justAccepted = true;
    }

    /**
     * Answers true if the last proposal was accepted and more than {@code lFreq} states have
     * passed since the last time this method answered true. Always clears the
     * "just accepted" flag.
     *
     * @param state the current state number
     * @return whether to log now
     */
    public boolean logNow(long state) {
        boolean r = justAccepted;
        if (lastLog + lFreq >= state) {
            r = false;
        } else if (r) {
            lastLog = state;
        }
        justAccepted = false;
        return r;
    }

    /**
     * @return the operator name, including the id of the tree
     * @see dr.inference.operators.SimpleMCMCOperator#getOperatorName()
     */
    @Override
    public String getOperatorName() {
        return "Importance Narrow Exchange" + "(" + tree.getId() + ")";
    }

    /**
     * @return the constant 0.025
     */
    public double getMinimumAcceptanceLevel() {
        return 0.025;
    }

    /**
     * @return the constant 0.05
     */
    public double getMinimumGoodAcceptanceLevel() {
        return 0.05;
    }

    /**
     * @return an empty string; this operator offers no tuning suggestion
     * @see dr.inference.operators.MCMCOperator#getPerformanceSuggestion()
     */
    public String getPerformanceSuggestion() {
        return "";
    }
}