/*
 * Copyright (c) 2016 Villu Ruusmann
 *
 * This file is part of JPMML-SkLearn
 *
 * JPMML-SkLearn is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SkLearn is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
 */
package sklearn.neural_network;

import java.util.ArrayList;
import java.util.List;

import com.google.common.collect.Iterables;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.OpType;
import org.dmg.pmml.neural_network.Connection;
import org.dmg.pmml.neural_network.NeuralInput;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.NeuralOutput;
import org.dmg.pmml.neural_network.NeuralOutputs;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.HasArray;

/**
 * Utility methods for converting scikit-learn multi-layer perceptron models
 * (<code>MLPRegressor</code>, <code>MLPClassifier</code>) to the PMML
 * {@link NeuralNetwork} representation.
 */
public class NeuralNetworkUtil {

	private NeuralNetworkUtil(){
	}

	/**
	 * Returns the number of input features of a network, as given by the row
	 * count of the first (input-to-hidden) coefficient matrix.
	 *
	 * @param coefs Per-layer weight matrices; each matrix has the shape (inputs x neurons).
	 * @throws IllegalArgumentException If the first coefficient matrix is not two-dimensional.
	 */
	static
	public int getNumberOfFeatures(List<? extends HasArray> coefs){
		HasArray input = coefs.get(0);

		int[] shape = input.getArrayShape();
		if(shape.length != 2){
			throw new IllegalArgumentException("Expected a two-dimensional coefficient matrix, got " + shape.length + " dimension(s)");
		}

		return shape[0];
	}

	/**
	 * Encodes a fully-connected feed-forward network as a PMML {@link NeuralNetwork}.
	 *
	 * <p>
	 * Binary classification appends a logistic transform layer plus a label binarizer
	 * layer; multi-class classification applies softmax normalization to the final layer.
	 * </p>
	 *
	 * @param miningFunction Either {@link MiningFunction#REGRESSION} or {@link MiningFunction#CLASSIFICATION}.
	 * @param activation Scikit-learn activation function name ("identity", "logistic", "relu" or "tanh").
	 * @param coefs Per-layer weight matrices; each matrix has the shape (inputs x neurons).
	 * @param intercepts Per-layer bias vectors; parallel to <code>coefs</code>.
	 * @param schema Feature and label description.
	 * @throws IllegalArgumentException If the mining function is unsupported, or if a
	 *   classification label declares fewer than two target categories.
	 */
	static
	public NeuralNetwork encodeNeuralNetwork(MiningFunction miningFunction, String activation, List<? extends HasArray> coefs, List<? extends HasArray> intercepts, Schema schema){
		NeuralNetwork.ActivationFunction activationFunction = parseActivationFunction(activation);

		ClassDictUtil.checkSize(coefs, intercepts);

		NeuralInputs neuralInputs = encodeNeuralInputs(schema);

		// Entities of the previous layer; initially, the neural inputs
		List<? extends Entity> entities = neuralInputs.getNeuralInputs();

		List<NeuralLayer> neuralLayers = new ArrayList<>();

		for(int layer = 0; layer < coefs.size(); layer++){
			HasArray coef = coefs.get(layer);
			HasArray intercept = intercepts.get(layer);

			int[] shape = coef.getArrayShape();

			int rows = shape[0];
			int columns = shape[1];

			List<Neuron> neurons = new ArrayList<>();

			List<?> interceptVector = intercept.getArrayContent();

			for(int column = 0; column < columns; column++){
				Neuron neuron = new Neuron()
					.setId((layer + 1) + "/" + (column + 1));

				Double bias = ValueUtil.asDouble((Number)interceptVector.get(column));

				// Omit zero-valued biases to keep the PMML document compact
				if(!ValueUtil.isZero(bias)){
					neuron.setBias(bias);
				}

				neurons.add(neuron);
			}

			List<?> coefMatrix = coef.getArrayContent();

			// Row i of the coefficient matrix holds the outgoing weights of entity i of the previous layer
			for(int row = 0; row < rows; row++){
				List<?> weights = CMatrixUtil.getRow(coefMatrix, rows, columns, row);

				connect(entities.get(row), neurons, weights);
			}

			NeuralLayer neuralLayer = new NeuralLayer(neurons);

			if(layer == (coefs.size() - 1)){
				// Scikit-learn applies the output transformation separately from the hidden-layer activation
				neuralLayer.setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY);

				switch(miningFunction){
					case REGRESSION:
						break;
					case CLASSIFICATION:
						CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

						// Binary classification
						if(categoricalLabel.size() == 2){
							neuralLayers.add(neuralLayer);

							neuralLayer = encodeLogisticTransform(getOnlyNeuron(neuralLayer));

							neuralLayers.add(neuralLayer);

							neuralLayer = encodeLabelBinarizerTransform(getOnlyNeuron(neuralLayer));
						} else

						// Multi-class classification
						if(categoricalLabel.size() > 2){
							neuralLayer.setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
						} else

						{
							throw new IllegalArgumentException("Expected two or more target categories, got " + categoricalLabel.size() + " target categories");
						}
						break;
					default:
						throw new IllegalArgumentException("Expected a regression- or classification-type mining function, got " + miningFunction);
				}
			}

			entities = neuralLayer.getNeurons();

			neuralLayers.add(neuralLayer);
		}

		NeuralOutputs neuralOutputs;

		switch(miningFunction){
			case REGRESSION:
				neuralOutputs = encodeRegressionNeuralOutputs(entities, schema);
				break;
			case CLASSIFICATION:
				neuralOutputs = encodeClassificationNeuralOutputs(entities, schema);
				break;
			default:
				throw new IllegalArgumentException("Expected a regression- or classification-type mining function, got " + miningFunction);
		}

		NeuralNetwork neuralNetwork = new NeuralNetwork(miningFunction, activationFunction, ModelUtil.createMiningSchema(schema), neuralInputs, neuralLayers)
			.setNeuralOutputs(neuralOutputs);

		return neuralNetwork;
	}

	/**
	 * Encodes each feature as a continuous neural input with the identifier
	 * <code>"0/&lt;one-based column index&gt;"</code>.
	 */
	static
	private NeuralInputs encodeNeuralInputs(Schema schema){
		NeuralInputs neuralInputs = new NeuralInputs();

		List<Feature> features = schema.getFeatures();
		for(int column = 0; column < features.size(); column++){
			Feature feature = features.get(column);

			ContinuousFeature continuousFeature = feature.toContinuousFeature();

			DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE)
				.setExpression(continuousFeature.ref());

			NeuralInput neuralInput = new NeuralInput()
				.setId("0/" + (column + 1))
				.setDerivedField(derivedField);

			neuralInputs.addNeuralInputs(neuralInput);
		}

		return neuralInputs;
	}

	/**
	 * Encodes a single-neuron layer that applies the logistic (sigmoid) function
	 * to the decision function value of a binary classifier.
	 */
	static
	private NeuralLayer encodeLogisticTransform(Neuron input){
		NeuralLayer neuralLayer = new NeuralLayer()
			.setActivationFunction(NeuralNetwork.ActivationFunction.LOGISTIC);

		Neuron neuron = new Neuron()
			.setId("logistic/1")
			.setBias(0d)
			.addConnections(new Connection(input.getId(), 1d));

		neuralLayer.addNeurons(neuron);

		return neuralLayer;
	}

	/**
	 * Encodes a two-neuron layer that expands the event probability <code>p</code>
	 * into the complementary pair <code>(1 - p, p)</code> for the no-event and event
	 * categories, respectively.
	 */
	static
	private NeuralLayer encodeLabelBinarizerTransform(Neuron input){
		NeuralLayer neuralLayer = new NeuralLayer()
			.setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY);

		// (1 - p)
		Neuron noEventNeuron = new Neuron()
			.setId("event/false")
			.setBias(1d)
			.addConnections(new Connection(input.getId(), -1d));

		// p
		Neuron eventNeuron = new Neuron()
			.setId("event/true")
			.setBias(0d)
			.addConnections(new Connection(input.getId(), 1d));

		neuralLayer.addNeurons(noEventNeuron, eventNeuron);

		return neuralLayer;
	}

	/**
	 * Maps the single output neuron of a regression network to the continuous target field.
	 */
	static
	private NeuralOutputs encodeRegressionNeuralOutputs(List<? extends Entity> entities, Schema schema){
		ContinuousLabel continuousLabel = (ContinuousLabel)schema.getLabel();

		ClassDictUtil.checkSize(1, entities);

		Entity entity = Iterables.getOnlyElement(entities);

		DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE)
			.setExpression(new FieldRef(continuousLabel.getName()));

		NeuralOutput neuralOutput = new NeuralOutput()
			.setOutputNeuron(entity.getId())
			.setDerivedField(derivedField);

		NeuralOutputs neuralOutputs = new NeuralOutputs()
			.addNeuralOutputs(neuralOutput);

		return neuralOutputs;
	}

	/**
	 * Maps each output neuron of a classification network to the corresponding
	 * target category via a {@link NormDiscrete} indicator expression.
	 */
	static
	private NeuralOutputs encodeClassificationNeuralOutputs(List<? extends Entity> entities, Schema schema){
		CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

		ClassDictUtil.checkSize(categoricalLabel.size(), entities);

		NeuralOutputs neuralOutputs = new NeuralOutputs();

		for(int i = 0; i < categoricalLabel.size(); i++){
			Entity entity = entities.get(i);

			DerivedField derivedField = new DerivedField(OpType.CATEGORICAL, DataType.STRING)
				.setExpression(new NormDiscrete(categoricalLabel.getName(), categoricalLabel.getValue(i)));

			NeuralOutput neuralOutput = new NeuralOutput()
				.setOutputNeuron(entity.getId())
				.setDerivedField(derivedField);

			neuralOutputs.addNeuralOutputs(neuralOutput);
		}

		return neuralOutputs;
	}

	/**
	 * Connects the given entity of the previous layer to every neuron of the current
	 * layer using the given weight vector.
	 */
	static
	private void connect(Entity input, List<Neuron> neurons, List<?> weights){
		ClassDictUtil.checkSize(neurons, weights);

		for(int i = 0; i < neurons.size(); i++){
			Neuron neuron = neurons.get(i);

			Double weight = ValueUtil.asDouble((Number)weights.get(i));

			neuron.addConnections(new Connection(input.getId(), weight));
		}
	}

	static
	private Neuron getOnlyNeuron(NeuralLayer neuralLayer){
		List<Neuron> neurons = neuralLayer.getNeurons();

		return Iterables.getOnlyElement(neurons);
	}

	/**
	 * Translates a scikit-learn activation function name to the PMML equivalent.
	 *
	 * @throws IllegalArgumentException If the name is not one of
	 *   "identity", "logistic", "relu" or "tanh".
	 */
	static
	private NeuralNetwork.ActivationFunction parseActivationFunction(String activation){
		switch(activation){
			case "identity":
				return NeuralNetwork.ActivationFunction.IDENTITY;
			case "logistic":
				return NeuralNetwork.ActivationFunction.LOGISTIC;
			case "relu":
				return NeuralNetwork.ActivationFunction.RECTIFIER;
			case "tanh":
				return NeuralNetwork.ActivationFunction.TANH;
			default:
				throw new IllegalArgumentException(activation);
		}
	}
}