/*
* CaseToCaseTransmissionLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.epidemiology.casetocase;
import dr.app.tools.NexusExporter;
import dr.evomodel.coalescent.DemographicModel;
import dr.evomodel.epidemiology.casetocase.periodpriors.AbstractPeriodPriorDistribution;
import dr.inference.distribution.ParametricDistributionModel;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.*;
import dr.xml.*;
import java.io.PrintStream;
import java.util.*;
/**
 * A likelihood function for transmission between the cases of an identified epidemiological outbreak.
 *
 * Timescale must be in days. Python scripts to write XML for it and analyse the posterior set of networks exist;
 * contact MH.
 *
 * The log likelihood is the sum of three cached components:
 * <ul>
 *     <li>the transmission term: the index-case prior, an optional prior on the time of the first infection,
 *     and the pairwise infection hazards between cases;</li>
 *     <li>the infectious-period term: one hyperprior per infectious category of the outbreak;</li>
 *     <li>the log likelihood of the wrapped {@link CaseToCaseTreeLikelihood}.</li>
 * </ul>
 *
 * If the tree likelihood reports latent periods, a case is only able to infect others from the time of its
 * infectiousness event; otherwise it can do so from the moment it is infected.
 *
 * @author Matthew Hall
 * @version $Id: $
 */
public class CaseToCaseTransmissionLikelihood extends AbstractModelLikelihood implements Loggable {

    private static final boolean DEBUG = false;

    private final CategoryOutbreak outbreak;
    private final CaseToCaseTreeLikelihood treeLikelihood;
    // null when the model has no geography
    private final SpatialKernel spatialKernel;
    private final Parameter transmissionRate;

    // Cached likelihood components, the flags recording whether each cache is valid, and the copies saved by
    // storeState() and reinstated by restoreState().
    private boolean likelihoodKnown;
    private boolean storedLikelihoodKnown;
    private boolean transProbKnown;
    private boolean storedTransProbKnown;
    private boolean periodsProbKnown;
    private boolean storedPeriodsProbKnown;
    private boolean treeProbKnown;
    private boolean storedTreeProbKnown;
    private double logLikelihood;
    private double storedLogLikelihood;
    private double transLogProb;
    private double storedTransLogProb;
    private double periodsLogProb;
    private double storedPeriodsLogProb;
    private double treeLogProb;
    private double storedTreeLogProb;

    // Optional prior on the time of the first (index) infection; null if not supplied.
    private final ParametricDistributionModel initialInfectionTimePrior;
    // Prior probability of each ever-infected case being the index case, proportional to its outbreak weight.
    private final HashMap<AbstractCase, Double> indexCasePrior;
    private final boolean hasGeography;
    private final boolean hasLatentPeriods;

    // Infection, infectiousness and end-of-infectiousness events in time order; null when stale.
    private ArrayList<TreeEvent> sortedTreeEvents;
    private ArrayList<TreeEvent> storedSortedTreeEvents;
    // Case of the earliest infection event; null when stale (always recomputed together with sortedTreeEvents).
    private AbstractCase indexCase;
    private AbstractCase storedIndexCase;

    public static final String CASE_TO_CASE_TRANSMISSION_LIKELIHOOD = "caseToCaseTransmissionLikelihood";

    /**
     * @param name                      the model name
     * @param outbreak                  the outbreak whose cases, weights and infectious categories are used
     * @param treeLikelihood            the partitioned-tree likelihood supplying infection and infectiousness times
     * @param spatialKernel             the spatial kernel, or null if the model has no geography
     * @param transmissionRate          the baseline transmission rate parameter (its first value is used)
     * @param initialInfectionTimePrior prior on the time of the index infection, or null for none
     */
    public CaseToCaseTransmissionLikelihood(String name, CategoryOutbreak outbreak,
                                            CaseToCaseTreeLikelihood treeLikelihood, SpatialKernel spatialKernel,
                                            Parameter transmissionRate,
                                            ParametricDistributionModel initialInfectionTimePrior) {
        super(name);
        this.outbreak = outbreak;
        this.treeLikelihood = treeLikelihood;
        this.spatialKernel = spatialKernel;
        if (spatialKernel != null) {
            this.addModel(spatialKernel);
        }
        this.transmissionRate = transmissionRate;
        this.addModel(treeLikelihood);
        this.addVariable(transmissionRate);
        likelihoodKnown = false;
        hasGeography = spatialKernel != null;
        this.hasLatentPeriods = treeLikelihood.hasLatentPeriods();
        this.initialInfectionTimePrior = initialInfectionTimePrior;

        // The index-case prior is each ever-infected case's weight, normalised over all ever-infected cases.
        HashMap<AbstractCase, Double> weightMap = outbreak.getWeightMap();
        double totalWeights = 0;
        for (Map.Entry<AbstractCase, Double> entry : weightMap.entrySet()) {
            if (entry.getKey().wasEverInfected()) {
                totalWeights += entry.getValue();
            }
        }
        indexCasePrior = new HashMap<AbstractCase, Double>();
        for (AbstractCase aCase : outbreak.getCases()) {
            if (aCase.wasEverInfected()) {
                indexCasePrior.put(aCase, weightMap.get(aCase) / totalWeights);
            }
        }
        sortEvents();
    }

    /**
     * Invalidates the cached components affected by a change in a dependent model.
     *
     * A tree likelihood change always invalidates the tree term. Unless the change object is a DemographicModel,
     * it also invalidates the transmission and period terms and the sorted events. A spatial kernel change only
     * invalidates the transmission term. An outbreak change invalidates everything except the tree term.
     */
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        if (model instanceof CaseToCaseTreeLikelihood) {
            treeProbKnown = false;
            if (!(object instanceof DemographicModel)) {
                invalidateEventDependentTerms();
            }
        } else if (model instanceof SpatialKernel) {
            transProbKnown = false;
        } else if (model instanceof AbstractOutbreak) {
            invalidateEventDependentTerms();
        }
        likelihoodKnown = false;
    }

    /**
     * A change in the transmission rate invalidates the transmission term; any variable change invalidates the
     * total log likelihood.
     */
    protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        if (variable == transmissionRate) {
            transProbKnown = false;
        }
        likelihoodKnown = false;
    }

    /** Marks the transmission and period terms, the sorted events and the index case as stale. */
    private void invalidateEventDependentTerms() {
        transProbKnown = false;
        periodsProbKnown = false;
        sortedTreeEvents = null;
        indexCase = null;
    }

    protected void storeState() {
        storedLogLikelihood = logLikelihood;
        storedLikelihoodKnown = likelihoodKnown;
        storedPeriodsLogProb = periodsLogProb;
        storedPeriodsProbKnown = periodsProbKnown;
        storedTransLogProb = transLogProb;
        storedTransProbKnown = transProbKnown;
        storedTreeLogProb = treeLogProb;
        storedTreeProbKnown = treeProbKnown;
        // The event list is null whenever it is stale; copying null would throw.
        storedSortedTreeEvents = sortedTreeEvents == null ? null : new ArrayList<TreeEvent>(sortedTreeEvents);
        storedIndexCase = indexCase;
    }

    protected void restoreState() {
        logLikelihood = storedLogLikelihood;
        likelihoodKnown = storedLikelihoodKnown;
        transLogProb = storedTransLogProb;
        transProbKnown = storedTransProbKnown;
        treeLogProb = storedTreeLogProb;
        treeProbKnown = storedTreeProbKnown;
        periodsLogProb = storedPeriodsLogProb;
        periodsProbKnown = storedPeriodsProbKnown;
        sortedTreeEvents = storedSortedTreeEvents;
        indexCase = storedIndexCase;
    }

    protected void acceptState() {
        // nothing to do
    }

    public SpatialKernel getSpatialKernel() {
        return spatialKernel;
    }

    public Model getModel() {
        return this;
    }

    public CaseToCaseTreeLikelihood getTreeLikelihood() {
        return treeLikelihood;
    }

    /**
     * Returns the sum of the tree, infectious-period and transmission log probabilities, recomputing only the
     * components whose caches are invalid.
     *
     * @return the log likelihood; negative infinity if the partition is epidemiologically impossible or if any
     * component evaluates to positive infinity
     */
    public double getLogLikelihood() {
        if (!likelihoodKnown) {
            if (!treeProbKnown) {
                treeLikelihood.prepareTimings();
            }
            if (!transProbKnown) {
                try {
                    transLogProb = computeTransmissionLogProb();
                    transProbKnown = true;
                } catch (BadPartitionException e) {
                    // An impossible partition gives the whole state zero probability.
                    transLogProb = Double.NEGATIVE_INFINITY;
                    transProbKnown = true;
                    logLikelihood = Double.NEGATIVE_INFINITY;
                    likelihoodKnown = true;
                    return logLikelihood;
                }
            }
            if (!periodsProbKnown) {
                periodsLogProb = computePeriodsLogProb();
                periodsProbKnown = true;
            }
            if (!treeProbKnown) {
                treeLogProb = treeLikelihood.getLogLikelihood();
                treeProbKnown = true;
            }
            // just reject states where these round to +INF
            if (transLogProb == Double.POSITIVE_INFINITY) {
                System.out.println("TransLogProb +INF");
                return Double.NEGATIVE_INFINITY;
            }
            if (periodsLogProb == Double.POSITIVE_INFINITY) {
                System.out.println("PeriodsLogProb +INF");
                return Double.NEGATIVE_INFINITY;
            }
            if (treeLogProb == Double.POSITIVE_INFINITY) {
                System.out.println("TreeLogProb +INF");
                return Double.NEGATIVE_INFINITY;
            }
            logLikelihood = treeLogProb + periodsLogProb + transLogProb;
            likelihoodKnown = true;
        }
        return logLikelihood;
    }

    /**
     * Computes the transmission log probability by walking the events in time order.
     *
     * The first infection contributes the index-case prior and, if present, the initial-infection-time prior.
     * Each later infection contributes, for every case already able to infect others, the survival term
     * -rate * (time that case spent infectious before this infection, or before it stopped being infectious).
     * It also contributes log(rate) for the actual infector, where the rate is multiplied by the spatial kernel
     * value when the model has geography.
     *
     * @throws BadPartitionException if the event ordering is epidemiologically impossible
     */
    private double computeTransmissionLogProb() throws BadPartitionException {
        ensureEventsSorted();
        double rate = transmissionRate.getParameterValue(0);
        // Cases that have become able to infect others so far.
        ArrayList<AbstractCase> previouslyInfectious = new ArrayList<AbstractCase>();
        double logProb = 0;
        boolean first = true;
        for (TreeEvent event : sortedTreeEvents) {
            double eventTime = event.getTime();
            AbstractCase thisCase = event.getCase();
            if (event.getType() == EventType.INFECTION) {
                if (first) {
                    // index infection
                    logProb += Math.log(indexCasePrior.get(thisCase));
                    if (initialInfectionTimePrior != null) {
                        logProb += initialInfectionTimePrior.logPdf(eventTime);
                    }
                    first = false;
                } else {
                    AbstractCase infector = event.getInfector();
                    if (thisCase.wasEverInfected()) {
                        checkInfectionIsPossible(thisCase, infector, eventTime, previouslyInfectious);
                    }
                    // no previously infectious case infected this case before eventTime...
                    for (AbstractCase nonInfector : previouslyInfectious) {
                        double exposureEnd = nonInfector.endOfInfectiousTime < eventTime
                                ? nonInfector.endOfInfectiousTime
                                : eventTime;
                        double exposure = exposureEnd - treeLikelihood.getInfectiousTime(nonInfector);
                        if (exposure < 0) {
                            throw new RuntimeException("negative time");
                        }
                        logProb -= pairwiseTransmissionRate(rate, thisCase, nonInfector) * exposure;
                    }
                    // ...until the infector did so at eventTime
                    if (thisCase.wasEverInfected()) {
                        logProb += Math.log(pairwiseTransmissionRate(rate, thisCase, infector));
                    }
                }
                if (!hasLatentPeriods) {
                    // without latent periods a case can infect others from the moment it is infected
                    previouslyInfectious.add(thisCase);
                }
            } else if (event.getType() == EventType.INFECTIOUSNESS) {
                if (eventTime < Double.POSITIVE_INFINITY) {
                    if (eventTime > thisCase.endOfInfectiousTime) {
                        throw new BadPartitionException(thisCase.caseID + " noninfectious before " +
                                "infectious");
                    }
                    if (first) {
                        throw new RuntimeException("First event is not an infection");
                    }
                    previouslyInfectious.add(thisCase);
                }
            }
            // END events carry no probability mass; endOfInfectiousTime is read directly above.
        }
        return logProb;
    }

    /**
     * Checks that infectee could have been infected by infector at infectionTime.
     *
     * @throws BadPartitionException if the timings make the infection impossible
     * @throws RuntimeException      if the timings allow it but infector was never registered as infectious
     */
    private void checkInfectionIsPossible(AbstractCase infectee, AbstractCase infector, double infectionTime,
                                          ArrayList<AbstractCase> previouslyInfectious)
            throws BadPartitionException {
        if (previouslyInfectious.contains(infectee)) {
            throw new BadPartitionException(infectee.caseID +
                    " infected after it was infectious");
        }
        if (infectionTime > infectee.endOfInfectiousTime) {
            throw new BadPartitionException(infectee.caseID +
                    " ceased to be infected before it was infected");
        }
        if (infector.endOfInfectiousTime < infectionTime) {
            throw new BadPartitionException(infectee.caseID + " infected by "
                    + infector.caseID + " after the latter ceased to be infectious");
        }
        if (treeLikelihood.getInfectiousTime(infector) > infectionTime) {
            throw new BadPartitionException(infectee.caseID + " infected by "
                    + infector.caseID + " before the latter became infectious");
        }
        if (!previouslyInfectious.contains(infector)) {
            throw new RuntimeException("Infector not previously infected");
        }
    }

    /** The rate at which source infects target: baseRate, scaled by the spatial kernel if there is geography. */
    private double pairwiseTransmissionRate(double baseRate, AbstractCase target, AbstractCase source) {
        double pairRate = baseRate;
        if (hasGeography) {
            pairRate *= outbreak.getKernelValue(target, source, spatialKernel);
        }
        return pairRate;
    }

    /**
     * Computes the infectious-period log probability: for each infectious category, the category's hyperprior
     * evaluated on the infectious periods of the ever-infected cases in that category.
     */
    private double computePeriodsLogProb() {
        HashMap<String, ArrayList<Double>> infectiousPeriodsByCategory = new HashMap<String, ArrayList<Double>>();
        for (AbstractCase aCase : outbreak.getCases()) {
            if (aCase.wasEverInfected()) {
                String category = outbreak.getInfectiousCategory(aCase);
                ArrayList<Double> periods = infectiousPeriodsByCategory.get(category);
                if (periods == null) {
                    periods = new ArrayList<Double>();
                    infectiousPeriodsByCategory.put(category, periods);
                }
                periods.add(treeLikelihood.getInfectiousPeriod(aCase));
            }
        }
        double logProb = 0;
        for (String category : outbreak.getInfectiousCategories()) {
            // A category with no ever-infected cases has no entry in the map; pass it an empty sample
            // instead of failing with a NullPointerException.
            // NOTE(review): confirm every AbstractPeriodPriorDistribution accepts an empty array.
            ArrayList<Double> periods = infectiousPeriodsByCategory.get(category);
            int count = periods == null ? 0 : periods.size();
            double[] values = new double[count];
            for (int i = 0; i < count; i++) {
                values[i] = periods.get(i);
            }
            AbstractPeriodPriorDistribution hyperprior = outbreak.getInfectiousCategoryPrior(category);
            logProb += hyperprior.getLogLikelihood(values);
        }
        return logProb;
    }

    public void makeDirty() {
        likelihoodKnown = false;
        treeProbKnown = false;
        invalidateEventDependentTerms();
        treeLikelihood.makeDirty();
    }

    /** Orders events by time; ties keep their insertion order because Collections.sort is stable. */
    private static class EventComparator implements Comparator<TreeEvent> {
        public int compare(TreeEvent treeEvent1, TreeEvent treeEvent2) {
            return Double.compare(treeEvent1.getTime(), treeEvent2.getTime());
        }
    }

    private enum EventType {
        INFECTION,
        INFECTIOUSNESS,
        END
    }

    /** Rebuilds the sorted event list and the index case if they are stale. */
    private void ensureEventsSorted() {
        if (sortedTreeEvents == null) {
            sortEvents();
        }
    }

    /**
     * Builds the time-ordered event list. Every case gets an INFECTION event. Each ever-infected case also gets an
     * END event and, with latent periods, an INFECTIOUSNESS event. The index case is the case of the earliest
     * INFECTION event, matching the index infection used by computeTransmissionLogProb().
     */
    private void sortEvents() {
        ArrayList<TreeEvent> events = new ArrayList<TreeEvent>();
        for (AbstractCase aCase : outbreak.getCases()) {
            double infectionTime = treeLikelihood.getInfectionTime(aCase);
            events.add(new TreeEvent(infectionTime, aCase,
                    treeLikelihood.getInfector(outbreak.getCaseIndex(aCase))));
            if (aCase.wasEverInfected()) {
                events.add(new TreeEvent(EventType.END, aCase.endOfInfectiousTime, aCase));
                if (hasLatentPeriods) {
                    double infectiousnessTime = treeLikelihood.getInfectiousTime(aCase);
                    events.add(new TreeEvent(EventType.INFECTIOUSNESS, infectiousnessTime, aCase));
                }
            }
        }
        Collections.sort(events, new EventComparator());
        AbstractCase firstInfected = null;
        for (TreeEvent event : events) {
            if (event.getType() == EventType.INFECTION) {
                firstInfected = event.getCase();
                break;
            }
        }
        indexCase = firstInfected;
        sortedTreeEvents = events;
    }

    /** A single timed event in the history of one case; only INFECTION events record an infector. */
    private static class TreeEvent {

        private final EventType type;
        private final double time;
        private final AbstractCase aCase;
        private final AbstractCase infectorCase;

        private TreeEvent(EventType type, double time, AbstractCase aCase) {
            this.type = type;
            this.time = time;
            this.aCase = aCase;
            this.infectorCase = null;
        }

        private TreeEvent(double time, AbstractCase aCase, AbstractCase infectorCase) {
            this.type = EventType.INFECTION;
            this.time = time;
            this.aCase = aCase;
            this.infectorCase = infectorCase;
        }

        public double getTime() {
            return time;
        }

        public EventType getType() {
            return type;
        }

        public AbstractCase getCase() {
            return aCase;
        }

        public AbstractCase getInfector() {
            return infectorCase;
        }
    }

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        public static final String TRANSMISSION_RATE = "transmissionRate";
        public static final String INITIAL_INFECTION_TIME_PRIOR = "initialInfectionTimePrior";

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            CaseToCaseTreeLikelihood c2cTL = (CaseToCaseTreeLikelihood)
                    xo.getChild(CaseToCaseTreeLikelihood.class);
            SpatialKernel kernel = (SpatialKernel) xo.getChild(SpatialKernel.class);
            Parameter transmissionRate = (Parameter) xo.getElementFirstChild(TRANSMISSION_RATE);
            ParametricDistributionModel iitp = null;
            if (xo.hasChildNamed(INITIAL_INFECTION_TIME_PRIOR)) {
                iitp = (ParametricDistributionModel) xo.getElementFirstChild(INITIAL_INFECTION_TIME_PRIOR);
            }
            return new CaseToCaseTransmissionLikelihood(CASE_TO_CASE_TRANSMISSION_LIKELIHOOD,
                    (CategoryOutbreak) c2cTL.getOutbreak(), c2cTL, kernel, transmissionRate, iitp);
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        public String getParserDescription() {
            return "This element represents a probability distribution for epidemiological parameters of an outbreak" +
                    " given a phylogenetic tree";
        }

        public Class getReturnType() {
            return CaseToCaseTransmissionLikelihood.class;
        }

        public String getParserName() {
            return CASE_TO_CASE_TRANSMISSION_LIKELIHOOD;
        }

        private final XMLSyntaxRule[] rules = {
                new ElementRule(CaseToCaseTreeLikelihood.class, "The tree likelihood"),
                new ElementRule(SpatialKernel.class, "The spatial kernel", 0, 1),
                new ElementRule(TRANSMISSION_RATE, Parameter.class, "The transmission rate"),
                new ElementRule(INITIAL_INFECTION_TIME_PRIOR, ParametricDistributionModel.class, "The prior " +
                        "probability distribution of the first infection", true)
        };
    };

    // Not the most elegant solution, but you want two types of log out of this model, one for numerical parameters
    // (which Tracer can read) and one for the transmission tree (which it cannot). This is set up so that C2CTransL
    // is the numerical log and C2CTreeL the TT one.
    public LogColumn[] getColumns() {
        ArrayList<LogColumn> columns = new ArrayList<LogColumn>();
        columns.add(new LogColumn.Abstract("trans_LL") {
            protected String getFormattedValue() {
                return String.valueOf(transLogProb);
            }
        });
        columns.add(new LogColumn.Abstract("period_LL") {
            protected String getFormattedValue() {
                return String.valueOf(periodsLogProb);
            }
        });
        columns.addAll(Arrays.asList(treeLikelihood.passColumns()));
        for (AbstractPeriodPriorDistribution hyperprior : outbreak.getInfectiousMap().values()) {
            columns.addAll(Arrays.asList(hyperprior.getColumns()));
        }
        columns.add(new LogColumn.Abstract("FirstInfectionTime") {
            protected String getFormattedValue() {
                ensureEventsSorted();
                return String.valueOf(treeLikelihood.getInfectionTime(indexCase));
            }
        });
        columns.add(new LogColumn.Abstract("IndexCaseIndex") {
            protected String getFormattedValue() {
                // indexCase is null while the events are stale, so refresh before reading it
                ensureEventsSorted();
                return String.valueOf(treeLikelihood.getOutbreak().getCaseIndex(indexCase));
            }
        });
        return columns.toArray(new LogColumn[columns.size()]);
    }
}