/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.ml.bayesian.query.enumerate;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.encog.Encog;
import org.encog.ml.bayesian.BayesianError;
import org.encog.ml.bayesian.BayesianEvent;
import org.encog.ml.bayesian.BayesianNetwork;
import org.encog.ml.bayesian.EventType;
import org.encog.ml.bayesian.query.BasicQuery;
import org.encog.ml.bayesian.query.sample.EventState;
import org.encog.ml.bayesian.table.TableLine;
import org.encog.util.Format;
/**
 * An enumeration query allows probabilistic queries on a Bayesian network.
 * Enumeration works by visiting every combination of values of the hidden
 * events and summing the resulting joint probabilities (total probability).
 * This yields an exact, deterministic probability. However, the number of
 * combinations grows with every event enumerated over, so enumeration can be
 * slow for large Bayesian networks. For a quick estimate of probability the
 * sampling query can be used instead.
 *
 */
public class EnumerationQuery extends BasicQuery implements Serializable {

	/**
	 * The events that we will enumerate over. Filled by
	 * {@link #resetEnumeration(boolean, boolean)} and advanced by
	 * {@link #forward()}.
	 */
	private final List<EventState> enumerationEvents = new ArrayList<EventState>();

	/**
	 * The calculated probability, set by {@link #execute()}.
	 */
	private double probability;

	/**
	 * Construct the enumeration query.
	 *
	 * @param theNetwork
	 *            The Bayesian network to query.
	 */
	public EnumerationQuery(BayesianNetwork theNetwork) {
		super(theNetwork);
	}

	/**
	 * Default constructor.
	 */
	public EnumerationQuery() {
	}

	/**
	 * Reset the enumeration events. Hidden events are always enumerated;
	 * evidence and outcome events are optionally enumerated as well. Every
	 * enumerated event is reset to its first choice (value 0). Every other
	 * event is pinned to its compare value.
	 *
	 * @param includeEvidence
	 *            True if the evidence is to be reset (enumerated).
	 * @param includeOutcome
	 *            True if the outcome is to be reset (enumerated).
	 */
	public void resetEnumeration(boolean includeEvidence, boolean includeOutcome) {
		this.enumerationEvents.clear();
		for (EventState state : this.getEvents().values()) {
			EventType type = state.getEventType();
			boolean enumerate = type == EventType.Hidden
					|| (includeEvidence && type == EventType.Evidence)
					|| (includeOutcome && type == EventType.Outcome);
			if (enumerate) {
				this.enumerationEvents.add(state);
				state.setValue(0);
			} else {
				state.setValue(state.getCompareValue());
			}
		}
	}

	/**
	 * Roll the enumeration events forward by one. The events behave like the
	 * digits of an odometer: the first event is incremented, and whenever an
	 * event wraps past its last choice it is reset to 0 and the carry moves
	 * to the next event.
	 *
	 * @return False if there are no more values to roll into, which means we're
	 *         done.
	 */
	public boolean forward() {
		for (EventState state : this.enumerationEvents) {
			int next = state.getValue() + 1;
			if (next < state.getEvent().getChoices().size()) {
				state.setValue(next);
				return true;
			}
			// This digit wrapped; carry into the next event.
			state.setValue(0);
		}
		// Every event wrapped, or there were none: enumeration is complete.
		return false;
	}

	/**
	 * Obtain the arguments for an event, i.e. the current values of its
	 * parents in parent order.
	 *
	 * @param event
	 *            The event.
	 * @return The arguments.
	 */
	private int[] obtainArgs(BayesianEvent event) {
		int[] result = new int[event.getParents().size()];
		int index = 0;
		for (BayesianEvent parentEvent : event.getParents()) {
			EventState state = this.getEventState(parentEvent);
			result[index++] = state.getValue();
		}
		return result;
	}

	/**
	 * Calculate the probability for a state. This looks up the table line of
	 * the state's event that matches both the current parent values and the
	 * state's own current value.
	 *
	 * @param state
	 *            The state to calculate.
	 * @return The probability.
	 * @throws BayesianError
	 *             If no matching table line exists.
	 */
	private double calculateProbability(EventState state) {
		BayesianEvent event = state.getEvent();
		int[] args = obtainArgs(event);
		for (TableLine line : event.getTable().getLines()) {
			if (line.compareArgs(args)
					&& Math.abs(line.getResult() - state.getValue()) < Encog.DEFAULT_DOUBLE_EQUAL) {
				return line.getProbability();
			}
		}
		throw new BayesianError("Could not determine the probability for "
				+ state.toString());
	}

	/**
	 * Calculate the product of the probabilities of every event in the
	 * query, given the values currently assigned to them.
	 *
	 * @return The product, or 0 if the query has no events.
	 */
	private double calculateJointProbability() {
		boolean first = true;
		double prob = 0;
		for (EventState state : this.getEvents().values()) {
			if (first) {
				prob = calculateProbability(state);
				first = false;
			} else {
				prob *= calculateProbability(state);
			}
		}
		return prob;
	}

	/**
	 * Perform a single enumeration. This sums the joint probability over
	 * every combination of the enumeration events.
	 *
	 * @return The result.
	 */
	private double performEnumeration() {
		double result = 0;
		do {
			result += calculateJointProbability();
		} while (forward());
		return result;
	}

	/**
	 * {@inheritDoc}
	 *
	 * The numerator enumerates only the hidden events, with evidence and
	 * outcome pinned to their compare values. The denominator additionally
	 * enumerates the outcome events. If the denominator is 0, the result of
	 * the floating-point division is NaN or infinite rather than an
	 * exception.
	 */
	public void execute() {
		locateEventTypes();
		resetEnumeration(false, false);
		double numerator = performEnumeration();
		resetEnumeration(false, true);
		double denominator = performEnumeration();
		this.probability = numerator / denominator;
	}

	/**
	 * {@inheritDoc}
	 */
	public double getProbability() {
		return probability;
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public String toString() {
		StringBuilder result = new StringBuilder();
		result.append("[EnumerationQuery: ");
		result.append(getProblem());
		result.append("=");
		result.append(Format.formatPercent(getProbability()));
		result.append("]");
		return result.toString();
	}

	/**
	 * Roll a set of argument values forward by one, odometer style: the first
	 * argument is incremented, and when it wraps past the number of choices
	 * of its event it is reset to 0 and the carry moves to the next argument.
	 * Only the first {@code min(enumerationEvents.size(), args.length)}
	 * positions are rolled, so a length mismatch cannot index past either
	 * collection.
	 *
	 * @param enumerationEvents
	 *            The enumeration events; event {@code i} bounds {@code args[i]}.
	 * @param args
	 *            The arguments, modified in place.
	 * @return False if there are no more values to roll into, which means we're
	 *         done.
	 */
	public static boolean roll(List<BayesianEvent> enumerationEvents, int[] args) {
		int count = Math.min(enumerationEvents.size(), args.length);
		for (int i = 0; i < count; i++) {
			int next = args[i] + 1;
			if (next < enumerationEvents.get(i).getChoices().size()) {
				args[i] = next;
				return true;
			}
			// This position wrapped; carry into the next one.
			args[i] = 0;
		}
		return false;
	}

	/**
	 * Create a new query over the same network. Only the network is passed
	 * to the new instance; the enumeration state and computed probability of
	 * this query are not copied.
	 *
	 * @return A clone of this object.
	 */
	@Override
	public EnumerationQuery clone() {
		return new EnumerationQuery(this.getNetwork());
	}
}