package org.deeplearning4j.examples.dataexamples; import org.apache.commons.compress.archivers.tar.TarArchiveEntry; import org.apache.commons.compress.archivers.tar.TarArchiveInputStream; import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream; import org.apache.commons.io.FilenameUtils; import org.apache.http.HttpEntity; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpGet; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClientBuilder; import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.split.FileSplit; import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.recordreader.ImageRecordReader; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.*; import java.util.Random; /** * This code example is featured in this youtube video * https://www.youtube.com/watch?v=ECA6y6ahH5E * ** This differs slightly from the Video Example, * The Video example had the data already downloaded * This example includes code that downloads the data * * * The Data Directory mnist_png will have two child directories training and testing * The training and testing directories will have directories 0-9 with * 28 * 28 PNG images of handwritten images * * * * The data is downloaded from * wget http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz * followed by tar xzvf mnist_png.tar.gz * * * * This examples builds on the MnistImagePipelineExample * by adding a Neural Net */ public class MnistImagePipelineExampleAddNeuralNet { private static Logger log = LoggerFactory.getLogger(MnistImagePipelineExampleAddNeuralNet.class); /** Data URL for downloading */ public static final String DATA_URL = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz"; /** Location to save and extract the training/testing data */ public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_Mnist/"); public static void main(String[] args) throws Exception { // image information // 28 * 28 grayscale // grayscale implies single channel int height = 28; int width = 28; int channels = 1; int rngseed = 123; Random randNumGen = new Random(rngseed); int batchSize = 128; int outputNum = 10; int numEpochs = 1; /* This class downloadData() downloads the data stores the data in java's tmpdir 15MB download compressed It will take 158MB of space when uncompressed The data can be downloaded manually here http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz */ downloadData(); // Define the File Paths File trainData = new File(DATA_PATH + "/mnist_png/training"); File testData = new File(DATA_PATH + "/mnist_png/testing"); // Define the File Paths //File trainData = new File("/tmp/mnist_png/training"); //File testData = new File("/tmp/mnist_png/testing"); // Define the FileSplit(PATH, ALLOWED FORMATS,random) FileSplit train = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS,randNumGen); FileSplit test = new FileSplit(testData,NativeImageLoader.ALLOWED_FORMATS,randNumGen); // Extract the parent path as the image label ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); ImageRecordReader recordReader = new ImageRecordReader(height,width,channels,labelMaker); // Initialize the record reader // add a listener, to extract the name recordReader.initialize(train); //recordReader.setListeners(new LogRecordListener()); // DataSet Iterator DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader,batchSize,1,outputNum); // Scale pixel values to 0-1 DataNormalization scaler = new ImagePreProcessingScaler(0,1); scaler.fit(dataIter); dataIter.setPreProcessor(scaler); // Build Our Neural Network log.info("**** Build Model ****"); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(rngseed) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .iterations(1) .learningRate(0.006) .updater(Updater.NESTEROVS).momentum(0.9) .regularization(true).l2(1e-4) .list() .layer(0, new DenseLayer.Builder() .nIn(height * width) .nOut(100) .activation(Activation.RELU) .weightInit(WeightInit.XAVIER) .build()) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nIn(100) .nOut(outputNum) .activation(Activation.SOFTMAX) .weightInit(WeightInit.XAVIER) .build()) .pretrain(false).backprop(true) .setInputType(InputType.convolutional(height,width,channels)) .build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); // The Score iteration Listener will log // output to show how well the network is training model.setListeners(new ScoreIterationListener(10)); log.info("*****TRAIN MODEL********"); for(int i = 0; i<numEpochs; i++){ model.fit(dataIter); } log.info("******EVALUATE MODEL******"); recordReader.reset(); // The model trained on the training dataset split // now that it has trained we evaluate against the // test data of images the network has not seen recordReader.initialize(test); DataSetIterator testIter = new RecordReaderDataSetIterator(recordReader,batchSize,1,outputNum); scaler.fit(testIter); testIter.setPreProcessor(scaler); /* log the order of the labels for later use In previous versions the label order was consistent, but random In current verions label order is lexicographic preserving the RecordReader Labels order is no longer needed left in for demonstration purposes */ log.info(recordReader.getLabels().toString()); // Create Eval object with 10 possible classes Evaluation eval = new Evaluation(outputNum); // Evaluate the network while(testIter.hasNext()){ DataSet next = testIter.next(); INDArray output = model.output(next.getFeatureMatrix()); // Compare the Feature Matrix from the model // with the labels from the RecordReader eval.eval(next.getLabels(),output); } log.info(eval.stats()); } /* Everything below here has nothing to do with your RecordReader, or DataVec, or your Neural Network The classes downloadData, getMnistPNG(), and extractTarGz are for downloading and extracting the data */ private static void downloadData() throws Exception { //Create directory if required File directory = new File(DATA_PATH); if(!directory.exists()) directory.mkdir(); //Download file: String archizePath = DATA_PATH + "/mnist_png.tar.gz"; File archiveFile = new File(archizePath); String extractedPath = DATA_PATH + "mnist_png"; File extractedFile = new File(extractedPath); if( !archiveFile.exists() ){ System.out.println("Starting data download (15MB)..."); getMnistPNG(); //Extract tar.gz file to output directory extractTarGz(archizePath, DATA_PATH); } else { //Assume if archive (.tar.gz) exists, then data has already been extracted System.out.println("Data (.tar.gz file) already exists at " + archiveFile.getAbsolutePath()); if( !extractedFile.exists()){ //Extract tar.gz file to output directory extractTarGz(archizePath, DATA_PATH); } else { System.out.println("Data (extracted) already exists at " + extractedFile.getAbsolutePath()); } } } private static final int BUFFER_SIZE = 4096; private static void extractTarGz(String filePath, String outputPath) throws IOException { int fileCount = 0; int dirCount = 0; System.out.print("Extracting files"); try(TarArchiveInputStream tais = new TarArchiveInputStream( new GzipCompressorInputStream( new BufferedInputStream( new FileInputStream(filePath))))){ TarArchiveEntry entry; /** Read the tar entries using the getNextEntry method **/ while ((entry = (TarArchiveEntry) tais.getNextEntry()) != null) { //System.out.println("Extracting file: " + entry.getName()); //Create directories as required if (entry.isDirectory()) { new File(outputPath + entry.getName()).mkdirs(); dirCount++; }else { int count; byte data[] = new byte[BUFFER_SIZE]; FileOutputStream fos = new FileOutputStream(outputPath + entry.getName()); BufferedOutputStream dest = new BufferedOutputStream(fos,BUFFER_SIZE); while ((count = tais.read(data, 0, BUFFER_SIZE)) != -1) { dest.write(data, 0, count); } dest.close(); fileCount++; } if(fileCount % 1000 == 0) System.out.print("."); } } System.out.println("\n" + fileCount + " files and " + dirCount + " directories extracted to: " + outputPath); } public static void getMnistPNG() throws IOException { String tmpDirStr = System.getProperty("java.io.tmpdir"); String archizePath = DATA_PATH + "/mnist_png.tar.gz"; if (tmpDirStr == null) { throw new IOException("System property 'java.io.tmpdir' does specify a tmp dir"); } String url = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz"; File f = new File(archizePath); File dir = new File(tmpDirStr); if (!f.exists()) { HttpClientBuilder builder = HttpClientBuilder.create(); CloseableHttpClient client = builder.build(); try (CloseableHttpResponse response = client.execute(new HttpGet(url))) { HttpEntity entity = response.getEntity(); if (entity != null) { try (FileOutputStream outstream = new FileOutputStream(f)) { entity.writeTo(outstream); outstream.flush(); outstream.close(); } } } System.out.println("Data downloaded to " + f.getAbsolutePath()); } else { System.out.println("Using existing directory at " + f.getAbsolutePath()); } } }