package org.deeplearning4j.examples.misc.earlystopping;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.saver.LocalFileModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
/**Early stopping example on a subset of MNIST
* Idea: given a small subset of MNIST (1024 training examples + 512 test examples), conduct training and keep the
* parameters that have the minimum test set loss.
* This is an over-simplified example, but the principles used here should apply in more realistic cases.
*
* For further details on early stopping, see http://deeplearning4j.org/earlystopping.html
*
* @author Alex Black
*/
public class EarlyStoppingMNIST {

    /**
     * Trains a small convolutional network on an MNIST subset using early stopping, then prints the
     * termination reason, the best epoch/score and the score recorded for each evaluated epoch.
     *
     * @param args command-line arguments (unused)
     * @throws Exception if the model directory cannot be created, or if data loading or training fails
     */
    public static void main(String[] args) throws Exception {
        int batchSize = 25;

        //Configure network:
        MultiLayerConfiguration configuration = buildNetworkConfiguration();

        //Get data: 1024 examples for training, 512 examples for scoring
        DataSetIterator mnistTrain1024 = new MnistDataSetIterator(batchSize, 1024, false, true, true, 12345);
        DataSetIterator mnistTest512 = new MnistDataSetIterator(batchSize, 512, false, false, true, 12345);

        String tempDir = System.getProperty("java.io.tmpdir");
        String exampleDirectory = FilenameUtils.concat(tempDir, "DL4JEarlyStoppingExample/");
        //Make sure the target directory exists before handing it to the saver, so that a first run on a
        //clean machine does not depend on the saver creating it. No-op if the directory is already there.
        Files.createDirectories(Paths.get(exampleDirectory));

        EarlyStoppingModelSaver saver = new LocalFileModelSaver(exampleDirectory);
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(50)) //Max of 50 epochs
            .evaluateEveryNEpochs(1)
            .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(20, TimeUnit.MINUTES)) //Max of 20 minutes
            .scoreCalculator(new DataSetLossCalculator(mnistTest512, true)) //Calculate test set score
            .modelSaver(saver)
            .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, configuration, mnistTrain1024);

        //Conduct early stopping training:
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        //Print score vs. epoch, in ascending epoch order (the map itself is not assumed to be ordered)
        Map<Integer,Double> scoreVsEpoch = result.getScoreVsEpoch();
        List<Integer> list = new ArrayList<>(scoreVsEpoch.keySet());
        Collections.sort(list);
        System.out.println("Score vs. Epoch:");
        for (Integer i : list) {
            System.out.println(i + "\t" + scoreVsEpoch.get(i));
        }
    }

    /**
     * Builds the network configuration used by this example: a 5x5 convolution layer (20 outputs,
     * dropout 0.5), a 2x2 max-pooling layer, a dense layer with 500 outputs and a 10-way softmax
     * output layer, configured for flattened 28x28 single-channel input.
     *
     * @return the multi-layer configuration passed to the early stopping trainer
     */
    private static MultiLayerConfiguration buildNetworkConfiguration() {
        int nChannels = 1;
        int outputNum = 10;
        int iterations = 1;
        int seed = 123;

        return new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .regularization(true).l2(0.0005)
            .learningRate(0.02)
            .weightInit(WeightInit.XAVIER)
            .activation(Activation.RELU)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.NESTEROVS).momentum(0.9)
            .list()
            .layer(0, new ConvolutionLayer.Builder(5, 5)
                .nIn(nChannels)
                .stride(1, 1)
                .nOut(20).dropOut(0.5)
                .build())
            .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            .layer(2, new DenseLayer.Builder()
                .nOut(500).build())
            .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .build())
            //Same channel count as the first layer's nIn. See note in LenetMnistExample
            .setInputType(InputType.convolutionalFlat(28, 28, nChannels))
            .backprop(true).pretrain(false).build();
    }
}