/*******************************************************************************
 * Copyright (C) 2009-2012 Dominik Jain.
 * 
 * This file is part of ProbCog.
 * 
 * ProbCog is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * ProbCog is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
package probcog.inference;

import java.io.PrintStream;
import java.lang.reflect.InvocationTargetException;
import java.util.Vector;

import umontreal.iro.lecuyer.probdist.BetaDist;

/**
 * Base class for the representation of sampled distributions.
 * @author Dominik Jain
 */
public abstract class BasicSampledDistribution implements IParameterHandler {
	/**
	 * an array of values representing the distribution, one for each node and each domain element:
	 * values[i][j] is the value for the j-th domain element of the i-th node in the network
	 */
	public double[][] values = null;
	/**
	 * the normalization constant that applies to each of the distribution values
	 */
	public Double Z = null;
	/**
	 * the confidence level for the computation of confidence intervals
	 * if null, no confidence interval computations are carried out
	 */
	public Double confidenceLevel = null;
	public ParameterHandler paramHandler;

	/**
	 * creates the distribution and registers the "confidenceLevel" parameter,
	 * which is mapped to {@link #setConfidenceLevel(Double)}
	 * @throws Exception if the parameter handler cannot be set up
	 */
	public BasicSampledDistribution() throws Exception {
		paramHandler = new ParameterHandler(this);
		paramHandler.add("confidenceLevel", "setConfidenceLevel");
	}

	/**
	 * gets the normalized probability of a single domain element
	 * @param varIdx index of the variable
	 * @param domainIdx index of the domain element
	 * @return values[varIdx][domainIdx] / Z
	 */
	public double getProbability(int varIdx, int domainIdx) {
		return values[varIdx][domainIdx] / Z;
	}

	/**
	 * constructs a new array with the normalized distribution over values for a variable
	 * @param varIdx index of the variable whose distribution to generate
	 * @return a freshly allocated array whose i-th entry is values[varIdx][i] / Z
	 */
	public double[] getDistribution(int varIdx) {
		double[] ret = new double[values[varIdx].length];
		for(int i = 0; i < ret.length; i++)
			ret[i] = values[varIdx][i] / Z;
		return ret;
	}

	/**
	 * prints the normalized distribution of every variable
	 * @param out the stream to print to
	 */
	public void print(PrintStream out) {
		for(int i = 0; i < values.length; i++) {
			printVariableDistribution(out, i);
		}
	}

	public abstract Integer getNumSamples();

	/**
	 * prints the normalized distribution of one variable, one domain element per line;
	 * if a confidence level is set, the confidence interval of each entry is printed as well
	 * @param out the stream to print to
	 * @param idx index of the variable
	 */
	public void printVariableDistribution(PrintStream out, int idx) {
		out.println(getVariableName(idx) + ":");
		String[] domain = getDomain(idx);
		for(int j = 0; j < domain.length; j++) {
			double prob = values[idx][j] / Z;
			if(confidenceLevel == null) {
				out.printf(" %.4f %s\n", prob, domain[j]);
			}
			else {
				// the format has three specifiers, so the domain element must be passed too
				// (omitting it caused a MissingFormatArgumentException); each entry ends the line
				out.printf(" %.4f %s %s\n", prob, domain[j], getConfidenceInterval(idx, j).toString());
			}
		}
	}

	/**
	 * computes the confidence interval for a single entry of the distribution;
	 * requires {@link #confidenceLevel} to be set
	 * @param varIdx index of the variable
	 * @param domIdx index of the domain element
	 * @return the confidence interval
	 */
	public ConfidenceInterval getConfidenceInterval(int varIdx, int domIdx) {
		return new ConfidenceInterval(varIdx, domIdx);
	}

	public abstract String getVariableName(int idx);

	public abstract int getVariableIndex(String name);

	public abstract String[] getDomain(int idx);

	/**
	 * @param idx index of the variable
	 * @return the number of domain elements stored for the variable
	 */
	public int getDomainSize(int idx) {
		return values[idx].length;
	}

	/**
	 * converts this distribution into a {@link GeneralSampledDistribution} that shares
	 * this distribution's value array and normalization constant
	 * @return the general distribution
	 * @throws Exception
	 */
	public GeneralSampledDistribution toGeneralDistribution() throws Exception {
		int numVars = values.length;
		String[] varNames = new String[numVars];
		String[][] domains = new String[numVars][];
		for(int i = 0; i < numVars; i++) {
			varNames[i] = getVariableName(i);
			domains[i] = getDomain(i);
		}
		return new GeneralSampledDistribution(this.values, this.Z, varNames, domains);
	}

	/**
	 * gets the mean squared error between the entries of this distribution and those of another distribution d
	 * @param d the other distribution
	 * @return the mean squared error (averaged across all entries of the distribution)
	 * @throws Exception if a variable of this distribution has no counterpart in d
	 */
	public double getMSE(BasicSampledDistribution d) throws Exception {
		return compare(new MeanSquaredError(this), d);
	}

	/**
	 * gets the Hellinger distance between this distribution and d, averaged over all variables
	 * @param d the other distribution
	 * @return the average Hellinger distance
	 * @throws Exception if a variable of this distribution has no counterpart in d
	 */
	public double getHellingerDistance(BasicSampledDistribution d) throws Exception {
		return compare(new HellingerDistance(this), d);
	}

	/**
	 * compares this distribution with another one using a single entry comparison
	 * @param dec the entry comparison to apply (constructed for this distribution)
	 * @param otherDist the distribution to compare against
	 * @return the result of the entry comparison
	 * @throws Exception if a variable of this distribution has no counterpart in otherDist
	 */
	public double compare(DistributionEntryComparison dec, BasicSampledDistribution otherDist) throws Exception {
		DistributionComparison dc = new DistributionComparison(this, otherDist);
		dc.addEntryComparison(dec);
		dc.compare();
		return dec.getResult();
	}

	public void setConfidenceLevel(Double confidenceLevel) {
		this.confidenceLevel = confidenceLevel;
	}

	/**
	 * @return true if a confidence level is set, i.e. confidence intervals are computed
	 */
	public boolean usesConfidenceComputation() {
		return confidenceLevel != null;
	}

	public ParameterHandler getParameterHandler() {
		return paramHandler;
	}

	/**
	 * compares two (sets of posterior marginal) distributions and prints the per-entry errors
	 * followed by the MSE, mean absolute error, maximum absolute error and Hellinger distance
	 * @param mainDist the distribution whose variables are iterated
	 * @param referenceDist the distribution to compare against
	 * @param evidenceDomainIndices evidence domain indices array (indexed by variable indices in main distribution)
	 * @throws Exception if a variable of mainDist has no counterpart in referenceDist
	 */
	public static void compareDistributions(BasicSampledDistribution mainDist, BasicSampledDistribution referenceDist, int[] evidenceDomainIndices) throws Exception {
		DistributionComparison dc = new DistributionComparison(mainDist, referenceDist);
		dc.addEntryComparison(new ErrorList(mainDist));
		dc.addEntryComparison(new MeanSquaredError(mainDist));
		dc.addEntryComparison(new MeanAbsError(mainDist));
		dc.addEntryComparison(new MaxAbsError(mainDist));
		dc.addEntryComparison(new HellingerDistance(mainDist));
		dc.compare(evidenceDomainIndices);
		dc.printResults();
	}

	/**
	 * a confidence interval for a single normalized entry p of the distribution,
	 * computed from quantiles of a Beta(p*n + 1, n - p*n + 1) distribution, where n is the number of samples
	 */
	public class ConfidenceInterval {
		public double lowerEnd, upperEnd;
		protected int precisionDigits = 4;

		public ConfidenceInterval(int varIdx, int domIdx) {
			int numSamples = getNumSamples();
			double p = values[varIdx][domIdx] / Z;
			double alpha = p * numSamples;
			double beta = numSamples - alpha;
			alpha += 1;
			beta += 1;
			double confAlpha = 1-confidenceLevel;
			// two-sided interval using the alpha/2 and 1-alpha/2 quantiles
			lowerEnd = BetaDist.inverseF(alpha, beta, precisionDigits, confAlpha/2);
			upperEnd = BetaDist.inverseF(alpha, beta, precisionDigits, 1-confAlpha/2);
			// if the point estimate lies outside the two-sided interval, fall back to a one-sided interval
			if(p > upperEnd) {
				lowerEnd = BetaDist.inverseF(alpha, beta, precisionDigits, confAlpha);
				upperEnd = 1.0;
			}
			else if(p < lowerEnd) {
				lowerEnd = 0.0;
				upperEnd = BetaDist.inverseF(alpha, beta, precisionDigits, 1-confAlpha);
			}
		}

		/**
		 * @return the width of the interval (upperEnd - lowerEnd)
		 */
		public double getSize() {
			return upperEnd-lowerEnd;
		}

		/**
		 * @return "[lower;upper] size", with the ends printed to precisionDigits decimals
		 */
		@Override
		public String toString() {
			return String.format(String.format("[%%.%df;%%.%df] %%.4f", precisionDigits, precisionDigits), lowerEnd, upperEnd, getSize());
		}
	}

	/**
	 * iterates over the entries of a main distribution and the corresponding entries of a
	 * reference distribution, feeding each pair to a set of entry comparisons
	 */
	public static class DistributionComparison {
		protected BasicSampledDistribution referenceDist, mainDist;
		protected Vector<DistributionEntryComparison> processors;

		public DistributionComparison(BasicSampledDistribution mainDist, BasicSampledDistribution referenceDist) {
			this.referenceDist = referenceDist;
			this.mainDist = mainDist;
			processors = new Vector<DistributionEntryComparison>();
		}

		public void addEntryComparison(DistributionEntryComparison c) {
			processors.add(c);
		}

		/**
		 * instantiates an entry comparison reflectively and adds it; the comparison is constructed
		 * with the main distribution, because the variable indices passed to it during
		 * {@link #compare(int[])} are indices into the main distribution
		 * @param c the comparison class, which must have a public constructor taking a BasicSampledDistribution
		 */
		public void addEntryComparison(Class<? extends DistributionEntryComparison> c) throws IllegalArgumentException, SecurityException, InstantiationException, IllegalAccessException, InvocationTargetException, NoSuchMethodException {
			addEntryComparison(c.getConstructor(BasicSampledDistribution.class).newInstance(mainDist));
		}

		/**
		 * compare the (posterior marginal) distributions of the
		 * non-evidence variables, i.e. variables whose domain indices
		 * in evidenceDomainIndices are < 0
		 * @param evidenceDomainIndices evidence domain indices, indexed by variable index; may be null (no evidence)
		 * @throws Exception if a variable has no counterpart in the reference distribution,
		 *         or if the counterpart has a different domain size
		 */
		public void compare(int[] evidenceDomainIndices) throws Exception {
			for(int i = 0; i < mainDist.values.length; i++) {
				if(evidenceDomainIndices != null && evidenceDomainIndices[i] >= 0)
					continue;
				String varName = mainDist.getVariableName(i);
				int i2 = referenceDist.getVariableIndex(varName);
				if(i2 < 0)
					throw new Exception("Variable " + varName + " has no correspondence in reference distribution");
				int domSize = mainDist.values[i].length;
				// entries are matched by domain index, so both domains must have the same size
				if(referenceDist.values[i2].length != domSize)
					throw new Exception("Variable " + varName + " has domain size " + domSize + " but " + referenceDist.values[i2].length + " in reference distribution");
				for(int j = 0; j < domSize; j++) {
					double v1 = referenceDist.getProbability(i2, j);
					double v2 = mainDist.getProbability(i, j);
					for(DistributionEntryComparison p : processors)
						p.process(i, j, domSize, v1, v2);
				}
			}
		}

		/**
		 * compares all variables (no evidence)
		 * @throws Exception see {@link #compare(int[])}
		 */
		public void compare() throws Exception {
			compare(null);
		}

		public void printResults() {
			for(DistributionEntryComparison dec : processors)
				dec.printResult();
		}

		/**
		 * @param c the class of the entry comparison whose result to return
		 * @return the result of the first added comparison that is an instance of c
		 * @throws Exception if no such comparison was added
		 */
		public double getResult(Class<? extends DistributionEntryComparison> c) throws Exception {
			for(DistributionEntryComparison p : processors)
				if(c.isInstance(p)) {
					return p.getResult();
				}
			throw new Exception(c.getSimpleName() + " was not processed in this comparison");
		}
	}

	/**
	 * accumulates a statistic over pairs of corresponding distribution entries
	 */
	public static abstract class DistributionEntryComparison {
		/** the distribution whose variable indices are passed to {@link #process} */
		BasicSampledDistribution mainDist;

		public DistributionEntryComparison(BasicSampledDistribution refDist) {
			this.mainDist = refDist;
		}

		/**
		 * processes one pair of entries
		 * @param varIdx variable index (in the main distribution)
		 * @param domIdx domain element index
		 * @param domSize domain size of the variable
		 * @param p1 probability in the reference distribution
		 * @param p2 probability in the main distribution
		 */
		public abstract void process(int varIdx, int domIdx, int domSize, double p1, double p2);

		/**
		 * @return the accumulated result; may be called repeatedly
		 */
		public abstract double getResult();

		public void printResult() {
			System.out.printf("%s = %s\n", getClass().getSimpleName(), getResult());
		}
	}

	/**
	 * mean of the squared differences over all processed entries (NaN if no entries were processed)
	 */
	public static class MeanSquaredError extends DistributionEntryComparison {
		double sum = 0.0;
		int cnt = 0;

		public MeanSquaredError(BasicSampledDistribution refDist) {
			super(refDist);
		}

		@Override
		public void process(int varIdx, int domIdx, int domSize, double p1, double p2) {
			++cnt;
			double error = p1-p2;
			error *= error;
			sum += error;
		}

		@Override
		public double getResult() {
			return sum/cnt;
		}
	}

	/**
	 * mean of the absolute differences over all processed entries (NaN if no entries were processed)
	 */
	public static class MeanAbsError extends DistributionEntryComparison {
		double sum = 0.0;
		int cnt = 0;

		public MeanAbsError(BasicSampledDistribution refDist) {
			super(refDist);
		}

		@Override
		public void process(int varIdx, int domIdx, int domSize, double p1, double p2) {
			++cnt;
			double error = Math.abs(p1-p2);
			sum += error;
		}

		@Override
		public double getResult() {
			return sum/cnt;
		}
	}

	/**
	 * maximum absolute difference over all processed entries (0 if no entries were processed)
	 */
	public static class MaxAbsError extends DistributionEntryComparison {
		double max = 0.0;

		public MaxAbsError(BasicSampledDistribution refDist) {
			super(refDist);
		}

		@Override
		public void process(int varIdx, int domIdx, int domSize, double p1, double p2) {
			double error = Math.abs(p1-p2);
			if(error > max)
				max = error;
		}

		@Override
		public double getResult() {
			return max;
		}
	}

	/**
	 * Hellinger distance sqrt(1 - sum_j sqrt(p1_j * p2_j)) computed per variable and averaged over variables;
	 * relies on the entries of a variable being processed consecutively, ending with domIdx == domSize-1
	 */
	public static class HellingerDistance extends DistributionEntryComparison {
		double BhattacharyyaCoefficient = 0.0;
		double sum = 0.0;
		int numVars = 0;

		public HellingerDistance(BasicSampledDistribution refDist) {
			super(refDist);
		}

		@Override
		public void process(int varIdx, int domIdx, int domSize, double p1, double p2) {
			BhattacharyyaCoefficient += Math.sqrt(p1*p2);
			if(domIdx+1 == domSize) { // last entry of the current variable
				numVars++;
				// rounding can push the coefficient slightly above 1, which would make sqrt return NaN
				double Hellinger = Math.sqrt(Math.max(0.0, 1.0 - BhattacharyyaCoefficient));
				sum += Hellinger;
				BhattacharyyaCoefficient = 0;
			}
		}

		@Override
		public double getResult() {
			// must not modify the accumulator: getResult may be called more than once
			return sum / numVars;
		}
	}

	/**
	 * prints every entry pair whose values differ; produces no summary result
	 */
	public static class ErrorList extends DistributionEntryComparison {
		public ErrorList(BasicSampledDistribution refDist) {
			super(refDist);
		}

		@Override
		public void process(int varIdx, int domIdx, int domSize, double p1, double p2) {
			double error = p1 - p2;
			if(error != 0.0) {
				System.out.printf(" %s=%s: %f %f -> %f\n", mainDist.getVariableName(varIdx), mainDist.getDomain(varIdx)[domIdx], p1, p2, error);
			}
		}

		@Override
		public double getResult() {
			return 0;
		}

		@Override
		public void printResult() {}
	}
}