package org.deeplearning4j.examples.transferlearning.vgg16;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.examples.transferlearning.vgg16.dataHelpers.FlowerDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModelHelper;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import java.io.IOException;
/**
* @author susaneraly on 3/9/17.
*
* We use the transfer learning API to construct a new model based on VGG16.
* We will hold all layers but the very last one frozen and change the number of outputs in the last layer to
* match our classification task.
* In other words, where "fc2" and "predictions" are vertex names in VGG16, we go from
* fc2 -> predictions (1000 classes)
* to
* fc2 -> predictions (5 classes)
* The class "FitFromFeaturized" attempts to train this same architecture, the difference being that the outputs from
* the last frozen layer are presaved and the fit is carried out on this featurized dataset.
* When running multiple epochs this can save on computation time.
*/
public class EditLastLayerOthersFrozen {
    private static final Logger log = org.slf4j.LoggerFactory.getLogger(EditLastLayerOthersFrozen.class);

    /** Number of outputs (nOut) of the replacement "predictions" layer. */
    protected static final int numClasses = 5;
    /** Seed handed to the fine tune configuration. */
    protected static final long seed = 12345;
    /** Second argument to FlowerDataSetIterator.setup; presumably the percentage of data used for training. */
    private static final int trainPerc = 80;
    /** First argument to FlowerDataSetIterator.setup; presumably the minibatch size. */
    private static final int batchSize = 15;
    /** Vertex given to setFeatureExtractor; it also feeds the new output layer. */
    private static final String featureExtractionLayer = "fc2";
    /** Name of the vertex that is removed and then re-added with {@link #numClasses} outputs. */
    private static final String outputLayerName = "predictions";
    /** Input size (nIn) of the replacement output layer, which is fed by {@link #featureExtractionLayer}. */
    private static final int featureExtractionLayerSize = 4096;
    /** The test set is evaluated whenever the training iteration count is a multiple of this value. */
    private static final int evalFrequency = 10;

    /**
     * Loads VGG16, swaps its "predictions" vertex for a new {@link #numClasses}-way softmax output layer,
     * then makes a single pass over the training iterator, logging test-set evaluation statistics before
     * training and every {@link #evalFrequency} iterations.
     *
     * @param args command line arguments (not used)
     * @throws UnsupportedKerasConfigurationException declared for the model import step
     * @throws IOException declared for the model import / data loading steps
     * @throws InvalidKerasConfigurationException declared for the model import step
     */
    public static void main(String [] args) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {

        //Import vgg
        //Note that the model imported does not have an output layer (check printed summary)
        //  nor any training related configs (model from keras was imported with only weights and json)
        TrainedModelHelper modelImportHelper = new TrainedModelHelper(TrainedModels.VGG16);
        log.info("\n\nLoading vgg16...\n\n");
        ComputationGraph vgg16 = modelImportHelper.loadModel();
        log.info(vgg16.summary());

        //Decide on a fine tune configuration to use.
        //In cases where there already exists a setting the fine tune setting will
        //  override the setting for all layers that are not "frozen".
        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
            .learningRate(5e-5)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.NESTEROVS)
            .seed(seed)
            .build();

        //Construct a new model with the intended architecture and print summary
        ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(vgg16)
            .fineTuneConfiguration(fineTuneConf)
            .setFeatureExtractor(featureExtractionLayer) //the specified layer and below are "frozen"
            .removeVertexKeepConnections(outputLayerName) //replace the functionality of the final vertex
            .addLayer(outputLayerName,
                new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .nIn(featureExtractionLayerSize).nOut(numClasses)
                    .weightInit(WeightInit.DISTRIBUTION)
                    //This weight init dist gave better results than Xavier
                    .dist(new NormalDistribution(0, 0.2 * (2.0 / (featureExtractionLayerSize + numClasses))))
                    .activation(Activation.SOFTMAX).build(),
                featureExtractionLayer)
            .build();
        log.info(vgg16Transfer.summary());

        //Dataset iterators
        FlowerDataSetIterator.setup(batchSize, trainPerc);
        DataSetIterator trainIter = FlowerDataSetIterator.trainIterator();
        DataSetIterator testIter = FlowerDataSetIterator.testIterator();

        //Baseline: how the model scores before any fitting has taken place
        String baselineStats = evaluateAndReset(vgg16Transfer, testIter);
        log.info("Eval stats BEFORE fit.....");
        log.info("{}\n", baselineStats);

        //Single pass over the training data, fitting one minibatch per iteration
        for (int iter = 0; trainIter.hasNext(); iter++) {
            vgg16Transfer.fit(trainIter.next());
            if (iter % evalFrequency == 0) {
                log.info("Evaluate model at iter {} ....", iter);
                log.info(evaluateAndReset(vgg16Transfer, testIter));
            }
        }

        log.info("Model build complete");
    }

    /**
     * Evaluates the model on the given iterator and then resets the iterator so it can be reused.
     *
     * @param model the network to evaluate
     * @param testIter iterator over the evaluation data; reset before this method returns
     * @return the statistics string produced by the evaluation
     */
    private static String evaluateAndReset(ComputationGraph model, DataSetIterator testIter) {
        Evaluation eval = model.evaluate(testIter);
        String stats = eval.stats();
        testIter.reset();
        return stats;
    }
}