package org.deeplearning4j.examples.convolution;

import org.apache.commons.io.FilenameUtils;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.FlipImageTransform;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.WarpImageTransform;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import static org.bytedeco.javacpp.opencv_imgproc.COLOR_BGR2YCrCb;

/**
 * Animal Classification
 *
 * Example classification of photos from 4 different animals (bear, duck, deer, turtle).
 *
 * References:
 *  - U.S. Fish and Wildlife Service (animal sample dataset): http://digitalmedia.fws.gov/cdm/
 *  - Tiny ImageNet Classification with CNN: http://cs231n.stanford.edu/reports/2015/pdfs/leonyao_final.pdf
 *
 * CHALLENGE: Current setup gets low score results. Can you improve the scores? Some approaches:
 *  - Add additional images to the dataset
 *  - Apply more transforms to dataset
 *  - Increase epochs
 *  - Try different model configurations
 *  - Tune by adjusting learning rate, updaters, activation & loss functions, regularization, ...
 */
public class AnimalsClassification {
    protected static final Logger log = LoggerFactory.getLogger(AnimalsClassification.class);

    // Image dimensions handed to the record reader and to InputType.convolutional(...).
    protected static int height = 100;
    protected static int width = 100;
    protected static int channels = 3;

    // Limits and batch settings handed to BalancedPathFilter and RecordReaderDataSetIterator.
    protected static int numExamples = 80;
    protected static int numLabels = 4;
    protected static int batchSize = 20;

    protected static long seed = 42;
    protected static Random rng = new Random(seed);
    protected static int listenerFreq = 1;
    protected static int iterations = 1;
    protected static int epochs = 50;
    protected static double splitTrainTest = 0.8;
    // Third constructor argument of MultipleEpochsIterator.
    // NOTE(review): presumably a prefetch queue size rather than a core count - confirm.
    protected static int nCores = 2;
    protected static boolean save = false;

    // LeNet, AlexNet or Custom (customModel() must be filled out before selecting Custom)
    protected static String modelType = "AlexNet";

    /**
     * Runs the whole example: loads and splits the image data, builds the network chosen by
     * {@link #modelType}, trains it on the raw and on the transformed training images,
     * evaluates it on the test split and optionally saves the model.
     *
     * @param args command line arguments (not used)
     * @throws Exception if loading data, training, evaluating or saving fails
     */
    public void run(String[] args) throws Exception {

        log.info("Load data....");
        /*
         * Data Setup -> organize and limit data file paths:
         *  - mainPath = path to image files
         *  - fileSplit = define basic dataset split with limits on format
         *  - pathFilter = define additional file load filter to limit size and balance batch content
         */
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        File mainPath = new File(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/animals/");
        FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng);
        BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, batchSize);

        /*
         * Data Setup -> train test split
         *  - inputSplit = define train and test split
         */
        InputSplit[] inputSplit = fileSplit.sample(pathFilter, splitTrainTest, 1 - splitTrainTest);
        InputSplit trainData = inputSplit[0];
        InputSplit testData = inputSplit[1];

        /*
         * Data Setup -> transformation
         *  - Transform = how to transform images and generate large dataset to train on
         */
        ImageTransform flipTransform1 = new FlipImageTransform(rng);
        ImageTransform flipTransform2 = new FlipImageTransform(new Random(123));
        ImageTransform warpTransform = new WarpImageTransform(rng, 42);
        // Optional extra transform, left disabled:
        // ImageTransform colorTransform = new ColorConversionTransform(new Random(seed), COLOR_BGR2YCrCb);
        List<ImageTransform> transforms = Arrays.asList(flipTransform1, warpTransform, flipTransform2);

        /*
         * Data Setup -> normalization
         *  - how to normalize images and generate large dataset to train on
         */
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);

        log.info("Build model....");
        // The model is chosen through modelType. Note: for AlexNet, height and width
        // need to be at least 100.
        MultiLayerNetwork network = createNetwork();
        network.init();
        network.setListeners(new ScoreIterationListener(listenerFreq));

        /*
         * Data Setup -> define how to load data into net:
         *  - recordReader = the reader that loads and converts image data pass in inputSplit to initialize
         *  - dataIter = a generator that only loads one batch at a time into memory to save memory
         *  - trainIter = uses MultipleEpochsIterator to ensure model runs through the data for all epochs
         */
        ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);

        log.info("Train model....");
        // Train without transformations
        trainOn(network, recordReader, trainData, null, scaler);

        // Train with transformations
        for (ImageTransform transform : transforms) {
            System.out.print("\nTraining on transformation: " + transform.getClass().toString() + "\n\n");
            trainOn(network, recordReader, trainData, transform, scaler);
        }

        log.info("Evaluate model....");
        recordReader.initialize(testData);
        DataSetIterator dataIter = newScaledIterator(recordReader, scaler);
        Evaluation eval = network.evaluate(dataIter);
        log.info(eval.stats(true));

        // Example on how to get predict results with trained model
        dataIter.reset();
        DataSet testDataSet = dataIter.next();
        String expectedResult = testDataSet.getLabelName(0);
        List<String> predict = network.predict(testDataSet);
        String modelResult = predict.get(0);
        System.out.print("\nFor a single example that is labeled " + expectedResult
            + " the model predicted " + modelResult + "\n\n");

        if (save) {
            log.info("Save model....");
            String basePath = FilenameUtils.concat(System.getProperty("user.dir"), "src/main/resources/");
            ModelSerializer.writeModel(network, basePath + "model.bin", true);
        }
        log.info("****************Example finished********************");
    }

    /**
     * Builds the (not yet initialized) network selected by {@link #modelType}.
     *
     * @return the network for the configured model type, never {@code null}
     * @throws InvalidInputTypeException if {@link #modelType} is not a known model name
     * @throws IllegalStateException if the custom model was selected but
     *         {@link #customModel()} has not been filled out and returned {@code null}
     */
    private MultiLayerNetwork createNetwork() {
        MultiLayerNetwork network;
        switch (modelType) {
            case "LeNet":
                network = lenetModel();
                break;
            case "AlexNet":
                network = alexnetModel();
                break;
            case "Custom": // spelling used in the modelType documentation
            case "custom": // spelling originally accepted; kept for backward compatibility
                network = customModel();
                break;
            default:
                throw new InvalidInputTypeException("Incorrect model provided.");
        }
        if (network == null) {
            // Fail with an explanation instead of a bare NullPointerException on network.init().
            throw new IllegalStateException(
                "Model type '" + modelType + "' returned no network; fill out customModel() first.");
        }
        return network;
    }

    /**
     * Wraps the (already initialized) record reader in a batch iterator, fits the scaler on it
     * and installs the scaler as the iterator's pre-processor.
     *
     * @param recordReader reader that has been initialized with the split to iterate
     * @param scaler       normalizer to fit and attach
     * @return the iterator with the scaler attached
     */
    private DataSetIterator newScaledIterator(ImageRecordReader recordReader, DataNormalization scaler) {
        DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, numLabels);
        scaler.fit(dataIter);
        dataIter.setPreProcessor(scaler);
        return dataIter;
    }

    /**
     * Re-initializes the record reader on the training split with the given transform and fits
     * the network on it for {@link #epochs} epochs.
     *
     * @param network      network to train
     * @param recordReader reader to (re-)initialize on the training data
     * @param trainData    training split
     * @param transform    image transform to apply while reading, or {@code null} for none
     * @param scaler       normalizer to fit and attach to the data iterator
     * @throws Exception if the record reader cannot be initialized on the split
     */
    private void trainOn(MultiLayerNetwork network, ImageRecordReader recordReader, InputSplit trainData,
                         ImageTransform transform, DataNormalization scaler) throws Exception {
        recordReader.initialize(trainData, transform);
        DataSetIterator dataIter = newScaledIterator(recordReader, scaler);
        MultipleEpochsIterator trainIter = new MultipleEpochsIterator(epochs, dataIter, nCores);
        network.fit(trainIter);
    }

    /** Convolution layer with explicit input depth, kernel, stride and padding. */
    private ConvolutionLayer convInit(String name, int in, int out, int[] kernel, int[] stride, int[] pad, double bias) {
        return new ConvolutionLayer.Builder(kernel, stride, pad).name(name).nIn(in).nOut(out).biasInit(bias).build();
    }

    /** 3x3 convolution layer with stride {1,1} and padding {1,1}. */
    private ConvolutionLayer conv3x3(String name, int out, double bias) {
        return new ConvolutionLayer.Builder(new int[]{3, 3}, new int[]{1, 1}, new int[]{1, 1})
            .name(name).nOut(out).biasInit(bias).build();
    }

    /** 5x5 convolution layer with caller-supplied stride and padding (in that argument order). */
    private ConvolutionLayer conv5x5(String name, int out, int[] stride, int[] pad, double bias) {
        return new ConvolutionLayer.Builder(new int[]{5, 5}, stride, pad).name(name).nOut(out).biasInit(bias).build();
    }

    /** Subsampling layer built from the given kernel and a fixed stride of {2,2}. */
    private SubsamplingLayer maxPool(String name, int[] kernel) {
        return new SubsamplingLayer.Builder(kernel, new int[]{2, 2}).name(name).build();
    }

    /** Dense layer with the given output size, bias init, dropout and weight distribution. */
    private DenseLayer fullyConnected(String name, int out, double bias, double dropOut, Distribution dist) {
        return new DenseLayer.Builder().name(name).nOut(out).biasInit(bias).dropOut(dropOut).dist(dist).build();
    }

    /**
     * Revised LeNet model approach developed by ramgo2 achieves slightly above random.
     * Reference: https://gist.github.com/ramgo2/833f12e92359a2da9e5c2fb6333351c5
     *
     * @return the configured, not yet initialized, network
     */
    public MultiLayerNetwork lenetModel() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            // NOTE(review): l2 is set while regularization is switched off - presumably the
            // l2 value has no effect here; confirm this is intended.
            .regularization(false).l2(0.005) // tried 0.0001, 0.0005
            .activation(Activation.RELU)
            .learningRate(0.0001) // tried 0.00001, 0.00005, 0.000001
            .weightInit(WeightInit.XAVIER)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.RMSPROP).momentum(0.9)
            .list()
            .layer(0, convInit("cnn1", channels, 50, new int[]{5, 5}, new int[]{1, 1}, new int[]{0, 0}, 0))
            .layer(1, maxPool("maxpool1", new int[]{2, 2}))
            // NOTE(review): conv5x5 takes (stride, pad), so this is stride {5,5} and pad {1,1};
            // confirm that is intended and not a swapped kernel/stride pair.
            .layer(2, conv5x5("cnn2", 100, new int[]{5, 5}, new int[]{1, 1}, 0))
            .layer(3, maxPool("maxpool2", new int[]{2, 2}))
            .layer(4, new DenseLayer.Builder().nOut(500).build())
            .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(numLabels)
                .activation(Activation.SOFTMAX)
                .build())
            .backprop(true).pretrain(false)
            .setInputType(InputType.convolutional(height, width, channels))
            .build();

        return new MultiLayerNetwork(conf);
    }

    /**
     * AlexNet model interpretation based on the original paper ImageNet Classification with Deep
     * Convolutional Neural Networks and the imagenetExample code referenced.
     * http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf
     *
     * @return the configured, not yet initialized, network
     */
    public MultiLayerNetwork alexnetModel() {
        double nonZeroBias = 1;
        double dropOut = 0.5;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .weightInit(WeightInit.DISTRIBUTION)
            .dist(new NormalDistribution(0.0, 0.01))
            .activation(Activation.RELU)
            .updater(Updater.NESTEROVS)
            .iterations(iterations)
            // normalize to prevent vanishing or exploding gradients
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(1e-2)
            .biasLearningRate(1e-2 * 2)
            .learningRateDecayPolicy(LearningRatePolicy.Step)
            .lrPolicyDecayRate(0.1)
            .lrPolicySteps(100000)
            .regularization(true)
            .l2(5 * 1e-4)
            .momentum(0.9)
            .miniBatch(false)
            .list()
            .layer(0, convInit("cnn1", channels, 96, new int[]{11, 11}, new int[]{4, 4}, new int[]{3, 3}, 0))
            .layer(1, new LocalResponseNormalization.Builder().name("lrn1").build())
            .layer(2, maxPool("maxpool1", new int[]{3, 3}))
            .layer(3, conv5x5("cnn2", 256, new int[]{1, 1}, new int[]{2, 2}, nonZeroBias))
            .layer(4, new LocalResponseNormalization.Builder().name("lrn2").build())
            .layer(5, maxPool("maxpool2", new int[]{3, 3}))
            .layer(6, conv3x3("cnn3", 384, 0))
            .layer(7, conv3x3("cnn4", 384, nonZeroBias))
            .layer(8, conv3x3("cnn5", 256, nonZeroBias))
            .layer(9, maxPool("maxpool3", new int[]{3, 3}))
            .layer(10, fullyConnected("ffn1", 4096, nonZeroBias, dropOut, new GaussianDistribution(0, 0.005)))
            .layer(11, fullyConnected("ffn2", 4096, nonZeroBias, dropOut, new GaussianDistribution(0, 0.005)))
            .layer(12, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .name("output")
                .nOut(numLabels)
                .activation(Activation.SOFTMAX)
                .build())
            .backprop(true)
            .pretrain(false)
            .setInputType(InputType.convolutional(height, width, channels))
            .build();

        return new MultiLayerNetwork(conf);
    }

    /**
     * Use this method to build your own custom model.
     *
     * @return {@code null} until a custom model is implemented here
     */
    public static MultiLayerNetwork customModel() {
        return null;
    }

    /**
     * Entry point: runs the example with the given command line arguments.
     *
     * @param args command line arguments, passed through to {@link #run(String[])}
     * @throws Exception if the example run fails
     */
    public static void main(String[] args) throws Exception {
        new AnimalsClassification().run(args);
    }
}