/*
 * Copyright 2015 Red Hat, Inc. and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.drools.beliefs.bayes.integration;

import org.drools.beliefs.bayes.BayesNetwork;
import org.drools.beliefs.bayes.BayesVariable;
import org.drools.beliefs.bayes.model.Bif;
import org.drools.beliefs.bayes.model.Definition;
import org.drools.beliefs.bayes.model.Network;
import org.drools.beliefs.bayes.model.Variable;
import org.drools.beliefs.bayes.model.XmlBifParser;
import org.drools.beliefs.graph.GraphNode;
import org.junit.Test;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.assertNotNull;
import static junit.framework.Assert.assertTrue;
import static org.junit.Assert.assertNull;

/**
 * Exercises {@link XmlBifParser} against the {@code Garden.xmlbif} resource that is
 * resolved relative to this class: first the raw parsed {@link Bif} model, then the
 * {@link BayesNetwork} built from it.
 */
public class ParserTest {

    /** Outcome list asserted for every variable in the Garden network. */
    private static final List<String> BOOLEAN_OUTCOMES = Arrays.asList("false", "true");

    /**
     * Loads {@code Garden.xmlbif} and checks the parsed network name and property, the
     * four variables (name, outcomes, first property) and the four definitions
     * (name, given list, probability string).
     */
    @Test
    public void testSprinklerLoadBif() {
        Bif bif = (Bif) XmlBifParser.loadBif(ParserTest.class.getResource("Garden.xmlbif"));
        assertNotNull("Garden.xmlbif could not be loaded", bif);

        Network network = bif.getNetwork();
        assertNotNull("parsed Bif has no network", network);
        assertEquals("Garden", network.getName());
        assertEquals("package = org.drools.beliefs.bayes.integration", network.getProperties().get(0));

        Map<String, Variable> varMap = varToMap(network.getVariables());
        assertEquals(4, varMap.size());
        assertBooleanVariable(varMap, "WetGrass", "position = (0,10)");
        assertBooleanVariable(varMap, "Cloudy", "position = (0,-10)");
        assertBooleanVariable(varMap, "Sprinkler", "position = (13,0)");
        assertBooleanVariable(varMap, "Rain", "position = (-12,0)");

        Map<String, Definition> defMap = defToMap(network.getDefinitions());
        assertEquals(4, defMap.size());

        Definition def = requireDefinition(defMap, "WetGrass");
        assertEquals(2, def.getGiven().size());
        assertEquals(Arrays.asList("Sprinkler", "Rain"), def.getGiven());
        // Deliberately compared without trim(), unlike the definitions below.
        assertEquals("1.0 0.0 0.1 0.9 0.1 0.9 0.01 0.99", def.getProbabilities());

        def = requireDefinition(defMap, "Cloudy");
        assertNull(def.getGiven());
        assertEquals("0.5 0.5", def.getProbabilities().trim());

        def = requireDefinition(defMap, "Sprinkler");
        assertEquals(1, def.getGiven().size());
        assertEquals("Cloudy", def.getGiven().get(0));
        assertEquals("0.5 0.5 0.9 0.1", def.getProbabilities().trim());

        def = requireDefinition(defMap, "Rain");
        assertNull(def.getGiven());
        assertEquals("0.5 0.5", def.getProbabilities().trim());
    }

    /**
     * Builds a {@link BayesNetwork} from the parsed {@code Garden.xmlbif} and checks, for
     * each of the four nodes, the outcomes, the given (parent) names and the probability
     * table of the node's {@link BayesVariable}.
     */
    @Test
    public void testSprinklerBuildBayesNework() {
        Bif bif = (Bif) XmlBifParser.loadBif(ParserTest.class.getResource("Garden.xmlbif"));
        assertNotNull("Garden.xmlbif could not be loaded", bif);

        BayesNetwork network = XmlBifParser.buildBayesNetwork(bif);
        assertNotNull("no BayesNetwork was built", network);
        Map<String, GraphNode<BayesVariable>> map = nodeToMap(network);

        BayesVariable wetGrass = requireContent(map, "WetGrass");
        assertEquals(BOOLEAN_OUTCOMES, Arrays.asList(wetGrass.getOutcomes()));
        assertEquals(2, wetGrass.getGiven().length);
        assertEquals(Arrays.asList("Sprinkler", "Rain"), Arrays.asList(wetGrass.getGiven()));
        assertTrue(Arrays.deepEquals(
                new double[][] { { 1.0, 0.0 }, { 0.1, 0.9 }, { 0.1, 0.9 }, { 0.01, 0.99 } },
                wetGrass.getProbabilityTable()));

        BayesVariable sprinkler = requireContent(map, "Sprinkler");
        assertEquals(BOOLEAN_OUTCOMES, Arrays.asList(sprinkler.getOutcomes()));
        assertEquals(1, sprinkler.getGiven().length);
        assertEquals("Cloudy", sprinkler.getGiven()[0]);
        assertTrue(Arrays.deepEquals(
                new double[][] { { 0.5, 0.5 }, { 0.9, 0.1 } },
                sprinkler.getProbabilityTable()));

        BayesVariable cloudy = requireContent(map, "Cloudy");
        assertEquals(BOOLEAN_OUTCOMES, Arrays.asList(cloudy.getOutcomes()));
        assertEquals(0, cloudy.getGiven().length);
        assertTrue(Arrays.deepEquals(
                new double[][] { { 0.5, 0.5 } },
                cloudy.getProbabilityTable()));

        BayesVariable rain = requireContent(map, "Rain");
        assertEquals(BOOLEAN_OUTCOMES, Arrays.asList(rain.getOutcomes()));
        assertEquals(0, rain.getGiven().length);
        assertTrue(Arrays.deepEquals(
                new double[][] { { 0.5, 0.5 } },
                rain.getProbabilityTable()));
    }

    /**
     * Indexes the nodes of a network by the name of the variable each node holds.
     *
     * @param network the network whose nodes are iterated
     * @return a map from variable name to graph node; a later node with the same name
     *         replaces an earlier one
     */
    Map<String, GraphNode<BayesVariable>> nodeToMap(BayesNetwork network) {
        Map<String, GraphNode<BayesVariable>> map = new HashMap<String, GraphNode<BayesVariable>>();
        for ( GraphNode<BayesVariable> node : network ) {
            map.put( node.getContent().getName(), node );
        }
        return map;
    }

    /**
     * Indexes parsed variables by name.
     *
     * @param list the variables to index
     * @return a map from variable name to variable; a later entry with the same name
     *         replaces an earlier one
     */
    public Map<String, Variable> varToMap(List<Variable> list) {
        Map<String, Variable> map = new HashMap<String, Variable>();
        for ( Variable var : list ) {
            map.put( var.getName(), var );
        }
        return map;
    }

    /**
     * Indexes parsed definitions by name.
     *
     * @param list the definitions to index
     * @return a map from definition name to definition; a later entry with the same name
     *         replaces an earlier one
     */
    public Map<String, Definition> defToMap(List<Definition> list) {
        Map<String, Definition> map = new HashMap<String, Definition>();
        for ( Definition def : list ) {
            map.put( def.getName(), def );
        }
        return map;
    }

    /**
     * Asserts that {@code varMap} holds a variable called {@code name} with the two
     * outcomes {@code false}/{@code true} and {@code expectedProperty} as its first property.
     */
    private static void assertBooleanVariable(Map<String, Variable> varMap, String name, String expectedProperty) {
        Variable var = varMap.get(name);
        // Fail with a readable message instead of a NullPointerException on the next line.
        assertNotNull("variable '" + name + "' was not parsed", var);
        assertEquals(name, var.getName());
        assertEquals(2, var.getOutComes().size());
        assertEquals(BOOLEAN_OUTCOMES, var.getOutComes());
        assertEquals(expectedProperty, var.getProperties().get(0));
    }

    /**
     * Looks up the definition called {@code name}, asserting that it exists and reports
     * that same name.
     */
    private static Definition requireDefinition(Map<String, Definition> defMap, String name) {
        Definition def = defMap.get(name);
        assertNotNull("definition '" + name + "' was not parsed", def);
        assertEquals(name, def.getName());
        return def;
    }

    /**
     * Looks up the graph node called {@code name}, asserting that it exists, and returns
     * the variable it holds.
     */
    private static BayesVariable requireContent(Map<String, GraphNode<BayesVariable>> nodes, String name) {
        GraphNode<BayesVariable> node = nodes.get(name);
        assertNotNull("no graph node for variable '" + name + "'", node);
        return node.getContent();
    }
}