/* * RelaxedDriftModel.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.branchratemodel; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evomodel.tree.TreeModel; import dr.evomodel.tree.TreeParameterModel; import dr.evomodel.tree.randomlocalmodel.RandomLocalTreeVariable; import dr.evomodelxml.branchratemodel.RelaxedDriftModelParser; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.Variable; /** * Created by IntelliJ IDEA. * User: mandevgill * Date: 7/25/14 * Time: 4:27 PM * To change this template use File | Settings | File Templates. */ public class RelaxedDriftModel extends AbstractBranchRateModel implements RandomLocalTreeVariable { public RelaxedDriftModel(TreeModel treeModel, Parameter rateIndicatorParameter, Parameter ratesParameter, Parameter driftRates) { super(RelaxedDriftModelParser.RELAXED_DRIFT); rates = new TreeParameterModel(treeModel, ratesParameter, true); ratesParameter.addBounds(new Parameter.DefaultBounds(Double.MAX_VALUE, -Double.MAX_VALUE, ratesParameter.getDimension())); indicators = new TreeParameterModel(treeModel, rateIndicatorParameter, true); rateIndicatorParameter.addBounds(new Parameter.DefaultBounds(1, -1, rateIndicatorParameter.getDimension())); for (int i = 0; i < rateIndicatorParameter.getDimension(); i++) { rateIndicatorParameter.setParameterValue(i, 0.0); } for (int i = 0; i < ratesParameter.getDimension(); i++) { ratesParameter.setParameterValue(i, 0.0); } addModel(treeModel); this.treeModel = treeModel; addModel(indicators); addModel(rates); if (driftRates != null) { this.driftRates = driftRates; driftRates.setDimension(ratesParameter.getDimension()); } else { driftRates = null; } branchRates = new double[treeModel.getNodeCount()]; // Logger.getLogger("dr.evomodel").info(" indicator parameter name is '" + rateIndicatorParameter.getId() + "'"); calculateBranchRates(treeModel); //recalculateScaleFactor(); } /** * @param tree the tree * @param node the node to retrieve the variable of * @return the raw real-valued variable at this node */ public final double getVariable(Tree tree, NodeRef node) { return rates.getNodeValue(tree, node); } /** * @param tree the tree * @param node the node * @return true of the variable at this node is included in function, thus representing a change in the * function looking down the tree. */ public final boolean isVariableSelected(Tree tree, NodeRef node) { return indicators.getNodeValue(tree, node) != 0; } public void handleModelChangedEvent(Model model, Object object, int index) { recalculationNeeded = true; fireModelChanged(); } protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { recalculationNeeded = true; fireModelChanged(); } protected void storeState() { } protected void restoreState() { calculateBranchRates(treeModel); // recalculateScaleFactor(); } protected void acceptState() { } public double getBranchRate(final Tree tree, final NodeRef node) { if (recalculationNeeded) { calculateBranchRates(treeModel); // recalculateScaleFactor(); recalculationNeeded = false; } return branchRates[node.getNumber()]; } private void calculateBranchRates(TreeModel tree) { branchRates[tree.getRoot().getNumber()] = getVariable(tree, tree.getRoot()); if (driftRates != null) { driftRates.setParameterValue(tree.getRoot().getNumber(), getVariable(tree, tree.getRoot())); } cbr(tree, tree.getRoot(), branchRates[tree.getRoot().getNumber()]); } /** * This is a recursive function that does the work of * calculating the unscaled branch rates across the tree * taking into account the indicator variables. * * @param tree the tree * @param node the node * @param rate the rate of the parent node */ private void cbr(TreeModel tree, NodeRef node, double rate) { NodeRef childNode0 = tree.getChild(node, 0); NodeRef childNode1 = tree.getChild(node, 1); int nodeNumber0 = childNode0.getNumber(); int nodeNumber1 = childNode1.getNumber(); double nodeIndicator = indicators.getNodeValue(tree, node); if (indicators.getNodeValue(tree, node) != 0) { // System.err.println("nodeIndicator: " + nodeIndicator); } if (nodeIndicator < 0) { // System.err.println("child0 change"); branchRates[nodeNumber0] = rate + getVariable(tree, childNode0); // branchRates[nodeNumber0] = getVariable(tree, childNode0); branchRates[nodeNumber1] = rate; if (driftRates != null) { driftRates.setParameterValue(nodeNumber0, rate + getVariable(tree, childNode0)); driftRates.setParameterValue(nodeNumber1, rate); } } else if (nodeIndicator > 0) { // System.err.println("child1 change"); branchRates[nodeNumber0] = rate; branchRates[nodeNumber1] = rate + getVariable(tree, childNode1); // branchRates[nodeNumber1] = getVariable(tree, childNode1); if (driftRates != null) { driftRates.setParameterValue(nodeNumber0, rate); driftRates.setParameterValue(nodeNumber1, rate + getVariable(tree, childNode1)); } } else { // System.err.println("NO CHANGES!!!"); branchRates[nodeNumber0] = rate; branchRates[nodeNumber1] = rate; if (driftRates != null) { driftRates.setParameterValue(nodeNumber0, rate); driftRates.setParameterValue(nodeNumber1, rate); } } if (tree.getChildCount(childNode0) > 0) { cbr(tree, childNode0, branchRates[nodeNumber0]); } if (tree.getChildCount(childNode1) > 0) { cbr(tree, childNode1, branchRates[nodeNumber1]); } /* if (indicators.getNodeValue(tree, childNode0) > 0.5 && indicators.getNodeValue(tree, childNode1) < 0.5) { branchRates[nodeNumber0] = rate + getVariable(tree, childNode0); // branchRates[nodeNumber0] = getVariable(tree, childNode0); branchRates[nodeNumber1] = rate; } else if (indicators.getNodeValue(tree, childNode0) < 0.5 && indicators.getNodeValue(tree, childNode1) > 0.5) { branchRates[nodeNumber0] = rate; branchRates[nodeNumber1] = rate + getVariable(tree, childNode1); // branchRates[nodeNumber1] = getVariable(tree, childNode1); } else { branchRates[nodeNumber0] = rate; branchRates[nodeNumber1] = rate; } if (tree.getChildCount(childNode0) > 0) { cbr(tree, childNode0, branchRates[nodeNumber0]); } if (tree.getChildCount(childNode1) > 0) { cbr(tree, childNode1, branchRates[nodeNumber1]); } */ } // the tree model private TreeModel treeModel; // the unscaled rates of each branch, taking into account the indicators private double[] branchRates; private TreeParameterModel indicators; private TreeParameterModel rates; private Parameter driftRates; boolean recalculationNeeded = true; }