package org.deeplearning4j.examples.dataexamples; import org.apache.commons.io.FilenameUtils; import org.apache.http.HttpEntity; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpGet; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClientBuilder; import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.records.listener.impl.LogRecordListener; import org.datavec.api.split.FileSplit; import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.recordreader.ImageRecordReader; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.examples.utilities.DataUtilities; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.*; import java.util.Random; /** * Created by tom hanlon on 11/7/16. * This code example is featured in this youtube video * https://www.youtube.com/watch?v=GLC8CIoHDnI * * This differs slightly from the Video Example, * The Video example had the data already downloaded * This example includes code that downloads the data * * Instructions * Downloads a directory containing a testing and a training folder * each folder has 10 directories 0-9 * in each directory are 28 * 28 grayscale pngs of handwritten digits * The training and testing directories will have directories 0-9 with * 28 * 28 PNG images of handwritten images * * The code here shows how to use a ParentPathLabelGenerator to label the images as * they are read into the RecordReader * * The pixel values are scaled to values between 0 and 1 using * ImagePreProcessingScaler * * In this example a loop steps through 3 images and prints the DataSet to * the terminal. The expected output is the 28* 28 matrix of scaled pixel values * the list with the label for that image * and a list of the label values * * This example also applies a Listener to the RecordReader that logs the path of each image read * You would not want to do this in production * The reason it is done here is to show that a handwritten image 3 (for example) * was read from directory 3, * has a matrix with the shown values * Has a label value corresponding to 3 * */ public class MnistImagePipelineExample { /** Data URL for downloading */ public static final String DATA_URL = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz"; /** Location to save and extract the training/testing data */ public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_Mnist/"); private static Logger log = LoggerFactory.getLogger(MnistImagePipelineExample.class); public static void main(String[] args) throws Exception { /* image information 28 * 28 grayscale grayscale implies single channel */ int height = 28; int width = 28; int channels = 1; int rngseed = 123; Random randNumGen = new Random(rngseed); int batchSize = 1; int outputNum = 10; /* This class downloadData() downloads the data stores the data in java's tmpdir 15MB download compressed It will take 158MB of space when uncompressed The data can be downloaded manually here http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz */ downloadData(); // Define the File Paths File trainData = new File(DATA_PATH + "/mnist_png/training"); File testData = new File(DATA_PATH + "/mnist_png/testing"); // Define the FileSplit(PATH, ALLOWED FORMATS,random) FileSplit train = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen); FileSplit test = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen); // Extract the parent path as the image label ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker); // Initialize the record reader // add a listener, to extract the name recordReader.initialize(train); // The LogRecordListener will log the path of each image read // used here for information purposes, // If the whole dataset was ingested this would place 60,000 // lines in our logs // It will show up in the output with this format // o.d.a.r.l.i.LogRecordListener - Reading /tmp/mnist_png/training/4/36384.png recordReader.setListeners(new LogRecordListener()); // DataSet Iterator DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, outputNum); // Scale pixel values to 0-1 DataNormalization scaler = new ImagePreProcessingScaler(0, 1); scaler.fit(dataIter); dataIter.setPreProcessor(scaler); // In production you would loop through all the data // in this example the loop is just through 3 // images for demonstration purposes for (int i = 1; i < 3; i++) { DataSet ds = dataIter.next(); System.out.println(ds); System.out.println(dataIter.getLabels()); } } /* Everything below here has nothing to do with your RecordReader, or DataVec, or your Neural Network The classes downloadData, getMnistPNG(), and extractTarGz are for downloading and extracting the data */ private static void downloadData() throws Exception { //Create directory if required File directory = new File(DATA_PATH); if(!directory.exists()) directory.mkdir(); //Download file: String archizePath = DATA_PATH + "/mnist_png.tar.gz"; File archiveFile = new File(archizePath); String extractedPath = DATA_PATH + "mnist_png"; File extractedFile = new File(extractedPath); if( !archiveFile.exists() ){ System.out.println("Starting data download (15MB)..."); getMnistPNG(); //Extract tar.gz file to output directory DataUtilities.extractTarGz(archizePath, DATA_PATH); } else { //Assume if archive (.tar.gz) exists, then data has already been extracted System.out.println("Data (.tar.gz file) already exists at " + archiveFile.getAbsolutePath()); if( !extractedFile.exists()){ //Extract tar.gz file to output directory DataUtilities.extractTarGz(archizePath, DATA_PATH); } else { System.out.println("Data (extracted) already exists at " + extractedFile.getAbsolutePath()); } } } public static void getMnistPNG() throws IOException { String tmpDirStr = System.getProperty("java.io.tmpdir"); String archizePath = DATA_PATH + "/mnist_png.tar.gz"; if (tmpDirStr == null) { throw new IOException("System property 'java.io.tmpdir' does specify a tmp dir"); } String url = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz"; File f = new File(archizePath); File dir = new File(tmpDirStr); if (!f.exists()) { HttpClientBuilder builder = HttpClientBuilder.create(); CloseableHttpClient client = builder.build(); try (CloseableHttpResponse response = client.execute(new HttpGet(url))) { HttpEntity entity = response.getEntity(); if (entity != null) { try (FileOutputStream outstream = new FileOutputStream(f)) { entity.writeTo(outstream); outstream.flush(); outstream.close(); } } } System.out.println("Data downloaded to " + f.getAbsolutePath()); } else { System.out.println("Using existing directory at " + f.getAbsolutePath()); } } }