/*
* Copyright [2013-2016] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.pmml;
import org.dmg.pmml.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Glues a partial PMML neural network model to its neural layers by adding a
 * bias field and a neural input layer.
 */
public class NeuralNetworkModelIntegrator {

    /**
     * Completes a partial neural network model by attaching a neural input layer
     * and appending a constant bias field to its local transformations.
     *
     * <p>The input layer has one input per ACTIVE mining field, plus one bias input.
     *
     * @param model
     *            the partial nn model; it is modified in place
     * @return the same {@code model} instance after adaptation
     */
    public NeuralNetwork adaptPMML(NeuralNetwork model) {
        model.withNeuralInputs(getNeuralInputs(model));
        model.setLocalTransformations(buildLocalTransformations(model));
        return model;
    }

    /**
     * Builds the neural input layer.
     *
     * <p>Each ACTIVE mining field is followed through the chain of local transformations
     * to the last derived field produced from it. That derived field's name becomes
     * a neural input with id {@code "0,<k>"}, where k counts active fields in
     * mining-schema order. A final input for the bias field is appended, using
     * {@code PluginConstants.biasValue} as its id.
     */
    private NeuralInputs getNeuralInputs(final NeuralNetwork model) {
        NeuralInputs nnInputs = new NeuralInputs();
        Map<FieldName, FieldName> sourceToDerived = buildSourceToDerivedMap(model);

        int index = 0;
        for(MiningField mField: model.getMiningSchema().getMiningFields()) {
            if(mField.getUsageType() != FieldUsageType.ACTIVE) {
                continue;
            }
            FieldName fName = resolveLastDerivedField(sourceToDerived, mField.getName());
            DerivedField field = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE).withName(fName)
                    .withExpression(new FieldRef(fName));
            nnInputs.withNeuralInputs(new NeuralInput(field, "0," + (index++)));
        }

        DerivedField biasField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE).withName(
                new FieldName(PluginConstants.biasValue)).withExpression(
                new FieldRef(new FieldName(PluginConstants.biasValue)));
        nnInputs.withNeuralInputs(new NeuralInput(biasField, PluginConstants.biasValue));
        return nnInputs;
    }

    /**
     * Maps each transformation's input field to the name of the derived field it produces.
     *
     * <p>Only NormContinuous, MapValues and Discretize expressions are considered. If
     * several derived fields read the same input, the last one wins. A model without
     * local transformations yields an empty map.
     */
    private static Map<FieldName, FieldName> buildSourceToDerivedMap(NeuralNetwork model) {
        Map<FieldName, FieldName> sourceToDerived = new HashMap<FieldName, FieldName>();
        LocalTransformations transformations = model.getLocalTransformations();
        if(transformations == null) {
            return sourceToDerived;
        }
        for(DerivedField dField: transformations.getDerivedFields()) {
            FieldName source = getSourceField(dField);
            if(source != null) {
                sourceToDerived.put(source, dField.getName());
            }
        }
        return sourceToDerived;
    }

    /**
     * Returns the input field read by a supported derived-field expression.
     *
     * <p>For MapValues this is the first field-column pair. Returns {@code null} when
     * the expression type is not supported or a MapValues has no field-column pairs.
     */
    private static FieldName getSourceField(DerivedField dField) {
        if(dField.getExpression() instanceof NormContinuous) {
            return ((NormContinuous) dField.getExpression()).getField();
        }
        if(dField.getExpression() instanceof MapValues) {
            MapValues mapValues = (MapValues) dField.getExpression();
            // Guard against an empty pair list instead of throwing IndexOutOfBoundsException.
            if(mapValues.getFieldColumnPairs().isEmpty()) {
                return null;
            }
            return mapValues.getFieldColumnPairs().get(0).getField();
        }
        if(dField.getExpression() instanceof Discretize) {
            return ((Discretize) dField.getExpression()).getField();
        }
        return null;
    }

    /**
     * Follows source→derived links starting at {@code start} and returns the last field reached.
     *
     * <p>The walk stops as soon as a field repeats. This means a cyclic mapping (for
     * example a derived field named the same as its input) cannot loop forever.
     */
    private static FieldName resolveLastDerivedField(Map<FieldName, FieldName> sourceToDerived,
            FieldName start) {
        Set<FieldName> visited = new HashSet<FieldName>();
        FieldName current = start;
        while(sourceToDerived.containsKey(current) && visited.add(current)) {
            current = sourceToDerived.get(current);
        }
        return current;
    }

    /**
     * Builds the new local transformations: a copy of the model's existing derived fields
     * (if any), followed by a constant bias field named {@code PluginConstants.biasValue}
     * with value {@code PluginConstants.bias}.
     *
     * <p>The model's current derived-field list is not mutated.
     */
    private LocalTransformations buildLocalTransformations(NeuralNetwork model) {
        List<DerivedField> derivedFields = new ArrayList<DerivedField>();
        LocalTransformations existing = model.getLocalTransformations();
        if(existing != null) {
            derivedFields.addAll(existing.getDerivedFields());
        }

        DerivedField biasField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE).withName(new FieldName(
                PluginConstants.biasValue));
        biasField.withExpression(new Constant(String.valueOf(PluginConstants.bias)));
        derivedFields.add(biasField);
        return new LocalTransformations().withDerivedFields(derivedFields);
    }
}