package org.deeplearning4j.transferlearning.vgg16; import com.beust.jcommander.JCommander; import com.beust.jcommander.Parameter; import com.beust.jcommander.ParameterException; import lombok.extern.slf4j.Slf4j; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.input.PortableDataStream; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModelHelper; import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels; import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.transferlearning.TransferLearningHelper; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.spark.api.TrainingMaster; import org.deeplearning4j.spark.impl.graph.SparkComputationGraph; import org.deeplearning4j.spark.impl.multilayer.evaluation.IEvaluateFlatMapFunction; import org.deeplearning4j.spark.impl.multilayer.evaluation.IEvaluationReduceFunction; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; import org.deeplearning4j.transferlearning.vgg16.dataHelpers.FeaturizedPreSave; import org.deeplearning4j.transferlearning.vgg16.dataHelpers.FlowerDataSetIteratorFeaturized; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.lossfunctions.LossFunctions; import scala.Tuple2; import java.io.IOException; import java.io.OutputStream; /** * @author susaneraly on 3/10/17. * * Important: * Run the class "FeaturizePreSave" before attempting to run this. The outputs at the boundary of the frozen and unfrozen * vertices of a model are saved. These are referred to as "featurized" datasets in this description. * On a dataset of about 3000 images which is what is downloaded this can take "a while" * * Here we see how the transfer learning helper can be used to fit from a featurized datasets. * We attempt to train the same model architecture as the one in "EditLastLayerOthersFrozen". * Since the helper avoids the forward pass through the frozen layers we save on computation time when running multiple epochs. * In this manner, users can iterate quickly tweaking learning rates, weight initialization etc` to settle on a model that gives good results. */ @Slf4j public class FitFromFeaturized { public static final String featureExtractionLayer = FeaturizedPreSave.featurizeExtractionLayer; protected static final long seed = 12345; protected static final int numClasses = 5; protected static final int nEpochs = 3; @Parameter(names = "-useSparkLocal", description = "Use spark local (helper for testing/running without spark submit)", arity = 1) private boolean useSparkLocal = true; @Parameter(names = "-batchSizePerWorker", description = "Number of examples to fit each worker with") private int batchSizePerWorker = 16; @Parameter(names = "-numEpochs", description = "Number of epochs for training") private int numEpochs = 15; @Parameter(names = "-hdfsRoot", description = "The root directory for hdfs for training") private String hdfsRoot = "/tmp"; public static void main(String...args) throws Exception { new FitFromFeaturized().runMain(args); } public void runMain(String [] args) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { JCommander jcmdr = new JCommander(this); try { jcmdr.parse(args); } catch (ParameterException e) { //User provides invalid input -> print the usage info jcmdr.usage(); try { Thread.sleep(500); } catch (Exception e2) { } System.exit(1); } //Import vgg //Note that the model imported does not have an output layer (check printed summary) // nor any training related configs (model from keras was imported with only weights and json) TrainedModelHelper modelImportHelper = new TrainedModelHelper(TrainedModels.VGG16); log.info("\n\nLoading org.deeplearning4j.transferlearning.vgg16...\n\n"); ComputationGraph vgg16 = modelImportHelper.loadModel(); log.info(vgg16.summary()); //Decide on a fine tune configuration to use. //In cases where there already exists a setting the fine tune setting will // override the setting for all layers that are not "frozen". FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder() .learningRate(3e-5) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(Updater.NESTEROVS) .seed(seed) .build(); //Construct a new model with the intended architecture and print summary ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(vgg16) .fineTuneConfiguration(fineTuneConf) .setFeatureExtractor(featureExtractionLayer) //the specified layer and below are "frozen" .removeVertexKeepConnections("predictions") //replace the functionality of the final vertex .addLayer("predictions", new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nIn(4096).nOut(numClasses) .weightInit(WeightInit.DISTRIBUTION) .dist(new NormalDistribution(0,0.2 * (2.0/(4096 + numClasses)))) //This weight init dist gave better results than Xavier .activation(Activation.SOFTMAX).build(), "fc2") .build(); //Instantiate the transfer learning helper to fit and output from the featurized dataset //The .unfrozenGraph() is the unfrozen subset of the computation graph passed in. //If using with a UI or a listener attach them directly to the unfrozenGraph instance //With each iteration updated params from unfrozenGraph are copied over to the original model TransferLearningHelper transferLearningHelper = new TransferLearningHelper(vgg16Transfer); log.info(transferLearningHelper.unfrozenGraph().summary()); //Configuration for Spark training: see http://deeplearning4j.org/spark for explanation of these configuration options TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(batchSizePerWorker) //Each DataSet object: contains (by default) 32 examples .averagingFrequency(5) .workerPrefetchNumBatches(2) //Async prefetching: 2 examples per worker .batchSizePerWorker(batchSizePerWorker) .build(); log.info(vgg16Transfer.summary()); SparkConf sparkConf = new SparkConf(); if(useSparkLocal) sparkConf.setMaster("local[*]"); sparkConf.setAppName("vgg16"); JavaSparkContext sc = new JavaSparkContext(sparkConf); FileSystem fs = FileSystem.get(sc.hadoopConfiguration()); SparkComputationGraph sparkComputationGraph = new SparkComputationGraph(sc,transferLearningHelper.unfrozenGraph(),tm); DataSetIterator trainIter = FlowerDataSetIteratorFeaturized.trainIterator(); DataSetIterator testIter = FlowerDataSetIteratorFeaturized.testIterator(); System.out.println("Writing train to hdfs"); int trainCountWrote = 0; while(trainIter.hasNext()) { OutputStream os = fs.create(new Path(hdfsRoot + "/" + "train","dataset" + trainCountWrote++)); trainIter.next().save(os); os.close(); } System.out.println("Writing test to hdfs"); String testDir = hdfsRoot + "/" + "test"; int testCountWrote = 0; while(testIter.hasNext()) { OutputStream os = fs.create(new Path(testDir,"dataset" + testCountWrote++)); testIter.next().save(os); os.close(); } for (int epoch = 0; epoch < nEpochs; epoch++) { sparkComputationGraph.fit(hdfsRoot + "/train"); log.info("Epoch #" + epoch + " complete"); } JavaRDD<DataSet> data = sc.binaryFiles(testDir + "/*").map(new Function<Tuple2<String, PortableDataStream>, DataSet>() { @Override public DataSet call(Tuple2<String, PortableDataStream> v1) throws Exception { DataSet d = new DataSet(); d.load(v1._2().open()); return d; } }); IEvaluateFlatMapFunction<Evaluation> evalFn = new IEvaluateFlatMapFunction<>(sc.broadcast(vgg16.getConfiguration().toJson()), sc.broadcast(sparkComputationGraph.getNetwork().params()), batchSizePerWorker, new Evaluation(numClasses)); JavaRDD<Evaluation> evaluations = data.mapPartitions(evalFn); evaluations.reduce(new IEvaluationReduceFunction<>()); Evaluation eval = sparkComputationGraph.getNetwork().evaluate(testIter); log.info("Eval stats BEFORE fit....."); log.info(eval.stats()+"\n"); testIter.reset(); log.info("Model build complete"); } }