package org.deeplearning4j.examples.transferlearning.vgg16.dataHelpers; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModelHelper; import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels; import org.deeplearning4j.nn.transferlearning.TransferLearningHelper; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.slf4j.Logger; import java.io.File; import java.io.IOException; /** * The TransferLearningHelper class allows users to "featurize" a dataset at specific intermediate vertices/layers of a model * This example demonstrates how to presave these * Refer to the "FitFromFeaturized" example for how to fit a model with these featurized datasets * @author susaneraly on 2/28/17. */ public class FeaturizedPreSave { private static final Logger log = org.slf4j.LoggerFactory.getLogger(FeaturizedPreSave.class); private static final int trainPerc = 80; protected static final int batchSize = 15; public static final String featurizeExtractionLayer = "fc2"; public static void main(String [] args) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException { //import org.deeplearning4j.transferlearning.vgg16 and print summary TrainedModelHelper modelImportHelper = new TrainedModelHelper(TrainedModels.VGG16); log.info("\n\nLoading org.deeplearning4j.transferlearning.vgg16...\n\n"); ComputationGraph vgg16 = modelImportHelper.loadModel(); log.info(vgg16.summary()); //use the TransferLearningHelper to freeze the specified vertices and below //NOTE: This is done in place! Pass in a cloned version of the model if you would prefer to not do this in place TransferLearningHelper transferLearningHelper = new TransferLearningHelper(vgg16, featurizeExtractionLayer); log.info(vgg16.summary()); FlowerDataSetIterator.setup(batchSize,trainPerc); DataSetIterator trainIter = FlowerDataSetIterator.trainIterator(); DataSetIterator testIter = FlowerDataSetIterator.testIterator(); int trainDataSaved = 0; while(trainIter.hasNext()) { DataSet currentFeaturized = transferLearningHelper.featurize(trainIter.next()); saveToDisk(currentFeaturized,trainDataSaved,true); trainDataSaved++; } int testDataSaved = 0; while(testIter.hasNext()) { DataSet currentFeaturized = transferLearningHelper.featurize(testIter.next()); saveToDisk(currentFeaturized,testDataSaved,false); testDataSaved++; } log.info("Finished pre saving featurized test and train data"); } public static void saveToDisk(DataSet currentFeaturized, int iterNum, boolean isTrain) { File fileFolder = isTrain ? new File("trainFolder"): new File("testFolder"); if (iterNum == 0) { fileFolder.mkdirs(); } String fileName = "flowers-" + featurizeExtractionLayer + "-"; fileName += isTrain ? "train-" : "test-"; fileName += iterNum + ".bin"; currentFeaturized.save(new File(fileFolder,fileName)); log.info("Saved " + (isTrain?"train ":"test ") + "dataset #"+ iterNum); } }