/*
 * Copyright [2012-2014] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.core.dtrain;

import java.io.IOException;

import ml.shifu.shifu.core.dtrain.nn.NNParams;

import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.error.LinearErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
import org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation;
import org.encog.neural.networks.training.propagation.quick.QuickPropagation;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.util.benchmark.RandomTrainingFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;
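/**
 * Tests that compare Shifu's distributed gradient computation against Encog's
 * built-in multi-threaded trainers. Each test splits the training data across
 * {@code numSplit} {@link Gradient} workers, accumulates the workers' gradients
 * into a global {@link NNParams}, applies one {@link Weight} update per epoch,
 * and asserts that the per-epoch training error stays within a small tolerance
 * of the error produced by the corresponding Encog {@link Propagation} trainer
 * running on the full data set.
 */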
public class DTrainTest {

    public static final int INPUT_COUNT = 1000;
    public static final int HIDDEN_COUNT = 20;
    public static final int OUTPUT_COUNT = 1;
    public static final int NUM_EPOCHS = 20;

    public BasicNetwork network;
    public MLDataSet training;

    public double rate = 0.5;
    public int numSplit = 24;
    public double[] weights;

    private static final Logger log = LoggerFactory.getLogger(DTrainTest.class);

    @BeforeTest
    public void setup() {
        network = new BasicNetwork();
        network.addLayer(new BasicLayer(DTrainTest.INPUT_COUNT));
        network.addLayer(new BasicLayer(DTrainTest.HIDDEN_COUNT));
        network.addLayer(new BasicLayer(DTrainTest.OUTPUT_COUNT));
        network.getStructure().finalizeStructure();
        network.reset();

        weights = network.getFlat().getWeights();

        training = RandomTrainingFactory.generate(1000, 10000, INPUT_COUNT, OUTPUT_COUNT, -1, 1);
    }

    public Gradient initGradient(MLDataSet training) {
        FlatNetwork flat = network.getFlat().clone();

        // Copied from Encog's Propagation: sigmoid activations get a 0.1
        // flat-spot correction so their derivative never fully vanishes.
        double[] flatSpot = new double[flat.getActivationFunctions().length];
        for(int i = 0; i < flat.getActivationFunctions().length; i++) {
            final ActivationFunction af = flat.getActivationFunctions()[i];
            if(af instanceof ActivationSigmoid) {
                flatSpot[i] = 0.1;
            } else {
                flatSpot[i] = 0.0;
            }
        }

        return new Gradient(flat, training.openAdditional(), training, flatSpot, new LinearErrorFunction(), false);
    }

    @Test
    public void quickTest() throws IOException {
        double[] gradientError = new double[NUM_EPOCHS];
        double[] encogError = new double[NUM_EPOCHS];

        network.reset();
        weights = network.getFlat().getWeights();

        MLDataSet[] subsets = splitDataSet(training);
        Gradient[] workers = new Gradient[numSplit];
        Weight weightCalculator = null;

        for(int i = 0; i < workers.length; i++) {
            workers[i] = initGradient(subsets[i]);
            workers[i].setWeights(weights);
        }

        log.info("Running QuickPropagation test!");

        NNParams globalParams = new NNParams();
        globalParams.setWeights(weights);

        for(int i = 0; i < NUM_EPOCHS; i++) {
            double error = 0.0;

            // each worker computes gradients on its own subset
            for(int j = 0; j < workers.length; j++) {
                workers[j].run();
                error += workers[j].getError();
            }
            gradientError[i] = error / workers.length;

            log.info("The #" + i + " training error: " + gradientError[i]);

            // master aggregates the gradients and updates the weights
            globalParams.reset();
            for(int j = 0; j < workers.length; j++) {
                globalParams.accumulateGradients(workers[j].getGradients());
                globalParams.accumulateTrainSize(subsets[j].getRecordCount());
            }

            if(weightCalculator == null) {
                weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(),
                        this.rate, DTrainUtils.QUICK_PROPAGATION, 0, RegulationLevel.NONE, 0d);
            }

            double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(),
                    globalParams.getGradients());
            globalParams.setWeights(interWeight);

            // broadcast the updated weights back to the workers
            for(int j = 0; j < workers.length; j++) {
                workers[j].setWeights(interWeight);
            }
        }

        // restore the initial weights and train with Encog's multi-threaded
        // QuickPropagation for comparison
        network.reset();
        network.getFlat().setWeights(weights);

        Propagation p = new QuickPropagation(network, training, rate);
        p.setThreadCount(numSplit);
        for(int i = 0; i < NUM_EPOCHS; i++) {
            p.iteration(1);
            encogError[i] = p.getError();
        }

        // the averaged per-epoch error gap should stay small
        double diff = 0.0;
        for(int i = 0; i < NUM_EPOCHS; i++) {
            diff += Math.abs(encogError[i] - gradientError[i]);
        }
        Assert.assertTrue(diff / NUM_EPOCHS < 0.1);
    }

    private MLDataSet[] splitDataSet(MLDataSet data) {
        MLDataSet[] subsets = new MLDataSet[numSplit];
        for(int i = 0; i < subsets.length; i++) {
            subsets[i] = new BasicMLDataSet();
        }

        // deal records round-robin into numSplit subsets
        for(int i = 0; i < data.getRecordCount(); i++) {
            MLDataPair pair = BasicMLDataPair.createPair(INPUT_COUNT, OUTPUT_COUNT);
            data.getRecord(i, pair);
            subsets[i % numSplit].add(pair);
        }

        return subsets;
    }
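    /*
     * The remaining tests repeat the same worker/master loop as quickTest and
     * differ only in the update rule passed to Weight (Manhattan, standard back
     * propagation, RPROP), the matching Encog trainer, and the error tolerance
     * asserted at the end.
     */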
    @Test
    public void manhattanTest() throws IOException {
        double[] gradientError = new double[NUM_EPOCHS];
        double[] encogError = new double[NUM_EPOCHS];

        network.reset();
        weights = network.getFlat().getWeights();

        MLDataSet[] subsets = splitDataSet(training);
        Gradient[] workers = new Gradient[numSplit];
        Weight weightCalculator = null;

        for(int i = 0; i < workers.length; i++) {
            workers[i] = initGradient(subsets[i]);
            workers[i].setWeights(weights);
        }

        NNParams globalParams = new NNParams();
        globalParams.setWeights(weights);

        log.info("Starting Manhattan propagation test!");

        for(int i = 0; i < NUM_EPOCHS; i++) {
            double error = 0.0;

            // each worker computes gradients on its own subset
            for(int j = 0; j < workers.length; j++) {
                workers[j].run();
                error += workers[j].getError();
            }
            gradientError[i] = error / workers.length;

            log.info("The #" + i + " training error: " + gradientError[i]);

            // master aggregates the gradients and updates the weights
            globalParams.reset();
            for(int j = 0; j < workers.length; j++) {
                globalParams.accumulateGradients(workers[j].getGradients());
                globalParams.accumulateTrainSize(subsets[j].getRecordCount());
            }

            if(weightCalculator == null) {
                weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(),
                        this.rate, DTrainUtils.MANHATTAN_PROPAGATION, 0, RegulationLevel.NONE, 0d);
            }

            double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(),
                    globalParams.getGradients());
            globalParams.setWeights(interWeight);

            // broadcast the updated weights back to the workers
            for(int j = 0; j < workers.length; j++) {
                workers[j].setWeights(interWeight);
            }
        }

        // restore the initial weights and train with Encog's multi-threaded
        // ManhattanPropagation for comparison
        network.reset();
        network.getFlat().setWeights(weights);

        Propagation p = new ManhattanPropagation(network, training, rate);
        p.setThreadCount(numSplit);
        for(int i = 0; i < NUM_EPOCHS; i++) {
            p.iteration(1);
            encogError[i] = p.getError();
        }

        double diff = 0.0;
        for(int i = 0; i < NUM_EPOCHS; i++) {
            diff += Math.abs(encogError[i] - gradientError[i]);
        }
        Assert.assertTrue(diff / NUM_EPOCHS < 0.3);
    }

    @Test
    public void backTest() {
        double[] gradientError = new double[NUM_EPOCHS];
        double[] encogError = new double[NUM_EPOCHS];

        network.reset();
        weights = network.getFlat().getWeights();

        MLDataSet[] subsets = splitDataSet(training);
        Gradient[] workers = new Gradient[numSplit];
        Weight weightCalculator = null;

        for(int i = 0; i < workers.length; i++) {
            workers[i] = initGradient(subsets[i]);
            workers[i].setWeights(weights);
        }

        log.info("Starting back propagation test!");

        NNParams globalParams = new NNParams();
        globalParams.setWeights(weights);

        for(int i = 0; i < NUM_EPOCHS; i++) {
            double error = 0.0;

            // each worker computes gradients on its own subset
            for(int j = 0; j < workers.length; j++) {
                workers[j].run();
                error += workers[j].getError();
            }
            gradientError[i] = error / workers.length;

            log.info("The #" + i + " training error: " + gradientError[i]);

            // master aggregates the gradients and updates the weights
            globalParams.reset();
            for(int j = 0; j < workers.length; j++) {
                globalParams.accumulateGradients(workers[j].getGradients());
                globalParams.accumulateTrainSize(subsets[j].getRecordCount());
            }

            if(weightCalculator == null) {
                weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(),
                        this.rate, DTrainUtils.BACK_PROPAGATION, 0, RegulationLevel.NONE, 0d);
            }

            double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(),
                    globalParams.getGradients());
            globalParams.setWeights(interWeight);

            // broadcast the updated weights back to the workers
            for(int j = 0; j < workers.length; j++) {
                workers[j].setWeights(interWeight);
            }
        }

        // restore the initial weights and train with Encog's multi-threaded
        // Backpropagation for comparison
        network.reset();
        network.getFlat().setWeights(weights);

        Propagation p = new Backpropagation(network, training, rate, 0.5);
        p.setThreadCount(numSplit);
        for(int i = 0; i < NUM_EPOCHS; i++) {
            p.iteration(1);
            encogError[i] = p.getError();
        }

        double diff = 0.0;
        for(int i = 0; i < NUM_EPOCHS; i++) {
            diff += Math.abs(encogError[i] - gradientError[i]);
        }
        Assert.assertTrue(diff / NUM_EPOCHS < 0.2);
    }
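    /*
     * RPROP derives a per-weight step size from the sign of each gradient,
     * which is why Encog's ResilientPropagation constructor takes no learning
     * rate below. The distributed side still passes this.rate to Weight,
     * though RPROP-style updates adapt their own step sizes rather than using
     * a fixed rate.
     */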
    @Test
    public void resilientPropagationTest() {
        double[] gradientError = new double[NUM_EPOCHS];
        double[] encogError = new double[NUM_EPOCHS];

        network.reset();
        weights = network.getFlat().getWeights();

        MLDataSet[] subsets = splitDataSet(training);
        Gradient[] workers = new Gradient[numSplit];
        Weight weightCalculator = null;

        for(int i = 0; i < workers.length; i++) {
            workers[i] = initGradient(subsets[i]);
            workers[i].setWeights(weights);
        }

        log.info("Starting resilient propagation test!");

        NNParams globalParams = new NNParams();
        globalParams.setWeights(weights);

        for(int i = 0; i < NUM_EPOCHS; i++) {
            double error = 0.0;

            // each worker computes gradients on its own subset
            for(int j = 0; j < workers.length; j++) {
                workers[j].run();
                error += workers[j].getError();
            }
            gradientError[i] = error / workers.length;

            log.info("The #" + i + " training error: " + gradientError[i]);

            // master aggregates the gradients and updates the weights
            globalParams.reset();
            for(int j = 0; j < workers.length; j++) {
                globalParams.accumulateGradients(workers[j].getGradients());
                globalParams.accumulateTrainSize(subsets[j].getRecordCount());
            }

            if(weightCalculator == null) {
                weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(),
                        this.rate, DTrainUtils.RESILIENTPROPAGATION, 0, RegulationLevel.NONE, 0d);
            }

            double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(),
                    globalParams.getGradients());
            globalParams.setWeights(interWeight);

            // broadcast the updated weights back to the workers
            for(int j = 0; j < workers.length; j++) {
                workers[j].setWeights(interWeight);
            }
        }

        // restore the initial weights and train with Encog's multi-threaded
        // ResilientPropagation for comparison
        network.reset();
        network.getFlat().setWeights(weights);

        Propagation p = new ResilientPropagation(network, training);
        p.setThreadCount(numSplit);
        for(int i = 0; i < NUM_EPOCHS; i++) {
            p.iteration(1);
            encogError[i] = p.getError();
        }

        double diff = 0.0;
        for(int i = 0; i < NUM_EPOCHS; i++) {
            diff += Math.abs(encogError[i] - gradientError[i]);
        }
        Assert.assertTrue(diff / NUM_EPOCHS < 0.2);
    }
}