/*
* Copyright 2011 JBoss Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.pmml.pmml_4_1;
import org.drools.KnowledgeBase;
import org.drools.KnowledgeBaseFactory;
import org.drools.RuleBaseConfiguration;
import org.drools.builder.*;
import org.drools.compiler.PackageRegistry;
import org.drools.conf.EventProcessingOption;
import org.drools.io.Resource;
import org.drools.io.ResourceFactory;
import org.drools.io.impl.ClassPathResource;
import org.drools.runtime.StatefulKnowledgeSession;
import org.mvel2.templates.SimpleTemplateRegistry;
import org.mvel2.templates.TemplateCompiler;
import org.mvel2.templates.TemplateRegistry;
import org.w3c.dom.Element;
import javax.xml.XMLConstants;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import javax.xml.validation.Schema;
import javax.xml.validation.SchemaFactory;
import java.io.*;
import java.util.*;
import org.dmg.pmml.pmml_4_1.descr.*;
import org.xml.sax.SAXException;
public class PMML4Compiler implements org.drools.compiler.PMMLCompiler {
// Context path handed to JAXBContext.newInstance (through loadModel) when a PMML document is unmarshalled
public static final String PMML = "org.dmg.pmml.pmml_4_1.descr";
// Classpath location of the XSD that loadModel resolves to build the validating Schema
public static final String SCHEMA_PATH = "xsd/org/dmg/pmml/pmml_4_1/pmml-4-1.xsd";
// Package of this class with '.' replaced by '/'; used as a classpath prefix for rules, templates and PMML files
public static final String BASE_PACK = PMML4Compiler.class.getPackage().getName().replace('.','/');
// Package name set on the PMML4Helper created by the constructor
public static final String BASE_TEST_PACK = "org.drools.pmml.pmml_4_1.test";
// DRL resources added to the knowledge builder by initVisitor. Each (public, mutable) boolean flag
// records that the matching resource has already been added; resetVisitor clears all three flags.
public static final String VISITOR_RULES = BASE_PACK + "/pmml_visitor.drl";
public static boolean visitorRules = false;
public static final String COMPILER_RULES = BASE_PACK + "/pmml_compiler.drl";
public static boolean compilerRules = false;
// Only added by initVisitor when needsInformerExtension( pmml ) is true
public static final String INFORMER_RULES = BASE_PACK + "/pmml_informer.drl";
public static boolean informerRules = false;
// Set by initRegistry once every entry of GLOBAL_TEMPLATES has been passed to prepareTemplate
protected static boolean globalLoaded = false;
// Template names, relative to TEMPLATE_PATH, that initRegistry always loads
protected static final String[] GLOBAL_TEMPLATES = new String[] {
"global/pmml_header.drlt",
"global/pmml_import.drlt",
"global/modelMark.drlt",
"global/dataDefinition/common.drlt",
"global/dataDefinition/rootDataField.drlt",
"global/dataDefinition/inputBinding.drlt",
"global/dataDefinition/outputBinding.drlt",
"global/dataDefinition/ioTypeDeclare.drlt",
"global/dataDefinition/updateIOField.drlt",
"global/dataDefinition/inputFromEP.drlt",
"global/dataDefinition/ioTrait.drlt",
"global/manipulation/confirm.drlt",
"global/manipulation/mapMissingValues.drlt",
"global/manipulation/propagateMissingValues.drlt",
"global/validation/intervalsOnDomainRestriction.drlt",
"global/validation/valuesOnDomainRestriction.drlt",
"global/validation/valuesOnDomainRestrictionMissing.drlt",
"global/validation/valuesOnDomainRestrictionInvalid.drlt",
};
// Set by initRegistry once every entry of TRANSFORMATION_TEMPLATES has been passed to prepareTemplate
protected static boolean transformationLoaded = false;
// Template names, relative to TEMPLATE_PATH, that initRegistry always loads
protected static final String[] TRANSFORMATION_TEMPLATES = new String[] {
"transformations/normContinuous/boundedLowerOutliers.drlt",
"transformations/normContinuous/boundedUpperOutliers.drlt",
"transformations/normContinuous/normContOutliersAsMissing.drlt",
"transformations/normContinuous/linearTractNormalization.drlt",
"transformations/normContinuous/lowerExtrapolateLinearTractNormalization.drlt",
"transformations/normContinuous/upperExtrapolateLinearTractNormalization.drlt",
"transformations/aggregate/aggregate.drlt",
"transformations/aggregate/collect.drlt",
"transformations/simple/constantField.drlt",
"transformations/simple/aliasedField.drlt",
"transformations/normDiscrete/indicatorFieldYes.drlt",
"transformations/normDiscrete/indicatorFieldNo.drlt",
"transformations/normDiscrete/predicateField.drlt",
"transformations/discretize/intervalBinning.drlt",
"transformations/discretize/outOfBinningDefault.drlt",
"transformations/discretize/outOfBinningMissing.drlt",
"transformations/mapping/mapping.drlt",
"transformations/functions/apply.drlt",
"transformations/functions/function.drlt"
};
// Set by initRegistry once every entry of MINING_TEMPLATES has been passed to prepareTemplate
protected static boolean miningLoaded = false;
// Template names, relative to TEMPLATE_PATH, that initRegistry always loads
protected static final String[] MINING_TEMPLATES = new String[] {
"models/common/mining/miningField.drlt",
"models/common/mining/miningFieldInvalid.drlt",
"models/common/mining/miningFieldMissing.drlt",
"models/common/mining/miningFieldOutlierAsMissing.drlt",
"models/common/mining/miningFieldOutlierAsExtremeLow.drlt",
"models/common/mining/miningFieldOutlierAsExtremeUpp.drlt",
"models/common/target/targetReshape.drlt",
"models/common/target/aliasedOutput.drlt",
"models/common/target/addOutputFeature.drlt",
"models/common/target/addRelOutputFeature.drlt",
"models/common/target/outputQuery.drlt",
"models/common/target/outputQueryPredicate.drlt"
};
// Set by checkBuildingResources after NEURAL_TEMPLATES have been loaded; cleared by resetVisitor
protected static boolean neuralLoaded = false;
// Template names, relative to TEMPLATE_PATH, loaded the first time a NeuralNetwork model is encountered
protected static final String[] NEURAL_TEMPLATES = new String[] {
"models/neural/neuralBeans.drlt",
"models/neural/neuralWireInput.drlt",
"models/neural/neuralBuildSynapses.drlt",
"models/neural/neuralBuildNeurons.drlt",
"models/neural/neuralLinkSynapses.drlt",
"models/neural/neuralFire.drlt",
"models/neural/neuralLayerMaxNormalization.drlt",
"models/neural/neuralLayerSoftMaxNormalization.drlt",
"models/neural/neuralOutputField.drlt",
"models/neural/neuralClean.drlt"
};
// Set by checkBuildingResources after SVM_TEMPLATES have been loaded; cleared by resetVisitor
protected static boolean svmLoaded = false;
// Template names, relative to TEMPLATE_PATH, loaded the first time a SupportVectorMachineModel is encountered
protected static final String[] SVM_TEMPLATES = new String[] {
"models/svm/svmParams.drlt",
"models/svm/svmDeclare.drlt",
"models/svm/svmFunctions.drlt",
"models/svm/svmBuild.drlt",
"models/svm/svmInitSupportVector.drlt",
"models/svm/svmInitInputVector.drlt",
"models/svm/svmKernelEval.drlt",
"models/svm/svmOutputGeneration.drlt",
"models/svm/svmOutputVoteDeclare.drlt",
"models/svm/svmOutputVote1vN.drlt",
"models/svm/svmOutputVote1v1.drlt",
};
// Set by checkBuildingResources after SIMPLEREG_TEMPLATES have been loaded; cleared by resetVisitor
protected static boolean simpleRegLoaded = false;
// Template names, relative to TEMPLATE_PATH, loaded the first time a RegressionModel is encountered
protected static final String[] SIMPLEREG_TEMPLATES = new String[] {
"models/regression/regDeclare.drlt",
"models/regression/regCommon.drlt",
"models/regression/regParams.drlt",
"models/regression/regEval.drlt",
"models/regression/regClaxOutput.drlt",
"models/regression/regNormalization.drlt",
"models/regression/regDecumulation.drlt",
};
// Set by checkBuildingResources after CLUSTERING_TEMPLATES have been loaded; cleared by resetVisitor
protected static boolean clusteringLoaded = false;
// Template names, relative to TEMPLATE_PATH, loaded the first time a ClusteringModel is encountered
protected static final String[] CLUSTERING_TEMPLATES = new String[] {
"models/clustering/clusteringDeclare.drlt",
"models/clustering/clusteringInit.drlt",
"models/clustering/clusteringEvalDistance.drlt",
"models/clustering/clusteringEvalSimilarity.drlt",
"models/clustering/clusteringMatrixCompare.drlt"
};
// Set by checkBuildingResources after TREE_TEMPLATES have been loaded; cleared by resetVisitor
protected static boolean treeLoaded = false;
// Template names, relative to TEMPLATE_PATH, loaded the first time a TreeModel is encountered
protected static final String[] TREE_TEMPLATES = new String[] {
"models/tree/treeDeclare.drlt",
"models/tree/treeCommon.drlt",
"models/tree/treeInputDeclare.drlt",
"models/tree/treeInit.drlt",
"models/tree/treeAggregateEval.drlt",
"models/tree/treeDefaultEval.drlt",
"models/tree/treeEval.drlt",
"models/tree/treeIOBinding.drlt",
"models/tree/treeMissHandleAggregate.drlt",
"models/tree/treeMissHandleWeighted.drlt",
"models/tree/treeMissHandleLast.drlt",
"models/tree/treeMissHandleNull.drlt",
"models/tree/treeMissHandleNone.drlt"
};
// Set by checkBuildingResources after SCORECARD_TEMPLATES have been loaded; cleared by resetVisitor
protected static boolean scorecardLoaded = false;
// Template names, relative to TEMPLATE_PATH, loaded the first time a Scorecard model is encountered
protected static final String[] SCORECARD_TEMPLATES = new String[] {
"models/scorecard/scorecardInit.drlt",
"models/scorecard/scorecardParamsInit.drlt",
"models/scorecard/scorecardDeclare.drlt",
"models/scorecard/scorecardDataDeclare.drlt",
"models/scorecard/scorecardPartialScore.drlt",
"models/scorecard/scorecardScoring.drlt",
"models/scorecard/scorecardOutputGeneration.drlt",
"models/scorecard/scorecardOutputRankCode.drlt"
};
// Set by checkBuildingResources after INFORMER_TEMPLATES have been loaded; cleared by resetVisitor
protected static boolean informerLoaded = false;
// Template names, relative to TEMPLATE_PATH, loaded when needsInformerExtension( pmml ) reports true
protected static final String[] INFORMER_TEMPLATES = new String[] {
"informer/informer_imports.drlt",
"informer/modelQuestionnaire.drlt",
"informer/modelAddQuestionsToQuestionnaire.drlt",
"informer/modelQuestion.drlt",
"informer/modelMultiQuestion.drlt",
"informer/modelQuestionBinding.drlt",
"informer/modelQuestionRebinding.drlt" ,
"informer/modelCreateByBinding.drlt",
"informer/modelInvalidAnswer.drlt",
"informer/modelOutputBinding.drlt",
"informer/modelRevalidate.drlt"
};
// Classpath prefix used by compile( String, Map ) to locate PMML files
protected static final String RESOURCE_PATH = BASE_PACK;
// Folder path (with a leading '/') that prepareTemplate prepends to every template name
protected static final String TEMPLATE_PATH = "/" + RESOURCE_PATH + "/templates/";
// Shared template registry; created lazily by initRegistry and filled by prepareTemplate
private static TemplateRegistry registry;
// Builder used by initVisitor to compile the DRL resources
private static KnowledgeBuilder kBuilder;
// Knowledge base whose sessions generate the theory; built by initVisitor, set to null by resetVisitor
private static KnowledgeBase visitor;
// Static list of problems (rule build errors, schema warnings); shared by all instances and
// appended to the per-instance results in getResults()
private static List<KnowledgeBuilderResult> visitorBuildResults = new ArrayList<KnowledgeBuilderResult>();
// Per-instance results; replaced with a fresh list at the start of compile( InputStream, Map )
private List<KnowledgeBuilderResult> results;
// Helper passed to the visitor session as the "utils" global
private PMML4Helper helper;
/**
 * Creates a compiler with an empty result list and a new {@link PMML4Helper}
 * whose pack is set to {@link #BASE_TEST_PACK}.
 */
public PMML4Compiler() {
    this.results = new ArrayList<KnowledgeBuilderResult>();
    this.helper = new PMML4Helper();
    this.helper.setPack( BASE_TEST_PACK );
}
/**
 * Creates the static visitor knowledge base (STREAM event processing mode) and a new
 * knowledge builder, then adds the visitor, compiler and - when the model requires it -
 * informer rule resources that have not been added yet.
 * Builder errors are appended to {@code visitorBuildResults}; otherwise the built
 * packages are added to the knowledge base.
 *
 * @param pmml the model, inspected through {@link #needsInformerExtension(PMML)}
 */
private static void initVisitor( PMML pmml ) throws IOException, IllegalStateException {
    RuleBaseConfiguration streamConf = new RuleBaseConfiguration();
    streamConf.setEventProcessingMode( EventProcessingOption.STREAM );
    visitor = KnowledgeBaseFactory.newKnowledgeBase( streamConf );

    // TODO before rules can be structured, I need to double-check the incremental rule base assembly
    kBuilder = KnowledgeBuilderFactory.newKnowledgeBuilder();

    if ( ! visitorRules ) {
        kBuilder.add( ResourceFactory.newClassPathResource( VISITOR_RULES ), ResourceType.DRL );
        visitorRules = true;
    }

    if ( ! compilerRules ) {
        kBuilder.add( ResourceFactory.newClassPathResource( COMPILER_RULES ), ResourceType.DRL );
        compilerRules = true;
    }

    if ( ! informerRules && needsInformerExtension( pmml ) ) {
        Resource informerRes = ResourceFactory.newClassPathResource( INFORMER_RULES, PMML4Compiler.class );
        try {
            // the informer rules are only added when the resource can actually be reached
            boolean reachable = informerRes != null
                    && ( (ClassPathResource) informerRes ).getURL().openConnection().getContentType() != null;
            if ( reachable ) {
                kBuilder.add( informerRes, ResourceType.DRL );
            }
        } catch ( IOException ioe ) {
            ioe.printStackTrace();
            visitorBuildResults.add( new PMMLError( ioe.getMessage() ) );
        }
        // flagged even when the resource was not added, so the attempt is not repeated
        informerRules = true;
    }

    if ( ! kBuilder.hasErrors() ) {
        visitor.addKnowledgePackages( kBuilder.getKnowledgePackages() );
    } else {
        visitorBuildResults.addAll( kBuilder.getErrors() );
    }
}
/**
 * Generates the rule text for a PMML model by inserting the model into a session of
 * the visitor knowledge base and firing its rules.
 * The rules write into a {@link StringBuilder} exposed to them as the "theory" global.
 *
 * @param pmml the PMML descriptor to process
 * @return the text accumulated in the "theory" global, or {@code null} when the
 *         building resources could not be prepared (a {@code PMMLError} is then
 *         added to this compiler's results)
 */
public String generateTheory( PMML pmml ) {
    try {
        checkBuildingResources( pmml );
    } catch ( IOException e ) {
        this.results.add( new PMMLError( e.getMessage() ) );
        return null;
    }

    StringBuilder theory = new StringBuilder();
    StatefulKnowledgeSession visitorSession = visitor.newStatefulKnowledgeSession();
    try {
        visitorSession.setGlobal( "registry", registry );
        visitorSession.setGlobal( "fld2var", new HashMap() );
        visitorSession.setGlobal( "utils", helper );
        visitorSession.setGlobal( "theory", theory );

        visitorSession.insert( pmml );
        visitorSession.fireAllRules();

        return theory.toString();
    } finally {
        // release the session even when a global cannot be set or a rule throws
        visitorSession.dispose();
    }
}
/**
 * Creates the shared template registry if it does not exist yet and loads the
 * global, transformation and mining template groups that are not flagged as loaded.
 */
private static void initRegistry() {
    if ( registry == null ) {
        registry = new SimpleTemplateRegistry();
    }

    if ( ! globalLoaded ) {
        for ( int i = 0; i < GLOBAL_TEMPLATES.length; i++ ) {
            prepareTemplate( GLOBAL_TEMPLATES[ i ] );
        }
        globalLoaded = true;
    }

    if ( ! transformationLoaded ) {
        for ( int i = 0; i < TRANSFORMATION_TEMPLATES.length; i++ ) {
            prepareTemplate( TRANSFORMATION_TEMPLATES[ i ] );
        }
        transformationLoaded = true;
    }

    if ( ! miningLoaded ) {
        for ( int i = 0; i < MINING_TEMPLATES.length; i++ ) {
            prepareTemplate( MINING_TEMPLATES[ i ] );
        }
        miningLoaded = true;
    }
}
/**
 * Makes sure the registry, the visitor knowledge base and the templates needed by the
 * given model are available.
 * Model-specific template groups are loaded the first time a model of the matching
 * type is found. When the model needs the informer extension, the informer templates
 * are loaded as well, after resetting the visitor if the informer rules were not
 * part of it.
 *
 * @param pmml the model whose contained models decide which templates are loaded
 */
private static void checkBuildingResources( PMML pmml ) throws IOException {
    if ( registry == null ) {
        initRegistry();
    }
    if ( visitor == null ) {
        initVisitor( pmml );
    }

    for ( Object model : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) {
        if ( model instanceof NeuralNetwork && ! neuralLoaded ) {
            for ( String template : NEURAL_TEMPLATES ) {
                prepareTemplate( template );
            }
            neuralLoaded = true;
        }
        if ( model instanceof ClusteringModel && ! clusteringLoaded ) {
            for ( String template : CLUSTERING_TEMPLATES ) {
                prepareTemplate( template );
            }
            clusteringLoaded = true;
        }
        if ( model instanceof SupportVectorMachineModel && ! svmLoaded ) {
            for ( String template : SVM_TEMPLATES ) {
                prepareTemplate( template );
            }
            svmLoaded = true;
        }
        if ( model instanceof TreeModel && ! treeLoaded ) {
            for ( String template : TREE_TEMPLATES ) {
                prepareTemplate( template );
            }
            treeLoaded = true;
        }
        if ( model instanceof RegressionModel && ! simpleRegLoaded ) {
            for ( String template : SIMPLEREG_TEMPLATES ) {
                prepareTemplate( template );
            }
            simpleRegLoaded = true;
        }
        if ( model instanceof Scorecard && ! scorecardLoaded ) {
            for ( String template : SCORECARD_TEMPLATES ) {
                prepareTemplate( template );
            }
            scorecardLoaded = true;
        }
    }

    // nothing more to do unless the informer templates are still missing and required
    if ( informerLoaded || ! needsInformerExtension( pmml ) ) {
        return;
    }

    if ( ! informerRules ) {
        resetVisitor( pmml );
    }
    for ( String template : INFORMER_TEMPLATES ) {
        prepareTemplate( template );
    }
    informerLoaded = true;
}
/**
 * Discards the visitor knowledge base, clears every "rules added" and
 * "templates loaded" flag and rebuilds the resources for the given model.
 * An {@link IOException} raised while rebuilding replaces the content of
 * {@code visitorBuildResults} with a single {@code PMMLError}.
 *
 * @param pmml the model used to rebuild the resources
 */
private static void resetVisitor( PMML pmml ) throws IOException {
    visitor = null;

    // rule resources
    visitorRules = compilerRules = informerRules = false;

    // template groups
    globalLoaded = transformationLoaded = miningLoaded = false;
    clusteringLoaded = neuralLoaded = scorecardLoaded = false;
    simpleRegLoaded = svmLoaded = treeLoaded = false;
    informerLoaded = false;

    try {
        checkBuildingResources( pmml );
    } catch ( IOException failure ) {
        visitorBuildResults.clear();
        visitorBuildResults.add( new PMMLError( failure.getMessage() ) );
    }
}
/**
 * Tells whether any model contained in the PMML document carries an
 * {@code Extension} whose content includes a DOM element with tag name "Surveyable".
 *
 * @param pmml the document to inspect
 * @return {@code true} as soon as such an element is found, {@code false} otherwise
 */
protected static boolean needsInformerExtension( PMML pmml ) {
    for ( Object model : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) {
        // every model type exposes its children through a differently named accessor
        List<?> children;
        if ( model instanceof AssociationModel ) {
            children = ((AssociationModel) model).getExtensionsAndMiningSchemasAndOutputs();
        } else if ( model instanceof BaselineModel ) {
            children = ((BaselineModel) model).getExtensionsAndTestDistributionsAndMiningSchemas();
        } else if ( model instanceof ClusteringModel ) {
            children = ((ClusteringModel) model).getExtensionsAndClustersAndComparisonMeasures();
        } else if ( model instanceof GeneralRegressionModel ) {
            children = ((GeneralRegressionModel) model).getExtensionsAndParamMatrixesAndPPMatrixes();
        } else if ( model instanceof MiningModel ) {
            children = ((MiningModel) model).getExtensionsAndMiningSchemasAndOutputs();
        } else if ( model instanceof NaiveBayesModel ) {
            children = ((NaiveBayesModel) model).getExtensionsAndBayesOutputsAndBayesInputs();
        } else if ( model instanceof NearestNeighborModel ) {
            children = ((NearestNeighborModel) model).getExtensionsAndKNNInputsAndComparisonMeasures();
        } else if ( model instanceof NeuralNetwork ) {
            children = ((NeuralNetwork) model).getExtensionsAndNeuralLayersAndNeuralInputs();
        } else if ( model instanceof RegressionModel ) {
            children = ((RegressionModel) model).getExtensionsAndRegressionTablesAndMiningSchemas();
        } else if ( model instanceof RuleSetModel ) {
            children = ((RuleSetModel) model).getExtensionsAndRuleSetsAndMiningSchemas();
        } else if ( model instanceof Scorecard ) {
            children = ((Scorecard) model).getExtensionsAndCharacteristicsAndMiningSchemas();
        } else if ( model instanceof SequenceModel ) {
            children = ((SequenceModel) model).getExtensionsAndSequencesAndMiningSchemas();
        } else if ( model instanceof SupportVectorMachineModel ) {
            children = ((SupportVectorMachineModel) model).getExtensionsAndSupportVectorMachinesAndVectorDictionaries();
        } else if ( model instanceof TextModel ) {
            children = ((TextModel) model).getExtensionsAndDocumentTermMatrixesAndTextCorpuses();
        } else if ( model instanceof TimeSeriesModel ) {
            children = ((TimeSeriesModel) model).getExtensionsAndMiningSchemasAndOutputs();
        } else if ( model instanceof TreeModel ) {
            children = ((TreeModel) model).getExtensionsAndNodesAndMiningSchemas();
        } else {
            // should not happen
            children = Collections.emptyList();
        }

        for ( Object child : children ) {
            if ( ! ( child instanceof Extension ) ) {
                continue;
            }
            for ( Object content : ((Extension) child).getContent() ) {
                if ( ! ( content instanceof Element ) ) {
                    continue;
                }
                if ( ((Element) content).getTagName().equals( "Surveyable" ) ) {
                    return true;
                }
            }
        }
    }
    return false;
}
/**
 * Compiles the template stored at {@code TEMPLATE_PATH + ntempl} and registers it in
 * the shared registry under its simple file name (the part after the last '/').
 * Nothing is registered when the resource or its stream is {@code null}.
 * I/O problems are printed and otherwise ignored.
 *
 * @param ntempl the template path, relative to {@link #TEMPLATE_PATH}
 */
private static void prepareTemplate( String ntempl ) {
    String path = TEMPLATE_PATH + ntempl;
    InputStream stream = null;
    try {
        Resource res = ResourceFactory.newClassPathResource( path, PMML4Compiler.class );
        if ( res != null ) {
            stream = res.getInputStream();
            if ( stream != null ) {
                registry.addNamedTemplate( path.substring( path.lastIndexOf( '/' ) + 1 ),
                                           TemplateCompiler.compileTemplate( stream ) );
            }
        }
    } catch ( IOException e ) {
        e.printStackTrace();
    } finally {
        // the stream was previously left open: close it whether or not compilation succeeded
        if ( stream != null ) {
            try {
                stream.close();
            } catch ( IOException e ) {
                e.printStackTrace();
            }
        }
    }
}
/**
 * Compiles the PMML file found on the context class loader's classpath at
 * {@code RESOURCE_PATH + "/" + fileName}.
 *
 * @param fileName   name of the PMML file, relative to {@link #RESOURCE_PATH}
 * @param registries package registries, forwarded to {@link #compile(InputStream, Map)}
 * @return the generated theory, or {@code null} when the file cannot be found (a
 *         {@code PMMLError} is recorded in the results) or compilation fails
 */
public String compile(String fileName, Map<String,PackageRegistry> registries) {
    String location = RESOURCE_PATH + "/" + fileName;
    InputStream stream = Thread.currentThread().getContextClassLoader().getResourceAsStream( location );
    if ( stream == null ) {
        // a missing resource used to reach the unmarshaller as a null stream;
        // report it like any other compilation problem, on a fresh result list
        this.results = new ArrayList<KnowledgeBuilderResult>();
        this.results.add( new PMMLError( "Unable to find PMML resource : " + location ) );
        return null;
    }
    try {
        return compile( stream, registries );
    } finally {
        // this method opened the stream, so it is responsible for closing it
        try {
            stream.close();
        } catch ( IOException ignored ) {
            // the document has already been processed: a failure to close is not actionable
        }
    }
}
/**
 * Compiles a PMML document read from a stream.
 * The per-instance results are replaced by a fresh list, the document is loaded and,
 * when registries are given, the helper's resolver is taken from the registry of the
 * helper's pack (or set to {@code null} when that pack is not registered).
 *
 * @param source     the stream providing the PMML document
 * @param registries package registries by package name; may be {@code null}
 * @return the generated theory, or {@code null} if any static or per-instance
 *         result has been recorded
 */
public String compile(InputStream source, Map<String,PackageRegistry> registries) {
    this.results = new ArrayList<KnowledgeBuilderResult>();
    PMML pmml = loadModel( PMML, source );

    if ( registries != null ) {
        helper.setResolver( registries.containsKey( helper.getPack() )
                            ? registries.get( helper.getPack() ).getTypeResolver()
                            : null );
    }

    boolean nothingReported = visitorBuildResults.isEmpty() && results.isEmpty();
    return nothingReported ? generateTheory( pmml ) : null;
}
/**
 * Returns a new list holding this compiler's own results followed by the static
 * visitor build results.
 *
 * @return a fresh list; changing it does not affect the compiler's state
 */
public List<KnowledgeBuilderResult> getResults() {
    List<KnowledgeBuilderResult> all = new ArrayList<KnowledgeBuilderResult>();
    all.addAll( this.results );
    all.addAll( visitorBuildResults );
    return all;
}
/**
 * Empties this compiler's own result list; the static visitor build results are left untouched.
 */
public void clearResults() {
    results.clear();
}
/**
 * Writes a string to the given stream as UTF-8 and flushes it.
 * The stream is not closed. I/O problems are printed and otherwise ignored.
 *
 * @param s       the text to write
 * @param ostream the destination stream
 */
public void dump( String s, OutputStream ostream ) {
    Writer out = null;
    try {
        out = new OutputStreamWriter( ostream, "UTF-8" );
        out.write( s );
    } catch ( IOException e ) {
        // UnsupportedEncodingException is an IOException and is handled the same way
        e.printStackTrace();
    } finally {
        if ( out != null ) {
            try {
                out.flush();
            } catch ( IOException e ) {
                e.printStackTrace();
            }
        }
    }
}
/**
 * Imports a PMML document, returning its Java descriptor.
 * The document is validated against the schema at {@link #SCHEMA_PATH} when that
 * schema can be loaded; otherwise a {@code PMMLWarning} is added to the static
 * visitor build results and the document is unmarshalled without validation.
 *
 * @param model the JAXB context path (package of the classes derived from a specific schema)
 * @param source the stream providing the PMML document
 * @return the Java descriptor of the PMML document, or {@code null} when the source
 *         is {@code null} or cannot be unmarshalled (a {@code PMMLError} is then
 *         added to this compiler's results)
 */
public PMML loadModel( String model, InputStream source ) {
    if ( source == null ) {
        // the unmarshaller rejects a null stream with an exception; report it as a result instead
        this.results.add( new PMMLError( "No PMML source available to load the model from" ) );
        return null;
    }
    try {
        Schema schema = null;
        java.net.URL schemaUrl = Thread.currentThread().getContextClassLoader().getResource( SCHEMA_PATH );
        if ( schemaUrl == null ) {
            // a missing XSD would make newSchema fail with a NullPointerException:
            // handle it exactly like a schema that cannot be parsed
            visitorBuildResults.add( new PMMLWarning( ResourceFactory.newInputStreamResource( source ), "Could not validate PMML document :" + "schema not found at " + SCHEMA_PATH ) );
        } else {
            try {
                SchemaFactory sf = SchemaFactory.newInstance( XMLConstants.W3C_XML_SCHEMA_NS_URI );
                schema = sf.newSchema( schemaUrl );
            } catch ( SAXException e ) {
                e.printStackTrace();
                visitorBuildResults.add( new PMMLWarning( ResourceFactory.newInputStreamResource( source ), "Could not validate PMML document :" + e.getMessage() ) );
            }
        }

        JAXBContext jc = JAXBContext.newInstance( model );
        Unmarshaller unmarshaller = jc.createUnmarshaller();
        if ( schema != null ) {
            unmarshaller.setSchema( schema );
        }
        return (PMML) unmarshaller.unmarshal( source );
    } catch ( JAXBException e ) {
        this.results.add( new PMMLError( e.toString() ) );
        return null;
    }
}
/**
 * Marshals a PMML descriptor to the given stream as formatted XML, using a JAXB
 * context built from the package of the {@code PMML} class.
 * JAXB failures are printed and otherwise ignored.
 *
 * @param model  the descriptor to serialize
 * @param target the stream receiving the XML
 */
public static void dumpModel( PMML model, OutputStream target ) {
    try {
        String contextPath = PMML.class.getPackage().getName();
        Marshaller xmlWriter = JAXBContext.newInstance( contextPath ).createMarshaller();
        xmlWriter.setProperty( Marshaller.JAXB_FORMATTED_OUTPUT, Boolean.TRUE );
        xmlWriter.marshal( model, target );
    } catch ( JAXBException jaxbFailure ) {
        jaxbFailure.printStackTrace();
    }
}
}