package org.deeplearning4j.examples.feedforward.mnist; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /**A Simple Multi Layered Perceptron (MLP) applied to digit classification for * the MNIST Dataset (http://yann.lecun.com/exdb/mnist/). * * This file builds one input layer and one hidden layer. * * The input layer has input dimension of numRows*numColumns where these variables indicate the * number of vertical and horizontal pixels in the image. This layer uses a rectified linear unit * (relu) activation function. The weights for this layer are initialized by using Xavier initialization * (https://prateekvjoshi.com/2016/03/29/understanding-xavier-initialization-in-deep-neural-networks/) * to avoid having a steep learning curve. This layer will have 1000 output signals to the hidden layer. * * The hidden layer has input dimensions of 1000. These are fed from the input layer. The weights * for this layer is also initialized using Xavier initialization. The activation function for this * layer is a softmax, which normalizes all the 10 outputs such that the normalized sums * add up to 1. The highest of these normalized values is picked as the predicted class. * */ public class MLPMnistSingleLayerExample { private static Logger log = LoggerFactory.getLogger(MLPMnistSingleLayerExample.class); public static void main(String[] args) throws Exception { //number of rows and columns in the input pictures final int numRows = 28; final int numColumns = 28; int outputNum = 10; // number of output classes int batchSize = 128; // batch size for each epoch int rngSeed = 123; // random number seed for reproducibility int numEpochs = 15; // number of epochs to perform //Get the DataSetIterators: DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed); DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed); log.info("Build model...."); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(rngSeed) //include a random seed for reproducibility // use stochastic gradient descent as an optimization algorithm .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .iterations(1) .learningRate(0.006) //specify the learning rate .updater(Updater.NESTEROVS).momentum(0.9) //specify the rate of change of the learning rate. .regularization(true).l2(1e-4) .list() .layer(0, new DenseLayer.Builder() //create the first, input layer with xavier initialization .nIn(numRows * numColumns) .nOut(1000) .activation(Activation.RELU) .weightInit(WeightInit.XAVIER) .build()) .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) //create hidden layer .nIn(1000) .nOut(outputNum) .activation(Activation.SOFTMAX) .weightInit(WeightInit.XAVIER) .build()) .pretrain(false).backprop(true) //use backpropagation to adjust weights .build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); //print the score with every 1 iteration model.setListeners(new ScoreIterationListener(1)); log.info("Train model...."); for( int i=0; i<numEpochs; i++ ){ model.fit(mnistTrain); } log.info("Evaluate model...."); Evaluation eval = new Evaluation(outputNum); //create an evaluation object with 10 possible classes while(mnistTest.hasNext()){ DataSet next = mnistTest.next(); INDArray output = model.output(next.getFeatureMatrix()); //get the networks prediction eval.eval(next.getLabels(), output); //check the prediction against the true class } log.info(eval.stats()); log.info("****************Example finished********************"); } }