/*
 * Copyright 2011 JBoss Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.drools.pmml.pmml_4_1.predictive.models;

import junit.framework.Assert;
import org.drools.ClassObjectFilter;
import org.drools.KnowledgeBase;
import org.drools.KnowledgeBaseFactory;
import org.drools.builder.KnowledgeBuilder;
import org.drools.builder.KnowledgeBuilderFactory;
import org.drools.builder.ResourceType;
import org.drools.definition.type.FactType;
import org.drools.informer.Answer;
import org.drools.io.ResourceFactory;
import org.drools.io.impl.ClassPathResource;
import org.drools.pmml.pmml_4_1.DroolsAbstractPMMLTest;
import org.drools.pmml.pmml_4_1.ModelMarker;
import org.drools.pmml.pmml_4_1.PMML4Compiler;
import org.drools.runtime.StatefulKnowledgeSession;
import org.drools.runtime.rule.FactHandle;
import org.drools.runtime.rule.QueryResults;
import org.drools.runtime.rule.Variable;
import org.junit.After;
import org.junit.Test;

import java.util.Collection;
import java.util.Iterator;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

/**
 * Tests for PMML neural-network models loaded into a knowledge session.
 *
 * Most tests follow the same sequence: build a session from a PMML resource,
 * fire the rules once to initialise the model, insert input values through the
 * {@code in_*} entry points, fire the rules again and then check the produced
 * output facts or query results.
 */
public class NeuralNetworkTest extends DroolsAbstractPMMLTest {

    private static final boolean VERBOSE = true;

    private static final String source1 = "org/drools/pmml/pmml_4_1/test_ann_regression.xml";
    private static final String source2 = "org/drools/pmml/pmml_4_1/test_ann_iris.xml";
    private static final String source22 = "org/drools/pmml/pmml_4_1/test_ann_iris_v2.xml";
    private static final String source23 = "org/drools/pmml/pmml_4_1/test_ann_iris_prediction.xml";
    // Referenced by testANNCompilation, testSimpleANN and testOverride, so it must stay declared.
    private static final String source3 = "org/drools/pmml/pmml_4_1/test_miningSchema.xml";
    private static final String source4 = "org/drools/pmml/pmml_4_1/test_ann_mixed_inputs2.xml";
    private static final String source6 = "org/drools/pmml/pmml_4_1/mock_ptsd.xml";
    private static final String source7 = "org/drools/pmml/pmml_4_1/mock_cold.xml";
    private static final String source8 = "org/drools/pmml/pmml_4_1/mock_breastcancer.xml";
    private static final String source9 = "org/drools/pmml/pmml_4_1/test_nn_clax_output.xml";
    private static final String smartVent = "org/drools/pmml/pmml_4_1/smartvent.xml";

    private static final String packageName = "org.drools.pmml.pmml_4_1.test";

    /**
     * Disposes the session used by the test, if one was created.
     *
     * The null check keeps a failure that happened before the session was set
     * from being accompanied by a secondary NullPointerException here.
     */
    @After
    public void tearDown() {
        if ( getKSession() != null ) {
            getKSession().dispose();
        }
    }

    /**
     * Builds a knowledge base directly from the informer change-set and a
     * pre-generated DRL file, feeds it one set of inputs and dumps the working
     * memory. Only build errors cause a failure; no output value is asserted.
     */
    @Test
    public void testANNFromSource() throws Exception {
        KnowledgeBuilder knowledgeBuilder = KnowledgeBuilderFactory.newKnowledgeBuilder();
        knowledgeBuilder.add( ResourceFactory.newClassPathResource( "org/drools/informer/informer-changeset.xml" ),
                              ResourceType.CHANGE_SET );
        knowledgeBuilder.add( new ClassPathResource( "org/drools/pmml/pmml_4_1/ann_rules.drl" ),
                              ResourceType.DRL );
        if ( knowledgeBuilder.hasErrors() ) {
            fail( knowledgeBuilder.getErrors().toString() );
        }

        KnowledgeBase kBase = KnowledgeBaseFactory.newKnowledgeBase();
        kBase.addKnowledgePackages( knowledgeBuilder.getKnowledgePackages() );
        StatefulKnowledgeSession kSession = kBase.newStatefulKnowledgeSession();
        // Registered so that tearDown() disposes this session as well.
        setKSession( kSession );

//        kSession.addEventListener( new DebugAgendaEventListener() );
//        kSession.addEventListener( new DebugWorkingMemoryEventListener() );

        kSession.fireAllRules();  //init model

        kSession.getWorkingMemoryEntryPoint( "in_Gender" ).insert( "male" );
        kSession.getWorkingMemoryEntryPoint( "in_NoOfClaims" ).insert( "3" );
        kSession.getWorkingMemoryEntryPoint( "in_Scrambled" ).insert( 7 );
        kSession.getWorkingMemoryEntryPoint( "in_Domicile" ).insert( "urban" );
        kSession.getWorkingMemoryEntryPoint( "in_AgeOfCar" ).insert( 8.0 );

        kSession.fireAllRules();

        Thread.sleep( 200 );
        System.err.println( reportWMObjects( kSession ) );
    }

    /**
     * Loads the regression model, checks the number of Synapse facts after
     * initialisation and verifies the (floored) predicted amount of claims.
     */
    @Test
    public void testANN() throws Exception {
        setKSession( getModelSession( source1, VERBOSE ) );
        setKbase( getKSession().getKnowledgeBase() );

        getKSession().fireAllRules();  //init model

        Assert.assertEquals( 33, getNumAssertedSynapses() );

        getKSession().getWorkingMemoryEntryPoint( "in_Gender" ).insert( "male" );
        getKSession().getWorkingMemoryEntryPoint( "in_NoOfClaims" ).insert( "3" );
        getKSession().getWorkingMemoryEntryPoint( "in_Scrambled" ).insert( 7 );
        getKSession().getWorkingMemoryEntryPoint( "in_Domicile" ).insert( "urban" );
        getKSession().getWorkingMemoryEntryPoint( "in_AgeOfCar" ).insert( 8.0 );

        getKSession().fireAllRules();

        Thread.sleep( 200 );
        System.err.println( reportWMObjects( getKSession() ) );

        Assert.assertEquals( 828.0, Math.floor( queryDoubleField( "OutAmOfClaims", "NeuralInsurance" ) ) );
    }

    /**
     * Only checks that a session can be created from the mining-schema model.
     */
    @Test
    public void testANNCompilation() throws Exception {
        setKSession( getModelSession( source3, VERBOSE ) );
        setKbase( getKSession().getKnowledgeBase() );
    }

    /**
     * Feeds a temperature to the mock "cold" model and checks the Cold output.
     */
    @Test
    public void testCold() throws Exception {
        setKSession( getModelSession( source7, VERBOSE ) );
        setKbase( getKSession().getKnowledgeBase() );

        getKSession().fireAllRules();  //init model

        getKSession().getWorkingMemoryEntryPoint( "in_Temp" ).insert( 28.0 );

        getKSession().fireAllRules();

        System.err.println( reportWMObjects( getKSession() ) );

        Assert.assertEquals( 0.44, queryDoubleField( "Cold", "MockCold" ), 1e-6 );
    }

    /**
     * Produces a Cold output, retracts one of the Temp facts and checks that
     * the Cold query no longer returns any result.
     */
    @Test
    public void testClearOutput() throws Exception {
        setKSession( getModelSession( source7, VERBOSE ) );
        setKbase( getKSession().getKnowledgeBase() );

        getKSession().fireAllRules();  //init model

        getKSession().getWorkingMemoryEntryPoint( "in_Temp" ).insert( 28.0 );

        getKSession().fireAllRules();

        Assert.assertEquals( 0.44, queryDoubleField( "Cold", "MockCold" ), 1e-6 );

        for ( Object o : getKSession().getObjects() ) {
            System.out.println( o );
        }

        FactType tempKlass = getKSession().getKnowledgeBase().getFactType( packageName, "Temp" );
        Collection<?> temps = getKSession().getObjects( new ClassObjectFilter( tempKlass.getFactClass() ) );
        Iterator<?> iter = temps.iterator();

        // Take the first Temp fact; if its "value" field is set, move on to the next one.
        // NOTE(review): this relies on at least two Temp facts being present whenever the
        // first one has a value - confirm against the model and the intended selection rule.
        Object temp = iter.next();
        if ( tempKlass.get( temp, "value" ) != null ) {
            temp = iter.next();
        }

        getKSession().retract( getKSession().getFactHandle( temp ) );
        getKSession().fireAllRules();

        QueryResults results = getKSession().getQueryResults( "Cold", "MockCold", Variable.v );
        assertEquals( 0, results.size() );
    }

    /**
     * Feeds the mock PTSD model through entry points plus one questionnaire
     * Answer (for Age), then checks the PTSD output and that exactly one
     * ModelMarker is in working memory.
     */
    @Test
    public void testPTSD() throws Exception {
        setKSession( getModelSession( source6, VERBOSE ) );
        setKbase( getKSession().getKnowledgeBase() );

        getKSession().fireAllRules();  //init model

        getKSession().getWorkingMemoryEntryPoint( "in_Gender" ).insert( "male" );
        getKSession().getWorkingMemoryEntryPoint( "in_Alcohol" ).insert( "yes" );
        getKSession().getWorkingMemoryEntryPoint( "in_Deployments" ).insert( "1" );
//        getKSession().getWorkingMemoryEntryPoint("in_Age").insert(30.2);

        getKSession().fireAllRules();

        Answer ans2 = new Answer( getQId( "MockPTSD", "Age" ), "30.2" );
        getKSession().insert( ans2 );

        getKSession().fireAllRules();

        Thread.sleep( 200 );
        System.err.println( reportWMObjects( getKSession() ) );

        // Floating-point output: compare with a tolerance, as the other mock-model tests do.
        Assert.assertEquals( 0.2802, queryDoubleField( "PTSD", "MockPTSD" ), 1e-6 );

        assertEquals( 1, getKSession().getObjects( new ClassObjectFilter( ModelMarker.class ) ).size() );
    }

    /**
     * Starts the mock breast-cancer model with all inputs "Unknown", then
     * answers the questions one at a time and checks the output after each.
     */
    @Test
    public void testBreastCancer() throws Exception {
        setKSession( getModelSession( source8, VERBOSE ) );
        setKbase( getKSession().getKnowledgeBase() );

        getKSession().fireAllRules();  //init model

        getKSession().getWorkingMemoryEntryPoint( "in_Menses" ).insert( "Unknown" );
        getKSession().getWorkingMemoryEntryPoint( "in_Relatives" ).insert( "Unknown" );
        getKSession().getWorkingMemoryEntryPoint( "in_Biopsy" ).insert( "Unknown" );

        getKSession().fireAllRules();

        Assert.assertEquals( 0.15, queryDoubleField( "BreastCancer", "MockBC" ), 1e-6 );

        Answer ans = new Answer( getQId( "MockBC", "Menses" ), "7-11" );
        getKSession().insert( ans );
        getKSession().fireAllRules();

        Assert.assertEquals( 0.18, queryDoubleField( "BreastCancer", "MockBC" ), 1e-6 );

        Answer ans2 = new Answer( getQId( "MockBC", "Relatives" ), "2+" );
        getKSession().insert( ans2 );
        getKSession().fireAllRules();

        Assert.assertEquals( 0.34, queryDoubleField( "BreastCancer", "MockBC" ), 1e-6 );

        Answer ans3 = new Answer( getQId( "MockBC", "Biopsy" ), "Yes" );
        getKSession().insert( ans3 );
        getKSession().fireAllRules();

        Assert.assertEquals( 0.52, queryDoubleField( "BreastCancer", "MockBC" ), 1e-6 );

//        System.err.println( reportWMObjects( getKSession() ) );
    }

    /**
     * Resolves the id of the question associated with a model field by running
     * the "getItemId" query twice: first for the model itself, then for
     * {@code model_field} within the id returned by the first query.
     *
     * @param model name of the model
     * @param field name of the field within the model
     * @return the "$id" value of the first row of the second query
     */
    private String getQId( String model, String field ) {
        // ref : getItemId( String $type, String $context, String $id )
        String questId = (String) getKSession()
                .getQueryResults( "getItemId", model, Variable.v, Variable.v )
                .iterator().next().get( "$id" );
        return (String) getKSession()
                .getQueryResults( "getItemId", model + "_" + field, questId, Variable.v )
                .iterator().next().get( "$id" );
    }

    /**
     * Loads the iris classifier, checks the number of Synapse facts, and
     * verifies three Test_MLP_* values and the four output fields.
     */
    @Test
    public void testIris() throws Exception {
        setKSession( getModelSession( source2, VERBOSE ) );
        setKbase( getKSession().getKnowledgeBase() );

        getKSession().fireAllRules();  //init model

        Assert.assertEquals( 21, getNumAssertedSynapses() );

        getKSession().getWorkingMemoryEntryPoint( "in_PetalLen" ).insert( 2.2 );
        getKSession().getWorkingMemoryEntryPoint( "in_PetalWid" ).insert( 4.1 );
        getKSession().getWorkingMemoryEntryPoint( "in_SepalLen" ).insert( 2.3 );
        getKSession().getWorkingMemoryEntryPoint( "in_SepalWid" ).insert( 1.8 );

        getKSession().fireAllRules();

        System.err.println( reportWMObjects( getKSession() ) );

        FactType t7 = getKbase().getFactType( packageName, "Test_MLP_7" );
        FactType t8 = getKbase().getFactType( packageName, "Test_MLP_8" );
        FactType t9 = getKbase().getFactType( packageName, "Test_MLP_9" );
        FactType s1 = getKbase().getFactType( packageName, "Cspecies_virginica" );

        Assert.assertEquals( 0.001, truncN( getDoubleFieldValue( t7 ), 3 ), 1e-4 );
        Assert.assertEquals( 0.282, truncN( getDoubleFieldValue( t8 ), 3 ), 1e-4 );
        Assert.assertEquals( 0.716, truncN( getDoubleFieldValue( t9 ), 3 ), 1e-4 );

//        Assert.assertEquals("virginica",
//                getFieldValue("Cspecies_virginica", "Test_MLP"));
//        Assert.assertEquals("Test_setosa",
//                getFieldValue("Cspecies_setosa", "Test_MLP"));
//        Assert.assertEquals("Test_versicolor",
//                getFieldValue("Cspecies_versicolor", "Test_MLP"));

        checkFirstDataFieldOfTypeStatus( getKbase().getFactType( packageName, "SpecSetosa" ),
                                         true, false, "Test_MLP", 0.001111 );
        checkFirstDataFieldOfTypeStatus( getKbase().getFactType( packageName, "SpecVirgin" ),
                                         true, false, "Test_MLP", 0.716639 );
        checkFirstDataFieldOfTypeStatus( getKbase().getFactType( packageName, "SpecVersic" ),
                                         true, false, "Test_MLP", 0.282249 );
        checkFirstDataFieldOfTypeStatus( getKbase().getFactType( packageName, "SpecOut" ),
                                         true, false, "Test_MLP", "virginica" );
    }

    /**
     * Loads the second iris model, checks the number of Synapse facts, and
     * verifies three Test_MLP_* values plus the OutSpecies / OutProb fields.
     */
    @Test
    public void testIris2() throws Exception {
        setKSession( getModelSession( source22, VERBOSE ) );
        setKbase( getKSession().getKnowledgeBase() );

        getKSession().fireAllRules();  //init model

        Assert.assertEquals( 12, getNumAssertedSynapses() );

        getKSession().getWorkingMemoryEntryPoint( "in_PetalLen" ).insert( 101 );
        getKSession().getWorkingMemoryEntryPoint( "in_PetalWid" ).insert( 1 );
        getKSession().getWorkingMemoryEntryPoint( "in_SepalLen" ).insert( 151 );
        getKSession().getWorkingMemoryEntryPoint( "in_SepalWid" ).insert( 30 );

        getKSession().fireAllRules();

        System.err.println( reportWMObjects( getKSession() ) );

        FactType t4 = getKbase().getFactType( packageName, "Test_MLP_0" );
        FactType t5 = getKbase().getFactType( packageName, "Test_MLP_1" );
        FactType t6 = getKbase().getFactType( packageName, "Test_MLP_2" );

        // truncN() multiplies by a power of ten, so its result is only approximately
        // the decimal literal; compare with the same tolerance testIris uses.
        Assert.assertEquals( 1.542, truncN( getDoubleFieldValue( t4 ), 3 ), 1e-4 );
        Assert.assertEquals( 0.0, truncN( getDoubleFieldValue( t5 ), 3 ), 1e-4 );
        Assert.assertEquals( 3.0, truncN( getDoubleFieldValue( t6 ), 3 ), 1e-4 );

        checkFirstDataFieldOfTypeStatus( getKbase().getFactType( packageName, "OutSpecies" ),
                                         true, false, "Test_MLP", "versicolor" );
        checkFirstDataFieldOfTypeStatus( getKbase().getFactType( packageName, "OutProb" ),
                                         true, false, "Test_MLP", 0.999999 );
    }

    /**
     * Loads the iris prediction model, checks the number of Synapse facts and
     * verifies the OutSepLen query result.
     */
    @Test
    public void testIris3() throws Exception {
        setKSession( getModelSession( source23, VERBOSE ) );
        setKbase( getKSession().getKnowledgeBase() );

        getKSession().fireAllRules();  //init model

        Assert.assertEquals( 6, getNumAssertedSynapses() );

        getKSession().getWorkingMemoryEntryPoint( "in_PetalNum" ).insert( 101 );
        getKSession().getWorkingMemoryEntryPoint( "in_PetalWid" ).insert( 2 );
        getKSession().getWorkingMemoryEntryPoint( "in_Species" ).insert( "virginica" );
        getKSession().getWorkingMemoryEntryPoint( "in_SepalWid" ).insert( 30 );

        getKSession().fireAllRules();

        System.err.println( reportWMObjects( getKSession() ) );

        // NOTE(review): the assertEquals overload selected here depends on the return type
        // of queryIntegerField, which is declared outside this file - confirm it is intended.
        Assert.assertEquals( 24.0, queryIntegerField( "OutSepLen", "Neuiris" ) );
    }

    /**
     * Feeds two inputs to the mining-schema model and checks both mock outputs.
     */
    @Test
    public void testSimpleANN() throws Exception {
        // from mining schema test, simple network with fieldRef as output
        setKSession( getModelSession( source3, VERBOSE ) );
        setKbase( getKSession().getKnowledgeBase() );

        getKSession().getWorkingMemoryEntryPoint( "in_Feat2" ).insert( 4 );
        getKSession().getWorkingMemoryEntryPoint( "in_Feat1" ).insert( 3.5 );

        getKSession().fireAllRules();

        System.err.println( reportWMObjects( getKSession() ) );

        checkFirstDataFieldOfTypeStatus( getKbase().getFactType( packageName, "MockOutput2" ),
                                         true, false, "Test_MLP", 1.0 );
        checkFirstDataFieldOfTypeStatus( getKbase().getFactType( packageName, "MockOutput1" ),
                                         true, false, "Test_MLP", 0.0 );
    }

    /**
     * Loads the mixed-input "heart" model, checks the number of Synapse facts,
     * feeds thirteen features and verifies the OutN / OutP fields.
     */
    @Test
    public void testHeart() throws Exception {
        setKSession( getModelSession( source4, VERBOSE ) );
        setKbase( getKSession().getKnowledgeBase() );

        getKSession().fireAllRules();  //init model

        Assert.assertEquals( 81, getNumAssertedSynapses() );

        getKSession().getWorkingMemoryEntryPoint( "in_Feat1" ).insert( 83.0 );
        getKSession().getWorkingMemoryEntryPoint( "in_Feat2" ).insert( 1.0 );
        getKSession().getWorkingMemoryEntryPoint( "in_Feat3" ).insert( 5.0 );
        getKSession().getWorkingMemoryEntryPoint( "in_Feat4" ).insert( "asympt" );
        getKSession().getWorkingMemoryEntryPoint( "in_Feat5" ).insert( "yes" );
        getKSession().getWorkingMemoryEntryPoint( "in_Feat6" ).insert( "t" );
        getKSession().getWorkingMemoryEntryPoint( "in_Feat7" ).insert( 1.0 );
        getKSession().getWorkingMemoryEntryPoint( "in_Feat8" ).insert( "normal" );
        getKSession().getWorkingMemoryEntryPoint( "in_Feat9" ).insert( "male" );
        getKSession().getWorkingMemoryEntryPoint( "in_Feat10" ).insert( "flat" );
        getKSession().getWorkingMemoryEntryPoint( "in_Feat11" ).insert( "normal" );
        getKSession().getWorkingMemoryEntryPoint( "in_Feat12" ).insert( 3.3 );
        getKSession().getWorkingMemoryEntryPoint( "in_Feat13" ).insert( 2.5 );

        getKSession().fireAllRules();

        System.err.println( reportWMObjects( getKSession() ) );

        checkFirstDataFieldOfTypeStatus( getKbase().getFactType( packageName, "OutN" ),
                                         true, false, "HEART_MLP", ">50_1" );
        checkFirstDataFieldOfTypeStatus( getKbase().getFactType( packageName, "OutP" ),
                                         true, false, "HEART_MLP", 0.943336 );
    }

    /**
     * Inserts inputs twice and checks that the number of Out1, Out2 and Feat2
     * facts is the same after the second round as after the first.
     */
    @Test
    public void testOverride() throws Exception {
        setKSession( getModelSession( source3, VERBOSE ) );
        setKbase( getKSession().getKnowledgeBase() );

        getKSession().fireAllRules();

        getKSession().getWorkingMemoryEntryPoint( "in_Feat1" ).insert( 2.2 );
        getKSession().fireAllRules();

        getKSession().getWorkingMemoryEntryPoint( "in_Feat2" ).insert( 5 );
        getKSession().fireAllRules();

        System.err.println( reportWMObjects( getKSession() ) );

        FactType out1 = getKbase().getFactType( packageName, "Out1" );
        FactType out2 = getKbase().getFactType( packageName, "Out2" );
        FactType nump = getKbase().getFactType( packageName, "Feat2" );

        assertEquals( 1, getKSession().getObjects( new ClassObjectFilter( out1.getFactClass() ) ).size() );
        assertEquals( 1, getKSession().getObjects( new ClassObjectFilter( out2.getFactClass() ) ).size() );
        assertEquals( 2, getKSession().getObjects( new ClassObjectFilter( nump.getFactClass() ) ).size() );

        getKSession().getWorkingMemoryEntryPoint( "in_Feat1" ).insert( 2.5 );
        getKSession().getWorkingMemoryEntryPoint( "in_Feat2" ).insert( 6 );
        getKSession().fireAllRules();

        System.err.println( reportWMObjects( getKSession() ) );

        assertEquals( 1, getKSession().getObjects( new ClassObjectFilter( out1.getFactClass() ) ).size() );
        assertEquals( 1, getKSession().getObjects( new ClassObjectFilter( out2.getFactClass() ) ).size() );
        assertEquals( 2, getKSession().getObjects( new ClassObjectFilter( nump.getFactClass() ) ).size() );
    }

    /**
     * Runs the SmartVent model on two consecutive sets of inputs and checks
     * the five Out_s* query results after each set.
     */
    @Test
    public void testSmartVent() throws Exception {
        setKSession( getModelSession( smartVent, VERBOSE ) );
        setKbase( getKSession().getKnowledgeBase() );

        getKSession().fireAllRules();  //init model

        getKSession().getWorkingMemoryEntryPoint( "in_PIP" ).insert( 28.0 );
        getKSession().getWorkingMemoryEntryPoint( "in_PEEP" ).insert( 5.0 );
        getKSession().getWorkingMemoryEntryPoint( "in_RATE" ).insert( 30.0 );
        getKSession().getWorkingMemoryEntryPoint( "in_IT" ).insert( 0.4 );
        getKSession().getWorkingMemoryEntryPoint( "in_Ph" ).insert( 7.281 );
        getKSession().getWorkingMemoryEntryPoint( "in_CO2" ).insert( 39.3 );
        getKSession().getWorkingMemoryEntryPoint( "in_PaO2" ).insert( 126.0 );
        getKSession().getWorkingMemoryEntryPoint( "in_FIO2" ).insert( 100.0 );

        getKSession().fireAllRules();

        System.err.println( reportWMObjects( getKSession() ) );

        assertEquals( 24.0, queryDoubleField( "Out_sPIP", "SmartVent" ), 0.5 );
        assertEquals( 5, queryDoubleField( "Out_sPEEP", "SmartVent" ), 0.1 );
        assertEquals( 30, queryDoubleField( "Out_sRATE", "SmartVent" ), 0.5 );
        assertEquals( 0.4, queryDoubleField( "Out_sIT", "SmartVent" ), 0.05 );
        assertEquals( -1, queryDoubleField( "Out_sFIO2", "SmartVent" ), 0.05 );

        getKSession().getWorkingMemoryEntryPoint( "in_RATE" ).insert( 20.0 );
        getKSession().getWorkingMemoryEntryPoint( "in_PaO2" ).insert( 75.0 );
        getKSession().getWorkingMemoryEntryPoint( "in_Ph" ).insert( 7.31 );
        getKSession().getWorkingMemoryEntryPoint( "in_CO2" ).insert( 37.0 );
        getKSession().getWorkingMemoryEntryPoint( "in_IT" ).insert( 0.4 );
        getKSession().getWorkingMemoryEntryPoint( "in_PIP" ).insert( 20.0 );
        getKSession().getWorkingMemoryEntryPoint( "in_PEEP" ).insert( 4.0 );
        getKSession().getWorkingMemoryEntryPoint( "in_FIO2" ).insert( 38.0 );

        getKSession().fireAllRules();

        System.err.println( reportWMObjects( getKSession() ) );

        assertEquals( 18, queryDoubleField( "Out_sPIP", "SmartVent" ), 0.5 );
        assertEquals( 4.12, queryDoubleField( "Out_sPEEP", "SmartVent" ), 0.1 );
        assertEquals( 19, queryDoubleField( "Out_sRATE", "SmartVent" ), 0.5 );
        assertEquals( 0.4, queryDoubleField( "Out_sIT", "SmartVent" ), 0.05 );
        assertEquals( -1, queryDoubleField( "Out_sFIO2", "SmartVent" ), 0.05 );
    }

    /**
     * Feeds a temperature to the classification-output model and checks the
     * ColdCat, ColdYES and ColdNO fields.
     */
    @Test
    public void testClaxOutput() throws Exception {
        setKSession( getModelSession( source9, VERBOSE ) );
        setKbase( getKSession().getKnowledgeBase() );

        getKSession().fireAllRules();  //init model

        getKSession().getWorkingMemoryEntryPoint( "in_Temp" ).insert( 28.0 );

        getKSession().fireAllRules();

        Thread.sleep( 200 );
        System.err.println( reportWMObjects( getKSession() ) );

        checkFirstDataFieldOfTypeStatus( getKbase().getFactType( packageName, "ColdCat" ),
                                         true, false, "MockCold", "SURE" );
        checkFirstDataFieldOfTypeStatus( getKbase().getFactType( packageName, "ColdYES" ),
                                         true, false, "MockCold", 0.6475435612444598 );
        checkFirstDataFieldOfTypeStatus( getKbase().getFactType( packageName, "ColdNO" ),
                                         true, false, "MockCold", 0.0036540476859388943 );
    }

    /**
     * Counts the facts of the declared "Synapse" type currently in the session.
     *
     * @return number of Synapse objects in working memory
     */
    private int getNumAssertedSynapses() {
        Class<?> synClass = getKSession().getKnowledgeBase().getFactType( packageName, "Synapse" ).getFactClass();
        return getKSession().getObjects( new ClassObjectFilter( synClass ) ).size();
    }

    /**
     * Truncates (towards negative infinity) a value to a number of decimals.
     *
     * @param x          value to truncate
     * @param numDecimal number of decimal digits to keep
     * @return {@code floor(x * 10^numDecimal) * 10^-numDecimal}
     */
    private double truncN( double x, int numDecimal ) {
        return ( Math.floor( x * Math.pow( 10, numDecimal ) ) ) * Math.pow( 10, -numDecimal );
    }

}