/** * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.pmml; import org.dmg.pmml.*; import org.encog.engine.network.activation.ActivationFunction; import org.encog.neural.flat.FlatNetwork; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; /** * The class that converts an Encog NeuralNetwork to a PMML RegressionModel. * This class extends the abstract class * PMMLModelBuilder(pmml.RegressionModel,Encog.NeuralNetwork). */ public class PMMLEncogNeuralNetworkModel implements PMMLModelBuilder<org.dmg.pmml.NeuralNetwork, org.encog.neural.networks.BasicNetwork> { private FlatNetwork network; /** * The function which converts an Encog NeuralNetwork to a PMML * NeuralNetwork Model. * <p> * This function reads the weights from the Encog NeuralNetwork model and assign them to the corresponding * connections of Neurons in PMML model. * * @param bNetwork * Encog NeuralNetwork * @param pmmlModel * DataFieldUtility that provides supplementary data field for * the model conversion * @return The generated PMML NeuralNetwork Model */ public org.dmg.pmml.NeuralNetwork adaptMLModelToPMML(org.encog.neural.networks.BasicNetwork bNetwork, org.dmg.pmml.NeuralNetwork pmmlModel) { network = bNetwork.getFlat(); pmmlModel = new NeuralNetworkModelIntegrator().adaptPMML(pmmlModel); int[] layerCount = network.getLayerCounts(); int[] layerFeedCount = network.getLayerFeedCounts(); double[] weights = network.getWeights(); ActivationFunctionType[] functionList = transformActivationFunction(network.getActivationFunctions()); int numLayers = layerCount.length; int weightID = 0; List<NeuralLayer> layerList = new ArrayList<NeuralLayer>(); pmmlModel.withFunctionName(MiningFunctionType.REGRESSION); for(int i = 0; i < numLayers - 1; i++) { NeuralLayer layer = new NeuralLayer(); layer.setNumberOfNeurons(layerFeedCount[i]); layer.setActivationFunction(functionList[i]); int layerID = numLayers - i - 1; for(int j = 0; j < layerFeedCount[i]; j++) { Neuron neuron = new Neuron(String.valueOf(layerID + "," + j)); for(int k = 0; k < layerFeedCount[i + 1]; k++) { neuron.withConnections(new Connection(String.valueOf(layerID - 1 + "," + k), weights[weightID++])); }// weights int tmp = layerCount[i + 1] - layerFeedCount[i + 1]; for(int k = 0; k < tmp; k++) { neuron.setBias(weights[weightID++]); } // bias neuron for each layer layer.withNeurons(neuron); }// finish build Neuron layerList.add(layer); }// finish build layer // reserve the layer list to fit fot PMML format Collections.reverse(layerList); pmmlModel.withNeuralLayers(layerList); // set neural output based on target id pmmlModel.withNeuralOutputs(PMMLAdapterCommonUtil.getOutputFields(pmmlModel.getMiningSchema(), numLayers - 1)); return pmmlModel; } private ActivationFunctionType[] transformActivationFunction(ActivationFunction[] functions) { int funLen = functions.length; ActivationFunctionType[] functionType = new ActivationFunctionType[funLen]; @SuppressWarnings("serial") HashMap<String, ActivationFunctionType> functionMap = new HashMap<String, ActivationFunctionType>() { { put("ActivationSigmoid", ActivationFunctionType.LOGISTIC); put("ActivationLinear", ActivationFunctionType.IDENTITY); put("ActivationTANH", ActivationFunctionType.TANH); } }; for(int i = 0; i < funLen; i++) { String trimS = functions[i].getClass().getName(); String[] functionS = trimS.split("\\."); functionType[i] = functionMap.get(functionS[functionS.length - 1]); } return functionType; } }