/* * Copyright 2011 JBoss Inc * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.drools.pmml.pmml_4_1; import org.drools.KnowledgeBase; import org.drools.KnowledgeBaseConfiguration; import org.drools.KnowledgeBaseFactory; import org.drools.RuleBaseConfiguration; import org.drools.builder.*; import org.drools.common.DefaultFactHandle; import org.drools.common.EventFactHandle; import org.drools.conf.EventProcessingOption; import org.drools.definition.type.FactType; import org.drools.io.ResourceFactory; import org.drools.runtime.ClassObjectFilter; import org.drools.runtime.StatefulKnowledgeSession; import org.drools.runtime.rule.FactHandle; import org.drools.runtime.rule.QueryResults; import org.drools.runtime.rule.Variable; import java.io.*; import java.util.Arrays; import java.util.Iterator; import java.util.List; import java.util.PriorityQueue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; public abstract class DroolsAbstractPMMLTest { public static final String PMML = PMML4Compiler.PMML; public static final String BASE_PACK = DroolsAbstractPMMLTest.class.getPackage().getName().replace('.','/'); public static final String RESOURCE_PATH = BASE_PACK; private StatefulKnowledgeSession kSession; private KnowledgeBase kbase; private static PMML4Compiler compiler = new PMML4Compiler(); public DroolsAbstractPMMLTest() { super(); } protected StatefulKnowledgeSession getModelSession(String pmmlSource, boolean verbose) { return getModelSession(new String[] {pmmlSource}, verbose); } protected StatefulKnowledgeSession getModelSession(String[] pmmlSources, boolean verbose) { KnowledgeBuilder kbuilder = KnowledgeBuilderFactory.newKnowledgeBuilder(); kbuilder.add(ResourceFactory.newClassPathResource("org/drools/informer/informer-changeset.xml"), ResourceType.CHANGE_SET); if (! verbose) { for ( String pmmlSource : pmmlSources ) { kbuilder.add(ResourceFactory.newClassPathResource(pmmlSource),ResourceType.PMML); } } else { try { for ( String pmmlSource : pmmlSources ) { String src = compiler.compile( ResourceFactory.newClassPathResource( pmmlSource ).getInputStream(), null ); kbuilder.add( ResourceFactory.newByteArrayResource( src.getBytes() ), ResourceType.DRL ); System.out.println(src); } } catch (IOException e) { fail(e.getMessage()); } } KnowledgeBuilderErrors errors = kbuilder.getErrors(); if ( errors.size() > 0 ) { throw new IllegalArgumentException( "Could not parse knowledge : " + errors.toString() ); } KnowledgeBaseConfiguration conf = KnowledgeBaseFactory.newKnowledgeBaseConfiguration(); conf.setOption( EventProcessingOption.STREAM ); //conf.setConflictResolver(LifoConflictResolver.getInstance()); KnowledgeBase kbase = KnowledgeBaseFactory.newKnowledgeBase( conf ); kbase.addKnowledgePackages( kbuilder.getKnowledgePackages() ); return kbase.newStatefulKnowledgeSession(); } protected StatefulKnowledgeSession getSession(String theory) { KnowledgeBase kbase = readKnowledgeBase(new ByteArrayInputStream(theory.getBytes())); return kbase != null ? kbase.newStatefulKnowledgeSession() : null; } protected void refreshKSession() { if (getKSession() != null) getKSession().dispose(); setKSession(getKbase().newStatefulKnowledgeSession()); } private static KnowledgeBase readKnowledgeBase(InputStream theory) { return readKnowledgeBase(Arrays.asList(theory)); } private static KnowledgeBase readKnowledgeBase(List<InputStream> theory) { KnowledgeBuilder kbuilder = KnowledgeBuilderFactory.newKnowledgeBuilder(); for (InputStream is : theory) kbuilder.add(ResourceFactory.newInputStreamResource(is), ResourceType.DRL); KnowledgeBuilderErrors errors = kbuilder.getErrors(); if (errors.size() > 0) { for (KnowledgeBuilderError error: errors) { System.err.println(error); } throw new IllegalArgumentException("Could not parse knowledge."); } RuleBaseConfiguration conf = new RuleBaseConfiguration(); conf.setEventProcessingMode(EventProcessingOption.STREAM); conf.setAssertBehaviour(RuleBaseConfiguration.AssertBehaviour.EQUALITY); //conf.setConflictResolver(LifoConflictResolver.getInstance()); KnowledgeBase kbase = KnowledgeBaseFactory.newKnowledgeBase(conf); kbase.addKnowledgePackages(kbuilder.getKnowledgePackages()); return kbase; } public String reportWMObjects(StatefulKnowledgeSession session) { PriorityQueue<String> queue = new PriorityQueue<String>(); for (FactHandle fh : session.getFactHandles()) { Object o; if (fh instanceof EventFactHandle) { EventFactHandle efh = (EventFactHandle) fh; queue.add("\t " + efh.getStartTimestamp() + "\t" + efh.getObject().toString() + "\n"); } else { o = ((DefaultFactHandle) fh).getObject(); queue.add("\t " + o.toString() + "\n"); } } String ans = " ---------------- WM " + session.getObjects().size() + " --------------\n"; while (! queue.isEmpty()) ans += queue.poll(); ans += " ---------------- END WM -----------\n"; return ans; } private void dump(String s, OutputStream ostream) { // write to outstream Writer writer = null; try { writer = new OutputStreamWriter(ostream, "UTF-8"); writer.write(s); } catch (UnsupportedEncodingException e) { e.printStackTrace(); } catch (IOException e) { e.printStackTrace(); } finally { try { if (writer != null) { writer.flush(); } } catch (IOException e) { e.printStackTrace(); } } } public StatefulKnowledgeSession getKSession() { return kSession; } public void setKSession(StatefulKnowledgeSession kSession) { this.kSession = kSession; } public KnowledgeBase getKbase() { return kbase; } public void setKbase(KnowledgeBase kbase) { this.kbase = kbase; } protected void checkFirstDataFieldOfTypeStatus(FactType type, boolean valid, boolean missing, String ctx, Object... target) { Class<?> klass = type.getFactClass(); Iterator iter = getKSession().getObjects( new ClassObjectFilter( klass ) ).iterator(); assertTrue( iter.hasNext() ); Object obj = iter.next(); if (ctx == null) { while ( type.get( obj, "context" ) != null && iter.hasNext() ) obj = iter.next(); } else { while ( ( ! ctx.equals( type.get( obj, "context" ) ) ) && iter.hasNext() ) obj = iter.next(); } Object tgt = type.get( obj, "value" ); if ( tgt instanceof Double ) { assert( target[0] instanceof Double ); assertEquals( (Double) target[0], (Double) tgt, 1e-6 ); } else { assertEquals( target[0], tgt ); } assertEquals( valid, type.get( obj, "valid" ) ); assertEquals( missing, type.get( obj, "missing" ) ); } protected double queryDoubleField( String target, String modelName ) { QueryResults results = getKSession().getQueryResults( target, modelName, Variable.v ); assertEquals( 1, results.size() ); return (Double) results.iterator().next().get( "$result" ); } protected double queryIntegerField( String target, String modelName ) { QueryResults results = getKSession().getQueryResults( target, modelName, Variable.v ); assertEquals( 1, results.size() ); return (Integer) results.iterator().next().get( "$result" ); } protected String queryStringField( String target, String modelName ) { QueryResults results = getKSession().getQueryResults( target, modelName, Variable.v ); assertEquals( 1, results.size() ); return (String) results.iterator().next().get( "$result" ); } public Double getDoubleFieldValue( FactType type ) { Class<?> klass = type.getFactClass(); Iterator iter = getKSession().getObjects( new ClassObjectFilter( klass ) ).iterator(); Object obj = iter.next(); return (Double) type.get( obj, "value" ); } public Object getFieldValue( FactType type ) { Class<?> klass = type.getFactClass(); Iterator iter = getKSession().getObjects( new ClassObjectFilter( klass ) ).iterator(); Object obj = iter.next(); return type.get( obj, "value" ); } }