import InferenceGraphs.*;
import InterchangeFormat.*;
import java.io.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Vector;
public class a4{
/**
 * Approximate-inference engine: draws Gibbs samples over a Bayesian
 * network (InferenceGraphs library) to estimate the posterior
 * distribution of selected query ("required") variables given evidence.
 */
private static class GibbsSampler
{
// current assignment for EVERY node in the network: node name -> value
HashMap<String, String> allValuesMap;
// subset of allValuesMap for nodes that are resampled (not evidence);
// value here is only the initial value — allValuesMap holds the live state
HashMap<String, String> nonEvidenceNodesMap;
// per query variable: value -> number of times sampled, plus a running
// total under the TotalValue sentinel key
HashMap<String, HashMap<String, Integer>> requiredVariablesCount;
// per query variable: value -> count/total probability estimate
HashMap<String, HashMap<String, Double>> normalizedVariableValues;
// sentinel key used inside requiredVariablesCount for the denominator
public final String TotalValue = "TotalValue";
// node name -> graph node, for CPT and structure lookups
HashMap<String, InferenceGraphNode> allNodesMap;
// timing bookkeeping for the samples/second progress report
public long startTimeMillis = 0;
public long endTimeMillis = 0;
/**
 * Builds the sampler state: registers every node, clamps evidence nodes
 * to their observed values, seeds all remaining nodes with the first
 * value they declare, and allocates count / probability tables for each
 * query variable (including the TotalValue denominator slot).
 */
GibbsSampler(Object[] allNodes, HashMap<String, String> evidenceNodes,
String[] requiredVariables)
{
    allNodesMap = new HashMap<String, InferenceGraphNode>();
    allValuesMap = new HashMap<String, String>();
    nonEvidenceNodesMap = new HashMap<String, String>();
    // register every node and give it an initial assignment
    for (Object raw : allNodes)
    {
        InferenceGraphNode node = (InferenceGraphNode) raw;
        String name = node.get_name();
        assert (!allNodesMap.containsKey(name));
        allNodesMap.put(name, node);
        String firstValue = node.get_values()[0];
        if (evidenceNodes.containsKey(name))
        {
            // evidence is fixed to its observed value and never resampled
            allValuesMap.put(name, evidenceNodes.get(name));
        }
        else
        {
            allValuesMap.put(name, firstValue);
            nonEvidenceNodesMap.put(name, firstValue);
        }
    }
    requiredVariablesCount = new HashMap<String, HashMap<String, Integer>>();
    normalizedVariableValues = new HashMap<String, HashMap<String, Double>>();
    for (String required : requiredVariables)
    {
        InferenceGraphNode node = allNodesMap.get(required);
        HashMap<String, Integer> counts = new HashMap<String, Integer>();
        HashMap<String, Double> probs = new HashMap<String, Double>();
        for (String value : node.get_values())
        {
            counts.put(value, 0);
            probs.put(value, 0.0);
        }
        // running denominator used when normalizing the counts
        counts.put(TotalValue, 0);
        requiredVariablesCount.put(node.get_name(), counts);
        normalizedVariableValues.put(node.get_name(), probs);
    }
}
/**
 * Runs the Gibbs chain with the historical burn-in of 10000 sweeps.
 * Kept for backward compatibility; delegates to the three-argument form.
 *
 * @param N             total number of full sweeps over the non-evidence nodes
 * @param samplesToSkip reporting interval (progress is printed every
 *                      samplesToSkip sweeps); 0 disables reporting
 */
public void generateGibbsSamples(int N, int samplesToSkip)
{
    generateGibbsSamples(N, samplesToSkip, 10000);
}

/**
 * Runs the Gibbs chain: each sweep resamples every non-evidence node from
 * its Markov-blanket conditional and accumulates value counts for the
 * query variables. Periodically prints throughput and the current
 * normalized estimates.
 *
 * @param N             total number of sweeps
 * @param samplesToSkip reporting interval; 0 disables reporting
 * @param burnInSamples counts accumulated up to this sweep are discarded
 *                      (previously a hard-coded 10000)
 */
public void generateGibbsSamples(int N, int samplesToSkip, int burnInSamples)
{
    startTimeMillis = System.currentTimeMillis();
    int lastSamplesCount = 0;
    for (int i = 1; i <= N; i++)
    {
        for (String s : nonEvidenceNodesMap.keySet())
        {
            String result = SampleGivenMarkovBlanket(s);
            allValuesMap.put(s, result);
            HashMap<String, Integer> counts = requiredVariablesCount.get(s);
            if (counts != null)
            {
                // primitive arithmetic: avoids the Integer autoboxing churn
                counts.put(result, counts.get(result) + 1);
                counts.put(TotalValue, counts.get(TotalValue) + 1);
            }
        }
        // guard samplesToSkip > 0: the old code threw ArithmeticException
        // when called with a zero interval
        if (samplesToSkip > 0 && i % samplesToSkip == 0)
        {
            endTimeMillis = System.currentTimeMillis();
            double elapsedSec = (endTimeMillis - startTimeMillis) / 1000.0;
            // avoid division by zero when the interval completed in < 1 ms
            double samplesPerSec = elapsedSec > 0.0
                ? (i - lastSamplesCount) / elapsedSec
                : Double.POSITIVE_INFINITY;
            startTimeMillis = endTimeMillis;
            lastSamplesCount = i;
            System.out.println("So far:" + i + " at " + samplesPerSec + " per second");
            // let's record the probability values
            NormalizeVariableCounts();
            printNormalizedValues();
            if (i == burnInSamples)
            {
                // discard burn-in: estimates from here on use only
                // post-burn-in samples
                ClearVariableCounts();
            }
        }
    }
}
/**
 * Resamples node {@code s} from its distribution conditioned on its
 * Markov blanket: for each candidate value, weight =
 * P(node=value | parents) * product over children of P(child | its parents),
 * then draws a value by inverse-CDF sampling over the unnormalized weights.
 *
 * Fixes over the original: the loop-invariant children/values lookups are
 * hoisted, and the method can no longer return null — floating-point
 * rounding could previously leave the cumulative sum just short of the
 * sample point, returning null and NPE-ing the caller. The last declared
 * value is now the fallback.
 *
 * @param s name of the (non-evidence) node to resample
 * @return the newly drawn value; allValuesMap is left holding the last
 *         candidate value tried — the caller overwrites it with the result
 */
private String SampleGivenMarkovBlanket(String s) {
    InferenceGraphNode currNode = allNodesMap.get(s);
    String[] values = currNode.get_values();
    Vector<?> childNodes = currNode.get_children();
    HashMap<String, Double> probabilityMap = new HashMap<String, Double>();
    double probabilitySum = 0.0;
    for (String currValue : values)
    {
        // temporarily assign the candidate value so each child's CPT
        // lookup conditions on it
        allValuesMap.put(currNode.get_name(), currValue);
        double logProbabilityChildren = 0.0;
        for (Object o : childNodes)
        {
            InferenceGraphNode child = (InferenceGraphNode) o;
            logProbabilityChildren +=
                Math.log(calculateProbabilityOfNode(child, allValuesMap.get(child.get_name())));
        }
        double ownLogProb = Math.log(calculateProbabilityOfNode(currNode, currValue));
        double weight = Math.exp(ownLogProb + logProbabilityChildren);
        probabilitySum += weight;
        probabilityMap.put(currValue, weight);
    }
    // inverse-CDF draw; iterate in declared-value order (deterministic
    // given the same random draws, unlike HashMap key order)
    double samplePoint = Math.random() * probabilitySum;
    double cumulative = 0.0;
    String valueToUse = values[values.length - 1]; // rounding-error fallback
    for (String v : values)
    {
        cumulative += probabilityMap.get(v);
        if (cumulative >= samplePoint)
        {
            valueToUse = v;
            break;
        }
    }
    return valueToUse;
}
/**
 * Looks up P(currNode = currNodeValue | parents = current assignment) in
 * the node's conditional probability table.
 *
 * BUGFIX: the original ignored the {@code currNodeValue} argument and read
 * the node's value from allValuesMap instead; callers only got correct
 * results because they happened to write the candidate value into the map
 * first. The parameter is now honored, making the method's contract real.
 *
 * @param currNode      node whose CPT is consulted
 * @param currNodeValue the value of currNode to evaluate
 * @return the CPT entry for the given value under the current parent
 *         assignment (taken from allValuesMap)
 */
private double calculateProbabilityOfNode(InferenceGraphNode currNode, String currNodeValue)
{
    Vector<?> parentNodes = currNode.get_parents();
    String[][] vars = new String[parentNodes.size() + 1][2];
    vars[0][0] = currNode.get_name();
    vars[0][1] = currNodeValue; // honor the requested value, not the map state
    int i = 1;
    for (Object o : parentNodes)
    {
        InferenceGraphNode parent = (InferenceGraphNode) o;
        vars[i][0] = parent.get_name();
        vars[i][1] = allValuesMap.get(parent.get_name());
        i++;
    }
    return currNode.get_function_value(vars);
}
/**
 * Prints the current normalized probability estimate of every query
 * variable, one line per variable, values in the order the node declares
 * them: "Var" = { p1  p2 ... }
 */
private void printNormalizedValues() {
    StringBuilder out = new StringBuilder();
    for (String variable : normalizedVariableValues.keySet())
    {
        HashMap<String, Double> estimates = normalizedVariableValues.get(variable);
        out.append(String.format("\"%s\" = {", variable));
        String[] orderedValues = allNodesMap.get(variable).get_values();
        for (String value : orderedValues)
        {
            // every declared value was seeded in the constructor
            assert (estimates.get(value) != null);
            out.append(String.format(" %s ", estimates.get(value)));
        }
        out.append("}\n");
    }
    System.out.println(out.toString());
}
/**
 * Converts the raw occurrence counts of each query variable into
 * probability estimates (count / total), writing the results into
 * normalizedVariableValues.
 *
 * Robustness fix: a variable whose total is still 0 (never sampled —
 * e.g. a query variable that is also evidence) now gets 0.0 estimates
 * instead of NaN/Infinity from dividing by zero.
 */
private void NormalizeVariableCounts() {
    for (String variable : requiredVariablesCount.keySet())
    {
        HashMap<String, Integer> counts = requiredVariablesCount.get(variable);
        int total = counts.get(TotalValue);
        for (String value : counts.keySet())
        {
            // the running-total sentinel is not a value of the variable
            if (value.equalsIgnoreCase(TotalValue))
            {
                continue;
            }
            double prob = total > 0 ? counts.get(value) / (double) total : 0.0;
            normalizedVariableValues.get(variable).put(value, prob);
        }
    }
}
/**
 * Zeroes every per-value counter (including the TotalValue denominator)
 * for all query variables — used to discard burn-in samples.
 */
private void ClearVariableCounts() {
    for (HashMap<String, Integer> counts : requiredVariablesCount.values())
    {
        for (String value : counts.keySet())
        {
            counts.put(value, 0);
        }
    }
}
}
/**
 * Entry point: loads the "alarm.bif" network, reads evidence assignments
 * from "input.txt" (one per line, form: "VarName" = "value"; terminated
 * by EOF or a line containing "."), then runs the Gibbs sampler for the
 * fixed set of diagnostic (root) variables.
 *
 * Fixes over the original: the quote-stripping used the UNTRIMMED
 * token's length on the TRIMMED token, leaving a trailing quote whenever
 * spaces surrounded '='; an empty input file caused an NPE (do-while
 * dereferenced a null first line); the reader leaked on exception; the
 * IOException handler printed the wrong message.
 */
public static void main(String[] args){
    try{
        InferenceGraph G = new InferenceGraph("alarm.bif");
        Vector<?> nodes = G.get_nodes();
        InferenceGraphNode n = ((InferenceGraphNode) nodes.elementAt(0));
        System.out.println(n.get_name());
        n.get_Prob().print();
        ArrayList<InferenceGraphNode> diagnosticNodes = new ArrayList<InferenceGraphNode>();
        // build diagnostic nodes list (roots: nodes with no parents)
        for (Object d : nodes)
        {
            InferenceGraphNode node = (InferenceGraphNode) d;
            if (node.get_parents().isEmpty())
            {
                diagnosticNodes.add(node);
                System.out.print("\"" + node.get_name() + "\",");
            }
        }
        String[] diagnosticArray = {"Hypovolemia","LVFailure","Anaphylaxis","InsuffAnesth",
            "PulmEmbolus","Intubation","Disconnect","KinkedTube"};
        HashMap<String, String> evidenceVariables = new HashMap<String, String>();
        BufferedReader inputFile = new BufferedReader(new FileReader("input.txt"));
        try
        {
            String line = inputFile.readLine();
            // stop at EOF or at the "." terminator line; checking null
            // first also handles an empty input file gracefully
            while (line != null && !line.contains("."))
            {
                String[] lineSplit = line.split("=");
                // BUGFIX: strip the surrounding quotes using the trimmed
                // token's OWN length (the old code used the untrimmed length)
                String varToken = lineSplit[0].trim();
                String valueToken = lineSplit[1].trim();
                String var = varToken.substring(1, varToken.length() - 1);
                String varValue = valueToken.substring(1, valueToken.length() - 1);
                evidenceVariables.put(var, varValue);
                line = inputFile.readLine();
            }
        }
        finally
        {
            // close even if parsing throws
            inputFile.close();
        }
        GibbsSampler gc = new GibbsSampler(nodes.toArray(),
            evidenceVariables, diagnosticArray);
        gc.generateGibbsSamples(1000 * 1000, 1000);
    }
    catch(IFException e){
        System.out.println("Formatting Incorrect "+e.toString());
    }
    catch(FileNotFoundException e) {
        System.out.println("File not found "+e.toString());
    }
    catch(IOException e){
        // BUGFIX: was misleadingly reported as "File not found"
        System.out.println("IO error "+e.toString());
    }
}
}