/**
*
*/
package test.dr.evomodel.operators;
import dr.evolution.io.Importer;
import dr.evolution.io.NewickImporter;
import dr.evolution.io.NexusImporter;
import dr.evolution.tree.FlexibleTree;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.Units;
import dr.evomodel.speciation.BirthDeathGernhard08Model;
import dr.evomodel.speciation.SpeciationLikelihood;
import dr.evomodel.speciation.SpeciationModel;
import dr.evomodel.tree.TreeHeightStatistic;
import dr.evomodel.tree.TreeLengthStatistic;
import dr.evomodel.tree.TreeLogger;
import dr.evomodel.tree.TreeModel;
import dr.evomodelxml.coalescent.OldCoalescentSimulatorParser;
import dr.inference.loggers.MCLogger;
import dr.inference.loggers.TabDelimitedFormatter;
import dr.inference.mcmc.MCMC;
import dr.inference.mcmc.MCMCOptions;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.operators.OperatorSchedule;
import dr.math.MathUtils;
import junit.framework.TestCase;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
/**
* @author Sebastian Hoehna
*
*/
public abstract class OperatorAssert extends TestCase {
static final String TL = "TL";
static final String TREE_HEIGHT = OldCoalescentSimulatorParser.ROOT_HEIGHT;
protected FlexibleTree tree5;
protected FlexibleTree tree6;
public void setUp() throws Exception {
super.setUp();
MathUtils.setSeed(666);
NewickImporter importer = new NewickImporter(
"((((A:1.0,B:1.0):1.0,C:2.0):1.0,D:3.0):1.0,E:4.0);");
tree5 = (FlexibleTree) importer.importTree(null);
importer = new NewickImporter(
"(((((A:1.0,B:1.0):1.0,C:2.0):1.0,D:3.0):1.0,E:4.0),F:5.0);");
tree6 = (FlexibleTree) importer.importTree(null);
}
// 5 taxa trees should sample all 105 topologies
public void testIrreducibility5() throws IOException, Importer.ImportException {
irreducibilityTester(tree5, 105, 1000000, 10);
}
// 6 taxa trees should sample all 945 topologies
public void testIrreducibility6() throws IOException, Importer.ImportException {
irreducibilityTester(tree6, 945, 2000000, 4);
}
/**
* @param ep the expected (binomial) probability of success
* @param ap the actual proportion of successes
* @param count the number of attempts
*/
protected void assertExpectation(double ep, double ap, int count) {
if (count * ap < 5 || count * (1 - ap) < 5) throw new IllegalArgumentException();
double stdev = Math.sqrt(ap * (1.0 - ap) * count) / count;
double upper = ap + 2 * stdev;
double lower = ap - 2 * stdev;
assertTrue("Expected p=" + ep + " but got " + ap + " +/- " + stdev,
upper > ep && lower < ep);
}
private void irreducibilityTester(Tree tree, int numLabelledTopologies, int chainLength, int sampleTreeEvery)
throws IOException, Importer.ImportException {
MCMC mcmc = new MCMC("mcmc1");
MCMCOptions options = new MCMCOptions(chainLength);
TreeModel treeModel = new TreeModel("treeModel", tree);
TreeLengthStatistic tls = new TreeLengthStatistic(TL, treeModel);
TreeHeightStatistic rootHeight = new TreeHeightStatistic(TREE_HEIGHT, treeModel);
OperatorSchedule schedule = getOperatorSchedule(treeModel);
Parameter b = new Parameter.Default("b", 2.0, 0.0, Double.MAX_VALUE);
Parameter d = new Parameter.Default("d", 0.0, 0.0, Double.MAX_VALUE);
SpeciationModel speciationModel = new BirthDeathGernhard08Model(b, d, null, BirthDeathGernhard08Model.TreeType.UNSCALED,
Units.Type.YEARS);
Likelihood likelihood = new SpeciationLikelihood(treeModel, speciationModel, "yule.like");
MCLogger[] loggers = new MCLogger[2];
// loggers[0] = new MCLogger(new ArrayLogFormatter(false), 100, false);
// loggers[0].add(likelihood);
// loggers[0].add(rootHeight);
// loggers[0].add(tls);
loggers[0] = new MCLogger(new TabDelimitedFormatter(System.out), 10000, false);
loggers[0].add(likelihood);
loggers[0].add(rootHeight);
loggers[0].add(tls);
File file = new File("yule.trees");
file.deleteOnExit();
FileOutputStream out = new FileOutputStream(file);
loggers[1] = new TreeLogger(treeModel, new TabDelimitedFormatter(out), sampleTreeEvery, true, true, false);
mcmc.setShowOperatorAnalysis(true);
mcmc.init(options, likelihood, schedule, loggers);
mcmc.run();
out.flush();
out.close();
Set<String> uniqueTrees = new HashSet<String>();
HashMap<String, Integer> topologies = new HashMap<String, Integer>();
HashMap<String, HashMap<String, Integer>> treeCounts = new HashMap<String, HashMap<String, Integer>>();
NexusImporter importer = new NexusImporter(new FileReader(file));
int sampleSize = 0;
while (importer.hasTree()) {
sampleSize++;
Tree t = importer.importNextTree();
String uniqueNewick = TreeUtils.uniqueNewick(t, t.getRoot());
String topology = uniqueNewick.replaceAll("\\w+", "X");
if (!uniqueTrees.contains(uniqueNewick)){
uniqueTrees.add(uniqueNewick);
}
HashMap<String, Integer> counts;
if (topologies.containsKey(topology)){
topologies.put(topology, topologies.get(topology)+1);
counts = treeCounts.get(topology);
}
else {
topologies.put(topology, 1);
counts = new HashMap<String, Integer>();
treeCounts.put(topology, counts);
}
if (counts.containsKey(uniqueNewick)){
counts.put(uniqueNewick, counts.get(uniqueNewick)+1);
}
else {
counts.put(uniqueNewick, 1);
}
}
TestCase.assertEquals(numLabelledTopologies, uniqueTrees.size());
TestCase.assertEquals(sampleSize, chainLength / sampleTreeEvery + 1);
Set<String> keys = topologies.keySet();
double ep = 1.0 / topologies.size();
for (String topology : keys){
double ap = ((double)topologies.get(topology)) / (sampleSize);
// assertExpectation(ep, ap, sampleSize);
HashMap<String, Integer> counts = treeCounts.get(topology);
Set<String> trees = counts.keySet();
double MSE = 0;
double ep1 = 1.0 / counts.size();
for (String t : trees){
double ap1 = ((double)counts.get(t)) / (topologies.get(topology));
// assertExpectation(ep1, ap1, topologies.get(topology));
MSE += (ep1-ap1)*(ep1-ap1);
}
MSE /= counts.size();
System.out.println("The Mean Square Error for the topolgy " + topology + " is " + MSE);
}
}
public abstract OperatorSchedule getOperatorSchedule(TreeModel treeModel);
}