import InferenceGraphs.*; import InterchangeFormat.*; import java.io.*; import java.util.ArrayList; import java.util.HashMap; import java.util.Vector; public class a4{ private static class GibbsSampler { HashMap<String, String> allValuesMap; HashMap<String, String> nonEvidenceNodesMap; HashMap<String, HashMap<String, Integer>> requiredVariablesCount; HashMap<String, HashMap<String, Double>> normalizedVariableValues; public final String TotalValue = "TotalValue"; HashMap<String, InferenceGraphNode> allNodesMap; public long startTimeMillis = 0; public long endTimeMillis = 0; /** * initialize stuff * * non-evidence nodes * * counter for the value seen for each non-evidence node */ GibbsSampler(Object[] allNodes, HashMap<String, String> evidenceNodes, String[] requiredVariables) { allNodesMap = new HashMap<String, InferenceGraphNode> (); allValuesMap = new HashMap<String, String>(); nonEvidenceNodesMap = new HashMap<String, String>(); // create a hashmap of all nodes & intialize each node to have a zero value for (int i = 0; i < allNodes.length; i++) { InferenceGraphNode currNode = (InferenceGraphNode)allNodes[i]; String currentNodeName = currNode.get_name(); assert (!allNodesMap.containsKey(currNode.get_name())); allNodesMap.put(currentNodeName, currNode); // set values to default as the first node value String defaultNodeValue = currNode.get_values()[0]; allValuesMap.put(currentNodeName, defaultNodeValue); if (!evidenceNodes.containsKey(currentNodeName)) { nonEvidenceNodesMap.put(currentNodeName, defaultNodeValue); } else { // the evidence nodes are fixed, so we ensure that we use // the given values allValuesMap.put(currentNodeName, evidenceNodes.get(currentNodeName)); } } requiredVariablesCount = new HashMap<String, HashMap<String, Integer>> (); normalizedVariableValues = new HashMap<String, HashMap<String, Double>> (); for (int i = 0; i < requiredVariables.length; i++) { InferenceGraphNode currNode = allNodesMap.get(requiredVariables[i]); String[] values = currNode.get_values(); HashMap<String, Integer> counts = new HashMap<String, Integer>(); HashMap<String, Double> normalizedCount = new HashMap<String, Double>(); for (int j = 0; j < values.length; j++) { counts.put(values[j], 0); normalizedCount.put(values[j], 0.0); } counts.put(TotalValue, 0); requiredVariablesCount.put(currNode.get_name(), counts); normalizedVariableValues.put(currNode.get_name(), normalizedCount); } } public void generateGibbsSamples(int N, int samplesToSkip) { startTimeMillis = System.currentTimeMillis(); int lastSamplesCount = 0; for (int i = 1; i <= N; i++) { for (String s : nonEvidenceNodesMap.keySet()) { String result = SampleGivenMarkovBlanket(s); allValuesMap.put(s, result); if (requiredVariablesCount.containsKey(s)) { Integer currCount = requiredVariablesCount.get(s).get(result); Integer total = requiredVariablesCount.get(s).get(this.TotalValue); currCount = currCount + 1; total = total + 1; requiredVariablesCount.get(s).put(result, currCount); requiredVariablesCount.get(s).put(this.TotalValue, total); } } if (i % samplesToSkip == 0) { endTimeMillis = System.currentTimeMillis(); double samplesPerSec = (i - lastSamplesCount + 0.0) / ((0.0 + endTimeMillis - startTimeMillis)/1000.0); startTimeMillis = endTimeMillis; lastSamplesCount = i; System.out.println("So far:" + i + " at " + samplesPerSec + " per second"); // let's record the probability values NormalizeVariableCounts(); printNormalizedValues(); if (i == 10000) { ClearVariableCounts(); } } } } private String SampleGivenMarkovBlanket(String s) { InferenceGraphNode currNode = allNodesMap.get(s); // now calculate possible probabilities HashMap<String, Double> probabilityMap = new HashMap<String, Double>(); Double rand = Math.random(); String valueToUse = null; Double probabilitySum = 0.0; Double currProbSum = 0.0; for (int i = 0; i < currNode.get_values().length; i++) { Vector<?> childNodes = currNode.get_children(); allValuesMap.put(currNode.get_name(), currNode.get_values()[i]); double logProbabilityChildren = 0; for(Object o : childNodes) { InferenceGraphNode nc = (InferenceGraphNode) o; logProbabilityChildren = logProbabilityChildren + Math.log(calculateProbabilityOfNode(nc, allValuesMap.get(nc.get_name()))); } String currValue = currNode.get_values()[i]; Double currProbValue = calculateProbabilityOfNode(currNode, currValue); currProbSum = currProbSum + currProbValue; Double currProbLog = Math.log(currProbValue); Double expectedProb = Math.exp(currProbLog + logProbabilityChildren); probabilitySum += expectedProb; probabilityMap.put(currValue, expectedProb); } Double samplePoint = rand * probabilitySum; Double tempSum = 0.0; for (String v : probabilityMap.keySet()) { Double currProb = probabilityMap.get(v); tempSum = tempSum + currProb; if (tempSum >= samplePoint) { valueToUse = v; break; } } if (valueToUse == null) { System.out.println("valueToUse is null!!!"); } return valueToUse; //return firstValue; } private double calculateProbabilityOfNode(InferenceGraphNode currNode, String currNodeValue) { Vector<?> parentNodes = currNode.get_parents(); String [][] vars = new String[parentNodes.size() + 1][2]; vars[0][0] = currNode.get_name(); vars[0][1] = allValuesMap.get(currNode.get_name()); int i = 1; for (Object o : parentNodes) { InferenceGraphNode n = (InferenceGraphNode) o; vars[i][0] = n.get_name(); vars[i][1] = allValuesMap.get(n.get_name()); i++; } return currNode.get_function_value(vars); } private void printNormalizedValues() { StringBuilder sb = new StringBuilder(); for (String s : this.normalizedVariableValues.keySet()) { HashMap<String, Double> probMap = this.normalizedVariableValues.get(s); sb.append(String.format("\"%s\" = {", s)); for (String valueNames : allNodesMap.get(s).get_values()) { assert (probMap.get(valueNames)!=null); sb.append(String.format(" %s ", probMap.get(valueNames))); } sb.append("}\n"); } System.out.println(sb.toString()); } private void NormalizeVariableCounts() { for (String s : this.requiredVariablesCount.keySet()) { for (String values: this.requiredVariablesCount.get(s).keySet()) { if (!values.equalsIgnoreCase(this.TotalValue)) { double prob = this.requiredVariablesCount.get(s).get(values) / (1.0 * this.requiredVariablesCount.get(s).get(this.TotalValue)); normalizedVariableValues.get(s).put(values, prob); } } } } private void ClearVariableCounts() { for (String s : this.requiredVariablesCount.keySet()) { for (String values: this.requiredVariablesCount.get(s).keySet()) { this.requiredVariablesCount.get(s).put(values, 0); } } } } public static void main(String[] args){ try{ InferenceGraph G=new InferenceGraph("alarm.bif"); Vector<?> nodes= G.get_nodes(); InferenceGraphNode n=((InferenceGraphNode)nodes.elementAt(0)); System.out.println(n.get_name()); n.get_Prob().print(); ArrayList<InferenceGraphNode> diagnosticNodes = new ArrayList<InferenceGraphNode>(); // build diagnostic nodes list for (Object d : nodes) { InferenceGraphNode node = (InferenceGraphNode) d; if (node.get_parents().isEmpty()) { diagnosticNodes.add(node); System.out.print("\"" + node.get_name() + "\","); } } String[] diagnosticArray = {"Hypovolemia","LVFailure","Anaphylaxis","InsuffAnesth", "PulmEmbolus","Intubation","Disconnect","KinkedTube"}; BufferedReader inputFile = new BufferedReader (new FileReader("input.txt")); String line; HashMap<String, String> evidenceVariables = new HashMap<String, String>(); line = inputFile.readLine(); do { String[] lineSplit = line.split("="); String var = lineSplit[0].trim().substring(1, lineSplit[0].length() - 1); String varValue = lineSplit[1].trim().substring(1, lineSplit[1].length() - 1); evidenceVariables.put(var, varValue); line = inputFile.readLine(); } while (line != null && !line.contains(".")); inputFile.close(); GibbsSampler gc = new GibbsSampler(nodes.toArray(), evidenceVariables, diagnosticArray); gc.generateGibbsSamples(1000 * 1000, 1000); } catch(IFException e){ System.out.println("Formatting Incorrect "+e.toString()); } catch(FileNotFoundException e) { System.out.println("File not found "+e.toString()); } catch(IOException e){ System.out.println("File not found "+e.toString()); } } }