/*
 * Copyright 2015 Red Hat, Inc. and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.drools.scorecards;

import java.io.InputStream;

import org.junit.Assert;
import org.dmg.pmml.pmml_4_2.descr.*;
import org.drools.pmml.pmml_4_2.extensions.PMMLExtensionNames;
import org.drools.scorecards.pmml.ScorecardPMMLUtils;
import org.junit.Before;
import org.junit.Test;

import static org.junit.Assert.*;
import static org.drools.scorecards.ScorecardCompiler.DrlType.INTERNAL_DECLARED_TYPES;

/**
 * Checks the PMML document that {@link ScorecardCompiler} produces from the
 * {@code /scoremodel_c.xls} spreadsheet: header extensions, data dictionary,
 * mining schema, characteristics with their attributes, and the output field.
 */
public class PMMLDocumentTest {

    /** Classpath location of the spreadsheet compiled before every test. */
    private static final String SCORE_MODEL_RESOURCE = "/scoremodel_c.xls";

    /** Name of the extension that records the spreadsheet cell an element came from. */
    private static final String CELL_REF = "cellRef";

    // Per-test fixture: rebuilt by setUp() before each test, so these are
    // instance fields rather than shared static state.
    private PMML pmmlDocument;
    private ScorecardCompiler scorecardCompiler;

    /**
     * Compiles the sample spreadsheet and keeps the resulting PMML document.
     *
     * @throws Exception if the spreadsheet cannot be read or closed
     */
    @Before
    public void setUp() throws Exception {
        scorecardCompiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
        InputStream excelStream = PMMLDocumentTest.class.getResourceAsStream(SCORE_MODEL_RESOURCE);
        assertNotNull("Test resource not found on classpath: " + SCORE_MODEL_RESOURCE, excelStream);
        try {
            scorecardCompiler.compileFromExcel(excelStream);
            pmmlDocument = scorecardCompiler.getPMMLDocument();
        } finally {
            // The stream is opened here, so it is released here as well.
            excelStream.close();
        }
    }

    @Test
    public void testPMMLDocument() throws Exception {
        assertNotNull(pmmlDocument);
        String pmml = scorecardCompiler.getPMML();
        assertNotNull(pmml);
        assertTrue(pmml.length() > 0);
    }

    @Test
    public void testHeader() throws Exception {
        Header header = pmmlDocument.getHeader();
        assertNotNull(header);
        assertNotNull(ScorecardPMMLUtils.getExtensionValue(header.getExtensions(), PMMLExtensionNames.MODEL_PACKAGE));
        assertNotNull(ScorecardPMMLUtils.getExtensionValue(header.getExtensions(), PMMLExtensionNames.MODEL_IMPORTS));
    }

    @Test
    public void testDataDictionary() throws Exception {
        DataDictionary dataDictionary = pmmlDocument.getDataDictionary();
        assertNotNull(dataDictionary);
        assertEquals(5, dataDictionary.getNumberOfFields().intValue());
        assertEquals("age", dataDictionary.getDataFields().get(0).getName());
        assertEquals("occupation", dataDictionary.getDataFields().get(1).getName());
        assertEquals("residenceState", dataDictionary.getDataFields().get(2).getName());
        assertEquals("validLicense", dataDictionary.getDataFields().get(3).getName());
    }

    @Test
    public void testMiningSchema() throws Exception {
        MiningSchema miningSchema = findInScorecards(MiningSchema.class);
        assertEquals(5, miningSchema.getMiningFields().size());
        assertEquals("age", miningSchema.getMiningFields().get(0).getName());
        assertEquals("occupation", miningSchema.getMiningFields().get(1).getName());
        assertEquals("residenceState", miningSchema.getMiningFields().get(2).getName());
        assertEquals("validLicense", miningSchema.getMiningFields().get(3).getName());
    }

    @Test
    public void testCharacteristicsAndAttributes() throws Exception {
        Characteristics characteristics = findInScorecards(Characteristics.class);
        assertEquals(4, characteristics.getCharacteristics().size());
        assertCharacteristic(characteristics, 0, "AgeScore", "$B$8");
        assertCharacteristic(characteristics, 1, "OccupationScore", "$B$16");
        assertCharacteristic(characteristics, 2, "ResidenceStateScore", "$B$22");
        assertCharacteristic(characteristics, 3, "ValidLicenseScore", "$B$28");
    }

    @Test
    public void testAgeScoreCharacteristic() throws Exception {
        Characteristics characteristics = findInScorecards(Characteristics.class);
        assertEquals(4, characteristics.getCharacteristics().size());
        assertCharacteristic(characteristics, 0, "AgeScore", "$B$8");
        assertAttributeCount(characteristics, 0, 4);

        Attribute attribute = attributeAt(characteristics, 0, 0);
        assertCellRef("$C$10", attribute);
        assertNotNull(attribute.getSimplePredicate());

        attribute = attributeAt(characteristics, 0, 1);
        assertCellRef("$C$11", attribute);
        assertNotNull(attribute.getCompoundPredicate());

        attribute = attributeAt(characteristics, 0, 2);
        assertCellRef("$C$12", attribute);
        assertNotNull(attribute.getCompoundPredicate());

        attribute = attributeAt(characteristics, 0, 3);
        assertCellRef("$C$13", attribute);
        assertNotNull(attribute.getSimplePredicate());
    }

    @Test
    public void testOccupationScoreCharacteristic() throws Exception {
        Characteristics characteristics = findInScorecards(Characteristics.class);
        assertEquals(4, characteristics.getCharacteristics().size());
        assertAttributeCount(characteristics, 1, 3);

        Attribute attribute = attributeAt(characteristics, 1, 0);
        assertCellRef("$C$18", attribute);
        assertNotNull(ScorecardPMMLUtils.getExtensionValue(attribute.getExtensions(), "description"));
        assertEquals("skydiving is a risky occupation",
                     ScorecardPMMLUtils.getExtensionValue(attribute.getExtensions(), "description"));
        assertNotNull(attribute.getSimplePredicate());

        attribute = attributeAt(characteristics, 1, 1);
        assertCellRef("$C$19", attribute);
        assertNotNull(attribute.getSimpleSetPredicate());

        attribute = attributeAt(characteristics, 1, 2);
        assertCellRef("$C$20", attribute);
        assertNotNull(attribute.getSimplePredicate());
    }

    @Test
    public void testResidenceStateScoreCharacteristic() throws Exception {
        Characteristics characteristics = findInScorecards(Characteristics.class);
        assertEquals(4, characteristics.getCharacteristics().size());
        assertAttributeCount(characteristics, 2, 3);

        Attribute attribute = attributeAt(characteristics, 2, 0);
        assertCellRef("$C$24", attribute);
        assertNotNull(attribute.getSimplePredicate());

        attribute = attributeAt(characteristics, 2, 1);
        assertCellRef("$C$25", attribute);
        assertNotNull(attribute.getSimplePredicate());

        attribute = attributeAt(characteristics, 2, 2);
        assertCellRef("$C$26", attribute);
        assertNotNull(attribute.getSimplePredicate());
    }

    @Test
    public void testValidLicenseScoreCharacteristic() throws Exception {
        Characteristics characteristics = findInScorecards(Characteristics.class);
        assertEquals(4, characteristics.getCharacteristics().size());
        assertAttributeCount(characteristics, 3, 2);

        Attribute attribute = attributeAt(characteristics, 3, 0);
        assertCellRef("$C$30", attribute);
        assertNotNull(attribute.getSimplePredicate());

        attribute = attributeAt(characteristics, 3, 1);
        assertCellRef("$C$31", attribute);
        assertNotNull(attribute.getSimplePredicate());
    }

    @Test
    public void testScorecardWithExtensions() throws Exception {
        Scorecard scorecard = findFirstScorecard();
        assertEquals("Sample Score", scorecard.getModelName());
        // NOTE(review): the two extension checks below are disabled; they refer to
        // ScorecardPMMLExtensionNames, which this file does not import.
        // assertNotNull(ScorecardPMMLUtils.getExtension(scorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), ScorecardPMMLExtensionNames.SCORECARD_OBJECT_CLASS));
        // assertNotNull(ScorecardPMMLUtils.getExtension(scorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), ScorecardPMMLExtensionNames.SCORECARD_BOUND_VAR_NAME));
    }

    @Test
    public void testOutput() throws Exception {
        Output output = findInScorecards(Output.class);
        assertEquals(1, output.getOutputFields().size());
        assertNotNull(output.getOutputFields().get(0));
        assertEquals("calculatedScore", output.getOutputFields().get(0).getName());
        assertEquals("Final Score", output.getOutputFields().get(0).getDisplayName());
        assertEquals("double", output.getOutputFields().get(0).getDataType().value());
        assertEquals("predictedValue", output.getOutputFields().get(0).getFeature().value());
    }

    /**
     * Returns the first {@link Scorecard} among the models of the PMML document.
     *
     * @return the first scorecard model
     * @throws AssertionError if the document contains no scorecard
     */
    private Scorecard findFirstScorecard() {
        for (Object model : pmmlDocument.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
            if (model instanceof Scorecard) {
                return (Scorecard) model;
            }
        }
        throw new AssertionError("No Scorecard found in the PMML document");
    }

    /**
     * Scans every {@link Scorecard} of the PMML document, in order, and returns
     * the first child element that is an instance of the requested type.
     *
     * @param elementType type of the scorecard child element to look for
     * @param <T> static type of the returned element
     * @return the first matching element
     * @throws AssertionError if no scorecard contains such an element
     */
    private <T> T findInScorecards(Class<T> elementType) {
        for (Object model : pmmlDocument.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
            if (!(model instanceof Scorecard)) {
                continue;
            }
            for (Object element : ((Scorecard) model).getExtensionsAndCharacteristicsAndMiningSchemas()) {
                if (elementType.isInstance(element)) {
                    return elementType.cast(element);
                }
            }
        }
        throw new AssertionError("No " + elementType.getSimpleName() + " found in any Scorecard of the PMML document");
    }

    /** Asserts the name and the {@code cellRef} extension of the characteristic at {@code index}. */
    private static void assertCharacteristic(Characteristics characteristics, int index,
                                             String expectedName, String expectedCellRef) {
        assertEquals(expectedName, characteristics.getCharacteristics().get(index).getName());
        assertEquals(expectedCellRef,
                     ScorecardPMMLUtils.getExtensionValue(characteristics.getCharacteristics().get(index).getExtensions(),
                                                          CELL_REF));
    }

    /** Asserts that the characteristic at {@code index} has a non-null attribute list of the expected size. */
    private static void assertAttributeCount(Characteristics characteristics, int index, int expectedCount) {
        assertNotNull(characteristics.getCharacteristics().get(index).getAttributes());
        assertEquals(expectedCount, characteristics.getCharacteristics().get(index).getAttributes().size());
    }

    /** Returns attribute {@code attributeIndex} of characteristic {@code characteristicIndex}. */
    private static Attribute attributeAt(Characteristics characteristics, int characteristicIndex, int attributeIndex) {
        return characteristics.getCharacteristics().get(characteristicIndex).getAttributes().get(attributeIndex);
    }

    /** Asserts the value of the {@code cellRef} extension of the given attribute. */
    private static void assertCellRef(String expectedCellRef, Attribute attribute) {
        assertEquals(expectedCellRef, ScorecardPMMLUtils.getExtensionValue(attribute.getExtensions(), CELL_REF));
    }
}