/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.ml.bayesian.query.sample;
import java.io.Serializable;
import org.encog.ml.bayesian.BayesianError;
import org.encog.ml.bayesian.BayesianEvent;
import org.encog.ml.bayesian.BayesianNetwork;
import org.encog.ml.bayesian.query.BasicQuery;
import org.encog.util.Format;
/**
 * A sampling query allows probabilistic queries on a Bayesian network. Sampling
 * works by actually simulating the probabilities using a random number
 * generator. A sample size must be specified. The higher the sample size, the
 * more accurate the probability will be. However, the higher the sample size,
 * the longer it takes to run the query.
 *
 * <p>
 * An enumeration query is more precise than the sampling query. However, the
 * enumeration query will become slow as the size of the Bayesian network grows.
 * Sampling can often be used for a quick estimation of a probability.
 *
 * <p>
 * Typical use: configure the query, call {@link #execute()}, then read the
 * estimate with {@link #getProbability()}. The counters backing the estimate
 * are reset at the start of every {@link #execute()} call.
 */
public class SamplingQuery extends BasicQuery implements Serializable {

    /**
     * The default sample size.
     */
    public static final int DEFAULT_SAMPLE_SIZE = 100000;

    /**
     * The number of samples to generate on each call to {@link #execute()}.
     */
    private int sampleSize = DEFAULT_SAMPLE_SIZE;

    /**
     * The number of usable samples, i.e. samples for which
     * {@code isNeededEvidence()} returned true. This is the denominator of
     * the estimated probability.
     */
    private int usableSamples;

    /**
     * The number of usable samples that also matched the result the query is
     * looking for. This is the numerator of the estimated probability.
     */
    private int goodSamples;

    /**
     * The total number of samples generated. This should match sampleSize at
     * the end of a query.
     */
    private int totalSamples;

    /**
     * Construct a sampling query.
     *
     * @param theNetwork
     *            The network that will be queried.
     */
    public SamplingQuery(BayesianNetwork theNetwork) {
        super(theNetwork);
    }

    /**
     * @return The number of samples generated by each call to
     *         {@link #execute()}.
     */
    public int getSampleSize() {
        return sampleSize;
    }

    /**
     * Set the number of samples generated by each call to {@link #execute()}.
     * A value of zero or less causes {@link #execute()} to generate no
     * samples, in which case {@link #getProbability()} returns NaN.
     *
     * @param sampleSize
     *            The sample size to set.
     */
    public void setSampleSize(int sampleSize) {
        this.sampleSize = sampleSize;
    }

    /**
     * Obtain the arguments for an event, that is, the current values of all
     * of its parent events in parent order.
     *
     * @param event
     *            The event.
     * @return The values of the parents of that event, or null if at least
     *         one parent has not been calculated yet.
     */
    private int[] obtainArgs(BayesianEvent event) {
        int[] result = new int[event.getParents().size()];
        int index = 0;
        for (BayesianEvent parentEvent : event.getParents()) {
            EventState state = this.getEventState(parentEvent);
            if (!state.isCalculated()) {
                // A parent is still unknown; the caller must retry later.
                return null;
            }
            result[index++] = state.getValue();
        }
        return result;
    }

    /**
     * Randomize the given event if it has not been calculated yet and all of
     * its parents have been, then recurse into its children so that newly
     * available parent values are propagated down the graph.
     *
     * @param eventState
     *            The state of the event to randomize.
     */
    private void randomizeEvents(EventState eventState) {
        // Only events that are not yet calculated need a value.
        if (!eventState.isCalculated()) {
            // The event can only be randomized once every parent has a value.
            int[] args = obtainArgs(eventState.getEvent());
            if (args != null) {
                eventState.randomize(args);
            }
        }

        // Children are visited regardless, since one of them may have just
        // become calculable.
        for (BayesianEvent childEvent : eventState.getEvent().getChildren()) {
            randomizeEvents(getEventState(childEvent));
        }
    }

    /**
     * @return The number of events that are still uncalculated.
     */
    private int countUnCalculated() {
        int result = 0;
        for (EventState state : getEvents().values()) {
            if (!state.isCalculated()) {
                result++;
            }
        }
        return result;
    }

    /**
     * Generate one complete sample: reset the event states and keep sweeping
     * over all events until every one of them has been calculated.
     *
     * @throws BayesianError
     *             If a full sweep makes no progress, meaning some events can
     *             never be calculated.
     */
    private void generateSample() {
        reset();

        int previousRemaining = Integer.MAX_VALUE;
        int remaining;
        do {
            for (EventState state : getEvents().values()) {
                randomizeEvents(state);
            }
            remaining = countUnCalculated();
            // The same count after two consecutive sweeps means the sweep
            // is stuck; bail out instead of looping forever.
            if (remaining == previousRemaining) {
                throw new BayesianError(
                        "Unable to calculate all nodes in the graph.");
            }
            previousRemaining = remaining;
        } while (remaining > 0);
    }

    /**
     * Run the query: generate {@code sampleSize} samples and count how many
     * of them are usable and how many of the usable ones are good. All
     * counters from a previous run are discarded first.
     *
     * @throws BayesianError
     *             If not all events of the network can be calculated.
     */
    public void execute() {
        locateEventTypes();
        this.usableSamples = 0;
        this.goodSamples = 0;
        this.totalSamples = 0;

        for (int i = 0; i < this.sampleSize; i++) {
            generateSample();
            this.totalSamples++;

            if (isNeededEvidence()) {
                this.usableSamples++;
                if (satisfiesDesiredOutcome()) {
                    this.goodSamples++;
                }
            }
        }
    }

    /**
     * Return the estimated probability as the ratio of good samples to
     * usable samples from the most recent {@link #execute()} call.
     *
     * @return The estimated probability, or NaN if there are no usable
     *         samples (for example, before {@link #execute()} has been
     *         called, or when no sample was counted as usable).
     */
    public double getProbability() {
        return (double) this.goodSamples / (double) this.usableSamples;
    }

    /**
     * @return The current state of every event as a string, one event per
     *         line.
     */
    public String dumpCurrentState() {
        StringBuilder result = new StringBuilder();
        for (EventState state : getEvents().values()) {
            result.append(state.toString());
            result.append("\n");
        }
        return result.toString();
    }

    /**
     * Create a new sampling query over the same network with the same sample
     * size. Only the network reference and the sample size are carried over;
     * the sample counters of the new query start at zero.
     *
     * @return A new query for the same network.
     */
    public SamplingQuery clone() {
        SamplingQuery copy = new SamplingQuery(this.getNetwork());
        // Keep the configured sample size; otherwise the clone would silently
        // fall back to DEFAULT_SAMPLE_SIZE.
        copy.sampleSize = this.sampleSize;
        return copy;
    }

    /**
     * @return A one-line summary of the query: the problem, the estimated
     *         probability and the sample counters.
     */
    public String toString() {
        StringBuilder result = new StringBuilder();
        result.append("[SamplingQuery: ");
        result.append(getProblem());
        result.append("=");
        result.append(Format.formatPercent(getProbability()));
        result.append(" ;good/usable=");
        result.append(Format.formatInteger(this.goodSamples));
        result.append("/");
        result.append(Format.formatInteger(this.usableSamples));
        result.append(";totalSamples=");
        result.append(Format.formatInteger(this.totalSamples));
        // Close the bracket opened by "[SamplingQuery: ".
        result.append("]");
        return result.toString();
    }
}