package org.deeplearning4j.earlystopping; import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener; import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver; import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator; import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition; import org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition; import org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition; import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition; import org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer; import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.Sin; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Random; import java.util.concurrent.TimeUnit; import static org.junit.Assert.*; public class TestEarlyStopping { @Test public void testEarlyStoppingIris() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1) .updater(Updater.SGD).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .pretrain(false).backprop(true).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.setListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>(); EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>() .epochTerminationConditions(new MaxEpochsTerminationCondition(5)) .iterationTerminationConditions( new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES)) .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver) .build(); IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, irisIter); EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit(); System.out.println(result); assertEquals(5, result.getTotalEpochs()); assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason()); Map<Integer, Double> scoreVsIter = result.getScoreVsEpoch(); assertEquals(5, scoreVsIter.size()); String expDetails = esConf.getEpochTerminationConditions().get(0).toString(); assertEquals(expDetails, result.getTerminationDetails()); MultiLayerNetwork out = result.getBestModel(); assertNotNull(out); //Check that best score actually matches (returned model vs. manually calculated score) MultiLayerNetwork bestNetwork = result.getBestModel(); irisIter.reset(); double score = bestNetwork.score(irisIter.next()); assertEquals(result.getBestModelScore(), score, 1e-2); } @Test public void testEarlyStoppingEveryNEpoch() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1) .updater(Updater.SGD).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .pretrain(false).backprop(true).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.setListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>(); EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>() .epochTerminationConditions(new MaxEpochsTerminationCondition(5)) .scoreCalculator(new DataSetLossCalculator(irisIter, true)) .evaluateEveryNEpochs(2).modelSaver(saver).build(); IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, irisIter); EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit(); System.out.println(result); assertEquals(5, result.getTotalEpochs()); assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason()); } @Test public void testEarlyStoppingIrisMultiEpoch() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1) .updater(Updater.SGD).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .pretrain(false).backprop(true).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.setListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter); EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>(); EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>() .epochTerminationConditions(new MaxEpochsTerminationCondition(5)) .iterationTerminationConditions( new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES)) .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver) .build(); IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, mIter); EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit(); System.out.println(result); assertEquals(5, result.getTotalEpochs()); assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason()); Map<Integer, Double> scoreVsIter = result.getScoreVsEpoch(); assertEquals(5, scoreVsIter.size()); String expDetails = esConf.getEpochTerminationConditions().get(0).toString(); assertEquals(expDetails, result.getTerminationDetails()); MultiLayerNetwork out = result.getBestModel(); assertNotNull(out); //Check that best score actually matches (returned model vs. manually calculated score) MultiLayerNetwork bestNetwork = result.getBestModel(); irisIter.reset(); double score = bestNetwork.score(irisIter.next()); assertEquals(result.getBestModelScore(), score, 1e-2); } @Test public void testBadTuning() { //Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition Nd4j.getRandom().setSeed(12345); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1) .updater(Updater.SGD).learningRate(5.0) //Intentionally huge LR .weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .pretrain(false).backprop(true).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.setListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>(); EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>() .epochTerminationConditions(new MaxEpochsTerminationCondition(5000)) .iterationTerminationConditions( new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES), new MaxScoreIterationTerminationCondition(10)) //Initial score is ~2.5 .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver) .build(); IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, irisIter); EarlyStoppingResult result = trainer.fit(); assertTrue(result.getTotalEpochs() < 5); assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason()); String expDetails = new MaxScoreIterationTerminationCondition(10).toString(); assertEquals(expDetails, result.getTerminationDetails()); assertEquals(0, result.getBestModelEpoch()); assertNotNull(result.getBestModel()); } @Test public void testTimeTermination() { //test termination after max time Nd4j.getRandom().setSeed(12345); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1) .updater(Updater.SGD).learningRate(1e-6).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .pretrain(false).backprop(true).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.setListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>(); EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>() .epochTerminationConditions(new MaxEpochsTerminationCondition(10000)) .iterationTerminationConditions( new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS), new MaxScoreIterationTerminationCondition(7.5)) //Initial score is ~2.5 //.scoreCalculator(new DataSetLossCalculator(irisIter, true)) //No score calculator in this test (don't need score) .modelSaver(saver).build(); IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, irisIter); long startTime = System.currentTimeMillis(); EarlyStoppingResult result = trainer.fit(); long endTime = System.currentTimeMillis(); int durationSeconds = (int) (endTime - startTime) / 1000; assertTrue(durationSeconds >= 3); assertTrue(durationSeconds <= 9); assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason()); String expDetails = new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS).toString(); assertEquals(expDetails, result.getTerminationDetails()); } @Test public void testNoImprovementNEpochsTermination() { //Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs //Simulate this by setting LR = 0.0 Nd4j.getRandom().setSeed(12345); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1) .updater(Updater.SGD).learningRate(0.0).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .pretrain(false).backprop(true).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.setListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>(); EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>() .epochTerminationConditions(new MaxEpochsTerminationCondition(100), new ScoreImprovementEpochTerminationCondition(5)) .iterationTerminationConditions( new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS), new MaxScoreIterationTerminationCondition(7.5)) //Initial score is ~2.5 .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver) .build(); IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, irisIter); EarlyStoppingResult result = trainer.fit(); //Expect no score change due to 0 LR -> terminate after 6 total epochs assertEquals(6, result.getTotalEpochs()); assertEquals(0, result.getBestModelEpoch()); assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason()); String expDetails = new ScoreImprovementEpochTerminationCondition(5).toString(); assertEquals(expDetails, result.getTerminationDetails()); } @Test public void testMinImprovementNEpochsTermination() { //Idea: terminate training if score (test set loss) does not improve more than minImprovement for 5 consecutive epochs //Simulate this by setting LR = 0.0 Random rng = new Random(123); Nd4j.getRandom().setSeed(12345); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).iterations(10) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(0.0) .updater(Updater.NESTEROVS).momentum(0.9).list() .layer(0, new DenseLayer.Builder().nIn(1).nOut(20) .weightInit(WeightInit.XAVIER).activation( Activation.TANH) .build()) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).weightInit(WeightInit.XAVIER) .activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER).nIn(20).nOut(1) .build()) .pretrain(false).backprop(true).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.setListeners(new ScoreIterationListener(1)); int nSamples = 100; //Generate the training data INDArray x = Nd4j.linspace(-10, 10, nSamples).reshape(nSamples, 1); INDArray y = Nd4j.getExecutioner().execAndReturn(new Sin(x.dup())); DataSet allData = new DataSet(x, y); List<DataSet> list = allData.asList(); Collections.shuffle(list, rng); DataSetIterator training = new ListDataSetIterator(list, nSamples); double minImprovement = 0.0009; EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>(); EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>() .epochTerminationConditions(new MaxEpochsTerminationCondition(1000), //Go on for max 5 epochs without any improvements that are greater than minImprovement new ScoreImprovementEpochTerminationCondition(5, minImprovement)) .iterationTerminationConditions( new MaxTimeIterationTerminationCondition(3, TimeUnit.MINUTES)) .scoreCalculator(new DataSetLossCalculator(training, true)).modelSaver(saver) .build(); IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, training); EarlyStoppingResult result = trainer.fit(); assertEquals(6, result.getTotalEpochs()); assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason()); String expDetails = new ScoreImprovementEpochTerminationCondition(5, minImprovement).toString(); assertEquals(expDetails, result.getTerminationDetails()); } @Test public void testEarlyStoppingGetBestModel() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1) .updater(Updater.SGD).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .pretrain(false).backprop(true).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.setListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter); EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>(); EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>() .epochTerminationConditions(new MaxEpochsTerminationCondition(5)) .iterationTerminationConditions( new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES)) .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver) .build(); IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, mIter); EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit(); System.out.println(result); MultiLayerNetwork mln = result.getBestModel(); assertEquals(net.getnLayers(), mln.getnLayers()); assertEquals(net.conf().getNumIterations(), mln.conf().getNumIterations()); assertEquals(net.conf().getOptimizationAlgo(), mln.conf().getOptimizationAlgo()); assertEquals(net.conf().getLayer().getActivationFn().toString(), mln.conf().getLayer().getActivationFn().toString()); assertEquals(net.conf().getLayer().getUpdater(), mln.conf().getLayer().getUpdater()); } @Test public void testListeners() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1) .updater(Updater.SGD).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .pretrain(false).backprop(true).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.setListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>(); EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>() .epochTerminationConditions(new MaxEpochsTerminationCondition(5)) .iterationTerminationConditions( new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES)) .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver) .build(); LoggingEarlyStoppingListener listener = new LoggingEarlyStoppingListener(); IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, irisIter, listener); trainer.fit(); assertEquals(1, listener.onStartCallCount); assertEquals(5, listener.onEpochCallCount); assertEquals(1, listener.onCompletionCallCount); } private static class LoggingEarlyStoppingListener implements EarlyStoppingListener<MultiLayerNetwork> { private static Logger log = LoggerFactory.getLogger(LoggingEarlyStoppingListener.class); private int onStartCallCount = 0; private int onEpochCallCount = 0; private int onCompletionCallCount = 0; @Override public void onStart(EarlyStoppingConfiguration esConfig, MultiLayerNetwork net) { log.info("EarlyStopping: onStart called"); onStartCallCount++; } @Override public void onEpoch(int epochNum, double score, EarlyStoppingConfiguration esConfig, MultiLayerNetwork net) { log.info("EarlyStopping: onEpoch called (epochNum={}, score={}}", epochNum, score); onEpochCallCount++; } @Override public void onCompletion(EarlyStoppingResult esResult) { log.info("EarlyStopping: onCompletion called (result: {})", esResult); onCompletionCallCount++; } } }