/******************************************************************************* * Copyright (C) 2007-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ // based on the class Converter_xmlbif from KSU's BNJ (Bayesian Network Tools in Java) package probcog.bayesnets.core.io; import edu.ksu.cis.bnj.ver3.core.BeliefNetwork; import edu.ksu.cis.bnj.ver3.streams.*; import java.io.*; import java.util.*; import java.util.Map.Entry; import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuilderFactory; import org.w3c.dom.*; /** * Converter (Exporter/Importer) for a PMML-based format (PMML 3.0 with custom extensions) * This class need not be used directly. BeliefNetworkEx implements the loading and storing * of PMML files using this class. Converters of this kind form the basis for BNJ export/import * plugins – the PMML plugin is made available, too. * (This class is largely based upon Converter_xmlbif, which is part of BNJ) * @author Dominik Jain */ public class Converter_pmml implements OmniFormatV1, Exporter, Importer { protected OmniFormatV1 _Writer; protected int bn_cnt; private int bnode_cnt; // saving protected HashMap<Integer, NodeData> nodeData; protected Writer w; public int netDepth; protected int curNodeIdx; //protected HashMap adjList; // loading protected HashMap<Integer, String> nodeNames; //protected HashMap<String, Integer> nodeIndices; protected HashMap<Integer, Integer> nodeIndices; // maps node IDs to node indices protected NodeData curNode; HashMap<Integer, Node> cptTags; // omiformat protected StringBuffer cpf; protected int cpfNodeID; public Converter_pmml() { w = null; curNodeIdx = 0; } public OmniFormatV1 getStream1() { return this; } // ************************************************ // ***************** LOADING ********************** // ************************************************ public void load(InputStream stream, OmniFormatV1 writer) { _Writer = writer; // this is usually an instance of OmniFormatV1_Reader _Writer.Start(); bn_cnt = 0; bnode_cnt = 0; nodeIndices = new HashMap<Integer, Integer>(); DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance(); //factory.setValidating(true); factory.setNamespaceAware(true); org.w3c.dom.Document doc; try { DocumentBuilder parser = factory.newDocumentBuilder(); doc = parser.parse(stream); } catch(Exception e) { throw new RuntimeException(e); } visitDocument(doc); System.gc(); } public void visitDocument(Node parent) { NodeList l = parent.getChildNodes(); if(l == null) throw new RuntimeException("Unexpected end of document!"); int max = l.getLength(); for(int i = 0; i < max; i++) { Node node = l.item(i); switch(node.getNodeType()) { case 1: // '\001' String name = node.getNodeName(); if(name.equals("PMML")) { // process PMML attributes NamedNodeMap attrs = node.getAttributes(); if(attrs != null) { int amax = attrs.getLength(); for(int j = 0; j < amax; j++) { Node attr = attrs.item(j); String aname = attr.getNodeName().toUpperCase(); if(aname.equals("version")) try { if(!aname.equals("3.0")) throw new RuntimeException("PMML version " + aname + " is not supported"); } catch(Exception e) { } //else // System.out.println("property:" + aname + " not handled"); } } // read child nodes cptTags = new HashMap<Integer, Node>(); visitDocument(node); // process cpt definitions for nodes that were gathered along the way //System.out.println("processing CPTs"); for(Entry<Integer,Node> e : cptTags.entrySet()) { visitDefinition(e.getValue(), e.getKey()); } cptTags = null; } else if(name.equals("DataDictionary")) { _Writer.CreateBeliefNetwork(bn_cnt); visitDataDict(node); bn_cnt++; } /*else { throw new RuntimeException("Unhandled element " + name); }*/ } } } public void visitDataDict(Node parent) { // process children (DataFields) NodeList l = parent.getChildNodes(); if(l == null) throw new RuntimeException("Unexpected end of document!"); int max = l.getLength(); for(int i = 0; i < max; i++) { Node node = l.item(i); switch(node.getNodeType()) { case 1: // '\001' String name = node.getNodeName(); if(name.equals("DataField")) { visitDataField(node); bnode_cnt++; } } } } /** * a <DataField> contains all the data on one node in the BN * @param parent */ protected void visitDataField(Node parent) { // read attributes NamedNodeMap attrs = parent.getAttributes(); String nodeName = null; Integer nodeID = null; int max; if(attrs != null) { max = attrs.getLength(); for(int i = 0; i < max; i++) { Node attr = attrs.item(i); String attrName = attr.getNodeName(); String value = attr.getNodeValue(); if(attrName.equals("name")) { nodeName = value; //nodeIndices.put(nodeName, new Integer(bnode_cnt)); } if(attrName.equals("id")) { nodeID = Integer.parseInt(value); } /*else { System.out.println("Unhandled variable property attribute " + name); }*/ } } if(nodeName == null || nodeID == null) throw new RuntimeException("Missing DataField attribute 'name' or 'id'!"); nodeIndices.put(nodeID, new Integer(bnode_cnt)); _Writer.BeginBeliefNode(bnode_cnt); _Writer.SetBeliefNodeName(nodeName); // process child tags NodeList l = parent.getChildNodes(); max = l.getLength(); for(int i = 0; i < max; i++) { Node node = l.item(i); switch(node.getNodeType()) { case 1: // '\001' String name = node.getNodeName(); if(name.equals("Value")) { attrs = node.getAttributes(); for(int j = attrs.getLength()-1; j >= 0; j--) { Node attr = attrs.item(j); if(attr.getNodeName().equals("value")) _Writer.BeliefNodeOutcome(attr.getNodeValue()); } } else if(name.equals("Extension")) { NodeList l_ext = node.getChildNodes(); for(int j = 0; j < l_ext.getLength(); j++) { Node n = l_ext.item(j); if(n.getNodeName().equals("X-NodeType")) { _Writer.SetType(getElementValue(n)); } else if(n.getNodeName().equals("X-Position")) { attrs = n.getAttributes(); int xPos = 0, yPos = 0; for(int k = attrs.getLength()-1; k >= 0; k--) { Node attr = attrs.item(k); if(attr.getNodeName().equals("x")) xPos = Integer.parseInt(attr.getNodeValue()); else if(attr.getNodeName().equals("y")) yPos = Integer.parseInt(attr.getNodeValue()); } _Writer.SetBeliefNodePosition(xPos, yPos); } else if(n.getNodeName().equals("X-Definition")) cptTags.put(nodeID, n); // remember the X-Definition node for later } } break; } } _Writer.EndBeliefNode(); } protected void visitDefinition(Node definition, int nodeID) { NodeList l = definition.getChildNodes(); if(l == null) return; LinkedList<Integer> parents = new LinkedList<Integer>(); int curNode = nodeIndices.get(nodeID); //nodeIndices.get(nodeName).intValue(); String CPTString = ""; int max = l.getLength(); for(int i = 0; i < max; i++) { Node node = l.item(i); switch(node.getNodeType()) { case 1: // '\001' String name = node.getNodeName(); if(name.equals("X-Given")) { parents.add(nodeIndices.get(Integer.parseInt(getElementValue(node)))); } else if(name.equals("X-Table")) CPTString = getElementValue(node); } } if(curNode >= 0) { for(Integer p : parents) { _Writer.Connect(p, curNode); } _Writer.BeginCPF(curNode); StringTokenizer tok = new StringTokenizer(CPTString); int maxz = tok.countTokens(); for(int c = 0; c < maxz; c++) { String SSS = tok.nextToken(); _Writer.ForwardFlat_CPFWriteValue(SSS); } _Writer.EndCPF(); } } protected String getElementValue(Node parent) { NodeList l = parent.getChildNodes(); if(l == null) return null; StringBuffer buf = new StringBuffer(); int max = l.getLength(); for(int i = 0; i < max; i++) { Node node = l.item(i); switch(node.getNodeType()) { case 3: // '\003' buf.append(node.getNodeValue()); break; default: System.out.println("Unhandled node " + node.getNodeName()); break; case 1: // '\001' case 8: // '\b' break; } } return buf.toString().trim(); } // ************************************************ // ***************** SAVING *********************** // ************************************************ protected class NodeData { public String cpfData, subElements, nodeType, opType, name, domainClassName; int index; int xPos, yPos; Vector<Integer> parents; public NodeData() { cpfData = new String(); subElements = new String(); parents = new Vector<Integer>(); } } public void save(BeliefNetwork bn, OutputStream os) { w = new OutputStreamWriter(os); OmniFormatV1_Writer.Write(bn, this); } public void fwrite(String x) { try { w.write(x); w.flush(); } catch(Exception e) { System.out.println("unable to write?"); } } public void Start() { netDepth = 0; nodeNames = new HashMap<Integer, String>(); //adjList = new HashMap(); fwrite("<?xml version=\"1.0\" encoding=\"US-ASCII\"?>\n"); fwrite("<!-- Bayesian network in a PMML-based format -->\n"); fwrite("<PMML version=\"3.0\" xmlns=\"http://www.dmg.org/PMML-3_0\">\n"); fwrite("\t<Header copyright=\"Technische Universitaet Muenchen\" />\n"); } public void CreateBeliefNetwork(int idx) { if(netDepth > 0) { netDepth = 0; fwrite("\t</DataDictionary>\n"); } nodeData = new HashMap<Integer,NodeData>(); fwrite("\t<DataDictionary>\n"); netDepth = 1; } public void SetBeliefNetworkName(int idx, String name) { //fwrite("<NAME>" + name + "</NAME>\n"); } public void BeginBeliefNode(int idx) { curNode = new NodeData(); curNode.index = idx; curNodeIdx = idx; //adjList.put(new Integer(curNodeIdx), new ArrayList()); } public void SetType(String type) { curNode.nodeType = type; if(type.equals("utility")) curNode.opType = "continuous"; else curNode.opType = "categorical"; } public void SetBeliefNodePosition(int x, int y) { curNode.xPos = x; curNode.yPos = y; } public void SetBeliefNodeDomainClass(String domainClassName) { curNode.domainClassName = domainClassName; } public void BeliefNodeOutcome(String outcome) { curNode.subElements += "\t\t\t<Value value=\"" + outcome.replaceAll("<", "<").replaceAll(">", ">") + "\" />\n"; } public void SetBeliefNodeName(String name) { //System.out.println(name); curNode.name = name; nodeNames.put(new Integer(curNodeIdx), name); } public void MakeContinuous(String s) { } public void EndBeliefNode() { nodeData.put(curNode.index, curNode); } public void Connect(int par_idx, int chi_idx) { nodeData.get(chi_idx).parents.add(par_idx); } public void BeginCPF(int idx) { //System.out.println("CPF: " + nodeData.get(idx).name); cpfNodeID = idx; cpf = new StringBuffer("\t\t\t\t<X-Definition>\n"); String gname; for(Integer given : nodeData.get(idx).parents) { //Integer given = (Integer)it.next(); gname = (String)nodeNames.get(given); cpf.append("\t\t\t\t\t<X-Given>" + given + "</X-Given> <!-- " + gname + " -->\n"); } cpf.append("\t\t\t\t\t<X-Table>"); } public void ForwardFlat_CPFWriteValue(String x) { cpf.append(x + " "); } public void EndCPF() { cpf.append("</X-Table>\n"); cpf.append("\t\t\t\t</X-Definition>\n"); NodeData d = (NodeData)nodeData.get(cpfNodeID); d.cpfData = cpf.toString(); //System.out.println("done."); } public int GetCPFSize() { return 0; } public void Finish() { if(netDepth > 0) { // output all node data Iterator<NodeData> i = nodeData.values().iterator(); while(i.hasNext()) { NodeData nd = i.next(); fwrite("\t\t<DataField name=\"" + nd.name + "\" optype=\"" + nd.opType + "\" id=\"" + nd.index + "\">\n"); fwrite("\t\t\t<Extension>\n"); fwrite("\t\t\t\t<X-NodeType>" + nd.nodeType + "</X-NodeType>\n"); if (nd.domainClassName != null) fwrite("\t\t\t\t<X-NodeDomainClass>" + nd.domainClassName + "</X-NodeDomainClass>\n"); fwrite("\t\t\t\t<X-Position x=\"" + nd.xPos + "\" y=\"" + nd.yPos + "\" />\n"); fwrite(nd.cpfData); fwrite("\t\t\t</Extension>\n"); fwrite(nd.subElements); fwrite("\t\t</DataField>\n"); } netDepth = 0; fwrite("\t</DataDictionary>\n"); } fwrite("</PMML>\n"); try { w.close(); } catch(Exception exception) { } } // --------- UI related --------- public String getExt() { return "*.pmml"; } public String getDesc() { return "PMML 3.0"; } }