package org.deeplearning4j.examples.transferlearning.vgg16;

import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.examples.transferlearning.vgg16.dataHelpers.FlowerDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModelHelper;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;

import java.io.IOException;

/**
 * @author susaneraly on 3/9/17.
 *
 * We use the transfer learning API to construct a new model based on org.deeplearning4j.transferlearning.vgg16.
 * We hold all layers except the very last one frozen, and change the number of outputs in the last layer to
 * match our classification task.
 * In other words, where fc2 and predictions are vertex names in org.deeplearning4j.transferlearning.vgg16, we go from
 *      fc2 -> predictions (1000 classes)
 * to
 *      fc2 -> predictions (5 classes)
 * The class "FitFromFeaturized" attempts to train this same architecture, the difference being that the outputs from
 * the last frozen layer are presaved and the fit is carried out on this featurized dataset.
 * When running multiple epochs this can save on computation time.
*/ public class EditLastLayerOthersFrozen { private static final Logger log = org.slf4j.LoggerFactory.getLogger(EditLastLayerOthersFrozen.class); protected static final int numClasses = 5; protected static final long seed = 12345; private static final int trainPerc = 80; private static final int batchSize = 15; private static final String featureExtractionLayer = "fc2"; public static void main(String [] args) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException { //Import vgg //Note that the model imported does not have an output layer (check printed summary) // nor any training related configs (model from keras was imported with only weights and json) TrainedModelHelper modelImportHelper = new TrainedModelHelper(TrainedModels.VGG16); log.info("\n\nLoading org.deeplearning4j.transferlearning.vgg16...\n\n"); ComputationGraph vgg16 = modelImportHelper.loadModel(); log.info(vgg16.summary()); //Decide on a fine tune configuration to use. //In cases where there already exists a setting the fine tune setting will // override the setting for all layers that are not "frozen". 
FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder() .learningRate(5e-5) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(Updater.NESTEROVS) .seed(seed) .build(); //Construct a new model with the intended architecture and print summary ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(vgg16) .fineTuneConfiguration(fineTuneConf) .setFeatureExtractor(featureExtractionLayer) //the specified layer and below are "frozen" .removeVertexKeepConnections("predictions") //replace the functionality of the final vertex .addLayer("predictions", new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nIn(4096).nOut(numClasses) .weightInit(WeightInit.DISTRIBUTION) .dist(new NormalDistribution(0,0.2*(2.0/(4096+numClasses)))) //This weight init dist gave better results than Xavier .activation(Activation.SOFTMAX).build(), "fc2") .build(); log.info(vgg16Transfer.summary()); //Dataset iterators FlowerDataSetIterator.setup(batchSize,trainPerc); DataSetIterator trainIter = FlowerDataSetIterator.trainIterator(); DataSetIterator testIter = FlowerDataSetIterator.testIterator(); Evaluation eval; eval = vgg16Transfer.evaluate(testIter); log.info("Eval stats BEFORE fit....."); log.info(eval.stats() + "\n"); testIter.reset(); int iter = 0; while(trainIter.hasNext()) { vgg16Transfer.fit(trainIter.next()); if (iter % 10 == 0) { log.info("Evaluate model at iter "+iter +" ...."); eval = vgg16Transfer.evaluate(testIter); log.info(eval.stats()); testIter.reset(); } iter++; } log.info("Model build complete"); } }