package org.societies.context.user.refinement.impl.bayesianLibrary.inference.solving;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Hashtable;

import org.societies.context.user.refinement.impl.bayesianLibrary.inference.structures.impl.DAG;
import org.societies.context.user.refinement.impl.bayesianLibrary.inference.structures.impl.Node;
import org.societies.context.user.refinement.impl.bayesianLibrary.inference.structures.impl.Probability;
import org.societies.context.user.refinement.impl.bayesianLibrary.inference.structures.impl.ProbabilityDistribution;
import org.societies.context.user.refinement.impl.bayesianLibrary.inference.structures.impl.StateRV;
import org.societies.context.user.refinement.impl.bayesianLibrary.inference.structures.interfaces.HasProbabilityTable;

/**
 * Infers the state of one target random variable of a {@link DAG} when every
 * other node of the network carries an observation.
 *
 * <p>For each candidate state of the target, the state is temporarily set as
 * evidence and the matching entries of all nodes' probability tables are
 * multiplied (see {@link #computeJPD()}). The resulting values are either
 * ranked ({@link #getJPD(int)}) or returned, optionally normalised
 * ({@link #getUniverse(StateRV[])}).
 */
public class JointProbabilityDistributionSolver {

	private DAG dagNetwork;
	private String targetRV;
	private Node targetRVNode;
	private ArrayList<HasProbabilityTable> nodes = new ArrayList<HasProbabilityTable>();
	private static boolean debug = false;
	private ArrayList<HasProbabilityTable> observedNodes = new ArrayList<HasProbabilityTable>();

	/**
	 * Creates a solver working on the given network.
	 *
	 * @param dag the network whose nodes are used for inference
	 */
	public JointProbabilityDistributionSolver(DAG dag) {
		this.dagNetwork = dag;
		for (Node node : this.dagNetwork.getNodes()) {
			nodes.add(node);
		}
	}

	/**
	 * Calls {@code initializeObservation()} on every node of the network.
	 */
	public void initializeStructure() {
		for (Node node : this.dagNetwork.getNodes()) {
			node.initializeObservation();
		}
	}

	/**
	 * Selects the random variable to infer (for example "activity").
	 *
	 * @param nameRV name of a node of the network
	 * @throws IllegalArgumentException if the network has no node with that
	 *         name; the previously selected target is left untouched
	 */
	public void setTargetRV(String nameRV) {
		Node targetNode = findNode(nameRV);
		if (targetNode == null) {
			// Library code must not terminate the JVM; report the bad argument instead.
			throw new IllegalArgumentException("TargetNode is not existing! name=" + nameRV);
		}
		this.targetRVNode = targetNode;
		this.targetRV = nameRV;
	}

	/**
	 * @return the node selected by {@link #setTargetRV(String)}, or
	 *         {@code null} if none was selected yet
	 */
	public Node getTargetRVNode() {
		return this.targetRVNode;
	}

	/**
	 * @return the name passed to the last successful
	 *         {@link #setTargetRV(String)} call, or {@code null}
	 */
	public String getTargetRV() {
		return this.targetRV;
	}

	/**
	 * Sets hard evidence on a node: the given state gets weight 1, all other
	 * states weight 0.
	 *
	 * @param node  the node to observe
	 * @param state the observed state; matched with {@code equals} against
	 *              {@code node.getStates()}
	 * @return {@code true} if the state exists and the observation was
	 *         accepted, {@code false} otherwise
	 */
	public boolean addEvidence(Node node, String state) {
		String[] values = node.getStates();
		// A fresh double[] is zero-filled, so only the observed state needs setting.
		double[] probs = new double[values.length];
		if (debug) {
			System.out.println("ADD EVIDENCE(node, state) called. Node =" + node.getName());
			System.out.println("States: " + Arrays.toString(values));
			System.out.println("Evidence: " + state);
		}
		boolean found = false;
		if (state != null) {
			for (int i = 0; i < values.length; i++) {
				if (state.equals(values[i])) {
					probs[i] = 1;
					found = true;
					break;
				}
			}
		}
		if (debug) {
			System.out.println("ADD EVIDENCE(node, state) called. FOUND=" + found);
		}
		return found && addEvidence(node, values, probs);
	}

	/**
	 * Sets an observation on a node and records the node as observed.
	 *
	 * @param node     the node to observe
	 * @param values   state names, forwarded to {@code node.setObservation}
	 * @param perCents weights for the states, forwarded to
	 *                 {@code node.setObservation}
	 * @return {@code false} if {@code node.setObservation} returned
	 *         {@code null} (the node is then no longer counted as observed),
	 *         {@code true} otherwise
	 */
	public boolean addEvidence(Node node, String[] values, double[] perCents) {
		ProbabilityDistribution nodeLikelihood = node.setObservation(values, perCents);
		if (nodeLikelihood == null) {
			observedNodes.remove(node);
			return false;
		}
		if (!observedNodes.contains(node)) {
			observedNodes.add(node);
		}
		return true;
	}

	/**
	 * Re-initialises the observation of every observed node and forgets all
	 * recorded evidence.
	 */
	public void removeAllEvidence() {
		for (HasProbabilityTable observed : observedNodes) {
			((Node) observed).initializeObservation();
		}
		observedNodes.clear();
	}

	/**
	 * Returns the state of the target variable whose joint probability is the
	 * {@code probabilityPosition}-th largest.
	 *
	 * <p>Requires that every node except the target is observed. If several
	 * states share the selected probability value, an error is printed and the
	 * last such state is returned.
	 *
	 * @param probabilityPosition 1-based rank, 1 being the most probable state;
	 *                            0 is treated as 1
	 * @return the selected state name, or {@code null} if the preconditions
	 *         are not met, the rank is out of range, or no state matches
	 */
	public String getJPD(int probabilityPosition) {
		Node[] allNodes = this.dagNetwork.getNodes();
		if (observedNodes.size() != allNodes.length - 1) {
			// All nodes but the target must carry an observation.
			System.err.println("STILL NOT ALL THE NODES ARE OBSERVED!!");
			return null;
		}
		if (isTargetObserved()) {
			System.err.println("TARGET NODE IS OBSERVED!!");
			return null;
		}
		Node targetNode = findNode(targetRV);
		if (targetNode == null) {
			System.err.println("TargetNode is not existing!");
			return null;
		}

		String[] possibleStatesTarget = targetNode.getStates();
		if (probabilityPosition == 0) {
			probabilityPosition = 1;
		}
		if (probabilityPosition < 1 || probabilityPosition > possibleStatesTarget.length) {
			System.err.println("probabilityPosition out of range: " + probabilityPosition
					+ " (target has " + possibleStatesTarget.length + " states)");
			return null;
		}

		/*
		 * ALGORITHM: assume each possible state of the target in turn as
		 * evidence and multiply the matching probability table entries of all
		 * nodes of the network.
		 */
		double[] originalOrder = new double[possibleStatesTarget.length];
		for (int i = 0; i < possibleStatesTarget.length; i++) {
			this.addEvidence(targetNode, possibleStatesTarget[i]);
			originalOrder[i] = this.computeJPD();
			this.removeEvidence(targetNode);
		}
		double[] sorted = originalOrder.clone();
		Arrays.sort(sorted);
		if (debug) {
			for (double d : sorted) {
				System.out.println(d);
			}
		}

		// sorted is ascending, so rank 1 is the last element.
		double searchedProb = sorted[sorted.length - probabilityPosition];
		String stateWanted = null;
		int counter = 0;
		for (int i = 0; i < possibleStatesTarget.length; i++) {
			if (searchedProb == originalOrder[i]) {
				stateWanted = possibleStatesTarget[i];
				counter++;
			}
		}
		if (stateWanted == null || counter > 1) {
			System.err.println("error with computing the Universe!!! stateWanted=" + stateWanted
					+ " counter =" + counter);
		}
		return stateWanted;
	}

	/**
	 * Computes the joint probability for each requested state of the target
	 * variable.
	 *
	 * <p>Requires that every node except the target is observed. The values
	 * are normalised to sum to 1 only when as many states are requested as the
	 * target has and their sum is positive.
	 *
	 * @param statesWanted states of the target variable, matched
	 *                     case-insensitively against the target's state names
	 * @return one {@link Probability} per requested state, in the same order,
	 *         or {@code null} if the preconditions are not met or a requested
	 *         state is unknown
	 */
	public Probability[] getUniverse(StateRV[] statesWanted) {
		Node[] allNodes = this.dagNetwork.getNodes();
		if (observedNodes.size() != allNodes.length - 1) {
			// All nodes but the target must carry an observation.
			System.err.println("STILL NOT ALL THE NODES ARE OBSERVED!!");
			return null;
		}
		if (isTargetObserved()) {
			System.err.println("TARGET NODE IS OBSERVED!!");
			return null;
		}
		if (this.targetRVNode == null) {
			System.err.println("TargetNode is not existing!");
			return null;
		}

		// Resolve each requested state to the spelling used by the target node:
		// addEvidence matches with equals(), so the caller's spelling may not match.
		String[] targetStates = this.targetRVNode.getStates();
		String[] evidenceStates = new String[statesWanted.length];
		for (int j = 0; j < statesWanted.length; j++) {
			String wanted = statesWanted[j].getNameState();
			for (String candidate : targetStates) {
				if (candidate.equalsIgnoreCase(wanted)) {
					evidenceStates[j] = candidate;
					break;
				}
			}
			if (evidenceStates[j] == null) {
				System.out.println("The state " + statesWanted[j]
						+ " is not a state of the target random variable!!!");
				return null;
			}
		}

		/*
		 * ALGORITHM: assume each requested state of the target in turn as
		 * evidence and multiply the matching probability table entries of all
		 * nodes of the network.
		 */
		double[] universesProbabilities = new double[statesWanted.length];
		for (int i = 0; i < statesWanted.length; i++) {
			this.addEvidence(targetRVNode, evidenceStates[i]);
			universesProbabilities[i] = this.computeJPD();
			this.removeEvidence(targetRVNode);
		}

		// Normalise only when all states of the target were requested.
		if (statesWanted.length == targetStates.length) {
			double totalSumProbabilities = 0;
			for (double p : universesProbabilities) {
				totalSumProbabilities += p;
			}
			if (totalSumProbabilities > 0) {
				for (int i = 0; i < universesProbabilities.length; i++) {
					universesProbabilities[i] = universesProbabilities[i] / totalSumProbabilities;
				}
			} else {
				// Dividing by a zero sum would turn every value into NaN.
				System.err.println("Cannot normalise: the sum of the probabilities is not positive.");
			}
		} else {
			System.out.println("Not all the states were given!");
		}

		Probability[] probs = new Probability[statesWanted.length];
		for (int i = 0; i < statesWanted.length; i++) {
			String[] states = { statesWanted[i].getNameState() };
			probs[i] = new Probability(states, universesProbabilities[i]);
		}
		return probs;
	}

	/** Re-initialises the observation of a node and stops counting it as observed. */
	private void removeEvidence(Node node) {
		node.initializeObservation();
		observedNodes.remove(node);
	}

	/** Returns the first node of the network with the given name, or {@code null}. */
	private Node findNode(String name) {
		for (Node candidate : this.dagNetwork.getNodes()) {
			if (candidate.getName().equals(name)) {
				return candidate;
			}
		}
		return null;
	}

	/** Tells whether a node named like the target variable is among the observed nodes. */
	private boolean isTargetObserved() {
		for (HasProbabilityTable observed : observedNodes) {
			if (observed.getName().equals(targetRV)) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Multiplies, over all nodes of the network, the probability table entry
	 * whose states equal the current hard evidence of the node's participants.
	 *
	 * <p>The product is accumulated as a sum of base-10 logarithms. A node
	 * without a matching table entry contributes probability 0. Callers are
	 * expected to have checked that the nodes are observed.
	 *
	 * @return the product of the selected table entries
	 */
	private double computeJPD() {
		double logSum = 0;
		for (Node node : this.dagNetwork.getNodes()) {
			HasProbabilityTable tableOwner = node;

			// Collect the current hard evidence of the node's participants.
			Node[] participants = node.getParticipants();
			String[] hardEvidenceState = new String[participants.length];
			for (int k = 0; k < participants.length; k++) {
				hardEvidenceState[k] = participants[k].getHardEvidence();
			}

			// Reset per node so a missing row never reuses the previous node's value.
			double currentProbValue = 0;
			Probability[] conditionalProbabilities = tableOwner.getProbTable().getProbabilities();
			for (Probability row : conditionalProbabilities) {
				String[] rowStates = row.getStates();
				if (rowStates.length != hardEvidenceState.length) {
					System.out.println("SOMETHING IS WRONG!! SEE COMPUTE UNIVERSE METHOD...");
				}
				if (Arrays.equals(rowStates, hardEvidenceState)) {
					currentProbValue = row.getProbability();
					break;
				}
			}
			logSum += Math.log10(currentProbValue);
		}
		return Math.pow(10, logSum);
	}
}