/* * TraceDistribution.java * * Copyright (C) 2002-2006 Alexei Drummond and Andrew Rambaut * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inference.trace; import dr.stats.DiscreteStatistics; import dr.util.HeapSort; import java.util.*; /** * A class that stores the distribution statistics for a trace * * @author Andrew Rambaut * @author Alexei Drummond * @version $Id: TraceDistribution.java,v 1.1.1.2 2006/04/25 23:00:09 rambaut Exp $ */ public class TraceDistribution<T> { private TraceFactory.TraceType traceType; public TraceDistribution(List<T> values, TraceFactory.TraceType traceType) { this.traceType = traceType; initStatistics(values, 0.95); } public TraceDistribution(List<T> values, TraceFactory.TraceType traceType, double ESS) { this(values, traceType); this.ESS = ESS; } public TraceFactory.TraceType getTraceType() { return traceType; } public void setTraceType(TraceFactory.TraceType traceType) { this.traceType = traceType; } // public String getTraceTypeBrief() { // if (traceType == TraceFactory.TraceType.DOUBLE) { // return TraceFactory.TraceType.DOUBLE.getBrief(); // } else if (traceType == TraceFactory.TraceType.INTEGER) { // return TraceFactory.TraceType.INTEGER.getBrief(); // } else if (traceType == TraceFactory.TraceType.STRING) { // return TraceFactory.TraceType.STRING.getBrief(); // } // throw new IllegalArgumentException("The trace type " + traceType + " is not recognized."); // } public boolean isValid() { return isValid; } public double getMean() { return mean; } public double getVariance() { return variance; } public double getStdError() { return stdError; } public boolean hasGeometricMean() { return hasGeometricMean; } public double getGeometricMean() { return geometricMean; } public double getMedian() { return median; } public double getLowerHPD() { return hpdLower; } public double getUpperHPD() { return hpdUpper; } public double getLowerCPD() { return cpdLower; } public double getUpperCPD() { return cpdUpper; } public double getESS() { return ESS; } public double getMinimum() { return minimum; } public double getMaximum() { return maximum; } public double getHpdLowerCustom() { return hpdLowerCustom; } public double getHpdUpperCustom() { return hpdUpperCustom; } public double getMeanSquaredError(double[] values, double trueValue) { if (values == null) { throw new RuntimeException("Trace values not yet set"); } if (traceType == TraceFactory.TraceType.DOUBLE || traceType == TraceFactory.TraceType.INTEGER) { return DiscreteStatistics.meanSquaredError(values, trueValue); } else { throw new RuntimeException("Require Number Trace Type in the Trace Distribution: " + this); } } /** * @param valuesC the values to analyze */ private void analyseDistributionContinuous(double[] valuesC, double proportion) { // this.values = values; // move to TraceDistribution(T[] values) mean = DiscreteStatistics.mean(valuesC); stdError = DiscreteStatistics.stdev(valuesC); variance = DiscreteStatistics.variance(valuesC); minimum = Double.POSITIVE_INFINITY; maximum = Double.NEGATIVE_INFINITY; for (double value : valuesC) { if (value < minimum) minimum = value; if (value > maximum) maximum = value; } if (minimum > 0) { geometricMean = DiscreteStatistics.geometricMean(valuesC); hasGeometricMean = true; } if (maximum == minimum) { isValid = false; return; } int[] indices = new int[valuesC.length]; HeapSort.sort(valuesC, indices); median = DiscreteStatistics.quantile(0.5, valuesC, indices); cpdLower = DiscreteStatistics.quantile(0.025, valuesC, indices); cpdUpper = DiscreteStatistics.quantile(0.975, valuesC, indices); calculateHPDInterval(proportion, valuesC, indices); ESS = valuesC.length; calculateHPDIntervalCustom(0.5, valuesC, indices); isValid = true; } /** * @param proportion the proportion of probability mass included within interval. * @param array the data array * @param indices the indices of the ranks of the values (sort order) */ private void calculateHPDInterval(double proportion, double[] array, int[] indices) { final double[] hpd = DiscreteStatistics.HPDInterval(proportion, array, indices); hpdLower = hpd[0]; hpdUpper = hpd[1]; } private void calculateHPDIntervalCustom(double proportion, double[] array, int[] indices) { final double[] hpd = DiscreteStatistics.HPDInterval(proportion, array, indices); hpdLowerCustom = hpd[0]; hpdUpperCustom = hpd[1]; } protected boolean isValid = false; protected boolean hasGeometricMean = false; protected double minimum, maximum; protected double mean; protected double median; protected double geometricMean; protected double stdError, meanSquaredError; protected double variance; protected double cpdLower, cpdUpper, hpdLower, hpdUpper; protected double hpdLowerCustom, hpdUpperCustom; protected double ESS; //************************************************************************ // new types //************************************************************************ // <T, frequency> for T = Integer and String public Map<T, Integer> valuesMap = new HashMap<T, Integer>(); // public Map<T, Integer> inCredibleSet = new HashMap<T, Integer>(); public List<T> credibleSet = new ArrayList<T>(); public List<T> inCredibleSet = new ArrayList<T>(); public T mode; public int freqOfMode = 0; public void initStatistics(List<T> values, double proportion) { valuesMap.clear(); credibleSet.clear(); inCredibleSet.clear(); if (values.size() < 1) throw new RuntimeException("There is no value sent to statistics calculation !"); if (traceType == TraceFactory.TraceType.DOUBLE || traceType == TraceFactory.TraceType.INTEGER) { double[] newValues = new double[values.size()]; for (int i = 0; i < values.size(); i++) { newValues[i] = ((Number) values.get(i)).doubleValue(); } analyseDistributionContinuous(newValues, proportion); } if (traceType == TraceFactory.TraceType.STRING || traceType == TraceFactory.TraceType.INTEGER) { for (T value : values) { if (valuesMap.containsKey(value)) { int i = valuesMap.get(value) + 1; valuesMap.put(value, i); } else { valuesMap.put(value, 1); } } for (T value : new TreeSet<T>(valuesMap.keySet())) { double prob = valuesMap.get(value).doubleValue() / (double) values.size(); if (prob < (1 - proportion)) { inCredibleSet.add(value); } else { credibleSet.add(value); } } calculateMode(); isValid = true; // what purpose? } } public boolean inside(T value) { return valuesMap.containsKey(value); } public boolean inside(Double value) { return value <= hpdUpper && value >= hpdLower; } public int getIndex(T value) { int i = -1; for (T v : new TreeSet<T>(valuesMap.keySet())) { i++; if (v.equals(value)) return i; } return i; } public boolean credibleSetContains(int valueORIndex) { return contains(credibleSet, valueORIndex); } public boolean inCredibleSetContains(int valueORIndex) { return contains(inCredibleSet, valueORIndex); } private boolean contains(List<T> list, int valueORIndex) { if (traceType == TraceFactory.TraceType.INTEGER) { return list.contains(valueORIndex); } else { // String String valueString = null; int i = -1; for (T v : new TreeSet<T>(valuesMap.keySet())) { i++; if (i == valueORIndex) valueString = v.toString(); } return list.contains(valueString); } } public T getMode() { return mode; } public int getFrequencyOfMode() { return freqOfMode; } public List<String> getRange() { List<String> valuesList = new ArrayList<String>(); for (T value : new TreeSet<T>(valuesMap.keySet())) { if (traceType == TraceFactory.TraceType.INTEGER) { // as Integer is stored as Double in Trace if (!valuesList.contains(Integer.toString(((Number) value).intValue()))) valuesList.add(Integer.toString(((Number) value).intValue())); } else { if (!valuesList.contains(value.toString())) valuesList.add(value.toString()); } } return valuesList; } private void calculateMode() { for (T value : new TreeSet<T>(valuesMap.keySet())) { if (freqOfMode < valuesMap.get(value)) { freqOfMode = valuesMap.get(value); mode = value; } } } private String printSet(List<T> list) { String line = "{"; for (T value : list) { line = line + value + ", "; } if (line.endsWith(", ")) { line = line.substring(0, line.lastIndexOf(", ")) + "}"; } else { line = "{}"; } return line; } public String printCredibleSet() { return printSet(credibleSet); } public String printInCredibleSet() { return printSet(inCredibleSet); } }