/* * Copyright 2011 JBoss Inc * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.drools.pmml.pmml_4_1.predictive.models; import org.drools.KnowledgeBase; import org.drools.KnowledgeBaseFactory; import org.drools.RuleBaseConfiguration; import org.drools.builder.KnowledgeBuilder; import org.drools.builder.KnowledgeBuilderFactory; import org.drools.builder.ResourceType; import org.drools.conf.EventProcessingOption; import org.drools.definition.type.FactType; import org.drools.io.ResourceFactory; import org.drools.io.impl.ByteArrayResource; import org.drools.pmml.pmml_4_1.DroolsAbstractPMMLTest; import org.drools.pmml.pmml_4_1.PMML4Compiler; import org.dmg.pmml.pmml_4_1.descr.*; import org.drools.runtime.ClassObjectFilter; import org.drools.runtime.StatefulKnowledgeSession; import org.junit.After; import org.junit.Test; import java.util.Collection; import static org.junit.Assert.*; public class DecisionTreeTest extends DroolsAbstractPMMLTest { private static final boolean VERBOSE = true; private static final String source1 = "org/drools/pmml/pmml_4_1/test_tree_simple.xml"; private static final String source2 = "org/drools/pmml/pmml_4_1/test_tree_missing.xml"; private static final String packageName = "org.drools.pmml.pmml_4_1.test"; @After public void tearDown() { getKSession().dispose(); } @Test public void testSimpleTree() throws Exception { setKSession( getModelSession( source1, VERBOSE ) ); setKbase( getKSession().getKnowledgeBase() ); StatefulKnowledgeSession kSession = getKSession(); // kSession.addEventListener( new org.drools.event.rule.DebugAgendaEventListener() ); kSession.fireAllRules(); //init model FactType tgt = kSession.getKnowledgeBase().getFactType( packageName, "Fld5" ); kSession.getWorkingMemoryEntryPoint( "in_Fld1" ).insert( 30.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld2" ).insert( 60.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld3" ).insert( "false" ); kSession.getWorkingMemoryEntryPoint( "in_Fld4" ).insert( "optA" ); kSession.fireAllRules(); System.err.println( reportWMObjects( kSession ) ); checkFirstDataFieldOfTypeStatus( tgt, true, false, "Missing", "tgtY" ); } protected Object getToken( StatefulKnowledgeSession kSession ) { FactType tok = kSession.getKnowledgeBase().getFactType( packageName, "TreeToken" ); assertNotNull( tok ); Collection c = kSession.getObjects( new ClassObjectFilter( tok.getFactClass() ) ); assertEquals( 1, c.size() ); return c.iterator().next(); } @Test public void testMissingTree() throws Exception { setKSession( getModelSession( source2, VERBOSE ) ); setKbase( getKSession().getKnowledgeBase() ); StatefulKnowledgeSession kSession = getKSession(); kSession.fireAllRules(); //init model FactType tgt = kSession.getKnowledgeBase().getFactType( packageName, "Fld9" ); FactType tok = kSession.getKnowledgeBase().getFactType( packageName, "TreeToken" ); kSession.getWorkingMemoryEntryPoint( "in_Fld1" ).insert( 45.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld2" ).insert( 60.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld3" ).insert( "optA" ); kSession.fireAllRules(); System.err.println( reportWMObjects( kSession ) ); Object token = getToken( kSession ); assertEquals( 0.6, tok.get( token, "confidence" ) ); assertEquals( "null", tok.get( token, "current" ) ); checkFirstDataFieldOfTypeStatus( tgt, true, false, "Missing", "tgtZ" ); } @Test public void testMissingTreeWeighted1() throws Exception { setKSession( getModelSession( source2, VERBOSE ) ); setKbase( getKSession().getKnowledgeBase() ); StatefulKnowledgeSession kSession = getKSession(); kSession.fireAllRules(); //init model FactType tgt = kSession.getKnowledgeBase().getFactType( packageName, "Fld9" ); FactType tok = kSession.getKnowledgeBase().getFactType( packageName, "TreeToken" ); kSession.getWorkingMemoryEntryPoint( "in_Fld1" ).insert( -1.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld2" ).insert( -1.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld3" ).insert( "optA" ); kSession.fireAllRules(); System.err.println( reportWMObjects( kSession ) ); Object token = getToken( kSession ); assertEquals( 0.8, tok.get( token, "confidence" ) ); assertEquals( "null", tok.get( token, "current" ) ); assertEquals( 50.0, tok.get( token, "totalCount" ) ); checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtX" ); } @Test public void testMissingTreeWeighted2() throws Exception { setKSession( getModelSession( source2, VERBOSE ) ); setKbase( getKSession().getKnowledgeBase() ); StatefulKnowledgeSession kSession = getKSession(); kSession.fireAllRules(); //init model FactType tgt = kSession.getKnowledgeBase().getFactType( packageName, "Fld9" ); FactType tok = kSession.getKnowledgeBase().getFactType( packageName, "TreeToken" ); kSession.getWorkingMemoryEntryPoint( "in_Fld1" ).insert( -1.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld2" ).insert( -1.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld3" ).insert( "miss" ); kSession.fireAllRules(); System.err.println( reportWMObjects( kSession ) ); Object token = getToken( kSession ); assertEquals( 0.6, tok.get( token, "confidence" ) ); assertEquals( "null", tok.get( token, "current" ) ); assertEquals( 100.0, tok.get( token, "totalCount" ) ); checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtX" ); } private StatefulKnowledgeSession compile( String drl ) { KnowledgeBuilder kbuilder = KnowledgeBuilderFactory.newKnowledgeBuilder(); kbuilder.add( new ByteArrayResource( drl.getBytes() ), ResourceType.DRL ); if ( kbuilder.hasErrors() ) { fail( kbuilder.getErrors().toString() ); } RuleBaseConfiguration conf = new RuleBaseConfiguration(); conf.setEventProcessingMode( EventProcessingOption.STREAM ); KnowledgeBase kBase = KnowledgeBaseFactory.newKnowledgeBase( conf ); kBase.addKnowledgePackages( kbuilder.getKnowledgePackages() ); return kBase.newStatefulKnowledgeSession(); } @Test public void testMissingTreeDefault() throws Exception { PMML4Compiler compiler = new PMML4Compiler(); PMML pmml = compiler.loadModel( PMML, ResourceFactory.newClassPathResource( source2 ).getInputStream() ); for ( Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) { if ( o instanceof TreeModel ) { TreeModel tree = (TreeModel) o; tree.setMissingValueStrategy( MISSINGVALUESTRATEGY.DEFAULT_CHILD ); } } String theory = compiler.generateTheory( pmml ); if ( VERBOSE ) { System.out.println( theory ); } StatefulKnowledgeSession kSession = compile( theory ); setKSession( kSession ); setKbase( getKSession().getKnowledgeBase() ); kSession.fireAllRules(); //init model FactType tgt = kSession.getKnowledgeBase().getFactType( packageName, "Fld9" ); FactType tok = kSession.getKnowledgeBase().getFactType( packageName, "TreeToken" ); kSession.getWorkingMemoryEntryPoint( "in_Fld1" ).insert( 70.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld2" ).insert( 40.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld3" ).insert( "miss" ); kSession.fireAllRules(); System.err.println( reportWMObjects( kSession ) ); Object token = getToken( kSession ); assertEquals( 0.72, (Double) tok.get( token, "confidence" ), 1e-6 ); assertEquals( "null", tok.get( token, "current" ) ); assertEquals( 40.0, tok.get( token, "totalCount" ) ); checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtX" ); } @Test public void testMissingTreeAllMissingDefault() throws Exception { PMML4Compiler compiler = new PMML4Compiler(); PMML pmml = compiler.loadModel( PMML, ResourceFactory.newClassPathResource( source2 ).getInputStream() ); for ( Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) { if ( o instanceof TreeModel ) { TreeModel tree = (TreeModel) o; tree.setMissingValueStrategy( MISSINGVALUESTRATEGY.DEFAULT_CHILD ); } } String theory = compiler.generateTheory( pmml ); if ( VERBOSE ) { System.out.println( theory ); } StatefulKnowledgeSession kSession = compile( theory ); setKSession( kSession ); setKbase( getKSession().getKnowledgeBase() ); kSession.fireAllRules(); //init model FactType tgt = kSession.getKnowledgeBase().getFactType( packageName, "Fld9" ); FactType tok = kSession.getKnowledgeBase().getFactType( packageName, "TreeToken" ); kSession.getWorkingMemoryEntryPoint( "in_Fld1" ).insert( -1.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld2" ).insert( -1.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld3" ).insert( "miss" ); kSession.fireAllRules(); System.err.println( reportWMObjects( kSession ) ); Object token = getToken( kSession ); assertEquals( 1.0, (Double) tok.get( token, "confidence" ), 1e-6 ); assertEquals( "null", tok.get( token, "current" ) ); assertEquals( 0.0, tok.get( token, "totalCount" ) ); // checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtX" ); } @Test public void testMissingTreeLastChoice() throws Exception { PMML4Compiler compiler = new PMML4Compiler(); PMML pmml = compiler.loadModel( PMML, ResourceFactory.newClassPathResource( source2 ).getInputStream() ); for ( Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) { if ( o instanceof TreeModel ) { TreeModel tree = (TreeModel) o; tree.setMissingValueStrategy( MISSINGVALUESTRATEGY.LAST_PREDICTION ); } } String theory = compiler.generateTheory( pmml ); if ( VERBOSE ) { System.out.println( theory ); } StatefulKnowledgeSession kSession = compile( theory ); setKSession( kSession ); setKbase( getKSession().getKnowledgeBase() ); kSession.fireAllRules(); //init model FactType tgt = kSession.getKnowledgeBase().getFactType( packageName, "Fld9" ); FactType tok = kSession.getKnowledgeBase().getFactType( packageName, "TreeToken" ); kSession.getWorkingMemoryEntryPoint( "in_Fld1" ).insert( -1.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld2" ).insert( -1.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld3" ).insert( "optA" ); kSession.fireAllRules(); System.err.println( reportWMObjects( kSession ) ); Object token = getToken( kSession ); assertEquals( 0.8, (Double) tok.get( token, "confidence" ), 1e-6 ); assertEquals( "null", tok.get( token, "current" ) ); assertEquals( 50.0, tok.get( token, "totalCount" ) ); checkFirstDataFieldOfTypeStatus( tgt, true, false, "Missing", "tgtX" ); } @Test public void testMissingTreeNull() throws Exception { PMML4Compiler compiler = new PMML4Compiler(); PMML pmml = compiler.loadModel( PMML, ResourceFactory.newClassPathResource( source2 ).getInputStream() ); for ( Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) { if ( o instanceof TreeModel ) { TreeModel tree = (TreeModel) o; tree.setMissingValueStrategy( MISSINGVALUESTRATEGY.NULL_PREDICTION ); } } String theory = compiler.generateTheory( pmml ); if ( VERBOSE ) { System.out.println( theory ); } StatefulKnowledgeSession kSession = compile( theory ); setKSession( kSession ); setKbase( getKSession().getKnowledgeBase() ); kSession.fireAllRules(); //init model FactType tgt = kSession.getKnowledgeBase().getFactType( packageName, "Fld9" ); FactType tok = kSession.getKnowledgeBase().getFactType( packageName, "TreeToken" ); kSession.getWorkingMemoryEntryPoint( "in_Fld1" ).insert( -1.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld2" ).insert( -1.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld3" ).insert( "optA" ); kSession.fireAllRules(); System.err.println( reportWMObjects( kSession ) ); Object token = getToken( kSession ); assertEquals( 0.0, (Double) tok.get( token, "confidence" ), 1e-6 ); assertEquals( "null", tok.get( token, "current" ) ); assertEquals( 0.0, tok.get( token, "totalCount" ) ); assertEquals( 0, getKSession().getObjects( new ClassObjectFilter( tgt.getFactClass() ) ).size() ); } @Test public void testMissingAggregate() throws Exception { PMML4Compiler compiler = new PMML4Compiler(); PMML pmml = compiler.loadModel( PMML, ResourceFactory.newClassPathResource( source2 ).getInputStream() ); for ( Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) { if ( o instanceof TreeModel ) { TreeModel tree = (TreeModel) o; tree.setMissingValueStrategy( MISSINGVALUESTRATEGY.AGGREGATE_NODES ); } } String theory = compiler.generateTheory( pmml ); if ( VERBOSE ) { System.out.println( theory ); } StatefulKnowledgeSession kSession = compile( theory ); setKSession( kSession ); setKbase( getKSession().getKnowledgeBase() ); kSession.fireAllRules(); //init model FactType tgt = kSession.getKnowledgeBase().getFactType( packageName, "Fld9" ); FactType tok = kSession.getKnowledgeBase().getFactType( packageName, "TreeToken" ); kSession.getWorkingMemoryEntryPoint( "in_Fld1" ).insert( 45.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld2" ).insert( 90.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld3" ).insert( "miss" ); kSession.fireAllRules(); System.err.println( reportWMObjects( kSession ) ); Object token = getToken( kSession ); assertEquals( 0.47, (Double) tok.get( token, "confidence" ), 1e-2 ); assertEquals( "null", tok.get( token, "current" ) ); assertEquals( 60.0, tok.get( token, "totalCount" ) ); checkFirstDataFieldOfTypeStatus( tgt, true, false, "Missing", "tgtY" ); } @Test public void testMissingTreeNone() throws Exception { PMML4Compiler compiler = new PMML4Compiler(); PMML pmml = compiler.loadModel( PMML, ResourceFactory.newClassPathResource( source2 ).getInputStream() ); for ( Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) { if ( o instanceof TreeModel ) { TreeModel tree = (TreeModel) o; tree.setMissingValueStrategy( MISSINGVALUESTRATEGY.NONE ); } } String theory = compiler.generateTheory( pmml ); if ( VERBOSE ) { System.out.println( theory ); } StatefulKnowledgeSession kSession = compile( theory ); setKSession( kSession ); setKbase( getKSession().getKnowledgeBase() ); kSession.fireAllRules(); //init model FactType tgt = kSession.getKnowledgeBase().getFactType( packageName, "Fld9" ); FactType tok = kSession.getKnowledgeBase().getFactType( packageName, "TreeToken" ); kSession.getWorkingMemoryEntryPoint( "in_Fld1" ).insert( -1.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld2" ).insert( -1.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld3" ).insert( "miss" ); kSession.fireAllRules(); System.err.println( reportWMObjects( kSession ) ); Object token = getToken( kSession ); assertEquals( 0.6, (Double) tok.get( token, "confidence" ), 1e-6 ); assertEquals( "null", tok.get( token, "current" ) ); assertEquals( 100.0, tok.get( token, "totalCount" ) ); checkFirstDataFieldOfTypeStatus( tgt, true, false, "Missing", "tgtX" ); } @Test public void testSimpleTreeOutput() throws Exception { setKSession( getModelSession( source2, VERBOSE ) ); setKbase( getKSession().getKnowledgeBase() ); StatefulKnowledgeSession kSession = getKSession(); kSession.fireAllRules(); //init model FactType tgt = kSession.getKnowledgeBase().getFactType( packageName, "Fld9" ); FactType tok = kSession.getKnowledgeBase().getFactType( packageName, "TreeToken" ); kSession.getWorkingMemoryEntryPoint( "in_Fld1" ).insert( -1.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld2" ).insert( -1.0 ); kSession.getWorkingMemoryEntryPoint( "in_Fld3" ).insert( "optA" ); kSession.fireAllRules(); System.err.println( reportWMObjects( kSession ) ); Object token = getToken( kSession ); assertEquals( 0.8, tok.get( token, "confidence" ) ); assertEquals( "null", tok.get( token, "current" ) ); assertEquals( 50.0, tok.get( token, "totalCount" ) ); checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtX" ); checkFirstDataFieldOfTypeStatus( kSession.getKnowledgeBase().getFactType( packageName, "OutClass" ), true, false, "Missing", "tgtX" ); checkFirstDataFieldOfTypeStatus( kSession.getKnowledgeBase().getFactType( packageName, "OutProb" ), true, false, "Missing", 0.8 ); } }