/**
 * Copyright [2012-2014] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.core.dtrain.nn;

import java.util.Arrays;
import java.util.concurrent.Callable;

import ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair;
import ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork;
import ml.shifu.shifu.core.dtrain.dataset.FloatFlatNetwork;
import ml.shifu.shifu.core.dtrain.dataset.FloatMLDataPair;
import ml.shifu.shifu.core.dtrain.dataset.FloatMLDataSet;

import org.encog.engine.network.activation.ActivationFunction;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.flat.FlatNetwork;

/**
 * {@link SubGradient} is copied from the Encog framework. The reason is that the original Gradient class does not
 * expose {@link #gradients}, while we need the gradients accumulated into {@link NNMaster} to update the NN weights.
 */
public class SubGradient implements Callable<double[]> {

    /**
     * The network to train.
     */
    private FloatFlatNetwork network;

    /**
     * The error calculation method.
     */
    private final ErrorCalculation errorCalculation = new ErrorCalculation();

    /**
     * The actual values from the neural network.
     */
    private double[] actual;

    /**
     * The deltas for each layer.
     */
    private double[] layerDelta;

    /**
     * The neuron counts, per layer.
     */
    private int[] layerCounts;

    /**
     * The feed counts, per layer.
     */
    private int[] layerFeedCounts;

    /**
     * The layer indexes.
     */
    private int[] layerIndex;

    /**
     * The index to each layer's weights and thresholds.
     */
    private int[] weightIndex;

    /**
     * The output from each layer.
     */
    private double[] layerOutput;

    /**
     * The sums.
     */
    private double[] layerSums;

    /**
     * The gradients.
     */
    private double[] gradients;

    /**
     * The weights and thresholds.
     */
    private double[] weights;

    /**
     * The pair to use for training.
     */
    private FloatMLDataPair pair;

    /**
     * The training data.
     */
    private final FloatMLDataSet training;

    /**
     * The testing data; the test data set here is used for training and testing cross-over.
     */
    private FloatMLDataSet testing;

    /**
     * Whether to replace training and testing elements.
     */
    private final boolean isCrossOver;

    /**
     * Seed used to sample the training and testing data sets to choose which element is used for training.
     */
    private long seed = System.currentTimeMillis();

    /**
     * The error.
     */
    private double error;

    /**
     * Derivative add constant. Used to combat the flat spot problem.
     */
    private double[] flatSpot;
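
    /*
     * A minimal sketch of how a driver might run several SubGradient workers in parallel and merge their gradient
     * arrays; this is illustrative only and assumes a plain ExecutorService-based driver (in Shifu the actual
     * driver is ParallelGradient, which also owns the data set partitioning):
     *
     *   ExecutorService pool = Executors.newFixedThreadPool(workers.size());
     *   double[] merged = null;
     *   for(Future<double[]> f: pool.invokeAll(workers)) { // workers: List<SubGradient>
     *       double[] g = f.get();
     *       if(merged == null) {
     *           merged = g.clone();
     *       } else {
     *           for(int i = 0; i < merged.length; i++) {
     *               merged[i] += g[i]; // gradients from disjoint record ranges simply add up
     *           }
     *       }
     *   }
     *   // the merged gradients are then accumulated into NNMaster to update the NN weights
     */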
    /**
     * The error function to use.
     */
    private final ErrorFunction errorFunction;

    private final long trainLow;
    private final long trainHigh;
    private final long testLow;
    private final long testHigh;

    private ParallelGradient owner;

    private double[] doubleIdeal;

    public SubGradient(final FloatFlatNetwork theNetwork, final FloatMLDataSet theTraining, long trainLow,
            long trainHigh, final FloatMLDataSet theTesting, long testLow, long testHigh, final double[] flatSpot,
            ErrorFunction ef, boolean isCrossOver, ParallelGradient owner) {
        this.network = theNetwork;
        this.training = theTraining;
        this.trainLow = trainLow;
        this.trainHigh = trainHigh;
        this.testing = theTesting;
        this.testLow = testLow;
        this.testHigh = testHigh;
        this.isCrossOver = isCrossOver;
        this.flatSpot = flatSpot;
        this.errorFunction = ef;
        this.owner = owner;
        this.initNetworkParams();
    }

    private void initNetworkParams() {
        this.layerDelta = new double[this.network.getLayerOutput().length];
        this.gradients = new double[this.network.getWeights().length];
        this.actual = new double[this.network.getOutputCount()];

        this.weights = this.network.getWeights();
        this.layerIndex = this.network.getLayerIndex();
        this.layerCounts = this.network.getLayerCounts();
        this.weightIndex = this.network.getWeightIndex();
        this.layerOutput = this.network.getLayerOutput();
        this.layerSums = this.network.getLayerSums();
        this.layerFeedCounts = this.network.getLayerFeedCounts();

        this.pair = BasicFloatMLDataPair.createPair(this.network.getInputCount(), getNetwork().getOutputCount());
    }

    /**
     * Process one training set element.
     *
     * @param input
     *            The network input.
     * @param ideal
     *            The ideal values.
     * @param s
     *            The significance.
     */
    private void process(final float[] input, final float[] ideal, double s) {
        ((FloatFlatNetwork) this.getNetwork()).compute(input, this.actual);

        // have to copy the float ideal array to a double array; since the ideal array is small, the copy is cheap
        if(doubleIdeal == null) {
            doubleIdeal = new double[ideal.length];
        }
        for(int i = 0; i < doubleIdeal.length; i++) {
            doubleIdeal[i] = ideal[i];
        }

        this.errorCalculation.updateError(this.actual, doubleIdeal, s);
        this.errorFunction.calculateError(doubleIdeal, actual, this.getLayerDelta());

        for(int i = 0; i < this.actual.length; i++) {
            this.getLayerDelta()[i] = (this.getNetwork().getActivationFunctions()[0].derivativeFunction(
                    this.layerSums[i], this.layerOutput[i]) + this.flatSpot[0])
                    * (this.getLayerDelta()[i] * s);
        }

        int beginTraining = this.getNetwork().getBeginTraining();
        for(int i = beginTraining; i < this.getNetwork().getEndTraining(); i++) {
            processLevel(i);
        }
    }
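
    /*
     * For reference, the output-layer delta computed at the end of process() above is, per output neuron i:
     *
     *   layerDelta[i] = (f'(layerSums[i], layerOutput[i]) + flatSpot[0]) * (error_i * s)
     *
     * where f' is the derivative of the output activation function, error_i is what errorFunction.calculateError()
     * wrote into layerDelta (simply ideal_i - actual_i for Encog's linear error function), and s is the record
     * significance. The flat-spot constant keeps the derivative away from zero when the activation saturates (e.g.
     * sigmoid output near 0 or 1), so the weights can still move. A hypothetical standalone form of the same step,
     * for illustration only:
     *
     *   static double outputDelta(ActivationFunction fn, double sum, double output, double err, double flat,
     *           double significance) {
     *       return (fn.derivativeFunction(sum, output) + flat) * (err * significance);
     *   }
     */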
    /**
     * Process one level.
     *
     * @param currentLevel
     *            The level.
     */
    private void processLevel(final int currentLevel) {
        final int fromLayerIndex = this.layerIndex[currentLevel + 1];
        final int toLayerIndex = this.layerIndex[currentLevel];
        final int fromLayerSize = this.layerCounts[currentLevel + 1];
        final int toLayerSize = this.layerFeedCounts[currentLevel];

        final int index = this.weightIndex[currentLevel];
        final ActivationFunction activation = this.getNetwork().getActivationFunctions()[currentLevel + 1];
        final double currentFlatSpot = this.flatSpot[currentLevel + 1];

        // handle weights
        int yi = fromLayerIndex;
        for(int y = 0; y < fromLayerSize; y++) {
            final double output = this.layerOutput[yi];
            double sum = 0;
            int xi = toLayerIndex;
            int wi = index + y;
            for(int x = 0; x < toLayerSize; x++) {
                if(this.owner.isELM() && currentLevel == 0) {
                    // in ELM mode the weights at this level are frozen, so force their gradient to zero
                    this.gradients[wi] = 0d;
                } else {
                    this.gradients[wi] += output * this.getLayerDelta()[xi];
                }
                sum += this.weights[wi] * this.getLayerDelta()[xi];
                wi += fromLayerSize;
                xi++;
            }
            this.getLayerDelta()[yi] = sum
                    * (activation.derivativeFunction(this.layerSums[yi], this.layerOutput[yi]) + currentFlatSpot);
            yi++;
        }
    }

    /**
     * Perform the gradient calculation.
     */
    public final double[] call() {
        try {
            // reset errors and gradients first
            this.errorCalculation.reset();
            Arrays.fill(this.gradients, 0.0);
            for(long i = this.trainLow; i <= this.trainHigh; i++) {
                synchronized(this.owner) {
                    if(this.isCrossOver) {
                        // 3:1 sampling between the training and testing data sets; the ratio is hard-coded for now,
                        // TODO fix the hard-coded ratio and extract such logic to a method
                        if((i + seed) % 4 < 3) {
                            this.training.getRecord(i, this.pair);
                        } else {
                            long testingSize = this.testing.getRecordCount();
                            // it's ok to take data from the whole testing set
                            if(i < testingSize) {
                                this.testing.getRecord(i, this.pair);
                            } else {
                                this.testing.getRecord(i % testingSize, this.pair);
                            }
                        }
                    } else {
                        this.training.getRecord(i, this.pair);
                    }
                }
                process(this.pair.getInputArray(), this.pair.getIdealArray(), pair.getSignificance());
            }
            this.error = this.errorCalculation.calculate();
        } catch (final Throwable ex) {
            throw new RuntimeException(ex);
        }
        return this.gradients;
    }
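
    /*
     * The TODO above asks for the cross-over record selection to be extracted; a minimal sketch of such a
     * hypothetical helper (not part of this class yet), keeping the same hard-coded 3:1 ratio:
     *
     *   private void loadCrossOverRecord(long i, FloatMLDataPair pair) {
     *       if((i + seed) % 4 < 3) {
     *           this.training.getRecord(i, pair);
     *       } else {
     *           long testingSize = this.testing.getRecordCount();
     *           // wrap around when the index runs past the end of the testing set
     *           this.testing.getRecord(i < testingSize ? i : i % testingSize, pair);
     *       }
     *   }
     *
     * call() and calculateError() could then share it, with the training/testing roles swapped for the latter.
     */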
    /**
     * Calculate the error for this neural network on the testing data. The error is calculated using
     * root-mean-square (RMS).
     *
     * @param ec
     *            The shared error computation logic; this worker's errors are accumulated into it.
     * @return always -1; the accumulated error is read from the passed-in {@link ErrorCalculation}.
     */
    public final double calculateError(ErrorCalculation ec) {
        final double[] actual = new double[this.getNetwork().getOutputCount()];
        final FloatMLDataPair pair = BasicFloatMLDataPair.createPair(testing.getInputSize(), testing.getIdealSize());

        for(long i = testLow; i <= testHigh; i++) {
            synchronized(this.owner) {
                if(this.isCrossOver) {
                    // 3:1 sampling between the testing and training data sets; TODO fix the hard-coded ratio
                    if((i + seed) % 4 < 3) {
                        this.testing.getRecord(i, pair);
                    } else {
                        long trainingSize = this.training.getRecordCount();
                        // it's ok to take data from the whole training set
                        if(i < trainingSize) {
                            this.training.getRecord(i, pair);
                        } else {
                            this.training.getRecord(i % trainingSize, pair);
                        }
                    }
                } else {
                    this.testing.getRecord(i, pair);
                }
            }
            ((FloatFlatNetwork) this.getNetwork()).compute(pair.getInputArray(), actual);
            // copy the float ideal array to double for API compatibility
            if(doubleIdeal == null) {
                doubleIdeal = new double[pair.getIdealArray().length];
            }
            for(int j = 0; j < doubleIdeal.length; j++) {
                doubleIdeal[j] = pair.getIdealArray()[j];
            }
            synchronized(ec) {
                ec.updateError(actual, doubleIdeal, pair.getSignificance());
            }
        }
        return -1;
    }

    public ErrorCalculation getErrorCalculation() {
        return errorCalculation;
    }

    /**
     * @return the gradients
     */
    public double[] getGradients() {
        return this.gradients;
    }

    /**
     * @return the error
     */
    public double getError() {
        return error;
    }

    /**
     * @return the weights
     */
    public double[] getWeights() {
        return weights;
    }

    /**
     * @param weights
     *            the weights to set
     */
    public void setWeights(double[] weights) {
        this.weights = weights;
        this.getNetwork().setWeights(weights);
    }

    public void setParams(BasicFloatNetwork network) {
        this.setNetwork((FloatFlatNetwork) network.getFlat());
        this.weights = network.getFlat().getWeights();
    }

    public FlatNetwork getNetwork() {
        return network;
    }

    public double[] getLayerDelta() {
        return layerDelta;
    }

    /**
     * @return the seed
     */
    public long getSeed() {
        return seed;
    }

    /**
     * @param seed
     *            the seed to set
     */
    public void setSeed(long seed) {
        this.seed = seed;
    }

    /**
     * @param network
     *            the network to set
     */
    public void setNetwork(FloatFlatNetwork network) {
        this.network = network;
        this.weights = this.network.getWeights();
    }
}
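
/*
 * A minimal sketch of how the testing error might be aggregated across workers, assuming the driver holds the
 * SubGradient instances (illustrative only; ParallelGradient is the real owner in Shifu):
 *
 *   ErrorCalculation shared = new ErrorCalculation();
 *   for(SubGradient worker: workers) {
 *       worker.calculateError(shared); // each worker accumulates its [testLow, testHigh] slice into `shared`
 *   }
 *   double testError = shared.calculate(); // error over all slices; calculateError() itself returns -1
 */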