/*
* Copyright 2012 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.scorecards.pmml;
import org.dmg.pmml.pmml_4_2.descr.*;
import org.drools.core.util.StringUtils;
import org.drools.pmml.pmml_4_2.extensions.PMMLExtensionNames;
import org.drools.pmml.pmml_4_2.extensions.PMMLIOAdapterMode;
import org.drools.scorecards.StringUtil;
import org.drools.scorecards.parser.xls.XLSKeywords;
import java.math.BigInteger;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Locale;
public class ScorecardPMMLGenerator {
public PMML generateDocument(Scorecard pmmlScorecard) {
//first clean up the scorecard
removeEmptyExtensions(pmmlScorecard);
createAndSetPredicates(pmmlScorecard);
//second add additional elements to scorecard
createAndSetOutput(pmmlScorecard);
repositionExternalClassExtensions(pmmlScorecard);
Extension scorecardPackage = ScorecardPMMLUtils.getExtension(pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), PMMLExtensionNames.MODEL_PACKAGE );
if ( scorecardPackage != null) {
pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().remove(scorecardPackage);
}
Extension importsExt = ScorecardPMMLUtils.getExtension(pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), PMMLExtensionNames.MODEL_IMPORTS );
if ( importsExt != null) {
pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().remove(importsExt);
}
Extension agendaGroupExt = ScorecardPMMLUtils.getExtension(pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), PMMLExtensionNames.AGENDA_GROUP );
if ( agendaGroupExt != null) {
pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().remove(agendaGroupExt);
}
Extension ruleFlowGroupExt = ScorecardPMMLUtils.getExtension(pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), PMMLExtensionNames.RULEFLOW_GROUP);
if ( ruleFlowGroupExt != null) {
pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().remove(ruleFlowGroupExt);
}
//now create the PMML document
PMML pmml = new PMML();
Header header = new Header();
Timestamp timestamp = new Timestamp();
timestamp.getContent().add(new SimpleDateFormat("yyyy.MM.dd 'at' HH:mm:ss z", Locale.ENGLISH).format(new Date()));
header.setTimestamp(timestamp);
header.setDescription("generated by the drools-scorecards module");
header.getExtensions().add(scorecardPackage);
header.getExtensions().add(importsExt);
if (ruleFlowGroupExt != null){
header.getExtensions().add(ruleFlowGroupExt);
}
if (agendaGroupExt != null){
header.getExtensions().add(agendaGroupExt);
}
pmml.setHeader(header);
createAndSetDataDictionary(pmml, pmmlScorecard);
pmml.getAssociationModelsAndBaselineModelsAndClusteringModels().add(pmmlScorecard);
removeAttributeFieldExtension(pmmlScorecard);
return pmml;
}
private void repositionExternalClassExtensions(Scorecard pmmlScorecard) {
Characteristics characteristics = null;
for (Object obj : pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
if ( obj instanceof Characteristics ) {
characteristics = (Characteristics) obj;
break;
}
}
for (Object obj : pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
if ( obj instanceof MiningSchema ) {
MiningSchema schema = (MiningSchema)obj;
Extension adapter = new Extension();
adapter.setName( PMMLExtensionNames.IO_ADAPTER );
adapter.setValue( PMMLIOAdapterMode.BEAN.name() );
schema.getExtensions().add( adapter );
for (MiningField miningField : schema.getMiningFields()) {
String fieldName = miningField.getName();
for (Characteristic characteristic : characteristics.getCharacteristics()){
String characteristicName = ScorecardPMMLUtils.extractFieldNameFromCharacteristic(characteristic);
if (fieldName.equalsIgnoreCase(characteristicName)){
Extension extension = ScorecardPMMLUtils.getExtension(characteristic.getExtensions(), PMMLExtensionNames.EXTERNAL_CLASS );
if ( extension != null ) {
characteristic.getExtensions().remove(extension);
if ( ScorecardPMMLUtils.getExtension(miningField.getExtensions(), PMMLExtensionNames.EXTERNAL_CLASS ) == null ) {
miningField.getExtensions().add(extension);
}
}
}
}
}
MiningField targetField = new MiningField();
targetField.setName( ScorecardPMMLExtensionNames.DEFAULT_PREDICTED_FIELD );
targetField.setUsageType( FIELDUSAGETYPE.PREDICTED );
schema.getMiningFields().add( targetField );
} else if ( obj instanceof Output ) {
Extension adapter = new Extension();
adapter.setName( PMMLExtensionNames.IO_ADAPTER );
adapter.setValue( PMMLIOAdapterMode.BEAN.name() );
( (Output) obj ).getExtensions().add( adapter );
}
}
}
private void removeAttributeFieldExtension(Scorecard pmmlScorecard) {
for (Object obj : pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
if (obj instanceof Characteristics) {
Characteristics characteristics = (Characteristics) obj;
for (org.dmg.pmml.pmml_4_2.descr.Characteristic characteristic : characteristics.getCharacteristics()) {
for (Attribute attribute : characteristic.getAttributes()) {
Extension fieldExtension = ScorecardPMMLUtils.getExtension(attribute.getExtensions(), ScorecardPMMLExtensionNames.CHARACTERTISTIC_FIELD);
if ( fieldExtension != null ) {
attribute.getExtensions().remove(fieldExtension);
//break;
}
}
}
}
}
}
private void createAndSetDataDictionary(PMML pmml, Scorecard pmmlScorecard) {
DataDictionary dataDictionary = new DataDictionary();
pmml.setDataDictionary(dataDictionary);
int ctr = 0;
for (Object obj : pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
if (obj instanceof Characteristics) {
Characteristics characteristics = (Characteristics) obj;
for (org.dmg.pmml.pmml_4_2.descr.Characteristic characteristic : characteristics.getCharacteristics()) {
DataField dataField = new DataField();
Extension dataTypeExtension = ScorecardPMMLUtils.getExtension(characteristic.getExtensions(), ScorecardPMMLExtensionNames.CHARACTERTISTIC_DATATYPE);
String dataType = dataTypeExtension.getValue();
String factType = ScorecardPMMLUtils.getExtensionValue(characteristic.getExtensions(), ScorecardPMMLExtensionNames.CHARACTERTISTIC_FACTTYPE);
if ( factType != null ){
Extension extension = new Extension();
extension.setName("FactType");
extension.setValue(factType);
dataField.getExtensions().add(extension);
}
if (XLSKeywords.DATATYPE_NUMBER.equalsIgnoreCase(dataType)) {
dataField.setDataType(DATATYPE.DOUBLE);
dataField.setOptype(OPTYPE.CONTINUOUS);
} else if (XLSKeywords.DATATYPE_TEXT.equalsIgnoreCase(dataType)) {
dataField.setDataType(DATATYPE.STRING);
dataField.setOptype(OPTYPE.CATEGORICAL);
} else if (XLSKeywords.DATATYPE_BOOLEAN.equalsIgnoreCase(dataType)) {
dataField.setDataType(DATATYPE.BOOLEAN);
dataField.setOptype(OPTYPE.CATEGORICAL);
}
String field = "";
for (Attribute attribute : characteristic.getAttributes()) {
for (Extension extension : attribute.getExtensions()) {
if ( ScorecardPMMLExtensionNames.CHARACTERTISTIC_FIELD.equalsIgnoreCase(extension.getName())) {
field = extension.getValue();
break;
}//
}
}
dataField.setName(field);
dataDictionary.getDataFields().add(dataField);
characteristic.getExtensions().remove(dataTypeExtension);
ctr++;
}
}
}
DataField targetField = new DataField();
targetField.setName( ScorecardPMMLExtensionNames.DEFAULT_PREDICTED_FIELD );
targetField.setDataType( DATATYPE.DOUBLE );
targetField.setOptype( OPTYPE.CONTINUOUS );
dataDictionary.getDataFields().add( targetField );
dataDictionary.setNumberOfFields(BigInteger.valueOf(ctr + 1));
}
private void createAndSetOutput(Scorecard pmmlScorecard) {
Extension classExtension = ScorecardPMMLUtils.getExtension(pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), PMMLExtensionNames.EXTERNAL_CLASS);
Extension fieldExtension = ScorecardPMMLUtils.getExtension(pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), ScorecardPMMLExtensionNames.SCORECARD_RESULTANT_SCORE_FIELD);
Extension reasonCodeExtension = ScorecardPMMLUtils.getExtension(pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), ScorecardPMMLExtensionNames.SCORECARD_RESULTANT_REASONCODES_FIELD);
for (Object obj : pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
if (obj instanceof Output) {
Output output = (Output)obj;
OutputField outputField = new OutputField();
outputField.setDataType(DATATYPE.DOUBLE);
outputField.setFeature(RESULTFEATURE.PREDICTED_VALUE);
outputField.setDisplayName("Final Score");
if ( fieldExtension != null ) {
pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().remove(fieldExtension);
outputField.setName(fieldExtension.getValue());
} else {
outputField.setName( "calculatedScore" );
}
if ( classExtension != null ) {
pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().remove(classExtension);
outputField.getExtensions().add( classExtension );
}
output.getOutputFields().add(outputField);
if ( pmmlScorecard.getUseReasonCodes() ) {
OutputField reasonCodeField = new OutputField();
reasonCodeField.setDataType( DATATYPE.STRING );
reasonCodeField.setFeature( RESULTFEATURE.REASON_CODE );
reasonCodeField.setDisplayName( "Principal Reason Code" );
if ( reasonCodeExtension != null ) {
pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().remove(reasonCodeExtension);
reasonCodeField.getExtensions().add( classExtension );
reasonCodeField.setName( reasonCodeExtension.getValue() );
} else {
reasonCodeField.setName( "reasonCode" );
}
output.getOutputFields().add( reasonCodeField );
}
break;
}
}
}
private void createAndSetPredicates(Scorecard pmmlScorecard) {
for (Object obj : pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
if (obj instanceof Characteristics) {
Characteristics characteristics = (Characteristics) obj;
for (org.dmg.pmml.pmml_4_2.descr.Characteristic characteristic : characteristics.getCharacteristics()) {
String dataType = ScorecardPMMLUtils.getExtensionValue(characteristic.getExtensions(), ScorecardPMMLExtensionNames.CHARACTERTISTIC_DATATYPE);
Extension predicateExtension = null;
for (Attribute attribute : characteristic.getAttributes()) {
String predicateAsString = "";
String field = ScorecardPMMLUtils.getExtensionValue(attribute.getExtensions(), ScorecardPMMLExtensionNames.CHARACTERTISTIC_FIELD);
for (Extension extension : attribute.getExtensions()) {
if ("predicateResolver".equalsIgnoreCase(extension.getName())) {
predicateAsString = extension.getValue();
predicateExtension = extension;
break;
}
}
setPredicatesForAttribute(attribute, dataType, field, predicateAsString);
attribute.getExtensions().remove(predicateExtension);
}
}
}
}
}
private void setPredicatesForAttribute(Attribute pmmlAttribute, String dataType, String field, String predicateAsString) {
predicateAsString = StringUtil.unescapeXML(predicateAsString);
if (XLSKeywords.DATATYPE_NUMBER.equalsIgnoreCase(dataType)) {
setNumericPredicate(pmmlAttribute, field, predicateAsString);
} else if (XLSKeywords.DATATYPE_TEXT.equalsIgnoreCase(dataType)) {
setTextPredicate(pmmlAttribute, field, predicateAsString);
} else if (XLSKeywords.DATATYPE_BOOLEAN.equalsIgnoreCase(dataType)) {
setBooleanPredicate(pmmlAttribute, field, predicateAsString);
}
}
private void setBooleanPredicate(Attribute pmmlAttribute, String field, String predicateAsString) {
SimplePredicate simplePredicate = new SimplePredicate();
simplePredicate.setField(field);
simplePredicate.setOperator(PMMLOperators.EQUAL);
if ("TRUE".equalsIgnoreCase(predicateAsString)){
simplePredicate.setValue("TRUE");
} else if ("FALSE".equalsIgnoreCase(predicateAsString)){
simplePredicate.setValue("FALSE");
}
pmmlAttribute.setSimplePredicate(simplePredicate);
}
private void setTextPredicate(Attribute pmmlAttribute, String field, String predicateAsString) {
String operator = "";
if (predicateAsString.startsWith("=")) {
operator = "=";
predicateAsString = predicateAsString.substring(1);
} else if (predicateAsString.startsWith("!=")) {
operator = "!=";
predicateAsString = predicateAsString.substring(2);
}
if (predicateAsString.contains(",")) {
SimpleSetPredicate simpleSetPredicate = new SimpleSetPredicate();
if ("!=".equalsIgnoreCase(operator)) {
simpleSetPredicate.setBooleanOperator(PMMLOperators.IS_NOT_IN);
} else {
simpleSetPredicate.setBooleanOperator(PMMLOperators.IS_IN);
}
simpleSetPredicate.setField(field);
predicateAsString = predicateAsString.trim();
if (predicateAsString.endsWith(",")) {
predicateAsString = predicateAsString.substring(0, predicateAsString.length()-1);
}
Array array = new Array();
array.setContent(predicateAsString.replace(",", " "));
array.setType("string");
array.setN(BigInteger.valueOf(predicateAsString.split(",").length));
simpleSetPredicate.setArray(array);
pmmlAttribute.setSimpleSetPredicate(simpleSetPredicate);
} else {
SimplePredicate simplePredicate = new SimplePredicate();
simplePredicate.setField(field);
if ("!=".equalsIgnoreCase(operator)) {
simplePredicate.setOperator(PMMLOperators.NOT_EQUAL);
} else {
simplePredicate.setOperator(PMMLOperators.EQUAL);
}
simplePredicate.setValue(predicateAsString);
pmmlAttribute.setSimplePredicate(simplePredicate);
}
}
private void setNumericPredicate(Attribute pmmlAttribute, String field, String predicateAsString) {
if (predicateAsString.indexOf("-") > 0) {
CompoundPredicate compoundPredicate = new CompoundPredicate();
compoundPredicate.setBooleanOperator("and");
String left = predicateAsString.substring(0, predicateAsString.indexOf("-")).trim();
String right = predicateAsString.substring(predicateAsString.indexOf("-") + 1).trim();
SimplePredicate simplePredicate = new SimplePredicate();
simplePredicate.setField(field);
simplePredicate.setOperator(PMMLOperators.GREATER_OR_EQUAL);
simplePredicate.setValue(left);
compoundPredicate.getSimplePredicatesAndCompoundPredicatesAndSimpleSetPredicates().add(simplePredicate);
simplePredicate = new SimplePredicate();
simplePredicate.setField(field);
simplePredicate.setOperator(PMMLOperators.LESS_THAN);
simplePredicate.setValue(right);
compoundPredicate.getSimplePredicatesAndCompoundPredicatesAndSimpleSetPredicates().add(simplePredicate);
pmmlAttribute.setCompoundPredicate(compoundPredicate);
} else {
SimplePredicate simplePredicate = new SimplePredicate();
simplePredicate.setField(field);
if (predicateAsString.startsWith("<=")) {
simplePredicate.setOperator(PMMLOperators.LESS_OR_EQUAL);
simplePredicate.setValue(predicateAsString.substring(3).trim());
} else if (predicateAsString.startsWith(">=")) {
simplePredicate.setOperator(PMMLOperators.GREATER_OR_EQUAL);
simplePredicate.setValue(predicateAsString.substring(3).trim());
} else if (predicateAsString.startsWith("=")) {
simplePredicate.setOperator(PMMLOperators.EQUAL);
simplePredicate.setValue(predicateAsString.substring(2).trim());
} else if (predicateAsString.startsWith("!=")) {
simplePredicate.setOperator(PMMLOperators.NOT_EQUAL);
simplePredicate.setValue(predicateAsString.substring(3).trim());
} else if (predicateAsString.startsWith("<")) {
simplePredicate.setOperator(PMMLOperators.LESS_THAN);
simplePredicate.setValue(predicateAsString.substring(2).trim());
} else if (predicateAsString.startsWith(">")) {
simplePredicate.setOperator(PMMLOperators.GREATER_THAN);
simplePredicate.setValue(predicateAsString.substring(2).trim());
}
pmmlAttribute.setSimplePredicate(simplePredicate);
}
}
private void removeEmptyExtensions(Scorecard pmmlScorecard) {
for (Object obj : pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
if (obj instanceof Characteristics) {
Characteristics characteristics = (Characteristics) obj;
for (org.dmg.pmml.pmml_4_2.descr.Characteristic characteristic : characteristics.getCharacteristics()) {
List<Extension> toRemoveExtensionsList = new ArrayList<Extension>();
for (Extension extension : characteristic.getExtensions()) {
if (StringUtils.isEmpty(extension.getValue())) {
toRemoveExtensionsList.add(extension);
}
}
for (Extension extension : toRemoveExtensionsList) {
characteristic.getExtensions().remove(extension);
}
for (Attribute attribute : characteristic.getAttributes()) {
List<Extension> toRemoveExtensionsList2 = new ArrayList<Extension>();
for (Extension extension : attribute.getExtensions()) {
if (StringUtils.isEmpty(extension.getValue())) {
toRemoveExtensionsList2.add(extension);
}
}
for (Extension extension : toRemoveExtensionsList2) {
attribute.getExtensions().remove(extension);
}
}
}
}
}
}
}