/*
* TraceAnalysis.java
*
* Copyright (C) 2002-2006 Alexei Drummond and Andrew Rambaut
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.trace;
import java.io.File;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.List;
/**
 * Static helpers that load a trace log file, analyse every trace in it and print
 * tab-separated summary reports to {@code System.out}.
 *
 * @author Alexei Drummond
 * @version $Id: TraceAnalysis.java,v 1.23 2005/05/24 20:26:00 rambaut Exp $
 */
public class TraceAnalysis {

    /** Burn-in argument that asks for one tenth of the maximum state to be used instead. */
    private static final int AUTO_BURNIN = -1;

    /** Statistics whose ESS is below this value are flagged with a '*' by {@code report}. */
    private static final double LOW_ESS_THRESHOLD = 100;

    /** Analysis type handed to every {@code MarginalLikelihoodAnalysis} created here. */
    private static final String MARGINAL_LIKELIHOOD_ANALYSIS_TYPE = "aicm";

    /** Bootstrap length handed to every {@code MarginalLikelihoodAnalysis} created here. */
    private static final int MARGINAL_LIKELIHOOD_BOOTSTRAP_LENGTH = 1000;

    /**
     * Loads a log file, applies the given burn-in and analyses every trace.
     * The burn-in is passed to the trace list unchanged; -1 is not given any
     * special treatment by this method.
     *
     * @param fileName the name of the log file to analyze
     * @param burnin   the state to discard up to
     * @return the loaded traces, with every trace analysed
     * @throws java.io.IOException if general error reading file
     * @throws TraceException      if trace file in wrong format or corrupted
     */
    public static LogFileTraces analyzeLogFile(String fileName, int burnin) throws java.io.IOException, TraceException {
        LogFileTraces traces = loadTraces(fileName);
        traces.setBurnIn(burnin);

        for (int i = 0; i < traces.getTraceCount(); i++) {
            traces.analyseTrace(i);
        }
        return traces;
    }

    /**
     * Prints the full report for a log file using the automatic burn-in and
     * without a marginal likelihood analysis.
     *
     * @param fileName the name of the log file to report on
     * @return the traces loaded from the file
     * @throws java.io.IOException if general error reading file
     * @throws TraceException      if trace file in wrong format or corrupted
     */
    public static TraceList report(String fileName) throws java.io.IOException, TraceException {
        return report(fileName, AUTO_BURNIN, null);
    }

    /**
     * Prints the full report, including the standard error, median and custom
     * HPD columns.
     *
     * @param fileName       the name of the log file to report on
     * @param burnin         the burn-in state, or -1 to use a tenth of the maximum state
     * @param likelihoodName name of the column to run the marginal likelihood
     *                       analysis on, or {@code null} to skip that analysis
     * @return the traces loaded from the file
     * @throws java.io.IOException if general error reading file
     * @throws TraceException      if trace file in wrong format or corrupted, or if
     *                             {@code likelihoodName} is not a column of the file
     */
    public static TraceList report(String fileName, int burnin, String likelihoodName) throws java.io.IOException, TraceException {
        return report(fileName, burnin, likelihoodName, true);
    }

    /**
     * Prints one row of summary statistics per trace, preceded by a burn-in line
     * and a header row, and optionally followed by a marginal likelihood analysis.
     * Rows whose ESS is below 100 end with a '*' and trigger a warning after the table.
     *
     * @param fileName       the name of the log file to report on
     * @param inBurnin       the burn-in state, or -1 to use a tenth of the maximum state
     * @param likelihoodName name of the column to run the marginal likelihood
     *                       analysis on, or {@code null} to skip that analysis
     * @param withStdError   if true the stdErr, median, 50hpdLower and 50hpdUpper
     *                       columns are printed as well
     * @return the traces loaded from the file
     * @throws java.io.IOException if general error reading file
     * @throws TraceException      if trace file in wrong format or corrupted, or if
     *                             {@code likelihoodName} is not a column of the file
     */
    public static TraceList report(String fileName, int inBurnin, String likelihoodName, boolean withStdError)
            throws java.io.IOException, TraceException {

        LogFileTraces traces = loadTraces(fileName);
        int burnin = resolveBurnin(traces, inBurnin);
        traces.setBurnIn(burnin);

        System.out.println("burnIn <= " + burnin + ", maxState = " + traces.getMaxState());

        System.out.print("statistic");
        String[] names;
        if (withStdError) {
            names = new String[]{"mean", "stdErr", "median", "hpdLower", "hpdUpper", "ESS", "50hpdLower", "50hpdUpper"};
        } else {
            names = new String[]{"mean", "hpdLower", "hpdUpper", "ESS"};
        }
        for (String name : names) {
            System.out.print("\t" + name);
        }
        System.out.println();

        int warning = 0;
        for (int i = 0; i < traces.getTraceCount(); i++) {
            traces.analyseTrace(i);
            TraceDistribution distribution = traces.getDistributionStatistics(i);
            double ess = distribution.getESS();

            // Column order must match the header row printed above.
            System.out.print(traces.getTraceName(i));
            System.out.print("\t" + formattedNumber(distribution.getMean()));
            if (withStdError) {
                System.out.print("\t" + formattedNumber(distribution.getStdError()));
                System.out.print("\t" + formattedNumber(distribution.getMedian()));
            }
            System.out.print("\t" + formattedNumber(distribution.getLowerHPD()));
            System.out.print("\t" + formattedNumber(distribution.getUpperHPD()));
            System.out.print("\t" + formattedNumber(ess));
            if (withStdError) {
                System.out.print("\t" + formattedNumber(distribution.getHpdLowerCustom()));
                System.out.print("\t" + formattedNumber(distribution.getHpdUpperCustom()));
            }

            if (ess < LOW_ESS_THRESHOLD) {
                warning += 1;
                System.out.println("\t" + "*");
            } else {
                System.out.println("\t");
            }
        }
        System.out.println();

        if (warning > 0) {
            System.out.println(" * WARNING: The results of this MCMC analysis may be invalid as ");
            System.out.println(" one or more statistics had very low effective sample sizes (ESS)");
        }

        if (likelihoodName != null) {
            System.out.println();
            int traceIndex = findTraceIndex(traces, likelihoodName);
            if (traceIndex == -1) {
                throw new TraceException("Column '" + likelihoodName +
                        "' can not be found for marginal likelihood analysis.");
            }

            List<Double> sample = traces.getValues(traceIndex);
            MarginalLikelihoodAnalysis analysis = new MarginalLikelihoodAnalysis(sample,
                    traces.getTraceName(traceIndex), burnin,
                    MARGINAL_LIKELIHOOD_ANALYSIS_TYPE, MARGINAL_LIKELIHOOD_BOOTSTRAP_LENGTH);
            System.out.println(analysis.toString());
        }

        System.out.flush();
        return traces;
    }

    /**
     * Prints a single tab-separated line with the mean, standard error, median,
     * HPD bounds and custom HPD bounds of one trace. No header is printed and
     * every value, including the last, is followed by a tab.
     *
     * @param fileName  the name of the log file to report on
     * @param inBurnin  the burn-in state, or -1 to use a tenth of the maximum state
     * @param traceName the name of the trace to report
     * @throws IOException    if general error reading file
     * @throws TraceException if trace file in wrong format or corrupted, or if the
     *                        lookup of {@code traceName} yields a negative index
     */
    public static void reportTrace(String fileName, int inBurnin, String traceName) throws IOException, TraceException {
        LogFileTraces traces = loadTraces(fileName);
        traces.setBurnIn(resolveBurnin(traces, inBurnin));

        int id = traces.getTraceIndex(traceName);
        // NOTE(review): presumably getTraceIndex signals an unknown name with a negative
        // index - confirm against LogFileTraces. A negative index can never address a
        // trace, so fail with the same kind of error the other reports use.
        if (id < 0) {
            throw new TraceException("Column '" + traceName + "' can not be found in file " + fileName + ".");
        }

        traces.analyseTrace(id);
        TraceDistribution distribution = traces.getDistributionStatistics(id);

        System.out.print(formattedNumber(distribution.getMean()) + "\t");
        System.out.print(formattedNumber(distribution.getStdError()) + "\t");
        System.out.print(formattedNumber(distribution.getMedian()) + "\t");
        System.out.print(formattedNumber(distribution.getLowerHPD()) + "\t");
        System.out.print(formattedNumber(distribution.getUpperHPD()) + "\t");
        System.out.print(formattedNumber(distribution.getHpdLowerCustom()) + "\t");
        System.out.print(formattedNumber(distribution.getHpdUpperCustom()) + "\t");
        System.out.println();
    }

    /**
     * Prints all statistics of a log file on one tab-separated line, optionally
     * preceded by a header line. The line starts with the file name and ends with
     * the minimum ESS over all traces, the optional marginal likelihood columns
     * and the maximum state of the chain.
     *
     * @param filename       the file name of the log file to report on
     * @param burnin         the number of states of burnin or if -1 then use 10%
     * @param drawHeader     if true then draw header
     * @param hpds           if true then report the lower and upper HPD bounds of each trace
     * @param individualESSs if true then report the ESS of each trace
     * @param stdErr         if true then report the standard error of the mean of each trace
     * @param likelihoodName name of the column to run the marginal likelihood
     *                       analysis on, or {@code null} to skip that analysis
     * @return the traces loaded from given file to create this short report
     * @throws java.io.IOException if general error reading file
     * @throws TraceException      if trace file in wrong format or corrupted, or if
     *                             {@code likelihoodName} is not a column of the file
     */
    public static TraceList shortReport(String filename,
                                        final int burnin, boolean drawHeader,
                                        boolean hpds, boolean individualESSs, boolean stdErr,
                                        String likelihoodName) throws java.io.IOException, TraceException {

        LogFileTraces loaded = loadTraces(filename);
        // Honour the documented "-1 means 10%" contract, exactly as report() does,
        // instead of handing the raw -1 to the trace list.
        final int effectiveBurnin = resolveBurnin(loaded, burnin);
        loaded.setBurnIn(effectiveBurnin);
        for (int i = 0; i < loaded.getTraceCount(); i++) {
            loaded.analyseTrace(i);
        }
        TraceList traces = loaded;

        int maxState = traces.getMaxState();
        double minESS = Double.MAX_VALUE;

        if (drawHeader) {
            System.out.print("file\t");
            for (int i = 0; i < traces.getTraceCount(); i++) {
                String traceName = traces.getTraceName(i);
                System.out.print(traceName + "\t");
                if (stdErr) {
                    System.out.print(traceName + " stdErr\t");
                }
                if (hpds) {
                    System.out.print(traceName + " hpdLower\t");
                    System.out.print(traceName + " hpdUpper\t");
                }
                if (individualESSs) {
                    System.out.print(traceName + " ESS\t");
                }
            }
            System.out.print("minESS\t");
            if (likelihoodName != null) {
                System.out.print("marginal likelihood\t");
                System.out.print("stdErr\t");
            }
            System.out.println("chainLength");
        }

        System.out.print(filename + "\t");
        for (int i = 0; i < traces.getTraceCount(); i++) {
            TraceCorrelation distribution = traces.getCorrelationStatistics(i);
            System.out.print(distribution.getMean() + "\t");
            if (stdErr) {
                System.out.print(distribution.getStdErrorOfMean() + "\t");
            }
            if (hpds) {
                System.out.print(distribution.getLowerHPD() + "\t");
                System.out.print(distribution.getUpperHPD() + "\t");
            }
            if (individualESSs) {
                System.out.print(distribution.getESS() + "\t");
            }
            // The minimum is tracked over all traces even when the per-trace column is off.
            double ess = distribution.getESS();
            if (ess < minESS) {
                minESS = ess;
            }
        }
        System.out.print(minESS + "\t");

        if (likelihoodName != null) {
            int traceIndex = findTraceIndex(traces, likelihoodName);
            if (traceIndex == -1) {
                throw new TraceException("Column '" + likelihoodName + "' can not be found in file " + filename + ".");
            }

            List<Double> sample = traces.getValues(traceIndex);
            MarginalLikelihoodAnalysis analysis = new MarginalLikelihoodAnalysis(sample,
                    traces.getTraceName(traceIndex), effectiveBurnin,
                    MARGINAL_LIKELIHOOD_ANALYSIS_TYPE, MARGINAL_LIKELIHOOD_BOOTSTRAP_LENGTH);
            System.out.print(analysis.getLogMarginalLikelihood() + "\t");
            System.out.print(analysis.getBootstrappedSE() + "\t");
        }

        System.out.println(maxState);
        return traces;
    }

    /**
     * Formats a number for the reports. Non-zero values with a magnitude below
     * 0.01 and positive values of at least 100000 use scientific notation with up
     * to four fraction digits; everything else uses plain notation with up to
     * four fraction digits.
     * <p>
     * Small negative values used to fall through to the plain pattern, which
     * rounded them to "-0"; they now use scientific notation like their positive
     * counterparts. Large negative values deliberately keep plain notation, as before.
     * <p>
     * The formatters are built with the default locale, so the decimal separator
     * follows the platform settings.
     *
     * @param value the number to format
     * @return the formatted number
     */
    public static String formattedNumber(double value) {
        double magnitude = Math.abs(value);
        boolean tiny = value != 0 && magnitude < 0.01;
        boolean largePositive = value >= 100000.0;

        if (tiny || largePositive) {
            return new DecimalFormat("0.####E0").format(value);
        }
        return new DecimalFormat("####0.####").format(value);
    }

    /** Opens the named log file and loads all of its traces. */
    private static LogFileTraces loadTraces(String fileName) throws IOException, TraceException {
        File file = new File(fileName);
        LogFileTraces traces = new LogFileTraces(fileName, file);
        traces.loadTraces();
        return traces;
    }

    /** Returns the burn-in to apply: a tenth of the maximum state when -1 was requested. */
    private static int resolveBurnin(TraceList traces, int requestedBurnin) {
        if (requestedBurnin == AUTO_BURNIN) {
            return traces.getMaxState() / 10;
        }
        return requestedBurnin;
    }

    /** Returns the index of the first trace with exactly the given name, or -1 if there is none. */
    private static int findTraceIndex(TraceList traces, String name) {
        for (int i = 0; i < traces.getTraceCount(); i++) {
            if (traces.getTraceName(i).equals(name)) {
                return i;
            }
        }
        return -1;
    }
}