/* * Encog(tm) Core v3.4 - Java Version * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-core * Copyright 2008-2016 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.ml.bayesian; import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import org.encog.ml.BasicML; import org.encog.ml.MLClassification; import org.encog.ml.MLError; import org.encog.ml.MLResettable; import org.encog.ml.bayesian.parse.ParseProbability; import org.encog.ml.bayesian.parse.ParsedChoice; import org.encog.ml.bayesian.parse.ParsedEvent; import org.encog.ml.bayesian.parse.ParsedProbability; import org.encog.ml.bayesian.query.BayesianQuery; import org.encog.ml.bayesian.query.enumerate.EnumerationQuery; import org.encog.ml.bayesian.query.sample.EventState; import org.encog.ml.data.MLData; import org.encog.ml.data.MLDataPair; import org.encog.ml.data.MLDataSet; import org.encog.util.EngineArray; import org.encog.util.csv.CSVFormat; /** * The Bayesian Network is a machine learning method that is based on * probability, and particularly Bayes' Rule. The Bayesian Network also forms * the basis for the Hidden Markov Model and Naive Bayesian Network. The * Bayesian Network is either constructed directly or inferred from training * data using an algorithm such as K2. * * http://www.heatonresearch.com/wiki/Bayesian_Network */ public class BayesianNetwork extends BasicML implements MLClassification, MLResettable, Serializable, MLError { /** * Default choices for a boolean event. */ public static final String[] CHOICES_TRUE_FALSE = { "true", "false" }; /** * Mapping between the event string names, and the actual events. */ private final Map<String, BayesianEvent> eventMap = new HashMap<String, BayesianEvent>(); /** * A listing of all of the events. */ private final List<BayesianEvent> events = new ArrayList<BayesianEvent>(); /** * The current Bayesian query. */ private BayesianQuery query; /** * Specifies if each input is present. */ private boolean[] inputPresent; /** * Specifies the classification target. */ private int classificationTarget; /** * The probabilities of each classification. */ private double[] classificationProbabilities; public BayesianNetwork() { this.query = new EnumerationQuery(this); } /** * @return The mapping from string names to events. */ public Map<String, BayesianEvent> getEventMap() { return eventMap; } /** * @return The events. */ public List<BayesianEvent> getEvents() { return this.events; } /** * Get an event based on the string label. * @param label The label to locate. * @return The event found. */ public BayesianEvent getEvent(String label) { return this.eventMap.get(label); } /** * Get an event based on label, throw an error if not found. * @param label THe event label to find. * @return The event. */ public BayesianEvent getEventError(String label) { if (!eventExists(label)) throw (new BayesianError("Undefined label: " + label)); return this.eventMap.get(label); } /** * Return true if the specified event exists. * @param label The label we are searching for. * @return True, if the event exists by label. */ public boolean eventExists(String label) { return this.eventMap.containsKey(label); } /** * Create, or register, the specified event with this bayesian network. * @param event The event to add. */ public void createEvent(BayesianEvent event) { if( eventExists(event.getLabel())) { throw new BayesianError("The label \"" + event.getLabel() + "\" has already been defined."); } this.eventMap.put(event.getLabel(), event); this.events.add(event); } /** * Create an event specified on the label and options provided. * @param label The label to create this event as. * @param options The options, or states, that this event can have. * @return The newly created event. */ public BayesianEvent createEvent(String label, List<BayesianChoice> options) { if( label==null) { throw new BayesianError("Can't create event with null label name"); } if (eventExists(label)) { throw new BayesianError("The label \"" + label + "\" has already been defined."); } BayesianEvent event; if( options.size()==0 ) { event = new BayesianEvent(label); } else { event = new BayesianEvent(label,options); } createEvent(event); return event; } /** * Create the specified events based on a variable number of options, or choices. * @param label The label of the event to create. * @param options The states that the event can have. * @return The newly created event. */ public BayesianEvent createEvent(String label, String ... options) { if( label==null) { throw new BayesianError("Can't create event with null label name"); } if (eventExists(label)) { throw new BayesianError("The label \"" + label + "\" has already been defined."); } BayesianEvent event; if( options.length==0 ) { event = new BayesianEvent(label); } else { event = new BayesianEvent(label,options); } createEvent(event); return event; } /** * Create a dependency between two events. * @param parentEvent The parent event. * @param childEvent The child event. */ public void createDependency(BayesianEvent parentEvent, BayesianEvent childEvent) { // does the dependency exist? if(!hasDependency(parentEvent,childEvent) ) { // create the dependency parentEvent.addChild(childEvent); childEvent.addParent(parentEvent); } } /** * Determine if the two events have a dependency. * @param parentEvent The parent event. * @param childEvent The child event. * @return True if a dependency exists. */ private boolean hasDependency(BayesianEvent parentEvent, BayesianEvent childEvent) { return( parentEvent.getChildren().contains(childEvent)); } /** * Create a dependency between a parent and multiple children. * @param parentEvent The parent event. * @param children The child events. */ public void createDependency(BayesianEvent parentEvent, BayesianEvent... children) { for (BayesianEvent childEvent : children) { parentEvent.addChild(childEvent); childEvent.addParent(parentEvent); } } /** * Create a dependency between two labels. * @param parentEventLabel The parent event. * @param childEventLabel The child event. */ public void createDependency(String parentEventLabel, String childEventLabel) { BayesianEvent parentEvent = getEventError(parentEventLabel); BayesianEvent childEvent = getEventError(childEventLabel); createDependency(parentEvent, childEvent); } /** * @return The contents as a string. Shows both events and dependences. */ public String getContents() { StringBuilder result = new StringBuilder(); boolean first = true; for (BayesianEvent e : this.events) { if (!first) result.append(" "); first = false; result.append(e.toFullString()); } return result.toString(); } /** * Define the structure of the Bayesian network as a string. * @param line The string to define events and relations. */ public void setContents(String line) { List<ParsedProbability> list = ParseProbability.parseProbabilityList(this, line); List<String> labelList = new ArrayList<String>(); // ensure that all events are there for(ParsedProbability prob: list ) { ParsedEvent parsedEvent = prob.getChildEvent(); String eventLabel = parsedEvent.getLabel(); labelList.add(eventLabel); // create event, if not already here BayesianEvent e = getEvent(eventLabel); if( e==null ) { List<BayesianChoice> cl = new ArrayList<BayesianChoice>(); for( ParsedChoice c : parsedEvent.getList() ) { cl.add(new BayesianChoice(c.getLabel(),c.getMin(),c.getMax())); } createEvent(eventLabel, cl); } } // now remove all events that were not covered for(int i=0; i<events.size();i++) { BayesianEvent event = this.events.get(i); if( !labelList.contains(event.getLabel()) ) { removeEvent(event); } } // handle dependencies for(ParsedProbability prob: list ) { ParsedEvent parsedEvent = prob.getChildEvent(); String eventLabel = parsedEvent.getLabel(); BayesianEvent event = requireEvent(eventLabel); // ensure that all "givens" are present List<String> givenList = new ArrayList<String>(); for( ParsedEvent given: prob.getGivenEvents() ) { if( !event.hasGiven(given.getLabel()) ) { BayesianEvent givenEvent = requireEvent(given.getLabel()); this.createDependency(givenEvent, event); } givenList.add(given.getLabel()); } // now remove givens that were not covered for(int i=0; i<event.getParents().size();i++) { BayesianEvent event2 = event.getParents().get(i); if( !givenList.contains(event2.getLabel()) ) { removeDependency(event2,event); } } } // finalize the structure finalizeStructure(); if (this.query != null) { this.query.finalizeStructure(); } } /** * Remove a dependency, if it it exists. * @param parent The parent event. * @param child The child event. */ private void removeDependency(BayesianEvent parent, BayesianEvent child) { parent.getChildren().remove(child); child.getParents().remove(parent); } /** * Remove the specified event. * @param event The event to remove. */ private void removeEvent(BayesianEvent event) { for( BayesianEvent e : event.getParents() ) { e.getChildren().remove(event); } this.eventMap.remove(event.getLabel()); this.events.remove(event); } /** * {@inheritDoc} */ public String toString() { StringBuilder result = new StringBuilder(); boolean first = true; for (BayesianEvent e : this.events) { if (!first) result.append(" "); first = false; result.append(e.toString()); } return result.toString(); } /** * @return The number of parameters in this Bayesian network. */ public int calculateParameterCount() { int result = 0; for (BayesianEvent e : this.eventMap.values()) { result += e.calculateParameterCount(); } return result; } /** * Finalize the structure of this Bayesian network. */ public void finalizeStructure() { for (BayesianEvent e : this.eventMap.values()) { e.finalizeStructure(); } if( this.query!=null ) { this.query.finalizeStructure(); } this.inputPresent = new boolean[this.events.size()]; EngineArray.fill(this.inputPresent, true); this.classificationTarget = -1; } /** * Validate the structure of this Bayesian network. */ public void validate() { for (BayesianEvent e : this.eventMap.values()) { e.validate(); } } /** * Determine if one Bayesian event is in an array of others. * @param given The events to check. * @param e See if e is amoung given. * @return True if e is amoung given. */ private boolean isGiven(BayesianEvent[] given, BayesianEvent e) { for (BayesianEvent e2 : given) { if (e == e2) return true; } return false; } /** * Determine if one event is a descendant of another. * @param a The event to check. * @param b The event that has children. * @return True if a is amoung b's children. */ public boolean isDescendant(BayesianEvent a, BayesianEvent b) { if (a == b) return true; for (BayesianEvent e : b.getChildren()) { if (isDescendant(a, e)) return true; } return false; } /** * True if this event is given or conditionally dependent on the others. * @param given The others to check. * @param e The event to check. * @return True, if the node is given or descendant. */ private boolean isGivenOrDescendant(BayesianEvent[] given, BayesianEvent e) { for (BayesianEvent e2 : given) { if (isDescendant(e2, e)) return true; } return false; } /** * Help determine if one event is conditionally independent of another. * @param previousHead The previous head, as we traverse the list. * @param a The event to check. * @param goal The goal. * @param searched List of events searched. * @param given Given events. * @return True if conditionally independent. */ private boolean isCondIndependent(boolean previousHead, BayesianEvent a, BayesianEvent goal, Set<BayesianEvent> searched, BayesianEvent... given) { // did we find it? if (a == goal) { return false; } // search children for (BayesianEvent e : a.getChildren()) { if (!searched.contains(e) || !isGiven(given, a)) { searched.add(e); if (!isCondIndependent(true, e, goal, searched, given)) return false; } } // search parents for (BayesianEvent e : a.getParents()) { if (!searched.contains(e)) { searched.add(e); if (!previousHead || isGivenOrDescendant(given, a)) if (!isCondIndependent(false, e, goal, searched, given)) return false; } } return true; } public boolean isCondIndependent(BayesianEvent a, BayesianEvent b, BayesianEvent... given) { Set<BayesianEvent> searched = new HashSet<BayesianEvent>(); return isCondIndependent(false, a, b, searched, given); } public BayesianQuery getQuery() { return query; } public void setQuery(BayesianQuery query) { this.query = query; } /** * {@inheritDoc} */ @Override public int getInputCount() { return this.events.size(); } /** * {@inheritDoc} */ @Override public int getOutputCount() { return 1; } public double computeProbability(MLData input) { // copy the input to evidence int inputIndex = 0; for (int i = 0; i < this.events.size(); i++) { BayesianEvent event = this.events.get(i); EventState state = this.query.getEventState(event); if (state.getEventType() == EventType.Evidence) { state.setValue((int)input.getData(inputIndex++)); } } // execute the query this.query.execute(); return this.query.getProbability(); } /** * Define the probability for an event. * @param line The event. * @param probability The probability. */ public void defineProbability(String line, double probability) { ParseProbability parse = new ParseProbability(this); ParsedProbability parsedProbability = parse.parse(line); parsedProbability.defineTruthTable(this, probability); } /** * Define a probability. * @param line The line to define the probability. */ public void defineProbability(String line) { int index = line.lastIndexOf('='); boolean error = false; double prob = 0.0; String left = ""; String right = ""; if (index != -1) { left = line.substring(0, index); right = line.substring(index + 1); try { prob = CSVFormat.EG_FORMAT.parse(right); } catch (NumberFormatException ex) { error = true; } } if( error || index==-1) { throw new BayesianError("Probability must be of the form \"P(event|condition1,condition2,etc.)=0.5\". Conditions are optional."); } defineProbability(left,prob); } /** * Require the specified event, thrown an error if it does not exist. * @param label The label. * @return The event. */ public BayesianEvent requireEvent(String label) { BayesianEvent result = getEvent(label); if( result==null ) { throw new BayesianError("The event " + label + " is not defined."); } return result; } /** * Define a relationship. * @param line The relationship to define. */ public void defineRelationship(String line) { ParseProbability parse = new ParseProbability(this); ParsedProbability parsedProbability = parse.parse(line); parsedProbability.defineRelationships(this); } /** * Perform a query. * @param line The query. * @return The probability. */ public double performQuery(String line) { if( this.query==null ) { throw new BayesianError("This Bayesian network does not have a query to define."); } ParseProbability parse = new ParseProbability(this); ParsedProbability parsedProbability = parse.parse(line); // create a temp query BayesianQuery q = this.query.clone(); // first, mark all events as hidden q.reset(); // deal with evidence (input) for( ParsedEvent parsedEvent : parsedProbability.getGivenEvents() ) { BayesianEvent event = this.requireEvent(parsedEvent.getLabel()); q.defineEventType(event, EventType.Evidence); q.setEventValue(event, parsedEvent.resolveValue(event)); } // deal with outcome (output) for( ParsedEvent parsedEvent : parsedProbability.getBaseEvents() ) { BayesianEvent event = requireEvent(parsedEvent.getLabel()); q.defineEventType(event, EventType.Outcome); q.setEventValue(event, parsedEvent.resolveValue(event)); } q.locateEventTypes(); q.execute(); return q.getProbability(); } /** * {@inheritDoc} */ @Override public void updateProperties() { // Not needed } public int getEventIndex(BayesianEvent event) { for(int i=0;i<this.events.size();i++) { if( event==events.get(i)) return i; } return -1; } /** * Remove all relations between nodes. */ public void removeAllRelations() { for(BayesianEvent event: this.events) { event.removeAllRelations(); } } /** * {@inheritDoc} */ @Override public void reset() { reset(0); } /** * {@inheritDoc} */ @Override public void reset(int seed) { for(BayesianEvent event: this.events) { event.reset(); } } /** * Determine the classes for the specified input. * @param input The input. * @return An array of class indexes. */ public int[] determineClasses(MLData input) { int[] result = new int[input.size()]; for(int i=0;i<input.size();i++) { BayesianEvent event = this.events.get(i); int classIndex = event.matchChoiceToRange(input.getData(i)); result[i] = classIndex; } return result; } /** * Classify the input. * @param input The input to classify. */ @Override public int classify(MLData input) { if( this.classificationTarget<0 || this.classificationTarget>=this.events.size() ) { throw new BayesianError("Must specify classification target by calling setClassificationTarget."); } int[] d = this.determineClasses(input); // properly tag all of the events for(int i=0;i<this.events.size();i++) { BayesianEvent event = this.events.get(i); if( i==this.classificationTarget ) { this.query.defineEventType(event, EventType.Outcome); } else if( this.inputPresent[i] ) { this.query.defineEventType(event, EventType.Evidence); this.query.setEventValue(event, d[i]); } else { this.query.defineEventType(event, EventType.Hidden); this.query.setEventValue(event, d[i]); } } // loop over and try each outcome choice BayesianEvent outcomeEvent = this.events.get(this.classificationTarget); this.classificationProbabilities = new double[outcomeEvent.getChoices().size()]; for(int i=0;i<outcomeEvent.getChoices().size();i++) { this.query.setEventValue(outcomeEvent, i); this.query.execute(); classificationProbabilities[i] = this.query.getProbability(); } return EngineArray.maxIndex(this.classificationProbabilities); } /** * Get the classification target. * @return The index of the classification target. */ public int getClassificationTarget() { return classificationTarget; } /** * Determine if the specified input is present. * @param idx The index of the input. * @return True, if the input is present. */ public boolean isInputPresent(int idx) { return this.inputPresent[idx]; } /** * Define a classification structure of the form P(A|B) = P(C) * @param line The line. */ public void defineClassificationStructure(String line) { List<ParsedProbability> list = ParseProbability.parseProbabilityList(this, line); if( list.size()>1) { throw new BayesianError("Must only define a single probability, not a chain."); } if( list.size()==0) { throw new BayesianError("Must define at least one probability."); } // first define everything to be hidden for(BayesianEvent event: this.events) { this.query.defineEventType(event, EventType.Hidden); } // define the base event ParsedProbability prob = list.get(0); if( prob.getBaseEvents().size()==0 ) { return; } BayesianEvent be = this.getEvent( prob.getChildEvent().getLabel() ); this.classificationTarget = this.events.indexOf(be); this.query.defineEventType(be, EventType.Outcome); // define the given events for(ParsedEvent parsedGiven: prob.getGivenEvents()) { BayesianEvent given = this.getEvent( parsedGiven.getLabel() ); this.query.defineEventType(given, EventType.Evidence); } this.query.locateEventTypes(); // set the values for(ParsedEvent parsedGiven: prob.getGivenEvents()) { BayesianEvent event = this.getEvent( parsedGiven.getLabel() ); this.query.setEventValue(event, parseInt(parsedGiven.getValue()) ); } this.query.setEventValue(be, parseInt(prob.getBaseEvents().get(0).getValue()) ); } private int parseInt(String str) { if( str==null ) { return 0; } try { return Integer.parseInt(str); } catch(NumberFormatException ex) { return 0; } } /** * @return The classification target. */ public BayesianEvent getClassificationTargetEvent() { if( this.classificationTarget==-1) { throw new BayesianError("No classification target defined."); } return this.events.get(this.classificationTarget); } /** * {@inheritDoc} */ @Override public double calculateError(final MLDataSet data) { if( !this.hasValidClassificationTarget()) return 1.0; // do the following just to throw an error if there is no classification target getClassificationTarget(); int badCount = 0; int totalCount = 0; for(MLDataPair pair: data) { int c = this.classify(pair.getInput()); totalCount++; if( c!=pair.getInput().getData(this.classificationTarget)) { badCount++; } } return (double)badCount/(double)totalCount; } /** * @return Returns a string representation of the classification structure. * Of the form P(a|b,c,d) */ public String getClassificationStructure() { StringBuilder result = new StringBuilder(); result.append("P("); boolean first = true; for(int i=0;i<this.getEvents().size();i++) { BayesianEvent event = this.events.get(i); EventState state = this.query.getEventState(event); if( state.getEventType()==EventType.Outcome ) { if(!first) { result.append(","); } result.append(event.getLabel()); first = false; } } result.append("|"); first = true; for(int i=0;i<this.getEvents().size();i++) { BayesianEvent event = this.events.get(i); if( this.query.getEventState(event).getEventType()==EventType.Evidence ) { if(!first) { result.append(","); } result.append(event.getLabel()); first = false; } } result.append(")"); return result.toString(); } /** * @return True if this network has a valid classification target. */ public boolean hasValidClassificationTarget() { if (this.classificationTarget < 0 || this.classificationTarget >= this.events.size()) { return false; } else { return true; } } }