/*
* Copyright 2011 JBoss Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.pmml.pmml_4_1;
import org.dmg.pmml.pmml_4_1.descr.*;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import java.io.OutputStream;
import java.io.Writer;
import java.math.BigInteger;
/**
 * Helpers for building PMML 4.1 documents programmatically and serializing them with JAXB.
 */
public class PMMLGeneratorUtils {

    /** Version attribute written into generated documents; matches the pmml_4_1 descriptor classes. */
    private static final String PMML_VERSION = "4.1";

    /**
     * Creates a JAXB marshaller for the package that contains the given document's class,
     * configured for UTF-8, pretty-printed output.
     *
     * @param pmml the document to be marshalled (only its class is inspected)
     * @return a configured marshaller
     * @throws JAXBException if the JAXB context or marshaller cannot be created
     */
    private static Marshaller initContext( PMML pmml ) throws JAXBException {
        JAXBContext pContext = JAXBContext.newInstance( pmml.getClass().getPackage().getName() );
        Marshaller marshaller = pContext.createMarshaller();
        marshaller.setProperty( Marshaller.JAXB_ENCODING, "UTF-8" );
        marshaller.setProperty( Marshaller.JAXB_FORMATTED_OUTPUT, Boolean.TRUE );
        return marshaller;
    }

    /**
     * Serializes a PMML document as formatted UTF-8 XML to a byte stream.
     * The stream is not closed by this method.
     *
     * @param pmml the document to write
     * @param out  destination stream
     * @return {@code true} on success, {@code false} if JAXB reported an error
     *         (the stack trace is printed to stderr; this is a best-effort API)
     */
    public static boolean streamPMML( PMML pmml, OutputStream out ) {
        try {
            Marshaller marshaller = initContext( pmml );
            marshaller.marshal( pmml, out );
            return true;
        } catch ( JAXBException e ) {
            e.printStackTrace();
        }
        return false;
    }

    /**
     * Serializes a PMML document as formatted XML to a character stream.
     * The writer is not closed by this method.
     *
     * @param pmml the document to write
     * @param out  destination writer
     * @return {@code true} on success, {@code false} if JAXB reported an error
     *         (the stack trace is printed to stderr; this is a best-effort API)
     */
    public static boolean streamPMML( PMML pmml, Writer out ) {
        try {
            Marshaller marshaller = initContext( pmml );
            marshaller.marshal( pmml, out );
            return true;
        } catch ( JAXBException e ) {
            e.printStackTrace();
        }
        return false;
    }

    /**
     * Builds a PMML document containing a single-hidden-layer regression neural network.
     * <p>
     * Neurons are numbered consecutively: inputs {@code 0..I-1}, hidden {@code I..I+H-1},
     * outputs {@code I+H..I+H+O-1}. The hidden layer uses the network-wide logistic activation
     * and the output layer uses identity. Inputs and outputs are z-score normalized with the
     * supplied means and standard deviations.
     * <p>
     * {@code weights} is read sequentially: for each hidden neuron, its bias followed by one weight
     * per input; then, for each output neuron, its bias followed by one weight per hidden neuron.
     * It must therefore hold at least {@code H*(I+1) + O*(H+1)} values (extra values are ignored).
     *
     * @param modelName        model name attribute
     * @param inputfieldNames  names of the input (active) fields
     * @param outputfieldNames names of the predicted fields
     * @param inputMeans       per-input mean, at least {@code inputfieldNames.length} entries
     * @param inputStds        per-input standard deviation (must be positive)
     * @param outputMeans      per-output mean, at least {@code outputfieldNames.length} entries
     * @param outputStds       per-output standard deviation (must be positive)
     * @param hiddenSize       number of hidden neurons, non-negative
     * @param weights          flattened biases and synapse weights, laid out as described above
     * @return the assembled document
     * @throws IllegalArgumentException if an array is too short, {@code hiddenSize} is negative,
     *                                  or a standard deviation is not positive
     */
    public static PMML generateSimpleNeuralNetwork( String modelName,
                                                    String[] inputfieldNames, String[] outputfieldNames,
                                                    double[] inputMeans, double[] inputStds,
                                                    double[] outputMeans, double[] outputStds,
                                                    int hiddenSize,
                                                    double[] weights ) {
        int numInputs = inputfieldNames.length;
        int numOutputs = outputfieldNames.length;
        if ( hiddenSize < 0 ) {
            throw new IllegalArgumentException( "hiddenSize must be non-negative, was " + hiddenSize );
        }
        checkMinLength( "inputMeans", inputMeans, numInputs );
        checkMinLength( "inputStds", inputStds, numInputs );
        checkMinLength( "outputMeans", outputMeans, numOutputs );
        checkMinLength( "outputStds", outputStds, numOutputs );
        // Computed in long so a huge topology reports a clear error instead of overflowing.
        long requiredWeights = (long) hiddenSize * ( numInputs + 1 ) + (long) numOutputs * ( hiddenSize + 1 );
        checkMinLength( "weights", weights, requiredWeights );

        PMML pmml = new PMML();
        pmml.setVersion( PMML_VERSION );
        pmml.setHeader( buildHeader() );
        pmml.setDataDictionary( buildDataDictionary( inputfieldNames, outputfieldNames ) );

        NeuralNetwork nnet = new NeuralNetwork();
        nnet.setActivationFunction( ACTIVATIONFUNCTION.LOGISTIC );
        nnet.setFunctionName( MININGFUNCTION.REGRESSION );
        nnet.setNormalizationMethod( NNNORMALIZATIONMETHOD.NONE );
        nnet.setModelName( modelName );

        int firstHiddenId = numInputs;
        int firstOutputId = numInputs + hiddenSize;
        // Safe as int: bounded by requiredWeights <= weights.length after the check above.
        int outputWeightsOffset = hiddenSize * ( numInputs + 1 );

        NeuralLayer hidden = buildLayer( hiddenSize, firstHiddenId, numInputs, 0, weights, 0 );
        NeuralLayer outer = buildLayer( numOutputs, firstOutputId, hiddenSize, firstHiddenId,
                                        weights, outputWeightsOffset );
        outer.setActivationFunction( ACTIVATIONFUNCTION.IDENTITY );

        // Children are added in the sequence the PMML NeuralNetwork element expects:
        // MiningSchema, Output, NeuralInputs, NeuralLayer(s), NeuralOutputs.
        nnet.getExtensionsAndNeuralLayersAndNeuralInputs().add( buildMiningSchema( inputfieldNames, outputfieldNames ) );
        nnet.getExtensionsAndNeuralLayersAndNeuralInputs().add( buildOutput( outputfieldNames ) );
        nnet.getExtensionsAndNeuralLayersAndNeuralInputs().add( buildNeuralInputs( inputfieldNames, inputMeans, inputStds ) );
        nnet.getExtensionsAndNeuralLayersAndNeuralInputs().add( hidden );
        nnet.getExtensionsAndNeuralLayersAndNeuralInputs().add( outer );
        nnet.getExtensionsAndNeuralLayersAndNeuralInputs().add(
                buildNeuralOutputs( outputfieldNames, outputMeans, outputStds, firstOutputId ) );

        pmml.getAssociationModelsAndBaselineModelsAndClusteringModels().add( nnet );
        return pmml;
    }

    /** Throws IllegalArgumentException if {@code values} has fewer than {@code required} entries (null counts as empty). */
    private static void checkMinLength( String name, double[] values, long required ) {
        int actual = values == null ? 0 : values.length;
        if ( actual < required ) {
            throw new IllegalArgumentException( name + " must contain at least " + required
                                                + " values, but has " + actual );
        }
    }

    /** Builds the document header: generator application, copyright, description and creation timestamp. */
    private static Header buildHeader() {
        Header header = new Header();
        Application app = new Application();
        app.setName( "Drools PMML Generator" );
        app.setVersion( "0.01 Alpha" );
        header.setApplication( app );
        header.setCopyright( "BSD" );
        header.setDescription( " Smart Vent Model " );
        Timestamp ts = new Timestamp();
        ts.getContent().add( new java.util.Date().toString() );
        header.setTimestamp( ts );
        return header;
    }

    /** Declares every input and output field as a continuous double. */
    private static DataDictionary buildDataDictionary( String[] inputfieldNames, String[] outputfieldNames ) {
        DataDictionary dic = new DataDictionary();
        dic.setNumberOfFields( BigInteger.valueOf( inputfieldNames.length + outputfieldNames.length ) );
        for ( String name : inputfieldNames ) {
            dic.getDataFields().add( continuousDoubleField( name ) );
        }
        for ( String name : outputfieldNames ) {
            dic.getDataFields().add( continuousDoubleField( name ) );
        }
        return dic;
    }

    /** Creates a continuous, double-typed data field whose display name equals its name. */
    private static DataField continuousDoubleField( String name ) {
        DataField dataField = new DataField();
        dataField.setName( name );
        dataField.setDataType( DATATYPE.DOUBLE );
        dataField.setDisplayName( name );
        dataField.setOptype( OPTYPE.CONTINUOUS );
        return dataField;
    }

    /** Marks inputs as active and outputs as predicted mining fields. */
    private static MiningSchema buildMiningSchema( String[] inputfieldNames, String[] outputfieldNames ) {
        MiningSchema miningSchema = new MiningSchema();
        for ( String name : inputfieldNames ) {
            miningSchema.getMiningFields().add( continuousMiningField( name, FIELDUSAGETYPE.ACTIVE ) );
        }
        for ( String name : outputfieldNames ) {
            miningSchema.getMiningFields().add( continuousMiningField( name, FIELDUSAGETYPE.PREDICTED ) );
        }
        return miningSchema;
    }

    /** Creates a continuous mining field with the given usage type. */
    private static MiningField continuousMiningField( String name, FIELDUSAGETYPE usage ) {
        MiningField mfld = new MiningField();
        mfld.setName( name );
        mfld.setOptype( OPTYPE.CONTINUOUS );
        mfld.setUsageType( usage );
        return mfld;
    }

    /** Creates one output field named {@code Out_<target>} per predicted field. */
    private static Output buildOutput( String[] outputfieldNames ) {
        Output outputs = new Output();
        for ( String name : outputfieldNames ) {
            OutputField outFld = new OutputField();
            outFld.setName( "Out_" + name );
            outFld.setTargetField( name );
            outputs.getOutputFields().add( outFld );
        }
        return outputs;
    }

    /** Creates input neurons {@code 0..n-1}, each reading the z-score of its input field. */
    private static NeuralInputs buildNeuralInputs( String[] names, double[] means, double[] stds ) {
        NeuralInputs nins = new NeuralInputs();
        nins.setNumberOfInputs( BigInteger.valueOf( names.length ) );
        for ( int j = 0; j < names.length; j++ ) {
            NeuralInput nin = new NeuralInput();
            nin.setId( String.valueOf( j ) );
            nin.setDerivedField( zScoreField( names[j], means[j], stds[j] ) );
            nins.getNeuralInputs().add( nin );
        }
        return nins;
    }

    /**
     * Builds a fully connected layer of {@code size} neurons with IDs starting at {@code firstId}.
     * Each neuron connects to the {@code fanIn} neurons whose IDs start at {@code firstSourceId}.
     * Starting at {@code offset}, each neuron consumes its bias and then {@code fanIn} synapse weights.
     */
    private static NeuralLayer buildLayer( int size, int firstId, int fanIn, int firstSourceId,
                                           double[] weights, int offset ) {
        NeuralLayer layer = new NeuralLayer();
        layer.setNumberOfNeurons( BigInteger.valueOf( size ) );
        int w = offset;
        for ( int j = 0; j < size; j++ ) {
            Neuron neuron = new Neuron();
            neuron.setId( String.valueOf( firstId + j ) );
            neuron.setBias( weights[ w++ ] );
            for ( int k = 0; k < fanIn; k++ ) {
                Synapse con = new Synapse();
                con.setFrom( String.valueOf( firstSourceId + k ) );
                con.setWeight( weights[ w++ ] );
                neuron.getCons().add( con );
            }
            layer.getNeurons().add( neuron );
        }
        return layer;
    }

    /** Maps each output neuron (IDs from {@code firstNeuronId}) back to its target field through a z-score transform. */
    private static NeuralOutputs buildNeuralOutputs( String[] names, double[] means, double[] stds, int firstNeuronId ) {
        NeuralOutputs finalOuts = new NeuralOutputs();
        finalOuts.setNumberOfOutputs( BigInteger.valueOf( names.length ) );
        for ( int j = 0; j < names.length; j++ ) {
            NeuralOutput output = new NeuralOutput();
            output.setOutputNeuron( String.valueOf( firstNeuronId + j ) );
            output.setDerivedField( zScoreField( names[j], means[j], stds[j] ) );
            finalOuts.getNeuralOutputs().add( output );
        }
        return finalOuts;
    }

    /**
     * Builds a continuous double derived field computing {@code (field - mean) / std}.
     * <p>
     * The mapping is encoded as two LinearNorm points, (mean, 0) and (mean + std, 1). PMML requires
     * LinearNorm {@code orig} values to be strictly ascending. The previous encoding,
     * (0, -mean/std) then (mean, 0), violated this whenever mean was zero or negative.
     *
     * @throws IllegalArgumentException if std is not positive or is too small to be distinguished from mean
     */
    private static DerivedField zScoreField( String field, double mean, double std ) {
        double upper = mean + std;
        // Rejects std <= 0, NaN/infinite values, and precision loss that would yield equal orig values.
        if ( !( upper > mean ) ) {
            throw new IllegalArgumentException( "Cannot normalize field '" + field + "': mean=" + mean
                                                + ", std=" + std + " (std must be positive)" );
        }
        DerivedField der = new DerivedField();
        der.setDataType( DATATYPE.DOUBLE );
        der.setOptype( OPTYPE.CONTINUOUS );
        NormContinuous nc = new NormContinuous();
        nc.setField( field );
        nc.setOutliers( OUTLIERTREATMENTMETHOD.AS_IS );
        nc.getLinearNorms().add( linearNorm( mean, 0 ) );
        nc.getLinearNorms().add( linearNorm( upper, 1 ) );
        der.setNormContinuous( nc );
        return der;
    }

    /** Creates a single (orig, norm) interpolation point. */
    private static LinearNorm linearNorm( double orig, double norm ) {
        LinearNorm lin = new LinearNorm();
        lin.setOrig( orig );
        lin.setNorm( norm );
        return lin;
    }
}