/* * TransmissionTreeToVirusTree.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.app.tools; import dr.app.util.Arguments; import dr.evolution.coalescent.CoalescentSimulator; import dr.evolution.coalescent.ConstantPopulation; import dr.evolution.coalescent.DemographicFunction; import dr.evolution.coalescent.ExponentialGrowth; import dr.evolution.tree.*; import dr.evolution.util.Date; import dr.evolution.util.Taxon; import dr.evolution.util.Units; import dr.evomodel.epidemiology.LogisticGrowthN0; import java.io.*; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; /** * Simulated a virus tree given a transmission tree and dates of sampling * * @author mhall */ public class TransmissionTreeToVirusTree { protected static PrintStream progressStream = System.out; private enum ModelType{CONSTANT, EXPONENTIAL, LOGISTIC} public static final String HELP = "help"; public static final String DEMOGRAPHIC_MODEL = "demoModel"; public static final String[] demographics = {"Constant", "Exponential", "Logistic"}; public static final String STARTING_POPULATION_SIZE = "N0"; public static final String GROWTH_RATE = "growthRate"; public static final String T50 = "t50"; private DemographicFunction demFunct; private ArrayList<InfectedUnit> units; private HashMap<String, InfectedUnit> idMap; private String outputFileRoot; private double coalescentProbability; public TransmissionTreeToVirusTree(String fileName, DemographicFunction demFunct, String outputFileRoot){ this.demFunct = demFunct; units = new ArrayList<InfectedUnit>(); idMap = new HashMap<String, InfectedUnit>(); this.outputFileRoot = outputFileRoot; coalescentProbability = 1; try { readInfectionEvents(fileName); readSamplingEvents(fileName); } catch(IOException e){ e.printStackTrace(); } } public TransmissionTreeToVirusTree(String sampFileName, String transFileName, DemographicFunction demFunct, String outputFileRoot){ this.demFunct = demFunct; units = new ArrayList<InfectedUnit>(); idMap = new HashMap<String, InfectedUnit>(); this.outputFileRoot = outputFileRoot; try { readInfectionEvents(transFileName); readSamplingEvents(sampFileName); } catch(IOException e){ e.printStackTrace(); } } private enum EventType{ INFECTION, SAMPLE } private void run() throws IOException{ ArrayList<FlexibleTree> detailedTrees = makeTrees(); ArrayList<FlexibleTree> simpleTrees = new ArrayList<FlexibleTree>(); for(FlexibleTree tree : detailedTrees) { FlexibleTree wbTree = makeWellBehavedTree(tree); wbTree.setAttribute("firstCase", tree.getAttribute("firstCase")); simpleTrees.add(wbTree); } for(FlexibleTree tree: detailedTrees){ NexusExporter exporter = new NexusExporter(new PrintStream(outputFileRoot + tree.getAttribute("firstCase") + "_detailed.nex")); exporter.exportTree(tree); } for(FlexibleTree tree: simpleTrees){ NexusExporter exporter = new NexusExporter(new PrintStream(outputFileRoot + tree.getAttribute("firstCase") + "_simple.nex")); exporter.exportTree(tree); } } private void readInfectionEvents(String fileName) throws IOException{ BufferedReader reader = new BufferedReader(new FileReader(fileName)); ArrayList<String[]> keptLines = new ArrayList<String[]>(); reader.readLine(); String line = reader.readLine(); while(line!=null){ String[] entries = line.split(","); keptLines.add(entries); InfectedUnit unit = new InfectedUnit("ID_"+entries[1]); units.add(unit); idMap.put("ID_"+entries[1], unit); line = reader.readLine(); } for(String[] repeatLine: keptLines){ InfectedUnit infectee = idMap.get("ID_"+repeatLine[1]); if(!repeatLine[2].equals("-1")){ InfectedUnit infector = idMap.get("ID_"+repeatLine[2]); Event infection = new Event(EventType.INFECTION, Double.parseDouble(repeatLine[3]), infector, infectee); infector.addInfectionEvent(infection); infectee.setInfectionEvent(infection); infectee.parent = infector; } else { Event infection = new Event(EventType.INFECTION, Double.parseDouble(repeatLine[3]), null, infectee); infectee.setInfectionEvent(infection); } } } private void readSamplingEvents(String fileName) throws IOException{ BufferedReader reader = new BufferedReader(new FileReader(fileName)); reader.readLine(); String line = reader.readLine(); while(line!=null){ String[] entries = line.split(","); if(!entries[7].equals("NA")) { if (!idMap.containsKey("ID_"+entries[1])) { throw new RuntimeException("Trying to add a sampling event to unit " + entries[2] + " but this " + "unit not previously defined"); } InfectedUnit unit = idMap.get("ID_"+entries[1]); unit.addSamplingEvent(Double.parseDouble(entries[7])); } line = reader.readLine(); } } // events are only relevant if there is a sampling event somewhere further up the tree private FlexibleTree makeTreelet(InfectedUnit unit, ArrayList<Event> relevantEvents){ if(relevantEvents.size()==0){ return null; } ArrayList<SimpleNode> nodes = new ArrayList<SimpleNode>(); unit.sortEvents(); double lastRelevantEventTime = Double.NEGATIVE_INFINITY; for(Event event : relevantEvents){ if(event.time > lastRelevantEventTime){ lastRelevantEventTime = event.time; } } double activeTime = lastRelevantEventTime - unit.infectionEvent.time; for(Event event : relevantEvents){ Taxon taxon; if(event.type == EventType.INFECTION){ taxon = new Taxon(event.infectee.id+"_infected_by_"+event.infector.id+"_"+event.time); } else { taxon = new Taxon(unit.id+"_sampled_"+event.time); } taxon.setDate(new Date(event.time - unit.infectionEvent.time, Units.Type.YEARS, false)); SimpleNode node = new SimpleNode(); node.setTaxon(taxon); nodes.add(node); node.setHeight(unit.infectionEvent.time - event.time); node.setAttribute("Event", event); } FlexibleNode treeletRoot; if(nodes.size()>1){ treeletRoot = simulateCoalescent(nodes, demFunct, activeTime); } else { treeletRoot = new FlexibleNode(new SimpleTree(nodes.get(0)), nodes.get(0), true); treeletRoot.setHeight(0); } // add the root branch length FlexibleNode infectionNode = new FlexibleNode(); infectionNode.setHeight(activeTime); infectionNode.addChild(treeletRoot); treeletRoot.setLength(activeTime - treeletRoot.getHeight()); infectionNode.setAttribute("Event", unit.infectionEvent); FlexibleTree outTree = new FlexibleTree(infectionNode); for(int i=0; i<outTree.getNodeCount(); i++){ FlexibleNode node = (FlexibleNode)outTree.getNode(i); node.setAttribute("Unit", unit.id); } return outTree; } private ArrayList<FlexibleTree> makeTrees(){ // find the first case ArrayList<InfectedUnit> introducedCases = new ArrayList<InfectedUnit>(); for(InfectedUnit unit : units){ if(unit.parent==null){ introducedCases.add(unit); } } if(introducedCases.size()==0){ throw new RuntimeException("Can't find a first case"); } ArrayList<FlexibleTree> out = new ArrayList<FlexibleTree>(); for(InfectedUnit introduction : introducedCases) { coalescentProbability = 1; System.out.println("Building tree for descendants of " + introduction.id); FlexibleNode outTreeRoot = makeSubtree(introduction); if (outTreeRoot != null) { FlexibleTree finalTree = new FlexibleTree(outTreeRoot, false, true); finalTree.setAttribute("firstCase", introduction.id); out.add(finalTree); if(coalescentProbability<0.9){ progressStream.println("WARNING: any phylogeny for descendants of "+introduction.id+" is quite " + "improbable (p<"+(coalescentProbability)+") given this demographic function. Consider " + "another."); } } else { progressStream.println("This individual has no sampled descendants"); } System.out.println(); } return out; } // make the tree from this unit up private FlexibleNode makeSubtree(InfectedUnit unit){ HashMap<Event, FlexibleNode> eventToSubtreeRoot = new HashMap<Event, FlexibleNode>(); ArrayList<Event> relevantEvents = new ArrayList<Event>(); for(Event event : unit.childEvents){ if(event.type == EventType.INFECTION){ FlexibleNode childSubtreeRoot = makeSubtree(event.infectee); if(childSubtreeRoot!=null){ relevantEvents.add(event); eventToSubtreeRoot.put(event, childSubtreeRoot); } } else if(event.type==EventType.SAMPLE) { relevantEvents.add(event); } } FlexibleTree unitTreelet = makeTreelet(unit, relevantEvents); if(unitTreelet==null){ return null; } for(int i=0; i<unitTreelet.getExternalNodeCount(); i++){ FlexibleNode tip = (FlexibleNode)unitTreelet.getExternalNode(i); Event tipEvent = (Event)unitTreelet.getNodeAttribute(tip, "Event"); if(tipEvent.type == EventType.INFECTION){ FlexibleNode subtreeRoot = eventToSubtreeRoot.get(tipEvent); FlexibleNode firstSubtreeSplit = subtreeRoot.getChild(0); subtreeRoot.removeChild(firstSubtreeSplit); tip.addChild(firstSubtreeSplit); } } return (FlexibleNode)unitTreelet.getRoot(); } private FlexibleNode simulateCoalescent(ArrayList<SimpleNode> nodes, DemographicFunction demogFunct, double maxHeight){ double earliestNodeHeight = Double.NEGATIVE_INFINITY; for(SimpleNode node : nodes){ if(node.getHeight()>earliestNodeHeight){ earliestNodeHeight = node.getHeight(); } } double maxLastInterval = earliestNodeHeight; double probNoCoalesenceInTime = Math.exp(demogFunct.getIntensity(maxLastInterval)); coalescentProbability *= (1-probNoCoalesenceInTime); CoalescentSimulator simulator = new CoalescentSimulator(); SimpleNode root; SimpleNode[] simResults; int failCount = 0; do { simResults = simulator.simulateCoalescent(nodes.toArray(new SimpleNode[nodes.size()]), demogFunct, -maxHeight, 0, true); if(simResults.length>1){ failCount++; System.out.println("Failed to coalesce lineages: "+failCount); } } while(simResults.length!=1); root = simResults[0]; SimpleTree simpleTreelet = new SimpleTree(root); for (int i=0; i<simpleTreelet.getNodeCount(); i++) { SimpleNode node = (SimpleNode)simpleTreelet.getNode(i); node.setHeight(node.getHeight() + maxHeight); } return new FlexibleNode(simpleTreelet, root, true); } private FlexibleTree makeWellBehavedTree(FlexibleTree tree){ FlexibleTree newPhylogeneticTree = new FlexibleTree(tree, false); newPhylogeneticTree.beginTreeEdit(); for(int i=0; i<newPhylogeneticTree.getInternalNodeCount(); i++){ FlexibleNode node = (FlexibleNode)newPhylogeneticTree.getInternalNode(i); if(newPhylogeneticTree.getChildCount(node)==1){ FlexibleNode parent = (FlexibleNode)newPhylogeneticTree.getParent(node); FlexibleNode child = (FlexibleNode)newPhylogeneticTree.getChild(node, 0); if(parent!=null){ double childHeight = newPhylogeneticTree.getNodeHeight(child); newPhylogeneticTree.removeChild(parent, node); newPhylogeneticTree.addChild(parent, child); newPhylogeneticTree.setNodeHeight(child, childHeight); } else { child.setParent(null); newPhylogeneticTree.setRoot(child); } } } newPhylogeneticTree.endTreeEdit(); return new FlexibleTree(newPhylogeneticTree, true); } private class InfectedUnit{ private String id; private ArrayList<Event> childEvents; private Event infectionEvent; private InfectedUnit parent; private InfectedUnit(String id){ this.id = id; parent = null; childEvents = new ArrayList<Event>(); } private void addSamplingEvent(double time){ if(time < infectionEvent.time){ throw new RuntimeException("Adding an event to case "+id+" before its infection time"); } childEvents.add(new Event(EventType.SAMPLE, time)); } private void setInfectionEvent(double time, InfectedUnit infector){ setInfectionEvent(new Event(EventType.INFECTION, time, infector, this)); } private void setInfectionEvent(Event event){ for(Event childEvent : childEvents){ if(event.time > childEvent.time){ throw new RuntimeException("Setting infection time for case "+id+" after an existing child event"); } } infectionEvent = event; } private void addChildInfectionEvent(double time, InfectedUnit infectee){ addInfectionEvent(new Event(EventType.INFECTION, time, this, infectee)); } private void addInfectionEvent(Event event){ if(infectionEvent!=null && event.time < infectionEvent.time){ throw new RuntimeException("Adding an infection event to case "+id+" at "+event.time+" before its " + "infection time at "+infectionEvent.time); } childEvents.add(event); } private void sortEvents(){ Collections.sort(childEvents); Collections.reverse(childEvents); } } private class Event implements Comparable<Event>{ private EventType type; private double time; private InfectedUnit infector; private InfectedUnit infectee; private Event(EventType type, double time){ this.type = type; this.time = time; } private Event(EventType type, double time, InfectedUnit infector, InfectedUnit infectee){ this.type = type; this.time = time; this.infector = infector; this.infectee = infectee; } public int compareTo(Event event) { return Double.compare(time, event.time); } } public static void printUsage(Arguments arguments) { arguments.printUsage("virusTreeBuilder", "<infections-file-name> <sample-file-name> <output-file-name-root>"); } public static void main(String[] args){ ModelType model = ModelType.CONSTANT; double startNe = 1; double growthRate = 0; double t50 = 0; Arguments arguments = new Arguments( new Arguments.Option[]{ new Arguments.StringOption(DEMOGRAPHIC_MODEL, demographics, false, "The type of within-host" + " demographic function to use, default = constant"), new Arguments.RealOption(STARTING_POPULATION_SIZE,"The effective population size at time zero" + " (used in all models), default = 1"), new Arguments.RealOption(GROWTH_RATE,"The effective population size growth rate (used in" + " exponential and logistic models), default = 0"), new Arguments.RealOption(T50,"The time point, relative to the time of infection in backwards " + "time, at which the population is equal to half its final asymptotic value, in the " + "logistic model default = 0") }); try { arguments.parseArguments(args); } catch (Arguments.ArgumentException ae) { System.out.println(ae); printUsage(arguments); System.exit(1); } if (arguments.hasOption(HELP)) { printUsage(arguments); System.exit(0); } if (arguments.hasOption(DEMOGRAPHIC_MODEL)) { String modelString = arguments.getStringOption(DEMOGRAPHIC_MODEL); if(modelString.toLowerCase().startsWith("c")){ model = ModelType.CONSTANT; } else if(modelString.toLowerCase().startsWith("e")){ model = ModelType.EXPONENTIAL; } else if(modelString.toLowerCase().startsWith("l")){ model = ModelType.LOGISTIC; } else { progressStream.print("Unrecognised demographic model type"); System.exit(1); } } if(arguments.hasOption(STARTING_POPULATION_SIZE)){ startNe = arguments.getRealOption(STARTING_POPULATION_SIZE); } if(arguments.hasOption(GROWTH_RATE) && model!=ModelType.CONSTANT){ growthRate = arguments.getRealOption(GROWTH_RATE); } if(arguments.hasOption(T50) && model==ModelType.LOGISTIC){ t50 = arguments.getRealOption(T50); } DemographicFunction demoFunction = null; switch(model){ case CONSTANT: { demoFunction = new ConstantPopulation(Units.Type.YEARS); ((ConstantPopulation)demoFunction).setN0(startNe); } case EXPONENTIAL: { demoFunction = new ExponentialGrowth(Units.Type.YEARS); ((ExponentialGrowth)demoFunction).setN0(startNe); ((ExponentialGrowth)demoFunction).setGrowthRate(growthRate); } case LOGISTIC: { demoFunction = new LogisticGrowthN0(Units.Type.YEARS); ((LogisticGrowthN0)demoFunction).setN0(startNe); ((LogisticGrowthN0)demoFunction).setGrowthRate(growthRate); ((LogisticGrowthN0)demoFunction).setT50(t50); } } final String[] args2 = arguments.getLeftoverArguments(); if(args2.length!=3){ printUsage(arguments); System.exit(1); } String infectionsFileName = args2[0]; String samplesFileName = args2[1]; String outputFileRoot = args2[2]; TransmissionTreeToVirusTree instance = new TransmissionTreeToVirusTree(samplesFileName, infectionsFileName, demoFunction, outputFileRoot); try{ instance.run(); } catch (IOException e){ e.printStackTrace(); } } }