/*
 * Copyright 2015 Red Hat, Inc. and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.drools.beliefs.bayes.model;

import com.thoughtworks.xstream.XStream;
import com.thoughtworks.xstream.io.xml.DomDriver;
import org.drools.beliefs.bayes.BayesNetwork;
import org.drools.beliefs.bayes.BayesVariable;
import org.drools.beliefs.bayes.assembler.BayesNetworkAssemblerError;
import org.drools.beliefs.graph.Graph;
import org.drools.beliefs.graph.GraphNode;
import org.drools.beliefs.graph.impl.EdgeImpl;
import org.drools.compiler.compiler.ParserError;
import org.drools.core.io.internal.InternalResource;
import org.kie.api.io.Resource;
import org.kie.internal.builder.KnowledgeBuilderErrors;

import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

/**
 * Static helpers that read an XML document into a {@link Bif} model through XStream and
 * turn that model into a {@link BayesNetwork} graph.
 */
public class XmlBifParser {

    /**
     * Reads a {@link Bif} from the given resource.
     * <p>
     * Failures are not thrown: they are appended to {@code errors} and {@code null} is
     * returned. The stream obtained from the resource is always closed before returning.
     *
     * @param resource source of the XML; when it is an {@link InternalResource} with a
     *                 non-null encoding, that encoding is handed to the XStream DOM driver
     * @param errors   collector that receives a {@link ParserError} when the stream cannot
     *                 be opened, or a {@link BayesNetworkAssemblerError} when parsing fails
     * @return the parsed model, or {@code null} when an error was recorded
     */
    public static Bif loadBif(Resource resource, KnowledgeBuilderErrors errors) {
        InputStream is;
        try {
            is = resource.getInputStream();
        } catch (IOException e) {
            errors.add( new ParserError(resource, "Exception opening Stream:\n" + e.toString(), 0, 0) );
            return null;
        }

        try {
            String encoding = resource instanceof InternalResource
                              ? ((InternalResource) resource).getEncoding()
                              : null;

            // NOTE(review): XStream is used with its default configuration; if resources can
            // come from untrusted parties, confirm the XStream security setup elsewhere.
            XStream xstream = (encoding != null)
                              ? new XStream(new DomDriver(encoding))
                              : new XStream();
            initXStream(xstream);

            return (Bif) xstream.fromXML(is);
        } catch (Exception e) {
            errors.add( new BayesNetworkAssemblerError(resource, "Unable to parse opening Stream:\n" + e.toString()) );
            return null;
        } finally {
            // The stream was opened by this method, so it is this method's job to release it.
            closeQuietly(is);
        }
    }

    /**
     * Reads a {@link Bif} from the given URL using a default {@link XStream} instance.
     * Unlike {@link #loadBif(Resource, KnowledgeBuilderErrors)}, nothing is caught here, so
     * any exception raised by XStream propagates to the caller.
     *
     * @param url location of the XML document
     * @return the parsed model
     */
    public static Bif loadBif(URL url) {
        XStream xstream = new XStream();
        initXStream( xstream );
        return (Bif) xstream.fromXML(url);
    }

    /**
     * Registers the annotated model classes ({@link Bif}, {@link Network},
     * {@link Probability}, {@link Definition}) with the given XStream instance.
     */
    private static void initXStream(XStream xstream) {
        xstream.processAnnotations(Bif.class);
        xstream.processAnnotations(Network.class);
        xstream.processAnnotations(Probability.class);
        xstream.processAnnotations(Definition.class);
    }

    /**
     * Closes the stream, ignoring a {@code null} argument and any {@link IOException}
     * raised by {@code close()}.
     */
    private static void closeQuietly(InputStream is) {
        if (is == null) {
            return;
        }
        try {
            is.close();
        } catch (IOException ignored) {
            // Best effort: the parse result (or error) has already been determined.
        }
    }

    /**
     * Builds a {@link BayesNetwork} from a parsed {@link Bif}.
     * <p>
     * One graph node is created per definition of the network; afterwards, for every
     * variable, one edge is created per entry of its "given" list, going out of the given
     * (parent) node and into the variable's node.
     * <p>
     * The package name is {@code "default"} unless a network property starting with
     * {@code "package"} is present, in which case the text after the first {@code '='}
     * (trimmed) is used; the last such property wins.
     *
     * @param bif parsed model; its network and the network's definitions must be non-null
     * @return the populated network
     */
    public static BayesNetwork buildBayesNetwork(Bif bif) {
        String name = bif.getNetwork().getName();

        String packageName = "default";
        List<String> props = bif.getNetwork().getProperties();
        if (props != null ) {
            for ( String prop : props ) {
                prop = prop.trim();
                if (prop.startsWith("package") ) {
                    packageName = prop.substring( prop.indexOf('=') + 1).trim();
                }
            }
        }

        BayesNetwork graph = new BayesNetwork(name, packageName);

        // Variable name -> node, used below to resolve the names in each "given" list.
        Map<String, GraphNode<BayesVariable>> map = new HashMap<String, GraphNode<BayesVariable>>();
        for (Definition def : bif.getNetwork().getDefinitions()) {
            GraphNode<BayesVariable> node = graph.addNode();
            BayesVariable variable = buildVariable(def, bif.getNetwork(), node.getId());
            node.setContent( variable );
            map.put( variable.getName(), node );
        }

        for (GraphNode<BayesVariable> node : map.values()) {
            BayesVariable variable = node.getContent();
            if ( variable.getGiven() == null || variable.getGiven().length == 0 ) {
                continue;
            }
            for ( String given : variable.getGiven() ) {
                // NOTE(review): givenNode is null when "given" names a variable that has no
                // definition; confirm how EdgeImpl reacts to a null node.
                GraphNode<BayesVariable> givenNode = map.get( given );
                EdgeImpl e = new EdgeImpl();
                e.setOutGraphNode(givenNode);
                e.setInGraphNode(node);
            }
        }

        return graph;
    }

    /**
     * Creates the {@link BayesVariable} for one definition.
     *
     * @param def     definition supplying the name, the probability table text and the
     *                optional "given" list (a null list is treated as empty)
     * @param network network whose variables supply the outcomes for {@code def}'s name
     * @param id      identifier passed through to the variable
     * @return the new variable
     */
    private static BayesVariable buildVariable(Definition def, Network network, int id) {
        List<String> outcomes = new ArrayList<String>();
        getOutcomesByVariable(network, def.getName(), outcomes);

        List<String> given = (def.getGiven() == null) ? Collections.<String>emptyList() : def.getGiven();

        return new BayesVariable<String>(def.getName(),
                                         id,
                                         outcomes.toArray( new String[ outcomes.size()] ),
                                         getProbabilities(def.getProbabilities(), outcomes),
                                         given.toArray(new String[given.size()]) );
    }

    /**
     * Appends to {@code outcomes} the outcomes of every network variable whose name equals
     * {@code nameDefinition}.
     */
    private static void getOutcomesByVariable(Network network, String nameDefinition, List<String> outcomes) {
        for (Variable var : network.getVariables()) {
            if (var.getName().equals(nameDefinition)) {
                for (String outcome : var.getOutComes()) {
                    outcomes.add(outcome);
                }
            }
        }
    }

    /**
     * Parses a whitespace-separated probability table into rows of
     * {@code outcomes.size()} values each.
     * <p>
     * The number of rows is {@code tokenCount / outcomes.size()}; tokens left over after
     * the last complete row are ignored. Tokens may be separated by any run of whitespace
     * (spaces, tabs, line breaks).
     *
     * @param table    table text; must not be null
     * @param outcomes outcomes of the variable; its size is the number of columns
     * @return the table, with zero rows when {@code table} is blank
     * @throws NumberFormatException if a token is not a valid double
     */
    private static double[][] getProbabilities(String table, List<String> outcomes) {
        final int columns = outcomes.size();
        final String trimmed = table.trim();
        if (trimmed.length() == 0) {
            return new double[0][columns];
        }

        final String[] values = trimmed.split("\\s+");
        if (columns == 0) {
            // Without outcomes a row count cannot be derived; keep the historical shape
            // (half the token count, zero columns) so existing callers see no change.
            return new double[values.length / 2][0];
        }

        // Row count follows from the column count; a fixed divisor of 2 would only be
        // right for two-outcome variables and overruns "values" for larger ones.
        final int rows = values.length / columns;
        final double[][] probabilities = new double[rows][columns];
        int k = 0;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < columns; j++) {
                probabilities[i][j] = Double.parseDouble(values[k++]);
            }
        }
        return probabilities;
    }

    /**
     * Writes the comma-separated numbers found in {@code stringPosition} (after
     * {@link #clearStringPostion(String)} has stripped the decoration) into
     * {@code position}. Nothing is written when {@code stringPosition} is null.
     * <p>
     * NOTE(review): this method always returns {@code null} and is not called from this
     * file; the index stepping ({@code i} advances only while {@code i < j}) is kept as
     * found - confirm the intended layout before using it.
     *
     * @return always {@code null}
     */
    private static double[][] getPosition(String stringPosition, double[][] position) {
        if (stringPosition != null) {
            stringPosition = clearStringPostion(stringPosition);
            int i = 0;
            int j = 0;
            for (String pos : stringPosition.split(",")) {
                position[i][j] = Double.parseDouble(pos);
                if (i < j) {
                    i += 1;
                }
                j += 1;
            }
        }
        return null;
    }

    /**
     * Removes every occurrence of {@code "position"}, {@code "="}, {@code "("} and
     * {@code ")"} from the text and trims the result.
     */
    private static String clearStringPostion(String stringPosition){
        return stringPosition.replace("position", "")
                             .replace("=", "")
                             .replace("(", "")
                             .replace(")", "")
                             .trim();
    }

}