/**
* This code is terrible and probably contains multiple exponential blowups and I apologize to anyone reading this.
*
* Note that some of the seeming unnecessary complexity in this class is in preparation for the possibility of
* implementing junction trees. Since this comment was originally written, I've forgotten which parts, however.
*/
package bayesGame.bayesbayes;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.fraction.Fraction;
import org.apache.commons.math3.util.Pair;
public class BayesNode {
public final Object type;
protected Object[] scope;
protected HashMap<Object,Integer> strides = new HashMap<Object,Integer>();
private Fraction[] cpt;
private Fraction[] potential;
public String cptDescription;
public String cptName;
// getProbability returns the contents of this variable, but it isn't actually used for anything
// inside the node: it's just a cache.
private Fraction probability;
private HashSet<Message> upstreamMessages = new HashSet<Message>();
private HashSet<Message> downstreamMessages = new HashSet<Message>();
private boolean observed = false;
private Boolean assumedValue = null;
private Boolean trueValue = null;
private Set<String> properties = new HashSet<String>();
protected BayesNode(Object type){
this.type = type;
this.scope = new Object[]{type};
this.cpt = this.createRawFractionArray(this.scope);
this.strides = this.createStridesFromScope(this.scope);
this.cptDescription = this.createCPTDescription(this.scope);
this.cptName = this.createCPTType(this.scope);
this.potential = this.cpt.clone();
}
protected BayesNode(Object type, Object[] scope){
this(type, scope, null, null);
}
protected BayesNode(Object type, Object[] scope, HashMap<Object,Integer> strides, Fraction[] cpt){
this.type = type;
// the scope of a node must contain its own type
if (!Arrays.asList(scope).contains(type)){
// if one is missing, try to add the type of the node to the scope
if ((strides == null) && (cpt == null)){
scope = this.copyArrayAddingItem(scope, type);
} else {
throw new IllegalArgumentException("The scope of a node must contain its own type");
}
}
this.scope = scope;
if (strides != null){
this.strides = strides;
} else {
this.strides = this.createStridesFromScope(this.scope);
}
if (cpt != null){
this.cpt = copyFraction(cpt);
} else {
this.cpt = createRawFractionArray(scope);
}
this.potential = copyFraction(this.cpt);
}
private Object[] copyArrayAddingItem(Object [] array, Object item){
Object[] newscope = new Object[array.length + 1];
newscope[0] = item;
int location = 1;
for (Object o : array){
newscope[location] = o;
location++;
}
return newscope;
}
protected boolean addItemToScope(Object item){
if (Arrays.asList(scope).contains(item)){
return false;
}
scope = copyArrayAddingItem(scope, item);
cpt = this.createRawFractionArray(scope);
strides = this.createStridesFromScope(scope);
resetPotential();
return true;
}
protected void updateProbability(){
Fraction[] probabilities = this.getNormalizedMarginalPotential(type);
probability = probabilities[0];
}
public Fraction getProbability(){
if (probability == null){
updateProbability();
}
return copyFraction(probability);
}
protected void setProbability(Fraction probability){
this.probability = probability;
}
public ArrayList<Map<Object,Boolean>> getNonZeroProbabilities(){
indexChooser chooser = new indexChooser();
ArrayList<Map<Object,Boolean>> truthValues = new ArrayList<Map<Object,Boolean>>();
// map<object,boolean> items = the truth values of single items in a p > 0 row
// list of maps = the whole thing
for (int i = 0; i < potential.length; i++){
Fraction f = potential[i];
if (f.doubleValue() > 0.00d){
ArrayList<Boolean> valuesAtIndex = chooser.getTruthValues(i);
Map<Object,Boolean> row = new HashMap<Object,Boolean>(valuesAtIndex.size());
for (int j = 0; j < valuesAtIndex.size(); j++){
Object o = scope[j];
row.put(o, valuesAtIndex.get(j));
}
truthValues.add(row);
}
}
return truthValues;
}
private Fraction[] createRawFractionArray(Object[] scope){
Fraction[] array = new Fraction[(int) Math.pow(2, scope.length)];
Fraction fraction = new Fraction(1, array.length);
Arrays.fill(array, fraction);
return array;
}
private HashMap<Object,Integer> createStridesFromScope(Object[] scope){
HashMap<Object,Integer> stride = new HashMap<Object,Integer>();
int i = 1;
for (Object o : scope){
stride.put(o, i);
i = i * 2;
}
return stride;
}
private String createCPTDescription(Object[] scope) {
if (scope.length == 1){
return "'" + this.type + "' is a prior variable. The truth values of any of its child variables are derived from it, as well as from any other parent variables.";
} else {
return "'" + this.type + "' is a conditional probability variable of type <b>custom distribution.";
}
}
private String createCPTType(Object[] scope){
if (scope.length == 1){
return "Prior";
} else {
return "Custom";
}
}
public void setTrueValue(boolean value){
trueValue = value;
}
/**
* Resets the node's potential to the initial CPT, clearing any changes from messages,
* setting the node as unobserved, and clearing any assumed values. To reset the node's
* potential while maintaining its status as observed, use resetPotential instead.
*/
public void resetNode(){
observed = false;
potential = copyFraction(cpt);
probability = null;
assumedValue = null;
}
/**
* Resets the node's potential to the initial CPT, clearing any changes from messages.
* If the node has been observed or assumed, it remains so, with corresponding effects
* to the CPT.
* To reset the node's observation status as well, use resetNode instead.
*/
public void resetPotential(){
potential = copyFraction(cpt);
if (observed){
observe();
} else if (assumedValue != null){
assumeValue(assumedValue);
}
probability = null;
}
public boolean isObserved(){
return observed;
}
public boolean isAssumed(){
if (assumedValue == null){
return false;
} else {
return true;
}
}
public boolean setProbabilityOfUntrueVariables(Fraction probability, Object... variables){
probability = copyFraction(probability);
if (!checkCPTInputForValidity(variables)){
return false;
}
if (probability.doubleValue() > 1.0d || probability.doubleValue() < 0.0d){
return false;
}
indexChooser selfChecker = new indexChooser();
selfChecker.requestUntrue(variables);
int index = selfChecker.getIndex();
cpt[index] = probability;
potential[index] = copyFraction(probability);
this.probability = null;
return true;
}
private boolean checkCPTInputForValidity(Object[] variables){
if (variables.length > scope.length){
return false;
}
List<Object> scopeList = Arrays.asList(scope);
for (Object o : variables){
if (!scopeList.contains(o)){
return false;
}
}
return true;
}
/**
* Observes the node, setting its probabilities according to its true value.
* If the true value has not been set, randomly generates it based on the
* current probabilities. Note that network is responsible for updating the
* probabilities of any adjacent nodes afterwards.
*/
public void observe(){
boolean assumedCleared = false;
if (assumedValue != null){
assumedCleared = true;
resetNode();
}
observed = true;
if (trueValue == null){
double diceRoll = Math.random();
if (diceRoll <= this.getProbability().doubleValue()){
trueValue = Boolean.TRUE;
} else {
trueValue = Boolean.FALSE;
}
if (assumedCleared){
System.out.println("WARNING: the node had an assumed value which was cleared when observing it, and its value was then randomly generated - the probability used for generating the value may not have been the intended one.");
}
}
changePotentialOfValues(!trueValue, Fraction.ZERO);
// normalizeNodePotential();
probability = null;
}
private void changePotentialOfValues(boolean value, Fraction newpotential){
indexChooser selfChooser = new indexChooser();
ArrayList<Integer> locationsToChange;
if (value){
locationsToChange = selfChooser.getAllIndexes(type, true);
} else {
locationsToChange = selfChooser.getAllIndexes(type, false);
}
for (Integer i : locationsToChange){
potential[i] = newpotential;
}
}
public void observe(boolean observation){
trueValue = observation;
observe();
}
public boolean assumeValue(boolean value){
if (observed){
return false;
}
if (assumedValue != null){
clearAssumedValue();
}
assumedValue = value;
changePotentialOfValues(!value, Fraction.ZERO);
// normalizeNodePotential();
probability = null;
return true;
}
public void clearAssumedValue(){
if (assumedValue != null){
if (!observed){
resetNode();
} else {
assumedValue = null;
}
}
}
public Boolean assumedValue(){
if (assumedValue == null){
return null;
}
if (assumedValue){
return Boolean.TRUE;
} else {
return Boolean.FALSE;
}
}
public Fraction[] getPotential(){
return copyFraction(potential);
}
private Fraction copyFraction(Fraction f){
Fraction newFraction = new Fraction(f.getNumerator(), f.getDenominator());
return newFraction;
}
private Fraction[] copyFraction(Fraction[] f){
Fraction[] newFraction = new Fraction[f.length];
for (int i = 0; i < newFraction.length ; i++){
newFraction[i] = copyFraction(f[i]);
}
return newFraction;
}
public Fraction[] getMarginalPotential(Object o){
Object[] targetScope = {o};
HashMap<Object,Integer> targetStride = this.createStridesFromScope(targetScope);
return marginalizeOut(potential, targetScope, targetStride);
}
public Fraction[] getNormalizedMarginalPotential(Object o){
Fraction[] marginalPotential = this.getMarginalPotential(o);
return this.normalizePotentials(marginalPotential);
}
protected void normalizeNodePotential(){
Fraction total = Fraction.ZERO;
for (Fraction f : potential){
total = total.add(f);
}
if (!total.equals(Fraction.ZERO)){
for (int i = 0; i < potential.length; i++){
Fraction f = potential[i];
potential[i] = f.divide(total);
}
}
for (Fraction f : potential){
}
}
protected void receiveUpstreamMessage(Message message){
upstreamMessages.add(message);
}
protected void receiveDownstreamMessage(Message message){
downstreamMessages.add(message);
}
protected boolean receivedMessageFrom(BayesNode source, boolean upstream){
HashSet<Message> receivedMessages;
if (upstream){
receivedMessages = upstreamMessages;
} else {
receivedMessages = downstreamMessages;
}
return receivedMessages.contains(new Message(source));
}
protected void clearMessages(){
upstreamMessages = new HashSet<Message>();
downstreamMessages = new HashSet<Message>();
}
public Fraction[] normalizePotentials(Fraction[] targetPotential){
Fraction totalSum = Fraction.ZERO;
for (Fraction f : targetPotential){
totalSum = totalSum.add(f);
}
if (totalSum.equals(Fraction.ZERO)){
Fraction f = Fraction.ZERO;
Arrays.fill(targetPotential, f);
System.out.println("Encountered division by zero, check your network");
} else {
for (int i = 0; i < targetPotential.length; i++){
Fraction f = targetPotential[i];
targetPotential[i] = f.divide(totalSum);
}
}
return targetPotential;
}
/**
* Multiplies received messages from the specified direction with the initial
* CPT, sums out any variables not in the specified scope, and packs the result into a
* message. Note that this method does NOT check whether the node has received all the
* prerequisite messages and is thus ready to send - this is the responsibility of the
* calling class. (Individual nodes are not aware of their neighbors, thus cannot check
* their own readiness.)
*
* @param upstream true if the message is sent to be upstream, false if downstream
* @param scope the scope of the message
* @return a message with the specified scope, computed message, and this node as the sender
*/
protected Message generateMessage(boolean upstream, Object... targetScope){
HashSet<Message> receivedMessages;
if (upstream){
receivedMessages = upstreamMessages;
} else {
receivedMessages = downstreamMessages;
}
Fraction[] multipliedPotential;
if (!receivedMessages.isEmpty()){
multipliedPotential = multiplyPotentialWithMessages(this.getPotential(), receivedMessages);
} else {
multipliedPotential = potential;
}
HashMap<Object,Integer> targetStride = this.createStridesFromScope(targetScope);
Fraction[] potentialMessage = marginalizeOut(multipliedPotential, targetScope, targetStride);
Message message = new Message(targetScope, potentialMessage, targetStride, this);
return message;
}
protected Message generateUpstreamMessage(Object... targetScope){
return generateMessage(true, targetScope);
}
protected Message generateDownstreamMessage(Object... targetScope){
return generateMessage(false, targetScope);
}
protected void multiplyPotentialWithMessages(){
probability = null;
if (!upstreamMessages.isEmpty()){
potential = multiplyPotentialWithMessages(potential, upstreamMessages);
}
probability = null;
if (!downstreamMessages.isEmpty()){
potential = multiplyPotentialWithMessages(potential, downstreamMessages);
}
probability = null;
clearMessages();
probability = null;
}
private Fraction[] multiplyPotentialWithMessages(Fraction [] currentPotential, HashSet<Message> receivedMessages){
indexChooser selfChooser = new indexChooser();
Fraction[] newPotential = copyFraction(currentPotential);
for (Message m : receivedMessages){
if (m.scope.length == 1){
Object o = m.scope[0];
selfChooser.requestUntrue(o);
ArrayList<Integer> arrayReferencesToWantedObjectBeingUntrueInOwnPotential = selfChooser.getAllIndexes();
Fraction trueMultiplier = m.message[0];
Fraction untrueMultiplier = m.message[1];
for (int i = 0; i < currentPotential.length; i++){
if (arrayReferencesToWantedObjectBeingUntrueInOwnPotential.contains(i)){
newPotential[i] = newPotential[i].multiply(untrueMultiplier);
} else {
newPotential[i] = newPotential[i].multiply(trueMultiplier);
}
}
} else {
throw new IllegalStateException("Message contained truth values for multiple variables, not yet implemented");
//TODO: implement the case where the message contains the truth values for multiple variables!
}
selfChooser.resetUntrue();
}
return newPotential;
}
private Fraction[] marginalizeOut(Fraction[] currentPotential, Object[] targetScope, HashMap<Object,Integer> targetStride){
Fraction[] newPotential = new Fraction[(int) Math.pow(2, targetScope.length)];
// note that we are assuming that the scope and stride of the potential we're summing out from are the same as for the node in general
indexChooser targetChooser = new indexChooser(targetScope, newPotential, targetStride);
indexChooser selfChooser = new indexChooser(this.scope, currentPotential, this.strides);
// find the variables to be summed out by comparing the current scope and the target one
HashSet<Object> targetSet = new HashSet<Object>(Arrays.asList(targetScope));
HashSet<Object> differenceSet = new HashSet<Object>(Arrays.asList(this.scope));
differenceSet.removeAll(targetSet);
ArrayList<Integer> currentPotentialArrayReferencesToItemsToBeSummedOut = new ArrayList<Integer>();
for (Object o : differenceSet){
for (int i = 0; i < this.scope.length; i++){
if (this.scope[i].equals(o)){
currentPotentialArrayReferencesToItemsToBeSummedOut.add(i);
}
}
}
// sum out the contents of the old array locations to the new ones
for (int i = 0; i < currentPotential.length; i++){
// take the fraction in the current index
Fraction f = currentPotential[i];
// find the array reference to the new array that corresponds to the logical contents of this index, but without
// the variables to be summed out
ArrayList<Boolean> logicalContentsOfIndex = selfChooser.getTruthValues(i);
ArrayList<Boolean> logicalContentsOfTargetIndex = new ArrayList<Boolean>();
for (int j = 0; j < logicalContentsOfIndex.size(); j++){
if (!currentPotentialArrayReferencesToItemsToBeSummedOut.contains(j)){
logicalContentsOfTargetIndex.add(logicalContentsOfIndex.get(j));
}
}
int targetPotentialArrayReference = targetChooser.getIndex(logicalContentsOfTargetIndex);
// add the fraction to that new location
if (newPotential[targetPotentialArrayReference] == null){
newPotential[targetPotentialArrayReference] = f;
} else {
newPotential[targetPotentialArrayReference] = newPotential[targetPotentialArrayReference].add(f);
}
}
return newPotential;
}
public boolean equals(Object other){
boolean result = false;
if (other instanceof BayesNode){
BayesNode theOther = (BayesNode)other;
result = (this.type.equals(theOther.type));
}
return result;
}
public int hashCode(){
return type.hashCode();
}
public String toString(){
return type.toString();
}
public void addProperty(String property){
properties.add(property);
}
public boolean hasProperty(String property){
return properties.contains(property);
}
public void removeProperty(String property){
properties.remove(property);
}
private class indexChooser {
private HashSet<Object> requestedUntrueVariables = new HashSet<Object>();
private final Object[] targetScope;
private final Fraction[] targetFactor;
private final HashMap<Object,Integer> targetStrides;
public indexChooser(Object[] targetScope, Fraction[] targetFactor, HashMap<Object,Integer> targetStrides){
this.targetScope = targetScope;
this.targetFactor = targetFactor;
this.targetStrides = targetStrides;
}
public indexChooser(Message message){
this(message.scope, message.message, message.strides);
}
public indexChooser(){
this(scope, potential, strides);
}
public void requestUntrue(Object o){
requestedUntrueVariables.add(o);
}
public void requestUntrue(Object[] o){
requestedUntrueVariables.addAll(Arrays.asList(o));
}
public void resetUntrue(){
requestedUntrueVariables = new HashSet<Object>();
}
/**
* Returns the index in the CPT/potential array corresponding to the row where
* all of the variables specified via the requestUntrue methods are untrue, and
* all the rest are true.
*
* @return the CPT/potential array index
*/
public int getIndex(){
int location = 0;
for (Object o : requestedUntrueVariables){
if (targetStrides.containsKey(o)){
location = location + targetStrides.get(o);
}
}
return location;
}
public int getIndex(ArrayList<Boolean> logicalValues){
this.resetUntrue();
for (int i = 0; i < logicalValues.size(); i++){
if (!logicalValues.get(i)){
this.requestUntrue(targetScope[i]);
}
}
return this.getIndex();
}
/**
* Returns a list of integers containing every index of the potential array
* where all the variables specified via the requestUntrue methods are untrue.
*
* This is probably a horrible terrible implementation and there's some obvious
* way of making it faster. Right now it loops through the whole potential array,
* checks the logical equivalent of each array index, and then adds it to the
* list of indexes if it matches the criteria.
*
* @return
*/
public ArrayList<Integer> getAllIndexes(){
ArrayList<Integer> allIndexes = new ArrayList<Integer>();
for (int i = 0; i < targetFactor.length; i++){
boolean canBeAdded = true;
ArrayList<Boolean> x = getTruthValues(i);
for (int j = 0; j < x.size(); j++){
if (x.get(j) && requestedUntrueVariables.contains(targetScope[j])){
canBeAdded = false;
}
}
if (canBeAdded){
allIndexes.add(i);
}
}
return allIndexes;
}
/**
* Returns a list of all indices in which the specified object is either true or false.
*
* @param o The object under examination
* @param t Whether to return indices where it's true, or indices where it's false
* @return An ArrayList of indexes
*/
public ArrayList<Integer> getAllIndexes(Object o, boolean t){
ArrayList<Integer> indexes = new ArrayList<Integer>();
int stride = targetStrides.get(o);
for (int location = 0; location < targetFactor.length; location++){
if (variableTruthValue(location, stride) == t) {
indexes.add(location);
}
}
return indexes;
}
public ArrayList<Boolean> getTruthValues(int location){
ArrayList<Boolean> values = new ArrayList<Boolean>();
for (Object o : targetScope){
int stride = targetStrides.get(o);
values.add(variableTruthValue(location, stride));
}
return values;
}
private boolean variableTruthValue(int location, int stride){
// yes, intentionally doing a division with ints here, as I'd want to round down the result anyway
int value = (location / stride) % 2;
if (value == 0){
return true;
} else {
return false;
}
}
}
}