/*******************************************************************************
* Copyright (C) 2011-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Vector;
import probcog.srl.directed.DecisionNode;
import probcog.srl.directed.RelationalNode;
import probcog.srl.directed.bln.AbstractGroundBLN;
import probcog.srl.directed.bln.BayesianLogicNetwork;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.CPF;
import edu.ksu.cis.bnj.ver3.core.CPT;
import edu.ksu.cis.bnj.ver3.core.Domain;
import edu.ksu.cis.bnj.ver3.core.Value;
/*
* Created on Sep 28, 2011
*/
public class BLNprintCPT {
public static class Options {
public Integer firstDataCol = null;
public Integer lastDataCol = null;
public Integer decimals = 2;
}
/**
* @param args
* @throws Exception
*/
public static void main(String[] args) throws Exception {
Options options = new Options();
int i;
for(i = 0; i < args.length; i++) {
if(!args[i].startsWith("-"))
break;
if(args[i].equals("-firstCol"))
options.firstDataCol = Integer.parseInt(args[++i]);
if(args[i].equals("-lastCol"))
options.lastDataCol = Integer.parseInt(args[++i]);
if(args[i].equals("-decimals"))
options.decimals = Integer.parseInt(args[++i]);
}
if(args.length != i+3) {
System.out.println("\nBLNprintCPTs -- format CPTs for printing using LaTeX\n\n");
System.out.println("\nusage: BLNprintCPT [options] <bln declarations file> <bln fragment network> <node name>\n\n");
System.out.println(" options: -firstCol N first data column to print (1-based index)\n" +
" -lastCol N last data column to print (1-based index), followed by dots\n" +
" -decimals N number of decimals for parameter output (default: 2)\n");
return;
}
String declsFile = args[i];
String fragmentsFile = args[i+1];
String nodeName = args[i+2];
BLNprintCPT printer = new BLNprintCPT(declsFile, fragmentsFile, options);
printer.writeCPT(nodeName);
}
protected BayesianLogicNetwork bln;
protected Options options;
public BLNprintCPT(String declsFile, String fragmentsFile, Options options) throws Exception {
bln = new BayesianLogicNetwork(declsFile, fragmentsFile);
this.options = options;
}
public void writeCPT(String nodeName) throws FileNotFoundException {
int i = 0;
for(BeliefNode n : bln.getNodes()) {
if(n.getName().equals(nodeName)) {
String filename = String.format("cpt-%s-%d.tex", nodeName, i++);
System.out.printf("writing %s...\n", filename);
File f = new File(filename);
writeCPT(n, new PrintStream(f));
}
}
}
public void writeCPT(BeliefNode node, PrintStream out) {
// construct no CPF where decision and precondition parents are clamped to true
RelationalNode rn = bln.getRelationalNode(node);
HashMap<BeliefNode,Integer> constantSettings = new HashMap<BeliefNode,Integer>();
HashSet<BeliefNode> excluded = new HashSet<BeliefNode>();
for(DecisionNode parent : rn.getDecisionParents()) {
constantSettings.put(parent.node, 0);
excluded.add(parent.node);
}
for(RelationalNode parent : rn.getRelationalParents()) {
if(parent.isPrecondition) {
constantSettings.put(parent.node, 0);
excluded.add(parent.node);
}
}
Value[] values = AbstractGroundBLN.getSubCPFValues(node.getCPF(), constantSettings);
Vector<BeliefNode> included = new Vector<BeliefNode>();
CPT cpf = (CPT)node.getCPF();
BeliefNode[] originalDomProd = cpf.getDomainProduct();
for(BeliefNode n : originalDomProd)
if(!excluded.contains(n))
included.add(n);
BeliefNode[] domprod = included.toArray(new BeliefNode[included.size()]);
cpf = new CPT(domprod);
cpf.setValues(values);
Table table = new Table(cpf, this.options);
table.writeLatex(out);
}
public static class Table {
String[][] table;
CPF cpf;
BeliefNode[] domprod;
int numParents;
int currentColumn = 1;
Options options;
public Table(CPF cpf, Options options) {
this.options = options;
this.cpf = cpf;
domprod = cpf.getDomainProduct();
Domain dom = domprod[0].getDomain();
int domSize = dom.getOrder();
int numDistributions = cpf.size() / domSize;
int numColumns = numDistributions+1;
numParents = domprod.length-1;
int numRows = numParents + domSize;
table = new String[numRows][numColumns];
// write parent names
for(int i = 1; i < domprod.length; i++)
table[i-1][0] = domprod[i].getName();
// write domain
for(int i = 0; i < domSize; i++)
table[numParents+i][0] = dom.getName(i);
int[] addr = new int[domprod.length];
writeData(1, addr);
}
protected void writeData(int i, int[] addr) {
String numberFormat = String.format("%%.%df", options.decimals);
Domain dom;
if(i == addr.length) {
// write parent configuration
for(int j = 1; j < domprod.length; j++)
table[j-1][currentColumn] = domprod[j].getDomain().getName(addr[j]);
// write probabilities
dom = domprod[0].getDomain();
int row = domprod.length-1;
for(int j = 0; j < dom.getOrder(); j++) {
addr[0] = j;
double p = cpf.getDouble(addr);
table[row++][currentColumn] = String.format(numberFormat, p);
}
++currentColumn;
return;
}
dom = domprod[i].getDomain();
for(int j = 0; j < dom.getOrder(); j++) {
addr[i] = j;
writeData(i+1, addr);
}
}
public void writeLatex(PrintStream out) {
out.println("\\documentclass{letter}\n\\usepackage[a0paper,landscape]{geometry}\n\\pagestyle{empty}\n\\begin{document}");
out.print("\\begin{tabular}{|l|");
printn("l", table[0].length-1, out);
out.print("|}\n\\hline\n");
for(int row = 0; row < table.length; row++) {
for(int col = 0; col < table[row].length; col++) {
boolean isDotsCol = false, end = false;
if(col > 0 && options.firstDataCol != null && col < options.firstDataCol) {
col = options.firstDataCol;
}
if(options.lastDataCol != null && options.lastDataCol == col-1) {
isDotsCol = true;
end = true;
}
if(col > 0) out.print(" & ");
String field = isDotsCol ? "\\dots" : toLatex(table[row][col]);
out.print(field == null ? "" : field);
if(end) break;
}
out.println("\\\\");
if(row+1 == numParents)
out.println("\\hline");
}
out.println("\\hline\\end{tabular}\n\\end{document}");
}
}
public static void printn(String s, int n, PrintStream out) {
for(int i = 0; i < n; i++)
out.print(s);
}
public static String toLatex(String s) {
return s.replace("_", "\\_").replace("#", "");
}
}