package org.deeplearning4j.examples.transferlearning.vgg16;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.examples.transferlearning.vgg16.dataHelpers.FlowerDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModelHelper;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
/**
* @author susaneraly on 3/1/17.
*
* IMPORTANT:
* 1. The forward pass on VGG16 is time-consuming. Refer to "FeaturizedPreSave" and "FitFromFeaturized" for how to use presaved datasets.
* 2. At least 16GB of RAM is needed; set the JVM -Xmx heap size accordingly.
*
* We use the transfer learning API to construct a new model based on org.deeplearning4j.transferlearning.vgg16.
* We keep block5_pool and below frozen
* and modify/add dense layers to form
* block5_pool -> flatten -> fc1 -> fc2 -> fc3 -> newpredictions (5 classes)
* from
* block5_pool -> flatten -> fc1 -> fc2 -> predictions (1000 classes)
*
* Note that we could presave the output of block5_pool, as we do in FeaturizedPreSave + FitFromFeaturized.
* Refer to those two classes for more detail.
*/
public class EditAtBottleneckOthersFrozen {
private static final Logger log = LoggerFactory.getLogger(EditAtBottleneckOthersFrozen.class);
protected static final int numClasses = 5;
protected static final long seed = 12345;
private static final int trainPerc = 80;
private static final int batchSize = 15;
private static final String featureExtractionLayer = "block5_pool";
public static void main(String[] args) throws Exception {
//Import VGG16
//Note that the imported model does not have an output layer (check the printed summary),
// nor any training-related configs (the model from Keras was imported with only weights and json)
TrainedModelHelper modelImportHelper = new TrainedModelHelper(TrainedModels.VGG16);
log.info("\n\nLoading org.deeplearning4j.transferlearning.vgg16...\n\n");
ComputationGraph vgg16 = modelImportHelper.loadModel();
log.info(vgg16.summary());
//Decide on a fine-tune configuration to use.
//Where a setting already exists in the model, the fine-tune setting will
// override it for all layers that are not "frozen".
FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
.activation(Activation.LEAKYRELU)
.weightInit(WeightInit.RELU)
.learningRate(5e-5)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(Updater.NESTEROVS)
.dropOut(0.5)
.seed(seed)
.build();
//Construct a new model with the intended architecture and print the summary
// Note: This architecture is constructed primarily to demonstrate the transfer learning API,
// not to maximize accuracy on this task
ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(vgg16)
.fineTuneConfiguration(fineTuneConf)
.setFeatureExtractor(featureExtractionLayer) //"block5_pool" and below are frozen
.nOutReplace("fc2",1024, WeightInit.XAVIER) //modify nOut of the "fc2" vertex
.removeVertexAndConnections("predictions") //remove the final vertex and its connections
.addLayer("fc3",new DenseLayer.Builder().activation(Activation.TANH).nIn(1024).nOut(256).build(),"fc2") //add in a new dense layer
.addLayer("newpredictions",new OutputLayer
.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(256)
.nOut(numClasses)
.build(),"fc3") //add in a final output dense layer,
// note that learning related configurations applied on a new layer here will be honored
// In other words - these will override the finetune confs.
// For eg. activation function will be softmax not RELU
.setOutputs("newpredictions") //since we removed the output vertex and it's connections we need to specify outputs for the graph
.build();
log.info(vgg16Transfer.summary());
//Dataset iterators: trainPerc = 80 gives an 80/20 train/test split of the flower dataset
FlowerDataSetIterator.setup(batchSize,trainPerc);
DataSetIterator trainIter = FlowerDataSetIterator.trainIterator();
DataSetIterator testIter = FlowerDataSetIterator.testIterator();
Evaluation eval = vgg16Transfer.evaluate(testIter);
log.info("Eval stats BEFORE fit.....");
log.info(eval.stats() + "\n");
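//evaluate(...) exhausts the test iterator, so reset it before it is reused in the training loop below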
testIter.reset();
int iter = 0;
while(trainIter.hasNext()) {
vgg16Transfer.fit(trainIter.next());
if (iter % 10 == 0) {
log.info("Evaluate model at iter "+iter +" ....");
eval = vgg16Transfer.evaluate(testIter);
log.info(eval.stats());
testIter.reset();
}
iter++;
}
log.info("Model build complete");
//Save the model
//Note that the saved model will not know which layers were frozen during training.
//Frozen layers always have to be specified before training.
// Models with frozen layers can be constructed in the following two ways:
// 1. .setFeatureExtractor in the transfer learning API, which will always return a new model (as seen in this example)
// 2. in place with the TransferLearningHelper constructor, which takes a model and a specific vertex name
// and freezes that vertex and the vertices on the path from an input to it (as seen in the FeaturizedPreSave class; see the sketch below)
//The saved model can be "fine-tuned" further, as in the FitFromFeaturized class
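// A minimal sketch of approach 2, assuming the TransferLearningHelper API used in FeaturizedPreSave
// (kept as comments so this example's behavior is unchanged):
//   TransferLearningHelper helper = new TransferLearningHelper(vgg16, "block5_pool");
//   ComputationGraph unfrozen = helper.unfrozenGraph(); //exposes only the trainable portion of the graph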
File locationToSave = new File("MyComputationGraph.zip");
boolean saveUpdater = false; //updater state is only needed if training will resume from the saved model
ModelSerializer.writeModel(vgg16Transfer, locationToSave, saveUpdater);
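// The saved graph can be restored later with ModelSerializer (a sketch, assuming the same file):
//   ComputationGraph restored = ModelSerializer.restoreComputationGraph(locationToSave);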
log.info("Model saved");
}
}