/*******************************************************************************
* Copyright (C) 2009-2012 Dominik Jain, Paul Maier.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.bayesnets.core.io;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.PrintStream;
import probcog.bayesnets.core.BeliefNetworkEx;
import edu.ksu.cis.bnj.ver3.core.BeliefNetwork;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.CPF;
import edu.ksu.cis.bnj.ver3.streams.Exporter;
import edu.ksu.cis.bnj.ver3.streams.OmniFormatV1;
/**
* Importer for the Ergo file format (http://graphmod.ics.uci.edu/group/Ergo_file_format)
* @author Dominik Jain
*/
public class Converter_ergo implements edu.ksu.cis.bnj.ver3.streams.Importer, Exporter {
protected boolean isUAIstyle = false;
public String getDesc() {
return "Ergo";
}
public String getExt() {
return "*.erg";
}
public void load(InputStream stream, OmniFormatV1 writer) {
BufferedReader br = new BufferedReader(new InputStreamReader(stream));
String line;
try {
// read preamble
int numVars = readLineOfInts(br)[0];
int[] domSizes = readLineOfInts(br);
int[][] parents = new int[numVars][];
for(int i = 0; i < numVars; i++) {
parents[i] = readLineOfInts(br);
}
// read probability tables
line = nextLine(br);
if(!line.contains("Probabilities"))
throw new IOException("Expected 'Probabilities' section, got this: " + line);
double[][] cpfs = new double[numVars][];
for(int i = 0; i < numVars; i++) {
int numEntries = readLineOfInts(br)[0];
int parentConfigs = 1;
for(int j = 1; j < parents[i].length; j++)
parentConfigs *= domSizes[parents[i][j]];
int entriesPerLine = numEntries / parentConfigs;
double[] cpf = new double[numEntries];
for(int j = 0; j < parentConfigs; j++) {
if(readCPF(br, cpf, j*entriesPerLine) != entriesPerLine)
throw new IOException("CPF line contained unexpected number of entries");
}
cpfs[i] = cpf;
}
// read variable names
line = nextLine(br);
if(!line.contains("Names"))
throw new IOException("Expected 'Names' section, got this: " + line);
String[] names = new String[numVars];
for(int i = 0; i < numVars; i++) {
names[i] = nextLine(br);
}
// read domain names
line = nextLine(br);
if(!line.contains("Labels"))
throw new IOException("Expected 'Labels' section, got this: " + line);
String[][] outcomes = new String[numVars][];
for(int i = 0; i < numVars; i++) {
outcomes[i] = readLineOfStrings(br);
if(outcomes[i].length != domSizes[i])
throw new IOException(String.format("Unexpected domain size: Got %d labels but domain size is %d for variable %s", outcomes[i].length, domSizes[i], names[i]));
}
// build the network
writer.Start();
writer.CreateBeliefNetwork(0);
// basic belief node data
for(int i = 0; i < numVars; i++) {
writer.BeginBeliefNode(i);
writer.SetType("chance");
for(int j = 0; j < outcomes[i].length; j++)
writer.BeliefNodeOutcome(outcomes[i][j]);
writer.SetBeliefNodeName(names[i]);
writer.EndBeliefNode();
}
// rest
for(int i = 0; i < numVars; i++) {
// connect parents
for(int j = 1; j < parents[i].length; j++)
writer.Connect(parents[i][j], i);
// cpf
writer.BeginCPF(i);
for(int j = 0; j < cpfs[i].length; j++)
writer.ForwardFlat_CPFWriteValue(Double.toString(cpfs[i][j]));
writer.EndCPF();
}
writer.Finish();
}
catch (IOException e) {
throw new RuntimeException(e);
}
}
protected String nextLine(BufferedReader br) throws IOException {
String l;
do
l = br.readLine().trim();
while(l.length() == 0);
return l;
}
protected String[] readLineOfStrings(BufferedReader br) throws IOException {
return nextLine(br).split("\\s+");
}
protected int[] readLineOfInts(BufferedReader br) throws IOException {
String[] elems = readLineOfStrings(br);
int[] ret = new int[elems.length];
for(int i = 0; i < elems.length; i++)
ret[i] = Integer.parseInt(elems[i]);
return ret;
}
protected int readCPF(BufferedReader br, double[] cpf, int i) throws IOException {
String l = nextLine(br);
String[] elems = l.split("\\s+");
for(int j = 0; j < elems.length; j++)
cpf[i++] = Double.parseDouble(elems[j]);
return elems.length;
}
@Override
public void save(BeliefNetwork bn, OutputStream os) {
BeliefNetworkEx bnex = new BeliefNetworkEx(bn);
PrintStream out = new PrintStream(os);
if(isUAIstyle) out.println("BAYES");
// number of nodes
BeliefNode[] nodes = bn.getNodes();
out.println(nodes.length);
// domain sizes
for(int i = 0; i < nodes.length; i++)
out.printf("%d ", nodes[i].getDomain().getOrder());
out.println();
// parents
for(BeliefNode n : nodes) {
BeliefNode[] domprod = n.getCPF().getDomainProduct();
out.printf("%d", domprod.length-1);
for(int i = 1; i < domprod.length; i++)
out.printf("\t%d", bnex.getNodeIndex(domprod[i]));
out.println();
}
// CPTs
if(!isUAIstyle)
out.println("\n/* Probabilities */");
else
out.println();
for(BeliefNode n : nodes) {
CPF cpf = n.getCPF();
out.println(n.getCPF().size());
writeTable(out, cpf, 1, new int[cpf.getDomainProduct().length]);
out.println();
}
if(!isUAIstyle) {
// variable names
if(!isUAIstyle) out.println("\n/* Names */");
for(BeliefNode n : nodes)
out.println(n.getName());
// domain entry names
out.println("\n/* Labels */");
for(BeliefNode n : nodes) {
int order = n.getDomain().getOrder();
for(int i = 0; i < order; i++) {
out.printf("%s ", n.getDomain().getName(i));
}
out.println();
}
}
}
public void writeTable(PrintStream out, CPF cpf, int i, int[] addr) {
BeliefNode[] domprod = cpf.getDomainProduct();
if(i == domprod.length) {
for(int d = 0; d < domprod[0].getDomain().getOrder(); d++) {
addr[0] = d;
out.printf(" %s", Double.toString(cpf.getDouble(addr)));
}
out.println();
return;
}
int order = domprod[i].getDomain().getOrder();
for(int d = 0; d < order; d++) {
addr[i] = d;
writeTable(out, cpf, i+1, addr);
}
}
}