/*
* Copyright 2015 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.scorecards;
import org.dmg.pmml.pmml_4_2.descr.Attribute;
import org.dmg.pmml.pmml_4_2.descr.Characteristic;
import org.dmg.pmml.pmml_4_2.descr.Characteristics;
import org.dmg.pmml.pmml_4_2.descr.PMML;
import org.dmg.pmml.pmml_4_2.descr.Scorecard;
import org.drools.pmml.pmml_4_2.PMML4Helper;
import org.junit.Assert;
import org.junit.Test;
import org.kie.api.KieBase;
import org.kie.api.KieServices;
import org.kie.api.builder.KieBuilder;
import org.kie.api.builder.KieFileSystem;
import org.kie.api.builder.Message;
import org.kie.api.builder.Results;
import org.kie.api.definition.type.FactType;
import org.kie.api.io.ResourceType;
import org.kie.api.runtime.ClassObjectFilter;
import org.kie.api.runtime.KieContainer;
import org.kie.api.runtime.KieSession;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import static org.drools.scorecards.ScorecardCompiler.DrlType.INTERNAL_DECLARED_TYPES;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
public class ScorecardReasonCodeTest {
/**
 * Compiles {@code /scoremodel_reasoncodes.xls} and checks that the compiler afterwards
 * exposes a non-null PMML document and a non-null, non-empty PMML string.
 */
@Test
public void testPMMLDocument() throws Exception {
    final ScorecardCompiler compiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
    if (!compiler.compileFromExcel(
            PMMLDocumentTest.class.getResourceAsStream("/scoremodel_reasoncodes.xls"))) {
        // Fails the test, listing every parse error the compiler collected.
        assertErrors(compiler);
    }

    assertNotNull(compiler.getPMMLDocument());

    final String pmmlText = compiler.getPMML();
    assertNotNull(pmmlText);
    assertFalse(pmmlText.isEmpty());
}
/**
 * Compiles {@code /scoremodel_c.xls} and verifies that every {@link Scorecard} found in
 * the resulting PMML document reports {@code useReasonCodes} as false.
 * <p>
 * The test fails when the compiler reports errors or when the document contains no
 * {@link Scorecard} at all, so it cannot pass without actually checking a model.
 */
@Test
public void testAbsenceOfReasonCodes() throws Exception {
    final ScorecardCompiler scorecardCompiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
    final boolean compileResult = scorecardCompiler.compileFromExcel(
            PMMLDocumentTest.class.getResourceAsStream("/scoremodel_c.xls"));
    if (!compileResult) {
        // Surface the parse errors instead of failing later with an unrelated NPE.
        assertErrors(scorecardCompiler);
    }

    final PMML pmml = scorecardCompiler.getPMMLDocument();
    assertNotNull(pmml);

    boolean scorecardFound = false;
    for (Object serializable : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (serializable instanceof Scorecard) {
            scorecardFound = true;
            assertFalse(((Scorecard) serializable).getUseReasonCodes());
        }
    }
    // Guard against a vacuous pass: at least one scorecard must have been inspected.
    assertTrue("No Scorecard model found in the generated PMML document", scorecardFound);
}
/**
 * Compiles {@code /scoremodel_reasoncodes.xls} and verifies, for every {@link Scorecard}
 * in the resulting PMML document, that reason codes are enabled, the initial score is
 * 100.0 and the reason code algorithm is {@code "pointsBelow"}.
 * <p>
 * The test fails when the document contains no {@link Scorecard}, matching the sibling
 * tests that end in {@code fail()} when nothing was inspected.
 */
@Test
public void testUseReasonCodes() throws Exception {
    final ScorecardCompiler scorecardCompiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
    final boolean compileResult = scorecardCompiler.compileFromExcel(
            PMMLDocumentTest.class.getResourceAsStream("/scoremodel_reasoncodes.xls"));
    if (!compileResult) {
        assertErrors(scorecardCompiler);
    }

    final PMML pmmlDocument = scorecardCompiler.getPMMLDocument();
    boolean scorecardFound = false;
    for (Object serializable : pmmlDocument.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (serializable instanceof Scorecard) {
            scorecardFound = true;
            final Scorecard scorecard = (Scorecard) serializable;
            assertTrue(scorecard.getUseReasonCodes());
            assertEquals(100.0, scorecard.getInitialScore(), 0.0);
            assertEquals("pointsBelow", scorecard.getReasonCodeAlgorithm());
        }
    }
    // Guard against a vacuous pass: at least one scorecard must have been inspected.
    assertTrue("No Scorecard model found in the generated PMML document", scorecardFound);
}
/**
 * Compiles {@code /scoremodel_reasoncodes.xls}, locates the first {@link Characteristics}
 * element of the first {@link Scorecard} that has one, and checks that it holds four
 * characteristics whose attributes all carry a non-null reason code.
 * Fails if no such element is found.
 */
@Test
public void testReasonCodes() throws Exception {
    final ScorecardCompiler compiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
    if (!compiler.compileFromExcel(
            PMMLDocumentTest.class.getResourceAsStream("/scoremodel_reasoncodes.xls"))) {
        assertErrors(compiler);
    }

    final PMML document = compiler.getPMMLDocument();
    for (Object model : document.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (!(model instanceof Scorecard)) {
            continue;
        }
        final Scorecard card = (Scorecard) model;
        for (Object element : card.getExtensionsAndCharacteristicsAndMiningSchemas()) {
            if (!(element instanceof Characteristics)) {
                continue;
            }
            final Characteristics found = (Characteristics) element;
            assertEquals(4, found.getCharacteristics().size());
            for (Characteristic entry : found.getCharacteristics()) {
                for (Attribute attr : entry.getAttributes()) {
                    assertNotNull(attr.getReasonCode());
                }
            }
            // The first Characteristics element is the only one inspected.
            return;
        }
    }
    // Reached only when no Characteristics element was present in any scorecard.
    fail();
}
/**
 * Compiles {@code /scoremodel_reasoncodes.xls}, locates the first {@link Characteristics}
 * element of the first {@link Scorecard} that has one, and checks the baseline score of
 * each of its four characteristics (10, 99, 12, 15) and of the scorecard itself (25).
 * Fails if no such element is found.
 */
@Test
public void testBaselineScores() throws Exception {
    final ScorecardCompiler compiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
    if (!compiler.compileFromExcel(
            PMMLDocumentTest.class.getResourceAsStream("/scoremodel_reasoncodes.xls"))) {
        assertErrors(compiler);
    }

    // Expected per-characteristic baselines, in document order.
    final double[] expectedBaselines = {10.0, 99.0, 12.0, 15.0};

    final PMML document = compiler.getPMMLDocument();
    for (Object model : document.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (!(model instanceof Scorecard)) {
            continue;
        }
        final Scorecard card = (Scorecard) model;
        for (Object element : card.getExtensionsAndCharacteristicsAndMiningSchemas()) {
            if (!(element instanceof Characteristics)) {
                continue;
            }
            final Characteristics found = (Characteristics) element;
            assertEquals(4, found.getCharacteristics().size());
            for (int i = 0; i < expectedBaselines.length; i++) {
                assertEquals(expectedBaselines[i], found.getCharacteristics().get(i).getBaselineScore(), 0.0);
            }
            assertEquals(25.0, card.getBaselineScore(), 0.0);
            return;
        }
    }
    // Reached only when no Characteristics element was present in any scorecard.
    fail();
}
/**
 * Compiles {@code /scoremodel_reasoncodes.xls} with the extra argument
 * {@code "scorecards_reason_error"} (presumably the worksheet name) using a
 * default-constructed compiler, and checks that three parse errors are reported,
 * the first two located at {@code $F$13} and {@code $F$22}.
 */
@Test
public void testMissingReasonCodes() throws Exception {
    final ScorecardCompiler compiler = new ScorecardCompiler();
    compiler.compileFromExcel(
            PMMLDocumentTest.class.getResourceAsStream("/scoremodel_reasoncodes.xls"),
            "scorecards_reason_error");

    assertEquals(3, compiler.getScorecardParseErrors().size());

    final String firstLocation = compiler.getScorecardParseErrors().get(0).getErrorLocation();
    assertEquals("$F$13", firstLocation);
    final String secondLocation = compiler.getScorecardParseErrors().get(1).getErrorLocation();
    assertEquals("$F$22", secondLocation);
}
/**
 * Compiles {@code /scoremodel_reasoncodes.xls} with the extra argument
 * {@code "scorecards_reason_error"} (presumably the worksheet name) and checks that
 * three parse errors are reported, the third located at {@code $D$30}.
 */
@Test
public void testMissingBaselineScores() throws Exception {
    final ScorecardCompiler compiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
    compiler.compileFromExcel(
            PMMLDocumentTest.class.getResourceAsStream("/scoremodel_reasoncodes.xls"),
            "scorecards_reason_error");

    assertEquals(3, compiler.getScorecardParseErrors().size());

    final String thirdLocation = compiler.getScorecardParseErrors().get(2).getErrorLocation();
    assertEquals("$D$30", thirdLocation);
}
/**
 * Builds {@code scoremodel_reasoncodes.xls} through the KIE build pipeline as an
 * {@code SCARD} resource, then runs three scenarios, each in a fresh session:
 * a {@code SampleScore} fact is inserted, all rules are fired, and the test checks
 * the calculated score on the input fact, the score and "ranking" map on the internal
 * {@code ScoreCard} fact, and the score and reason code on the {@code SampleScoreOutput} fact.
 * <p>
 * The build results are asserted to be error-free, and every session is disposed in a
 * {@code finally} block so a failing assertion does not leak it.
 */
@Test
public void testReasonCodesCombinations() throws Exception {
    KieServices ks = KieServices.Factory.get();
    KieFileSystem kfs = ks.newKieFileSystem();
    kfs.write( ks.getResources().newClassPathResource( "scoremodel_reasoncodes.xls" )
                       .setSourcePath( "scoremodel_reasoncodes.xls" )
                       .setResourceType( ResourceType.SCARD ) );
    KieBuilder kieBuilder = ks.newKieBuilder( kfs );
    Results res = kieBuilder.buildAll().getResults();
    assertFalse( "KIE build reported errors: " + res.getMessages( Message.Level.ERROR ),
                 res.hasMessages( Message.Level.ERROR ) );
    KieContainer kieContainer = ks.newKieContainer( kieBuilder.getKieModule().getReleaseId() );
    KieBase kbase = kieContainer.getKieBase();

    FactType scorecardType = kbase.getFactType( "org.drools.scorecards.example", "SampleScore" );
    FactType scorecardInternalsType = kbase.getFactType( PMML4Helper.pmmlDefaultPackageName(), "ScoreCard" );
    FactType scorecardOutputType = kbase.getFactType( "org.drools.scorecards.example", "SampleScoreOutput" );

    // Scenario 1: only age (10) is set.
    KieSession session = kbase.newKieSession();
    try {
        Object scorecard = scorecardType.newInstance();
        scorecardType.set( scorecard, "age", 10 );
        session.insert( scorecard );
        session.fireAllRules();
        assertEquals( 129.0, scorecardType.get( scorecard, "scorecard__calculatedScore" ) );

        Object scorecardInternals = session.getObjects( new ClassObjectFilter( scorecardInternalsType.getFactClass() ) ).iterator().next();
        assertEquals( 129.0, scorecardInternalsType.get( scorecardInternals, "score" ) );
        Map<?, ?> reasonCodesMap = (Map<?, ?>) scorecardInternalsType.get( scorecardInternals, "ranking" );
        assertNotNull( reasonCodesMap );
        assertEquals( 2, reasonCodesMap.size() );
        assertEquals( 16.0, reasonCodesMap.get( "VL002" ) );
        assertEquals( -20.0, reasonCodesMap.get( "AGE02" ) );

        Object scorecardOutput = session.getObjects( new ClassObjectFilter( scorecardOutputType.getFactClass() ) ).iterator().next();
        assertEquals( 129.0, scorecardOutputType.get( scorecardOutput, "calculatedScore" ) );
        assertEquals( "VL002", scorecardOutputType.get( scorecardOutput, "reasonCode" ) );
    } finally {
        session.dispose();
    }

    // Scenario 2: age 0 and occupation SKYDIVER.
    session = kbase.newKieSession();
    try {
        Object scorecard = scorecardType.newInstance();
        scorecardType.set( scorecard, "age", 0 );
        scorecardType.set( scorecard, "occupation", "SKYDIVER" );
        session.insert( scorecard );
        session.fireAllRules();
        assertEquals( 99.0, scorecardType.get( scorecard, "scorecard__calculatedScore" ) );

        Object scorecardInternals = session.getObjects( new ClassObjectFilter( scorecardInternalsType.getFactClass() ) ).iterator().next();
        assertEquals( 99.0, scorecardInternalsType.get( scorecardInternals, "score" ) );
        Map<?, ?> reasonCodesMap = (Map<?, ?>) scorecardInternalsType.get( scorecardInternals, "ranking" );
        assertNotNull( reasonCodesMap );
        assertEquals( 3, reasonCodesMap.size() );
        assertEquals( 109.0, reasonCodesMap.get( "OCC01" ) );
        assertEquals( 16.0, reasonCodesMap.get( "VL002" ) );
        assertEquals( 0.0, reasonCodesMap.get( "AGE01" ) );

        Object scorecardOutput = session.getObjects( new ClassObjectFilter( scorecardOutputType.getFactClass() ) ).iterator().next();
        assertEquals( 99.0, scorecardOutputType.get( scorecardOutput, "calculatedScore" ) );
        assertEquals( "OCC01", scorecardOutputType.get( scorecardOutput, "reasonCode" ) );
    } finally {
        session.dispose();
    }

    // Scenario 3: age, occupation, residence state and valid license are all set.
    session = kbase.newKieSession();
    try {
        Object scorecard = scorecardType.newInstance();
        scorecardType.set( scorecard, "age", 20 );
        scorecardType.set( scorecard, "occupation", "TEACHER" );
        scorecardType.set( scorecard, "residenceState", "AP" );
        scorecardType.set( scorecard, "validLicense", true );
        session.insert( scorecard );
        session.fireAllRules();
        assertEquals( 141.0, scorecardType.get( scorecard, "scorecard__calculatedScore" ) );

        Object scorecardInternals = session.getObjects( new ClassObjectFilter( scorecardInternalsType.getFactClass() ) ).iterator().next();
        assertEquals( 141.0, scorecardInternalsType.get( scorecardInternals, "score" ) );
        Map<?, ?> reasonCodesMap = (Map<?, ?>) scorecardInternalsType.get( scorecardInternals, "ranking" );
        assertNotNull( reasonCodesMap );
        assertEquals( 4, reasonCodesMap.size() );
        assertEquals( 89.0, reasonCodesMap.get( "OCC02" ) );
        assertEquals( 22.0, reasonCodesMap.get( "RS001" ) );
        assertEquals( 14.0, reasonCodesMap.get( "VL001" ) );
        assertEquals( -30.0, reasonCodesMap.get( "AGE03" ) );

        Object scorecardOutput = session.getObjects( new ClassObjectFilter( scorecardOutputType.getFactClass() ) ).iterator().next();
        assertEquals( 141.0, scorecardOutputType.get( scorecardOutput, "calculatedScore" ) );
        assertEquals( "OCC02", scorecardOutputType.get( scorecardOutput, "reasonCode" ) );
    } finally {
        session.dispose();
    }
}
/**
 * Compiles {@code /scoremodel_reasoncodes.xls} with the extra argument
 * {@code "scorecards_pointsAbove"} (presumably the worksheet name), builds the generated
 * DRL with KIE, and runs three scenarios, each in a fresh session. Every scenario checks
 * the calculated score, the values stored in the "ranking" map of the internal
 * {@code ScoreCard} fact and the reason code on the {@code SampleScoreOutput} fact.
 * The last two scenarios also check the iteration order of the map's keys.
 * <p>
 * The build results are asserted to be error-free, and every session is disposed in a
 * {@code finally} block so a failing assertion does not leak it.
 */
@Test
public void testPointsAbove() throws Exception {
    ScorecardCompiler scorecardCompiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
    scorecardCompiler.compileFromExcel( PMMLDocumentTest.class.getResourceAsStream("/scoremodel_reasoncodes.xls"), "scorecards_pointsAbove" );
    assertEquals( 0, scorecardCompiler.getScorecardParseErrors().size() );
    String drl = scorecardCompiler.getDRL();
    assertNotNull( drl );

    KieServices ks = KieServices.Factory.get();
    KieFileSystem kfs = ks.newKieFileSystem();
    kfs.write( ks.getResources().newByteArrayResource( drl.getBytes() )
                       .setSourcePath( "scoremodel_pointsAbove.drl" )
                       .setResourceType( ResourceType.DRL ) );
    KieBuilder kieBuilder = ks.newKieBuilder( kfs );
    Results res = kieBuilder.buildAll().getResults();
    assertFalse( "KIE build reported errors: " + res.getMessages( Message.Level.ERROR ),
                 res.hasMessages( Message.Level.ERROR ) );
    KieContainer kieContainer = ks.newKieContainer( kieBuilder.getKieModule().getReleaseId() );
    KieBase kbase = kieContainer.getKieBase();

    FactType scorecardType = kbase.getFactType( "org.drools.scorecards.example", "SampleScore" );
    FactType scorecardInternalsType = kbase.getFactType( PMML4Helper.pmmlDefaultPackageName(), "ScoreCard" );
    FactType scorecardOutputType = kbase.getFactType( "org.drools.scorecards.example", "SampleScoreOutput" );

    // Scenario 1: only age (10) is set.
    KieSession session = kbase.newKieSession();
    try {
        Object scorecard = scorecardType.newInstance();
        scorecardType.set( scorecard, "age", 10 );
        session.insert( scorecard );
        session.fireAllRules();
        assertEquals( 29.0, scorecardType.get( scorecard, "scorecard__calculatedScore" ) );

        Object scorecardInternals = session.getObjects( new ClassObjectFilter( scorecardInternalsType.getFactClass() ) ).iterator().next();
        assertEquals( 29.0, scorecardInternalsType.get( scorecardInternals, "score" ) );
        Map<?, ?> reasonCodesMap = (Map<?, ?>) scorecardInternalsType.get( scorecardInternals, "ranking" );
        assertNotNull( reasonCodesMap );
        assertEquals( 2, reasonCodesMap.size() );
        assertEquals( -16.0, reasonCodesMap.get( "VL002" ) );
        assertEquals( 20.0, reasonCodesMap.get( "AGE02" ) );

        Object scorecardOutput = session.getObjects( new ClassObjectFilter( scorecardOutputType.getFactClass() ) ).iterator().next();
        assertEquals( 29.0, scorecardOutputType.get( scorecardOutput, "calculatedScore" ) );
        assertEquals( "AGE02", scorecardOutputType.get( scorecardOutput, "reasonCode" ) );
    } finally {
        session.dispose();
    }

    // Scenario 2: age 0 and occupation SKYDIVER.
    session = kbase.newKieSession();
    try {
        Object scorecard = scorecardType.newInstance();
        scorecardType.set( scorecard, "age", 0 );
        scorecardType.set( scorecard, "occupation", "SKYDIVER" );
        session.insert( scorecard );
        session.fireAllRules();
        assertEquals( -1.0, scorecardType.get( scorecard, "scorecard__calculatedScore" ) );

        Object scorecardInternals = session.getObjects( new ClassObjectFilter( scorecardInternalsType.getFactClass() ) ).iterator().next();
        assertEquals( -1.0, scorecardInternalsType.get( scorecardInternals, "score" ) );
        Map<?, ?> reasonCodesMap = (Map<?, ?>) scorecardInternalsType.get( scorecardInternals, "ranking" );
        assertNotNull( reasonCodesMap );
        assertEquals( 3, reasonCodesMap.size() );
        assertEquals( -109.0, reasonCodesMap.get( "OCC01" ) );
        assertEquals( -16.0, reasonCodesMap.get( "VL002" ) );
        assertEquals( 0.0, reasonCodesMap.get( "AGE01" ) );
        // The key iteration order of the ranking map is part of the expectation.
        assertEquals( Arrays.asList( "AGE01", "VL002", "OCC01" ), new ArrayList<Object>( reasonCodesMap.keySet() ) );

        Object scorecardOutput = session.getObjects( new ClassObjectFilter( scorecardOutputType.getFactClass() ) ).iterator().next();
        assertEquals( -1.0, scorecardOutputType.get( scorecardOutput, "calculatedScore" ) );
        assertEquals( "AGE01", scorecardOutputType.get( scorecardOutput, "reasonCode" ) );
    } finally {
        session.dispose();
    }

    // Scenario 3: age, occupation, residence state and valid license are all set.
    session = kbase.newKieSession();
    try {
        Object scorecard = scorecardType.newInstance();
        scorecardType.set( scorecard, "age", 20 );
        scorecardType.set( scorecard, "occupation", "TEACHER" );
        scorecardType.set( scorecard, "residenceState", "AP" );
        scorecardType.set( scorecard, "validLicense", true );
        session.insert( scorecard );
        session.fireAllRules();
        assertEquals( 41.0, scorecardType.get( scorecard, "scorecard__calculatedScore" ) );

        Object scorecardInternals = session.getObjects( new ClassObjectFilter( scorecardInternalsType.getFactClass() ) ).iterator().next();
        assertEquals( 41.0, scorecardInternalsType.get( scorecardInternals, "score" ) );
        Map<?, ?> reasonCodesMap = (Map<?, ?>) scorecardInternalsType.get( scorecardInternals, "ranking" );
        assertNotNull( reasonCodesMap );
        assertEquals( 4, reasonCodesMap.size() );
        assertEquals( -89.0, reasonCodesMap.get( "OCC02" ) );
        assertEquals( -22.0, reasonCodesMap.get( "RS001" ) );
        assertEquals( -14.0, reasonCodesMap.get( "VL001" ) );
        assertEquals( 30.0, reasonCodesMap.get( "AGE03" ) );
        // The key iteration order of the ranking map is part of the expectation.
        assertEquals( Arrays.asList( "AGE03", "VL001", "RS001", "OCC02" ), new ArrayList<Object>( reasonCodesMap.keySet() ) );

        Object scorecardOutput = session.getObjects( new ClassObjectFilter( scorecardOutputType.getFactClass() ) ).iterator().next();
        assertEquals( 41.0, scorecardOutputType.get( scorecardOutput, "calculatedScore" ) );
        assertEquals( "AGE03", scorecardOutputType.get( scorecardOutput, "reasonCode" ) );
    } finally {
        session.dispose();
    }
}
/**
 * Compiles {@code /scoremodel_reasoncodes.xls} with the extra argument
 * {@code "scorecards_pointsBelow"} (presumably the worksheet name), builds the generated
 * DRL with KIE, and runs three scenarios, each in a fresh session. Every scenario checks
 * the calculated score, the values stored in the "ranking" map of the internal
 * {@code ScoreCard} fact and the reason code on the {@code SampleScoreOutput} fact.
 * <p>
 * The build results are asserted to be error-free, and every session is disposed in a
 * {@code finally} block so a failing assertion does not leak it.
 */
@Test
public void testPointsBelow() throws Exception {
    ScorecardCompiler scorecardCompiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
    scorecardCompiler.compileFromExcel( PMMLDocumentTest.class.getResourceAsStream("/scoremodel_reasoncodes.xls"), "scorecards_pointsBelow" );
    assertEquals( 0, scorecardCompiler.getScorecardParseErrors().size() );
    String drl = scorecardCompiler.getDRL();
    // Fail with a clear assertion instead of an NPE on drl.getBytes() below.
    assertNotNull( drl );

    KieServices ks = KieServices.Factory.get();
    KieFileSystem kfs = ks.newKieFileSystem();
    // Source path named after this test's model (was a copy-paste of the pointsAbove name).
    kfs.write( ks.getResources().newByteArrayResource( drl.getBytes() )
                       .setSourcePath( "scoremodel_pointsBelow.drl" )
                       .setResourceType( ResourceType.DRL ) );
    KieBuilder kieBuilder = ks.newKieBuilder( kfs );
    Results res = kieBuilder.buildAll().getResults();
    assertFalse( "KIE build reported errors: " + res.getMessages( Message.Level.ERROR ),
                 res.hasMessages( Message.Level.ERROR ) );
    KieContainer kieContainer = ks.newKieContainer( kieBuilder.getKieModule().getReleaseId() );
    KieBase kbase = kieContainer.getKieBase();

    FactType scorecardType = kbase.getFactType( "org.drools.scorecards.example", "SampleScore" );
    FactType scorecardInternalsType = kbase.getFactType( PMML4Helper.pmmlDefaultPackageName(), "ScoreCard" );
    FactType scorecardOutputType = kbase.getFactType( "org.drools.scorecards.example", "SampleScoreOutput" );

    // Scenario 1: only age (10) is set.
    KieSession session = kbase.newKieSession();
    try {
        Object scorecard = scorecardType.newInstance();
        scorecardType.set( scorecard, "age", 10 );
        session.insert( scorecard );
        session.fireAllRules();
        assertEquals( 29.0, scorecardType.get( scorecard, "scorecard__calculatedScore" ) );

        Object scorecardInternals = session.getObjects( new ClassObjectFilter( scorecardInternalsType.getFactClass() ) ).iterator().next();
        assertEquals( 29.0, scorecardInternalsType.get( scorecardInternals, "score" ) );
        Map<?, ?> reasonCodesMap = (Map<?, ?>) scorecardInternalsType.get( scorecardInternals, "ranking" );
        assertNotNull( reasonCodesMap );
        assertEquals( 2, reasonCodesMap.size() );
        assertEquals( 16.0, reasonCodesMap.get( "VL002" ) );
        assertEquals( -20.0, reasonCodesMap.get( "AGE02" ) );

        Object scorecardOutput = session.getObjects( new ClassObjectFilter( scorecardOutputType.getFactClass() ) ).iterator().next();
        assertEquals( 29.0, scorecardOutputType.get( scorecardOutput, "calculatedScore" ) );
        assertEquals( "VL002", scorecardOutputType.get( scorecardOutput, "reasonCode" ) );
    } finally {
        session.dispose();
    }

    // Scenario 2: age 0 and occupation SKYDIVER.
    session = kbase.newKieSession();
    try {
        Object scorecard = scorecardType.newInstance();
        scorecardType.set( scorecard, "age", 0 );
        scorecardType.set( scorecard, "occupation", "SKYDIVER" );
        session.insert( scorecard );
        session.fireAllRules();
        assertEquals( -1.0, scorecardType.get( scorecard, "scorecard__calculatedScore" ) );

        Object scorecardInternals = session.getObjects( new ClassObjectFilter( scorecardInternalsType.getFactClass() ) ).iterator().next();
        assertEquals( -1.0, scorecardInternalsType.get( scorecardInternals, "score" ) );
        Map<?, ?> reasonCodesMap = (Map<?, ?>) scorecardInternalsType.get( scorecardInternals, "ranking" );
        assertNotNull( reasonCodesMap );
        assertEquals( 3, reasonCodesMap.size() );
        assertEquals( 109.0, reasonCodesMap.get( "OCC01" ) );
        assertEquals( 16.0, reasonCodesMap.get( "VL002" ) );
        assertEquals( 0.0, reasonCodesMap.get( "AGE01" ) );

        Object scorecardOutput = session.getObjects( new ClassObjectFilter( scorecardOutputType.getFactClass() ) ).iterator().next();
        assertEquals( -1.0, scorecardOutputType.get( scorecardOutput, "calculatedScore" ) );
        assertEquals( "OCC01", scorecardOutputType.get( scorecardOutput, "reasonCode" ) );
    } finally {
        session.dispose();
    }

    // Scenario 3: age, occupation, residence state and valid license are all set.
    session = kbase.newKieSession();
    try {
        Object scorecard = scorecardType.newInstance();
        scorecardType.set( scorecard, "age", 20 );
        scorecardType.set( scorecard, "occupation", "TEACHER" );
        scorecardType.set( scorecard, "residenceState", "AP" );
        scorecardType.set( scorecard, "validLicense", true );
        session.insert( scorecard );
        session.fireAllRules();
        assertEquals( 41.0, scorecardType.get( scorecard, "scorecard__calculatedScore" ) );

        Object scorecardInternals = session.getObjects( new ClassObjectFilter( scorecardInternalsType.getFactClass() ) ).iterator().next();
        assertEquals( 41.0, scorecardInternalsType.get( scorecardInternals, "score" ) );
        Map<?, ?> reasonCodesMap = (Map<?, ?>) scorecardInternalsType.get( scorecardInternals, "ranking" );
        assertNotNull( reasonCodesMap );
        assertEquals( 4, reasonCodesMap.size() );
        assertEquals( 89.0, reasonCodesMap.get( "OCC02" ) );
        assertEquals( 22.0, reasonCodesMap.get( "RS001" ) );
        assertEquals( 14.0, reasonCodesMap.get( "VL001" ) );
        assertEquals( -30.0, reasonCodesMap.get( "AGE03" ) );

        Object scorecardOutput = session.getObjects( new ClassObjectFilter( scorecardOutputType.getFactClass() ) ).iterator().next();
        assertEquals( 41.0, scorecardOutputType.get( scorecardOutput, "calculatedScore" ) );
        assertEquals( "OCC02", scorecardOutputType.get( scorecardOutput, "reasonCode" ) );
    } finally {
        session.dispose();
    }
}
/**
 * Fails the current test with a message listing every parse error held by the given
 * compiler, one {@code location -> message} entry per line.
 *
 * @param compiler the compiler whose parse errors are reported
 */
private void assertErrors(final ScorecardCompiler compiler) {
    final StringBuilder report = new StringBuilder();
    compiler.getScorecardParseErrors().forEach(parseError -> report
            .append(parseError.getErrorLocation())
            .append(" -> ")
            .append(parseError.getErrorMessage())
            .append("\n"));
    fail("There are compile errors: \n" + report);
}
}