/*******************************************************************************
* Copyright (C) 2011-2012 Dominik Jain, Paul Maier.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.bayesnets.inference;
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import probcog.bayesnets.conversion.BNDB2Inst;
import probcog.bayesnets.core.BNDatabase;
import probcog.bayesnets.core.BeliefNetworkEx;
import edu.tum.cs.util.FileUtil;
/**
 * A simple wrapper for the ACE 2.0 inference engine (arithmetic circuit evaluation).
 * <p>
 * During initialization, the belief network is saved as a temporary XBIF file
 * ({@code ace.tmp.xbif}) and compiled with ACE's {@code compile} tool, and the evidence
 * is written to a temporary {@code ace.tmp.inst} file. Inference runs ACE's
 * {@code evaluate} tool and parses the resulting {@code .marginals} file into the output
 * distribution. All temporary files live in the current working directory and are deleted
 * after inference (whether it succeeds or fails).
 * <p>
 * Configurable parameters:
 * <ul>
 * <li>{@code acePath}: directory containing the ACE {@code compile}/{@code evaluate}
 * executables (or {@code .bat} scripts); required</li>
 * <li>{@code aceParams}: additional whitespace-separated arguments passed to the ACE compiler</li>
 * </ul>
 * @author Dominik Jain
 */
public class ACE extends Sampler {
	/**
	 * Regex fragment matching an unsigned floating-point number with either '.' or ','
	 * as decimal separator and an optional exponent. It contains no capturing groups,
	 * so embedding it does not shift the group numbers of the enclosing pattern.
	 */
	private static final String FLOAT_REGEX = "(?:\\d+(?:[.,]\\d+)?(?:[eE][-+]?\\d+)?)";
	/** compile time line of the compiler output; also accepts the misspelling "Complie" (presumably printed by some ACE versions) */
	private static final Pattern COMPILE_TIME_PATTERN = Pattern.compile("(?:Compile|Complie) Time \\(s\\) : (.*?)$", Pattern.MULTILINE);
	/** total inference time line (milliseconds) of the evaluator output */
	private static final Pattern INFERENCE_TIME_PATTERN = Pattern.compile("Total Inference Time \\(ms\\) : (\\d+(?:[.,]\\d+)?)", Pattern.MULTILINE);
	/** probability of the evidence in the marginals file; group 1 is the value */
	private static final Pattern PROB_EVIDENCE_PATTERN = Pattern.compile(String.format("p \\(e\\) = (%s)", FLOAT_REGEX));
	/** posterior marginal of one variable; group 1 is the variable name, group 2 the comma-separated values */
	private static final Pattern MARGINAL_PATTERN = Pattern.compile(String.format("p \\((.*?) \\| e\\) = \\[(%s(?:, %s)+)\\]", FLOAT_REGEX, FLOAT_REGEX));

	/** directory containing the ACE tools; must be set via {@link #setAcePath(String)} before initialization */
	protected File acePath = null;
	/** temporary network (.xbif) and evidence (.inst) files passed to ACE */
	protected File bnFile, instFile;
	/** additional command-line arguments for the ACE compiler (never null) */
	private String aceParams = "";
	/** compile time and evaluation time in seconds, as reported by ACE (0 if not reported) */
	protected double compileTime, evalTime;

	public ACE(BeliefNetworkEx bn) throws Exception {
		super(bn);
		paramHandler.add("acePath", "setAcePath");
		paramHandler.add("aceParams", "setAceParams");
	}

	/**
	 * Sets additional arguments for the ACE compiler.
	 * @param aceParams whitespace-separated arguments; null is treated as no arguments
	 */
	public void setAceParams(String aceParams) {
		this.aceParams = aceParams == null ? "" : aceParams;
	}

	/**
	 * Sets the directory in which the ACE tools are installed.
	 * The previously configured path is kept if the given one is invalid.
	 * @param path the ACE installation directory
	 * @throws Exception if the path does not exist or is not a directory
	 */
	public void setAcePath(String path) throws Exception {
		File dir = new File(path);
		// validate before assigning so that a bad path does not replace a good one
		if(!dir.isDirectory())
			throw new Exception("The given path " + path + " does not exist or is not a directory");
		this.acePath = dir;
	}

	/**
	 * Runs one of the ACE tools and waits for it to finish.
	 * @param command name of the tool in {@link #acePath} (a {@code .bat} variant is used if the plain name does not exist)
	 * @param params whitespace-separated command-line arguments (may be empty)
	 * @return the tool's complete standard output, buffered in memory
	 * @throws Exception if the tool cannot be found or started, or if it wrote anything to standard error
	 */
	protected BufferedInputStream runAce(String command, String params) throws Exception {
		String[] cmd = buildCommandLine(command, params);
		System.out.println(" " + Arrays.toString(cmd));
		Process p = Runtime.getRuntime().exec(cmd);
		// the ACE tools are given no input; close stdin so they cannot block waiting for it
		p.getOutputStream().close();
		// Drain stderr concurrently with stdout: waiting for the process before consuming
		// its output deadlocks as soon as one of the OS pipe buffers fills up.
		StreamGobbler errReader = new StreamGobbler(p.getErrorStream());
		errReader.start();
		byte[] output;
		InputStream stdout = p.getInputStream();
		try {
			output = readFully(stdout);
		}
		finally {
			stdout.close();
		}
		p.waitFor();
		errReader.join();
		String error = errReader.getResult();
		if(error != null && !error.isEmpty())
			throw new Exception("Error running ACE: " + error);
		return new BufferedInputStream(new ByteArrayInputStream(output));
	}

	/** Builds the argument vector: the resolved tool path followed by the non-empty tokens of {@code params}. */
	private String[] buildCommandLine(String command, String params) throws Exception {
		String trimmed = params == null ? "" : params.trim();
		// splitting "" would yield a single empty argument, so treat it as "no arguments"
		String[] args = trimmed.isEmpty() ? new String[0] : trimmed.split("\\s+");
		String[] cmd = new String[args.length + 1];
		cmd[0] = findAceExecutable(command).toString();
		System.arraycopy(args, 0, cmd, 1, args.length);
		return cmd;
	}

	/** Locates {@code command} (or {@code command.bat}) in the ACE directory. */
	private File findAceExecutable(String command) throws Exception {
		File cmdFile = new File(acePath, command);
		if(!cmdFile.exists()) {
			cmdFile = new File(acePath, command + ".bat");
			if(!cmdFile.exists())
				throw new Exception("Could not find " + command + " (or .bat) in " + acePath);
		}
		return cmdFile;
	}

	/** Reads the given stream to its end and returns all bytes read. */
	private static byte[] readFully(InputStream in) throws IOException {
		ByteArrayOutputStream buffer = new ByteArrayOutputStream();
		byte[] chunk = new byte[8192];
		int n;
		while((n = in.read(chunk)) != -1)
			buffer.write(chunk, 0, n);
		return buffer.toByteArray();
	}

	/** Background thread that reads a process stream to completion so that the process never blocks writing to it. */
	private static class StreamGobbler extends Thread {
		private final InputStream stream;
		private String result;
		private Exception failure;

		StreamGobbler(InputStream stream) {
			this.stream = stream;
			// must not keep the JVM alive if the child process hangs
			setDaemon(true);
		}

		@Override
		public void run() {
			try {
				result = FileUtil.readInputStreamAsString(stream);
			}
			catch(Exception e) {
				failure = e;
			}
			finally {
				try {
					stream.close();
				}
				catch(IOException ignored) {
					// nothing useful to do; the contents (or failure) have already been recorded
				}
			}
		}

		/**
		 * Returns the stream contents; only call after {@link #join()} has returned
		 * (which also makes the fields written by {@link #run()} visible).
		 * @throws Exception the exception raised while reading, if any
		 */
		String getResult() throws Exception {
			if(failure != null)
				throw failure;
			return result;
		}
	}

	/**
	 * Saves the network, compiles it into an arithmetic circuit using ACE and writes the evidence file.
	 * @throws Exception if no ACE path was configured or any step fails
	 */
	protected void _initialize() throws Exception {
		if(acePath == null)
			throw new Exception("No ACE 2.0 path was given. This inference method requires ACE2.0 and the location at which it is installed must be configured");
		// save belief network as .xbif
		bnFile = new File("ace.tmp.xbif");
		this.bn.save(bnFile.getPath());
		// compile arithmetic circuit using ace compiler
		if(verbose) System.out.println("compiling arithmetic circuit...");
		if(verbose && !aceParams.isEmpty()) System.out.println(" ACE params: " + this.aceParams);
		String compileOutput = FileUtil.readInputStreamAsString(runAce("compile", this.aceParams + " " + bnFile.getName()));
		if(debug)
			System.out.println(compileOutput);
		Matcher m = COMPILE_TIME_PATTERN.matcher(compileOutput);
		if(m.find()) {
			compileTime = parseDouble(m.group(1));
			report(String.format("ACE compile time: %ss", compileTime));
		}
		// write evidence to .inst file
		instFile = new File("ace.tmp.inst");
		BNDB2Inst.convert(new BNDatabase(this.bn, this.evidenceDomainIndices), instFile);
	}

	/**
	 * Evaluates the compiled circuit with ACE and stores the probability of the evidence and
	 * the posterior marginals in the distribution builder. Temporary files are always deleted afterwards.
	 * @throws Exception if ACE fails, the results cannot be parsed, or the evidence has probability 0
	 */
	@Override
	protected void _infer() throws Exception {
		File marginalsFile = new File(bnFile.getName() + ".marginals");
		try {
			// run ACE inference
			if(verbose) System.out.println("evaluating...");
			String output = FileUtil.readInputStreamAsString(runAce("evaluate", bnFile.getName() + " " + instFile.getName()));
			// read running time (reported in milliseconds)
			Matcher m = INFERENCE_TIME_PATTERN.matcher(output);
			if(m.find()) {
				evalTime = parseDouble(m.group(1)) / 1000.0;
				report(String.format("ACE evaluation time: %ss", evalTime));
			}
			// create output distribution
			SampledDistribution dist = createDistribution();
			if(verbose) System.out.println("reading results...");
			String results = FileUtil.readTextFile(marginalsFile);
			if(debug)
				System.out.println(results);
			dist.Z = readEvidenceProbability(results);
			System.out.println("probability of the evidence: " + dist.Z);
			int cnt = readMarginals(results, dist);
			System.out.println(cnt + " marginals read");
			((ImmediateDistributionBuilder)distributionBuilder).setDistribution(dist);
		}
		finally {
			cleanUpTempFiles(marginalsFile);
		}
	}

	/**
	 * Extracts the probability of the evidence from the marginals file contents.
	 * @throws Exception if the value is missing or zero
	 */
	private double readEvidenceProbability(String results) throws Exception {
		Matcher m = PROB_EVIDENCE_PATTERN.matcher(results);
		if(!m.find())
			throw new Exception("Could not find 'p (e)' in results");
		double pe = parseDouble(m.group(1));
		// compare the parsed value rather than its text: "0", "0.0" and "0E0" all mean zero
		if(pe == 0.0)
			throw new Exception("The probability of the evidence is 0");
		return pe;
	}

	/**
	 * Parses all posterior marginals from the marginals file contents into {@code dist}.
	 * @return the number of marginals read
	 * @throws Exception if a variable is unknown or its marginal has the wrong number of entries
	 */
	private int readMarginals(String results, SampledDistribution dist) throws Exception {
		Matcher m = MARGINAL_PATTERN.matcher(results);
		int cnt = 0;
		while(m.find()) {
			String varName = m.group(1);
			if(bn.getNode(varName) == null)
				throw new Exception("Unknown variable '" + varName + "' in ACE results");
			int nodeIdx = this.getNodeIndex(bn.getNode(varName));
			String[] v = m.group(2).split(", ");
			if(v.length != dist.values[nodeIdx].length)
				throw new Exception("Marginal vector length for '" + varName + "' incorrect");
			for(int i = 0; i < v.length; i++)
				dist.values[nodeIdx][i] = parseDouble(v[i]);
			cnt++;
		}
		return cnt;
	}

	/** Deletes the temporary files written by this class and by ACE (missing files are ignored). */
	private void cleanUpTempFiles(File marginalsFile) {
		new File(bnFile.getName() + ".ac").delete();
		new File(bnFile.getName() + ".lmap").delete();
		bnFile.delete();
		if(instFile != null)
			instFile.delete();
		marginalsFile.delete();
	}

	/** @return the compile time reported by ACE, in seconds */
	public double getAceCompileTime() {
		return compileTime;
	}

	/** @return the total inference time reported by ACE, in seconds */
	public double getAceEvalTime() {
		return evalTime;
	}

	/**
	 * Parses a floating-point number that may use ',' as decimal separator.
	 * @param s the number text
	 * @return the parsed value
	 */
	protected static double parseDouble(String s) {
		return Double.parseDouble(s.replace(',', '.'));
	}

	protected IDistributionBuilder createDistributionBuilder() {
		return new ImmediateDistributionBuilder();
	}
}