/******************************************************************************* * Copyright (C) 2009-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.bayesnets.inference; import java.util.Vector; import probcog.inference.BasicSampledDistribution; import probcog.inference.IParameterHandler; import probcog.inference.ParameterHandler; import probcog.inference.BasicSampledDistribution.DistributionComparison; import probcog.inference.BasicSampledDistribution.DistributionEntryComparison; import probcog.inference.BasicSampledDistribution.MeanSquaredError; import edu.tum.cs.util.Stopwatch; public class TimeLimitedInference implements IParameterHandler { protected ITimeLimitedInference inference; protected double time, interval; protected InferenceThread thread; protected BasicSampledDistribution referenceDistribution = null; /** * mean-squared errors */ protected Vector<Double> MSEs = null; protected Vector<Class<? extends DistributionEntryComparison>> comparisonClasses; protected ParameterHandler paramHandler; protected boolean verbose = true; protected int[] evidenceDomainIndices = null; public TimeLimitedInference(ITimeLimitedInference inference, double time, double interval) throws Exception { this.inference = inference; this.time = time; this.interval = interval; comparisonClasses = new Vector<Class<? extends DistributionEntryComparison>>(); paramHandler = new ParameterHandler(this); paramHandler.add("verbose", "setVerbose"); } public void setVerbose(boolean verbose) { this.verbose = verbose; } public void setReferenceDistribution(BasicSampledDistribution dist) { referenceDistribution = dist; comparisonClasses.add(BasicSampledDistribution.MeanSquaredError.class); comparisonClasses.add(BasicSampledDistribution.HellingerDistance.class); MSEs = new Vector<Double>(); } public SampledDistribution run() throws Exception { // start the inference thread thread = new InferenceThread(); thread.start(); // wait, repeatedly polling intermediate results Stopwatch sw = new Stopwatch(); sw.start(); boolean useIntervals = true; if(!useIntervals) Thread.sleep((int)(1000*time)); else { int numSteps = (int)(time / interval); for(int i = 1; i <= numSteps && thread.isAlive(); i++) { Thread.sleep((int)(1000*interval)); if(verbose) System.out.printf("polling results after %fs (interval %d)...\n", sw.getElapsedTimeSecs(), i); SampledDistribution dist = pollResults(true); if(verbose && dist != null) System.out.printf("%d samples taken\n", dist.steps); if(referenceDistribution != null) { double mse; if(dist == null) mse = Double.POSITIVE_INFINITY; else { DistributionComparison dc = doComparison(dist); mse = dc.getResult(MeanSquaredError.class); } MSEs.add(mse); } } } // get final results, terminating the inference thread if it is still running SampledDistribution results = pollResults(false); if(thread.isAlive()) thread.stop(); return results; } /** * sets evidence domain indices for distribution comparison (in order to be able to ignore * evidence variables in the comparisons) * @param evidenceDomainIndices */ public void setEvidenceDomainIndices(int[] evidenceDomainIndices) { this.evidenceDomainIndices = evidenceDomainIndices; } protected DistributionComparison doComparison(BasicSampledDistribution dist) throws Exception { DistributionComparison dc = new DistributionComparison(this.referenceDistribution, dist); for(Class<? extends DistributionEntryComparison> c : comparisonClasses) dc.addEntryComparison(c); dc.compare(this.evidenceDomainIndices); dc.printResults(); return dc; } public SampledDistribution pollResults(boolean allowPrint) throws Exception { SampledDistribution dist = thread.pollResults(); if(allowPrint && verbose && dist != null) printResults(dist); return dist; } protected void printResults(SampledDistribution dist) { // TODO } /** * returns the mean squared errors collected after each interval * @return */ public Vector<Double> getMSEs() { return MSEs; } protected class InferenceThread extends Thread { public void run() { try { inference.infer(); } catch(Exception e) { throw new RuntimeException(e); } } public SampledDistribution pollResults() throws Exception { return inference.pollResults(); } } @Override public ParameterHandler getParameterHandler() { return paramHandler; } }