package org.deeplearning4j.stats; import com.beust.jcommander.JCommander; import com.beust.jcommander.Parameter; import com.beust.jcommander.ParameterException; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.layers.GravesLSTM; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.rnn.SparkLSTMCharacterExample; import org.deeplearning4j.spark.api.stats.SparkTrainingStats; import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; import org.deeplearning4j.spark.stats.EventStats; import org.deeplearning4j.spark.stats.StatsUtils; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; import java.util.List; import java.util.Map; import java.util.Set; /** * This example is designed to show how to use DL4J's Spark training benchmarking/debugging/timing functionality. * For details: See https://deeplearning4j.org/spark#sparkstats * * The idea with this tool is to capture statistics on various aspects of Spark training, in order to identify * and debug performance issues. * * For the sake of the example, we will be using a network configuration and data as per the SparkLSTMCharacterExample. * * * To run the example locally: Run the example as-is. The example is set up to use Spark local. * * To run the example using Spark submit (for example on a cluster): pass "-useSparkLocal false" as the application argument, * OR first modify the example by setting the field "useSparkLocal = false" * * NOTE: On some clusters without internet access, this example may fail with "Error querying NTP server" * See: https://deeplearning4j.org/spark#sparkstatsntp * * @author Alex Black */ public class TrainingStatsExample { private static final Logger log = LoggerFactory.getLogger(TrainingStatsExample.class); @Parameter(names="-useSparkLocal", description = "Use spark local (helper for testing/running without spark submit)", arity = 1) private boolean useSparkLocal = true; public static void main(String[] args) throws Exception { new TrainingStatsExample().entryPoint(args); } private void entryPoint(String[] args) throws Exception { //Handle command line arguments JCommander jcmdr = new JCommander(this); try{ jcmdr.parse(args); } catch(ParameterException e){ //User provides invalid input -> print the usage info jcmdr.usage(); try{ Thread.sleep(500); } catch(Exception e2){ } throw e; } //Set up network configuration: MultiLayerConfiguration config = getConfiguration(); //Set up the Spark-specific configuration int examplesPerWorker = 8; //i.e., minibatch size that each worker gets int averagingFrequency = 3; //Frequency with which parameters are averaged //Set up Spark configuration and context SparkConf sparkConf = new SparkConf(); if(useSparkLocal){ sparkConf.setMaster("local[*]"); log.info("Using Spark Local"); } sparkConf.setAppName("DL4J Spark Stats Example"); JavaSparkContext sc = new JavaSparkContext(sparkConf); //Get data. See SparkLSTMCharacterExample for details JavaRDD<DataSet> trainingData = SparkLSTMCharacterExample.getTrainingData(sc); //Set up the TrainingMaster. The TrainingMaster controls how learning is actually executed on Spark //Here, we are using standard parameter averaging int examplesPerDataSetObject = 1; //We haven't pre-batched our data: therefore each DataSet object contains 1 example ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(examplesPerDataSetObject) .workerPrefetchNumBatches(2) //Async prefetch 2 batches for each worker .averagingFrequency(averagingFrequency) .batchSizePerWorker(examplesPerWorker) .build(); //Create the Spark network SparkDl4jMultiLayer sparkNetwork = new SparkDl4jMultiLayer(sc, config, tm); //*** Tell the network to collect training statistics. These will NOT be collected by default *** sparkNetwork.setCollectTrainingStats(true); //Fit for 1 epoch: sparkNetwork.fit(trainingData); //Delete the temp training files, now that we are done with them (if fitting for multiple epochs: would be re-used) tm.deleteTempFiles(sc); //Get the statistics: SparkTrainingStats stats = sparkNetwork.getSparkTrainingStats(); Set<String> statsKeySet = stats.getKeySet(); //Keys for the types of statistics log.info("--- Collected Statistics ---"); for(String s : statsKeySet){ log.info(s); } //Demo purposes: get one statistic and print it String first = statsKeySet.iterator().next(); List<EventStats> firstStatEvents = stats.getValue(first); EventStats es = firstStatEvents.get(0); log.info("Training stats example:"); log.info("Machine ID: " + es.getMachineID()); log.info("JVM ID: " + es.getJvmID()); log.info("Thread ID: " + es.getThreadID()); log.info("Start time ms: " + es.getStartTime()); log.info("Duration ms: " + es.getDurationMs()); //Export a HTML file containing charts of the various stats calculated during training StatsUtils.exportStatsAsHtml(stats, "SparkStats.html",sc); log.info("Training stats exported to {}", new File("SparkStats.html").getAbsolutePath()); log.info("****************Example finished********************"); } //Configuration for the network we will be training private static MultiLayerConfiguration getConfiguration(){ int lstmLayerSize = 200; //Number of units in each GravesLSTM layer int tbpttLength = 50; //Length for truncated backpropagation through time. i.e., do parameter updates ever 50 characters Map<Character, Integer> CHAR_TO_INT = SparkLSTMCharacterExample.getCharToInt(); int nIn = CHAR_TO_INT.size(); int nOut = CHAR_TO_INT.size(); //Set up network configuration: MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1) .learningRate(0.1) .updater(Updater.RMSPROP).rmsDecay(0.95) .seed(12345) .regularization(true).l2(0.001) .weightInit(WeightInit.XAVIER) .list() .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(lstmLayerSize).activation(Activation.TANH).build()) .layer(1, new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize).activation(Activation.TANH).build()) .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX) //MCXENT + softmax for classification .nIn(lstmLayerSize).nOut(nOut).build()) .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength) .pretrain(false).backprop(true) .build(); return conf; } }