/* * TopologyTracer.java * * Copyright (c) 2002-2017 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evolution.tree; import dr.evolution.io.Importer; import dr.evolution.io.NewickImporter; import dr.util.Pair; import java.io.IOException; import java.util.*; /** * @author Guy Baele * Path difference metric according to Kendall & Colijn (2015) */ public class KCPathDifferenceMetric { private Tree focalTree; private int dim; private double[] focalSmallM, focalLargeM; public KCPathDifferenceMetric() { } public KCPathDifferenceMetric(Tree focalTree) { this.focalTree = focalTree; //this.dim = (externalNodeCount-2)*(externalNodeCount-1)+externalNodeCount; this.dim = focalTree.getExternalNodeCount() * focalTree.getExternalNodeCount(); this.focalSmallM = new double[dim]; this.focalLargeM = new double[dim]; traverse(focalTree, focalTree.getRoot(), 0.0, 0, focalLargeM, focalSmallM); } public List<Double> getMetric(Tree tree, ArrayList<Double> lambda) { checkTreeTaxa(focalTree, tree); double[] smallMTwo = new double[dim]; double[] largeMTwo = new double[dim]; traverse(tree, tree.getRoot(), 0.0, 0, largeMTwo, smallMTwo); List<Double> results = new ArrayList<Double>(); int n = tree.getExternalNodeCount(); for (Double l : lambda) { results.add(calculateMetric(focalSmallM, focalLargeM, smallMTwo, largeMTwo, n, l)); } return results; } /** * This method bypasses the constructor entirely, computing the metric on the two provided trees * and ignoring the internally stored tree. * @param tree1 Focal tree that will be used for computing the metric * @param tree2 Provided tree that will be compared to the focal tree * @param lambda Collection of lambda values for which to compute the metric * @return */ public List<Double> getMetric(Tree tree1, Tree tree2, ArrayList<Double> lambda) { checkTreeTaxa(tree1, tree2); int dim = tree1.getExternalNodeCount() * tree1.getExternalNodeCount(); double[] smallMOne = new double[dim]; double[] largeMOne = new double[dim]; double[] smallMTwo = new double[dim]; double[] largeMTwo = new double[dim]; traverse(tree1, tree1.getRoot(), 0.0, 0, largeMOne, smallMOne); traverse(tree2, tree2.getRoot(), 0.0, 0, largeMTwo, smallMTwo); List<Double> results = new ArrayList<Double>(); int n = tree1.getExternalNodeCount(); for (Double l : lambda) { results.add(calculateMetric(smallMOne, largeMOne, smallMTwo, largeMTwo, n, l)); } return results; } private double calculateMetric(double[] smallMOne, double[] largeMOne, double[] smallMTwo, double[] largeMTwo, int n, double l) { double distance = 0.0; //calculate Euclidean distance for this lambda value int k = 0; for (int i = 0; i < n; i++) { for (int j = i; j < n; j++) { // include diagonal int index = (i * n) + j; double vOne = (1.0 - l) * smallMOne[index] + l * largeMOne[index]; double vTwo = (1.0 - l) * smallMTwo[index] + l * largeMTwo[index]; distance += Math.pow(vOne - vTwo, 2); } } return Math.sqrt(distance); } private void checkTreeTaxa(Tree tree1, Tree tree2) { //check if taxon lists are in the same order!! if (tree1.getExternalNodeCount() != tree2.getExternalNodeCount()) { throw new RuntimeException("Different number of taxa in both trees."); } else { for (int i = 0; i < tree1.getExternalNodeCount(); i++) { if (!tree1.getNodeTaxon(tree1.getExternalNode(i)).getId().equals(tree2.getNodeTaxon(tree2.getExternalNode(i)).getId())) { throw new RuntimeException("Mismatch between taxa in both trees: " + tree1.getNodeTaxon(tree1.getExternalNode(i)).getId() + " vs. " + tree2.getNodeTaxon(tree2.getExternalNode(i)).getId()); } } } } private Set<NodeRef> traverse(Tree tree, NodeRef node, double lengthFromRoot, int edgesFromRoot, double[] lengths, double[] edges) { NodeRef left = tree.getChild(node, 0); NodeRef right = tree.getChild(node, 1); Set<NodeRef> leftSet = null; Set<NodeRef> rightSet = null; if (!tree.isExternal(left)) { leftSet = traverse(tree, left, lengthFromRoot + tree.getBranchLength(left), edgesFromRoot + 1, lengths, edges); } else { leftSet = Collections.singleton(left); int index = (left.getNumber() * tree.getExternalNodeCount()) + left.getNumber(); lengths[index] = tree.getBranchLength(left); edges[index] = 1; } if (!tree.isExternal(right)) { rightSet = traverse(tree, right, lengthFromRoot + tree.getBranchLength(right), edgesFromRoot + 1, lengths, edges); } else { rightSet = Collections.singleton(right); int index = (right.getNumber() * tree.getExternalNodeCount()) + right.getNumber(); lengths[index] = tree.getBranchLength(right); edges[index] = 1; } for (NodeRef tip1 : leftSet) { for (NodeRef tip2 : rightSet) { int index; if (tip1.getNumber() < tip2.getNumber()) { index = (tip1.getNumber() * tree.getExternalNodeCount()) + tip2.getNumber(); } else { index = (tip2.getNumber() * tree.getExternalNodeCount()) + tip1.getNumber(); } lengths[index] = lengthFromRoot; edges[index] = edgesFromRoot; } } Set<NodeRef> tips = new HashSet<NodeRef>(); tips.addAll(leftSet); tips.addAll(rightSet); return tips; } @Deprecated public ArrayList<Double> getMetric_old(Tree tree, ArrayList<Double> lambda) { //check if taxon lists are in the same order!! if (focalTree.getExternalNodeCount() != tree.getExternalNodeCount()) { throw new RuntimeException("Different number of taxa in both trees."); } else { for (int i = 0; i < focalTree.getExternalNodeCount(); i++) { if (!focalTree.getNodeTaxon(focalTree.getExternalNode(i)).getId().equals(tree.getNodeTaxon(tree.getExternalNode(i)).getId())) { throw new RuntimeException("Mismatch between taxa in both trees: " + focalTree.getNodeTaxon(focalTree.getExternalNode(i)).getId() + " vs. " + tree.getNodeTaxon(tree.getExternalNode(i)).getId()); } } } double[] smallMTwo = new double[dim]; double[] largeMTwo = new double[dim]; int index = 0; for (int i = 0; i < tree.getExternalNodeCount(); i++) { for (int j = i+1; j < tree.getExternalNodeCount(); j++) { //get two leaf nodes NodeRef nodeOne = tree.getExternalNode(i); NodeRef nodeTwo = tree.getExternalNode(j); //get common ancestor of 2 leaf nodes NodeRef MRCA = TreeUtils.getCommonAncestor(tree, nodeOne, nodeTwo); int edges = 0; double branchLengths = 0.0; while (MRCA != tree.getRoot()) { edges++; branchLengths += tree.getNodeHeight(tree.getParent(MRCA)) - tree.getNodeHeight(MRCA); MRCA = tree.getParent(MRCA); } smallMTwo[index] = edges; largeMTwo[index] = branchLengths; index++; } } int externalNodeCount = tree.getExternalNodeCount(); int d = (externalNodeCount-2)*(externalNodeCount-1)+externalNodeCount; //fill out arrays further index = 0; for (int i = (externalNodeCount-1)*(externalNodeCount-2); i < d; i++) { smallMTwo[i] = 1.0; largeMTwo[i] = tree.getNodeHeight(tree.getParent(tree.getExternalNode(index))) - tree.getNodeHeight(tree.getExternalNode(index)); index++; } double[] vArrayOne = new double[dim]; double[] vArrayTwo = new double[dim]; ArrayList<Double> results = new ArrayList<Double>(); for (Double l : lambda) { double distance = 0.0; //calculate Euclidean distance for this lambda value for (int i = 0; i < dim; i++) { vArrayOne[i] = (1.0 - l)*focalSmallM[i] + l*focalLargeM[i]; vArrayTwo[i] = (1.0 - l)*smallMTwo[i] + l*largeMTwo[i]; distance += Math.pow(vArrayOne[i] - vArrayTwo[i],2); } distance = Math.sqrt(distance); results.add(distance); } return results; } @Deprecated public ArrayList<Double> getMetric_old(Tree tree1, Tree tree2, ArrayList<Double> lambda) { int dim = (tree1.getExternalNodeCount()-2)*(tree1.getExternalNodeCount()-1)+tree1.getExternalNodeCount(); double[] smallMOne = new double[dim]; double[] largeMOne = new double[dim]; double[] smallMTwo = new double[dim]; double[] largeMTwo = new double[dim]; //check if taxon lists are in the same order!! if (tree1.getExternalNodeCount() != tree2.getExternalNodeCount()) { throw new RuntimeException("Different number of taxa in both trees."); } else { for (int i = 0; i < tree1.getExternalNodeCount(); i++) { if (!tree1.getNodeTaxon(tree1.getExternalNode(i)).getId().equals(tree2.getNodeTaxon(tree2.getExternalNode(i)).getId())) { throw new RuntimeException("Mismatch between taxa in both trees: " + tree1.getNodeTaxon(tree1.getExternalNode(i)).getId() + " vs. " + tree2.getNodeTaxon(tree2.getExternalNode(i)).getId()); } } } int index = 0; for (int i = 0; i < tree1.getExternalNodeCount(); i++) { for (int j = i+1; j < tree1.getExternalNodeCount(); j++) { //get two leaf nodes NodeRef nodeOne = tree1.getExternalNode(i); NodeRef nodeTwo = tree1.getExternalNode(j); //get common ancestor of 2 leaf nodes NodeRef MRCA = TreeUtils.getCommonAncestor(tree1, nodeOne, nodeTwo); int edges = 0; double branchLengths = 0.0; while (MRCA != tree1.getRoot()) { edges++; branchLengths += tree1.getNodeHeight(tree1.getParent(MRCA)) - tree1.getNodeHeight(MRCA); MRCA = tree1.getParent(MRCA); } smallMOne[index] = edges; largeMOne[index] = branchLengths; index++; } } int externalNodeCount = tree2.getExternalNodeCount(); int d = (externalNodeCount-2)*(externalNodeCount-1)+externalNodeCount; //fill out arrays further index = 0; for (int i = (externalNodeCount-1)*(externalNodeCount-2); i < d; i++) { smallMOne[i] = 1.0; largeMOne[i] = tree1.getNodeHeight(tree1.getParent(tree1.getExternalNode(index))) - tree1.getNodeHeight(tree1.getExternalNode(index)); index++; } /*for (int i = 0; i < smallMOne.length; i++) { System.out.print(smallMOne[i] + " "); } System.out.println(); for (int i = 0; i < largeMOne.length; i++) { System.out.print(largeMOne[i] + " "); } System.out.println("\n");*/ index = 0; for (int i = 0; i < tree2.getExternalNodeCount(); i++) { for (int j = i+1; j < tree2.getExternalNodeCount(); j++) { //get two leaf nodes NodeRef nodeOne = tree2.getExternalNode(i); NodeRef nodeTwo = tree2.getExternalNode(j); //get common ancestor of 2 leaf nodes NodeRef MRCA = TreeUtils.getCommonAncestor(tree2, nodeOne, nodeTwo); int edges = 0; double branchLengths = 0.0; while (MRCA != tree2.getRoot()) { edges++; branchLengths += tree2.getNodeHeight(tree2.getParent(MRCA)) - tree2.getNodeHeight(MRCA); MRCA = tree2.getParent(MRCA); } smallMTwo[index] = edges; largeMTwo[index] = branchLengths; index++; } } //fill out arrays further index = 0; for (int i = (externalNodeCount-1)*(externalNodeCount-2); i < d; i++) { smallMTwo[i] = 1.0; largeMTwo[i] = tree2.getNodeHeight(tree2.getParent(tree2.getExternalNode(index))) - tree2.getNodeHeight(tree2.getExternalNode(index)); index++; } /*for (int i = 0; i < smallMTwo.length; i++) { System.out.print(smallMTwo[i] + " "); } System.out.println(); for (int i = 0; i < largeMTwo.length; i++) { System.out.print(largeMTwo[i] + " "); } System.out.println("\n");*/ double[] vArrayOne = new double[dim]; double[] vArrayTwo = new double[dim]; ArrayList<Double> results = new ArrayList<Double>(); for (Double l : lambda) { double distance = 0.0; //calculate Euclidean distance for this lambda value for (int i = 0; i < dim; i++) { vArrayOne[i] = (1.0 - l)*smallMOne[i] + l*largeMOne[i]; vArrayTwo[i] = (1.0 - l)*smallMTwo[i] + l*largeMTwo[i]; distance += Math.pow(vArrayOne[i] - vArrayTwo[i],2); } distance = Math.sqrt(distance); results.add(distance); } return results; } public static void main(String[] args) { // tree 1: ((A:1.2,B:0.8):0.5,(C:0.8,D:1.0):1.1); // tree 2: (((A:0.8,B:1.3999999999999997):0.30000000000000004,C:0.7000000000000002):0.8999999999999999,D:1.0); // // lambda (0.0) = 2.0 // lambda (0.5) = 1.9397164741270823 // lambda (1.0) = 1.962141687034858 // lambda (0.0) = 2.0 // lambda (0.5) = 1.9397164741270823 // lambda (1.0) = 1.962141687034858 try { //4-taxa example NewickImporter importer = new NewickImporter("(('A':1.2,'B':0.8):0.5,('C':0.8,'D':1.0):1.1)"); Tree treeOne = importer.importNextTree(); System.out.println("4-taxa tree 1: " + treeOne); importer = new NewickImporter("((('A':0.8,'B':1.4):0.3,'C':0.7):0.9,'D':1.0)"); Tree treeTwo = importer.importNextTree(); System.out.println("4-taxa tree 2: " + treeTwo + "\n"); ArrayList<Double> lambdaValues = new ArrayList<Double>(); lambdaValues.add(0.0); lambdaValues.add(0.5); lambdaValues.add(1.0); List<Double> metric = (new KCPathDifferenceMetric().getMetric(treeOne, treeTwo, lambdaValues)); List<Double> metric_old = (new KCPathDifferenceMetric().getMetric_old(treeOne, treeTwo, lambdaValues)); System.out.println("\nPaired trees:"); System.out.println("lambda (0.0) = " + metric.get(0) + " old = " + metric_old.get(0)); System.out.println("lambda (0.5) = " + metric.get(1) + " old = " + metric_old.get(1)); System.out.println("lambda (1.0) = " + metric.get(2) + " old = " + metric_old.get(2)); //Additional test for comparing a collection of trees against a (fixed) focal tree metric = new KCPathDifferenceMetric(treeOne).getMetric(treeTwo, lambdaValues); metric_old = new KCPathDifferenceMetric(treeOne).getMetric_old(treeTwo, lambdaValues); System.out.println("\nFocal trees:"); System.out.println("lambda (0.0) = " + metric.get(0) + " old = " + metric_old.get(0)); System.out.println("lambda (0.5) = " + metric.get(1) + " old = " + metric_old.get(1)); System.out.println("lambda (1.0) = " + metric.get(2) + " old = " + metric_old.get(2)); //5-taxa example importer = new NewickImporter("(((('A':0.6,'B':0.6):0.1,'C':0.5):0.4,'D':0.7):0.1,'E':1.3)"); treeOne = importer.importNextTree(); System.out.println("5-taxa tree 1: " + treeOne); importer = new NewickImporter("((('A':0.8,'B':1.4):0.1,'C':0.7):0.2,('D':1.0,'E':0.9):1.3)"); treeTwo = importer.importNextTree(); System.out.println("5-taxa tree 2: " + treeTwo + "\n"); //lambda = 0.0 should yield: sqrt(7) = 2.6457513110645907162 //lambda = 1.0 should yield: sqrt(2.96) = 1.7204650534085252911 lambdaValues = new ArrayList<Double>(); lambdaValues.add(0.0); lambdaValues.add(0.5); lambdaValues.add(1.0); metric = (new KCPathDifferenceMetric().getMetric(treeOne, treeTwo, lambdaValues)); System.out.println("\nPaired trees:"); System.out.println("lambda (0.0) = " + metric.get(0) + " old = " + metric_old.get(0)); System.out.println("lambda (0.5) = " + metric.get(1) + " old = " + metric_old.get(1)); System.out.println("lambda (1.0) = " + metric.get(2) + " old = " + metric_old.get(2)); //Additional test for comparing a collection of trees against a (fixed) focal tree metric = new KCPathDifferenceMetric(treeOne).getMetric(treeTwo, lambdaValues); System.out.println("\nFocal trees:"); System.out.println("lambda (0.0) = " + metric.get(0) + " old = " + metric_old.get(0)); System.out.println("lambda (0.5) = " + metric.get(1) + " old = " + metric_old.get(1)); System.out.println("lambda (1.0) = " + metric.get(2) + " old = " + metric_old.get(2)); //timings long startTime = System.currentTimeMillis(); for (int i = 0; i < 1000000; i++) { new KCPathDifferenceMetric().getMetric_old(treeOne, treeTwo, lambdaValues); } System.out.println("Old algorithm: " + (System.currentTimeMillis() - startTime) + " ms"); startTime = System.currentTimeMillis(); for (int i = 0; i < 1000000; i++) { new KCPathDifferenceMetric().getMetric(treeOne, treeTwo, lambdaValues); } System.out.println("New algorithm: " + (System.currentTimeMillis() - startTime) + " ms"); } catch(Importer.ImportException ie) { System.err.println(ie); } catch(IOException ioe) { System.err.println(ioe); } } }