package org.deeplearning4j.examples.misc.centerloss;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.examples.unsupervised.variational.plot.PlotUtil;
import org.deeplearning4j.examples.userInterface.util.GradientsAndParamsListener;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;

/**
 * Example: training an embedding using the center loss model, on MNIST.
 * The motivation is to use the class labels to learn embeddings that have the following properties:
 * (a) Intra-class similarity (i.e., similar vectors for same numbers)
 * (b) Inter-class dissimilarity (i.e., different vectors for different numbers)
 *
 * Refer to the paper "A Discriminative Feature Learning Approach for Deep Face Recognition", Wen et al. (2016)
 * http://ydwen.github.io/papers/WenECCV16.pdf
 *
 * This example trains a LeNet-style network with a 2-dimensional embedding layer, and plots the embedding
 * of a test batch at the end of each epoch.
 *
 * @author Alex Black
 */
public class CenterLossLenetMnistExample {
    private static final Logger log = LoggerFactory.getLogger(CenterLossLenetMnistExample.class);

    public static void main(String[] args) throws Exception {
        int outputNum = 10;     // The number of possible outcomes
        int batchSize = 64;     // Training batch size
        int nEpochs = 10;       // Number of training epochs
        int seed = 123;

        //Lambda defines the relative strength of the center loss component.
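        //For reference, the combined objective from Wen et al. (2016, linked above) is:
        //    L = L_softmax + (lambda / 2) * sum_i || x_i - c_{y_i} ||^2
        //where x_i is the embedding of example i and c_{y_i} is the center of its class. The centers are
        //updated each iteration as c_j <- c_j - alpha * delta_c_j, with alpha acting as the center learning rate.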
        //lambda = 0.0 is equivalent to training with standard softmax only
        double lambda = 1.0;

        //Alpha can be thought of as the learning rate for the centers for each class
        double alpha = 0.1;

        log.info("Load data....");
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(10000, false, 12345);

        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .regularization(true).l2(0.0005)
            .learningRate(0.01)
            .activation(Activation.LEAKYRELU)
            .weightInit(WeightInit.RELU)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.ADAM).adamMeanDecay(0.9).adamVarDecay(0.999)
            .list()
            .layer(0, new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(32).activation(Activation.LEAKYRELU).build())
            .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build())
            .layer(2, new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(64).build())
            .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build())
            .layer(4, new DenseLayer.Builder().nOut(256).build())
            //Layer 5 is our embedding layer: 2 dimensions, just so we can plot it on an X/Y grid. In practice,
            // you would usually use more dimensions
            .layer(5, new DenseLayer.Builder().activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER).nOut(2)
                //Larger L2 value on the embedding layer: can help to stop the embedding layer weights
                // (and hence activations) from getting too large. This is especially problematic with small values of
                // lambda such as 0.0
                .l2(0.1).build())
            .layer(6, new CenterLossOutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nIn(2).nOut(outputNum)
                .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX)
                //Alpha and lambda hyperparameters are specific to the center loss model: see the comments above and the paper
                .alpha(alpha).lambda(lambda)
                .build())
            .setInputType(InputType.convolutionalFlat(28, 28, 1))
            .backprop(true).pretrain(false).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        log.info("Train model....");
        model.setListeners(new GradientsAndParamsListener(model, 100), new ScoreIterationListener(100));

        List<Pair<INDArray, INDArray>> embeddingByEpoch = new ArrayList<>();
        List<Integer> epochNum = new ArrayList<>();

        DataSet testData = mnistTest.next();
        for (int i = 0; i < nEpochs; i++) {
            model.fit(mnistTrain);
            log.info("*** Completed epoch {} ***", i);

            //Feed forward to the embedding layer (layer 5) to get the 2d embedding to plot later.
            //feedForwardToLayer returns the input plus the activations of layers 0..5, so index 6 is the layer 5 output
            INDArray embedding = model.feedForwardToLayer(5, testData.getFeatures()).get(6);

            embeddingByEpoch.add(new Pair<>(embedding, testData.getLabels()));
            epochNum.add(i);
        }

        //Create a scatterplot: the slider allows the embeddings to be viewed at the end of each epoch
        PlotUtil.scatterPlot(embeddingByEpoch, epochNum, "MNIST Center Loss Embedding: l = " + lambda + ", a = " + alpha);
    }
}