package org.deeplearning4j.mlp;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;

/**
 * Train a simple/small MLP on MNIST data using Spark, then evaluate it on the test set in a distributed manner.
 *
 * Note that the network being trained here is too small to make proper use of Spark - but it shows the configuration
 * and evaluation used for Spark training.
 *
 * To run the example locally: Run the example as-is. The example is set up to use Spark local by default.
 * NOTE: Spark local should only be used for development/testing. For data parallel training on a single machine
 * (for example, multi-GPU systems) instead use ParallelWrapper (which is faster than using Spark for training on a single machine).
 * See for example MultiGpuLenetMnistExample in dl4j-cuda-specific-examples.
 *
 * To run the example using Spark submit (for example on a cluster): pass "-useSparkLocal false" as the application argument,
 * OR first modify the example by setting the field "useSparkLocal = false".
 *
 * @author Alex Black
 */
public class MnistMLPExample {
    private static final Logger log = LoggerFactory.getLogger(MnistMLPExample.class);

    @Parameter(names = "-useSparkLocal", description = "Use spark local (helper for testing/running without spark submit)", arity = 1)
    private boolean useSparkLocal = true;

    @Parameter(names = "-batchSizePerWorker", description = "Number of examples to fit each worker with")
    private int batchSizePerWorker = 16;

    @Parameter(names = "-numEpochs", description = "Number of epochs for training")
    private int numEpochs = 15;

    public static void main(String[] args) throws Exception {
        new MnistMLPExample().entryPoint(args);
    }

    /**
     * Parses command-line arguments, builds the network configuration, trains the network on MNIST via Spark,
     * and evaluates it (distributed) on the MNIST test set.
     *
     * @param args command-line arguments (see the {@code @Parameter}-annotated fields for supported options)
     * @throws Exception on argument-parsing failure or any training/evaluation error
     */
    protected void entryPoint(String[] args) throws Exception {
        //Handle command line arguments
        JCommander jcmdr = new JCommander(this);
        try {
            jcmdr.parse(args);
        } catch (ParameterException e) {
            //User provides invalid input -> print the usage info
            jcmdr.usage();
            try {
                //Brief pause so the (possibly asynchronous) usage output is flushed before the exception aborts the JVM
                Thread.sleep(500);
            } catch (Exception ignored) {
                //Best-effort delay only; safe to continue and rethrow the original parse error
            }
            throw e;
        }

        SparkConf sparkConf = new SparkConf();
        if (useSparkLocal) {
            sparkConf.setMaster("local[*]");
        }
        sparkConf.setAppName("DL4J Spark MLP Example");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        try {
            //Load the data into memory then parallelize
            //This isn't a good approach in general - but is simple to use for this example
            DataSetIterator iterTrain = new MnistDataSetIterator(batchSizePerWorker, true, 12345);
            //Fix: second argument must be false so evaluation runs on the MNIST *test* split, not the training data
            DataSetIterator iterTest = new MnistDataSetIterator(batchSizePerWorker, false, 12345);
            List<DataSet> trainDataList = new ArrayList<>();
            List<DataSet> testDataList = new ArrayList<>();
            while (iterTrain.hasNext()) {
                trainDataList.add(iterTrain.next());
            }
            while (iterTest.hasNext()) {
                testDataList.add(iterTest.next());
            }

            JavaRDD<DataSet> trainData = sc.parallelize(trainDataList);
            JavaRDD<DataSet> testData = sc.parallelize(testDataList);

            //----------------------------------
            //Create network configuration and conduct network training
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
                .activation(Activation.LEAKYRELU)
                .weightInit(WeightInit.XAVIER)
                .learningRate(0.02)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .regularization(true).l2(1e-4)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build())
                .layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
                .pretrain(false).backprop(true)
                .build();

            //Configuration for Spark training: see http://deeplearning4j.org/spark for explanation of these configuration options
            //Each DataSet object contains batchSizePerWorker (default 16) examples
            ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(batchSizePerWorker)
                .averagingFrequency(5)
                .workerPrefetchNumBatches(2)    //Async prefetching: 2 minibatches per worker
                .batchSizePerWorker(batchSizePerWorker)
                .build();

            //Create the Spark network
            SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, tm);

            //Execute training:
            for (int i = 0; i < numEpochs; i++) {
                sparkNet.fit(trainData);
                log.info("Completed Epoch {}", i);
            }

            //Perform evaluation (distributed)
            Evaluation evaluation = sparkNet.evaluate(testData);
            log.info("***** Evaluation *****");
            log.info(evaluation.stats());

            //Delete the temp training files, now that we are done with them
            tm.deleteTempFiles(sc);

            log.info("***** Example Complete *****");
        } finally {
            //Always release the Spark context, even if training/evaluation throws
            sc.stop();
        }
    }
}