package org.deeplearning4j.examples.transferlearning.vgg16;

import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.examples.transferlearning.vgg16.dataHelpers.FeaturizedPreSave;
import org.deeplearning4j.examples.transferlearning.vgg16.dataHelpers.FlowerDataSetIteratorFeaturized;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModelHelper;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;

import java.io.IOException;

/**
 * @author susaneraly on 3/10/17.
 *
 * Important:
 * Run the class "FeaturizedPreSave" before attempting to run this class. It saves the activations at the boundary
 * between the frozen and unfrozen vertices of the model; these saved activations are referred to as "featurized"
 * datasets in this description. On the dataset downloaded for this example (about 3,000 images), featurizing can
 * take a while.
 *
 * Here we see how the transfer learning helper can be used to fit from a featurized dataset.
 * We train the same model architecture as the one in "EditLastLayerOthersFrozen".
 * Since the helper avoids the forward pass through the frozen layers, we save computation time when running multiple epochs.
 * In this manner, users can iterate quickly, tweaking learning rates, weight initialization etc. to settle on a model
 * that gives good results.
 */
public class FitFromFeaturized {

    private static final Logger log = org.slf4j.LoggerFactory.getLogger(FitFromFeaturized.class);

    public static final String featureExtractionLayer = FeaturizedPreSave.featurizeExtractionLayer;
    protected static final long seed = 12345;
    protected static final int numClasses = 5;
    protected static final int nEpochs = 3;

    public static void main(String[] args) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {

        //Import VGG16
        //Note that the model imported does not have an output layer (check the printed summary)
        // nor any training-related configs (the model from Keras was imported with only weights and json)
        TrainedModelHelper modelImportHelper = new TrainedModelHelper(TrainedModels.VGG16);
        log.info("\n\nLoading org.deeplearning4j.transferlearning.vgg16...\n\n");
        ComputationGraph vgg16 = modelImportHelper.loadModel();
        log.info(vgg16.summary());

        //Decide on a fine-tune configuration to use.
        //In cases where a setting already exists, the fine-tune configuration will
        // override it for all layers that are not "frozen".
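        //(The imported VGG16 carries no training configuration at all, so without a fine-tune configuration the
        // unfrozen layers would pick up DL4J's builder defaults; the learning rate and updater below are this
        // example's choices and can be swapped freely when iterating on the featurized data.)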
        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
            .learningRate(3e-5)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.NESTEROVS)
            .seed(seed)
            .build();

        //Construct a new model with the intended architecture and print its summary
        ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(vgg16)
            .fineTuneConfiguration(fineTuneConf)
            .setFeatureExtractor(featureExtractionLayer) //the specified layer and below are "frozen"
            .removeVertexKeepConnections("predictions") //replace the functionality of the final vertex
            .addLayer("predictions",
                new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .nIn(4096).nOut(numClasses)
                    .weightInit(WeightInit.DISTRIBUTION)
                    .dist(new NormalDistribution(0, 0.2 * (2.0 / (4096 + numClasses)))) //This weight init dist gave better results than Xavier
                    .activation(Activation.SOFTMAX).build(),
                "fc2")
            .build();
        log.info(vgg16Transfer.summary());

        DataSetIterator trainIter = FlowerDataSetIteratorFeaturized.trainIterator();
        DataSetIterator testIter = FlowerDataSetIteratorFeaturized.testIterator();

        //Instantiate the transfer learning helper to fit and output from the featurized dataset
        //The .unfrozenGraph() is the unfrozen subset of the computation graph passed in.
        //If using a UI or a listener, attach them directly to the unfrozenGraph instance
        //With each iteration, the updated params from the unfrozenGraph are copied over to the original model
        TransferLearningHelper transferLearningHelper = new TransferLearningHelper(vgg16Transfer);
        log.info(transferLearningHelper.unfrozenGraph().summary());

        for (int epoch = 0; epoch < nEpochs; epoch++) {
            if (epoch == 0) {
                Evaluation eval = transferLearningHelper.unfrozenGraph().evaluate(testIter);
                log.info("Eval stats BEFORE fit.....");
                log.info(eval.stats() + "\n");
                testIter.reset();
            }
            int iter = 0;
            while (trainIter.hasNext()) {
                transferLearningHelper.fitFeaturized(trainIter.next());
                if (iter % 10 == 0) {
                    log.info("Evaluate model at iter " + iter + " ....");
                    Evaluation eval = transferLearningHelper.unfrozenGraph().evaluate(testIter);
                    log.info(eval.stats());
                    testIter.reset();
                }
                iter++;
            }
            trainIter.reset();
            log.info("Epoch #" + epoch + " complete");
        }
        log.info("Model build complete");
    }
}
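/*
 * For reference, a minimal sketch of the featurization step that "FeaturizedPreSave" performs before this class
 * can run. The variable names and save path here are illustrative assumptions, not the exact contents of that
 * class; it relies on TransferLearningHelper.featurize(DataSet) and DataSet.save(File):
 *
 *   TransferLearningHelper helper = new TransferLearningHelper(vgg16, featureExtractionLayer);
 *   int i = 0;
 *   while (trainIter.hasNext()) {
 *       DataSet featurized = helper.featurize(trainIter.next()); //forward pass through the frozen block, done once
 *       featurized.save(new File("featurizedTrain", "flowers-train-" + i++ + ".bin")); //persisted for fitFeaturized()
 *   }
 *
 * Paying the frozen-layer forward pass once up front is what makes the multi-epoch loop above cheap.
 */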