package org.deeplearning4j.examples.transferlearning.vgg16;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.examples.transferlearning.vgg16.dataHelpers.FlowerDataSetIterator;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
/**
 * Important:
 * 1. Either run "EditAtBottleneckOthersFrozen" first, or save a model named "MyComputationGraph.zip" based on
 *    org.deeplearning4j.examples.transferlearning.vgg16, with block4_pool and the layers below it intact.
 * 2. You will need a LOT of RAM, at the very least 16 GB. Set the max JVM heap space accordingly.
 *
 * Here we read in a model saved from one of our earlier runs on org.deeplearning4j.examples.transferlearning.vgg16
 * and "fine-tune" it. Since we already get reasonable results with the saved model, we can be confident that no
 * large, disruptive gradients will flow back and wreck the carefully trained weights in the lower VGG16 layers.
 *
 * Fine-tuning like this is usually done with a low learning rate and a plain SGD optimizer.
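 *
 * A minimal sketch of the prerequisite save step (an assumption about how the earlier run
 * saved its graph; "vgg16Transfer" here is a hypothetical variable name):
 * <pre>{@code
 * ModelSerializer.writeModel(vgg16Transfer, new File("MyComputationGraph.zip"), false);
 * }</pre>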
 *
 * @author susaneraly on 3/6/17.
 */
public class FineTuneFromBlockFour {
    private static final Logger log = LoggerFactory.getLogger(FineTuneFromBlockFour.class);
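    // Five flower categories; layers at and below "block4_pool" will stay frozen;
    // 80% of the data is used for training, in minibatches of 15 images.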
    protected static final int numClasses = 5;
    protected static final long seed = 12345;
    private static final String featureExtractionLayer = "block4_pool";
    private static final int trainPerc = 80;
    private static final int batchSize = 15;

    public static void main(String[] args) throws IOException {
        //Import the saved model
        File locationToSave = new File("MyComputationGraph.zip");
        log.info("\n\nRestoring saved model...\n\n");
        ComputationGraph vgg16Transfer = ModelSerializer.restoreComputationGraph(locationToSave);
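        // restoreComputationGraph loads both the architecture and the trained weights; an
        // overload with a boolean flag also restores the updater state if it was saved.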
        //Decide on a fine-tune configuration to use.
        //Where a setting already exists, the fine-tune configuration
        // overrides it for all layers that are not "frozen".
        // For example, we override the learning rate and updater,
        // but the optimization algorithm remains unchanged (already SGD).
        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
            .learningRate(1e-5)
            .updater(Updater.SGD)
            .seed(seed)
            .build();
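        // In newer DL4J releases (an assumption about versions later than this example was
        // written for), the learning rate is folded into the updater instead, e.g.
        //   .updater(new org.nd4j.linalg.learning.config.Sgd(1e-5))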
        ComputationGraph vgg16FineTune = new TransferLearning.GraphBuilder(vgg16Transfer)
            .fineTuneConfiguration(fineTuneConf)
            .setFeatureExtractor(featureExtractionLayer)
            .build();
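        // setFeatureExtractor marks "block4_pool" and everything below it as frozen;
        // only the layers above it will have their weights updated during fit().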
        log.info(vgg16FineTune.summary());

        //Dataset iterators
        FlowerDataSetIterator.setup(batchSize, trainPerc);
        DataSetIterator trainIter = FlowerDataSetIterator.trainIterator();
        DataSetIterator testIter = FlowerDataSetIterator.testIterator();
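        // FlowerDataSetIterator splits the flower dataset according to trainPerc: 80% of
        // the images go to trainIter, the remaining 20% to testIter.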
        Evaluation eval = vgg16FineTune.evaluate(testIter);
        log.info("Eval stats BEFORE fit.....");
        log.info(eval.stats() + "\n");
        testIter.reset();

        int iter = 0;
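        // One pass over the training data (a single epoch), evaluating on the held-out
        // test set every 10 minibatches.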
        while (trainIter.hasNext()) {
            vgg16FineTune.fit(trainIter.next());
            if (iter % 10 == 0) {
                log.info("Evaluate model at iter " + iter + " ....");
                eval = vgg16FineTune.evaluate(testIter);
                log.info(eval.stats());
                testIter.reset();
            }
            iter++;
        }
log.info("Model build complete");
//Save the model
File locationToSaveFineTune = new File("MyComputationGraphFineTune.zip");
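        // Plain SGD keeps no internal state (no momentum or history), so nothing is lost by
        // discarding the updater here; set this to true for stateful updaters if training
        // might resume later.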
        boolean saveUpdater = false;
        ModelSerializer.writeModel(vgg16FineTune, locationToSaveFineTune, saveUpdater);
        log.info("Model saved");
    }
}