package org.deeplearning4j.examples.dataexamples;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
import org.apache.commons.io.FilenameUtils;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.util.Random;
/**
* This code example is featured in this youtube video
* https://www.youtube.com/watch?v=ECA6y6ahH5E
*
** This differs slightly from the Video Example,
* The Video example had the data already downloaded
* This example includes code that downloads the data
*
*
* The Data Directory mnist_png will have two child directories training and testing
* The training and testing directories will have directories 0-9 with
* 28 * 28 PNG images of handwritten digits
*
*
*
* The data is downloaded from
* wget http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz
* followed by tar xzvf mnist_png.tar.gz
*
*
*
* This examples builds on the MnistImagePipelineExample
* by adding a Neural Net
*/
public class MnistImagePipelineExampleAddNeuralNet {
    private static Logger log = LoggerFactory.getLogger(MnistImagePipelineExampleAddNeuralNet.class);

    /** Data URL for downloading */
    public static final String DATA_URL = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";

    /** Location to save and extract the training/testing data */
    public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_Mnist/");

    /**
     * Downloads/extracts the MNIST PNG dataset, trains a simple two-layer
     * dense network on the training split, then evaluates it on the test split.
     */
    public static void main(String[] args) throws Exception {
        // image information
        // 28 * 28 grayscale
        // grayscale implies single channel
        int height = 28;
        int width = 28;
        int channels = 1;
        int rngseed = 123;
        Random randNumGen = new Random(rngseed);
        int batchSize = 128;
        int outputNum = 10;
        int numEpochs = 1;

        /*
        This class downloadData() downloads the data
        stores the data in java's tmpdir
        15MB download compressed
        It will take 158MB of space when uncompressed
        The data can be downloaded manually here
        http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz
        */
        downloadData();

        // Define the File Paths
        File trainData = new File(DATA_PATH + "/mnist_png/training");
        File testData = new File(DATA_PATH + "/mnist_png/testing");

        // Define the FileSplit(PATH, ALLOWED FORMATS, random)
        FileSplit train = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        FileSplit test = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);

        // Extract the parent path as the image label
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);

        // Initialize the record reader with the training split
        recordReader.initialize(train);

        // DataSet Iterator
        DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, outputNum);

        // Scale pixel values to 0-1
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        scaler.fit(dataIter);
        dataIter.setPreProcessor(scaler);

        // Build Our Neural Network
        log.info("**** Build Model ****");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(rngseed)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .learningRate(0.006)
            .updater(Updater.NESTEROVS).momentum(0.9)
            .regularization(true).l2(1e-4)
            .list()
            .layer(0, new DenseLayer.Builder()
                .nIn(height * width)
                .nOut(100)
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nIn(100)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .weightInit(WeightInit.XAVIER)
                .build())
            .pretrain(false).backprop(true)
            .setInputType(InputType.convolutional(height, width, channels))
            .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);

        // The Score iteration Listener will log
        // output to show how well the network is training
        model.setListeners(new ScoreIterationListener(10));

        log.info("*****TRAIN MODEL********");
        for (int i = 0; i < numEpochs; i++) {
            model.fit(dataIter);
        }

        log.info("******EVALUATE MODEL******");
        recordReader.reset();

        // The model trained on the training dataset split
        // now that it has trained we evaluate against the
        // test data of images the network has not seen
        recordReader.initialize(test);
        DataSetIterator testIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, outputNum);
        scaler.fit(testIter);
        testIter.setPreProcessor(scaler);

        /*
        log the order of the labels for later use
        In previous versions the label order was consistent, but random
        In current versions label order is lexicographic
        preserving the RecordReader Labels order is no
        longer needed left in for demonstration
        purposes
        */
        log.info(recordReader.getLabels().toString());

        // Create Eval object with 10 possible classes
        Evaluation eval = new Evaluation(outputNum);

        // Evaluate the network
        while (testIter.hasNext()) {
            DataSet next = testIter.next();
            INDArray output = model.output(next.getFeatureMatrix());
            // Compare the model's predictions
            // with the labels from the RecordReader
            eval.eval(next.getLabels(), output);
        }
        log.info(eval.stats());
    }

    /*
    Everything below here has nothing to do with your RecordReader,
    or DataVec, or your Neural Network
    The methods downloadData(), getMnistPNG(),
    and extractTarGz() are for downloading and extracting the data
    */

    /**
     * Downloads (if needed) and extracts (if needed) the MNIST PNG archive
     * into {@link #DATA_PATH}. Assumes that an existing .tar.gz implies a
     * previous extraction unless the extracted directory is missing.
     */
    private static void downloadData() throws Exception {
        // Create target directory (and any missing parents) if required.
        // mkdirs() instead of mkdir(): mkdir() fails when the parent is absent.
        File directory = new File(DATA_PATH);
        if (!directory.exists()) directory.mkdirs();

        String archivePath = DATA_PATH + "/mnist_png.tar.gz";
        File archiveFile = new File(archivePath);
        String extractedPath = DATA_PATH + "mnist_png";
        File extractedFile = new File(extractedPath);

        if (!archiveFile.exists()) {
            System.out.println("Starting data download (15MB)...");
            getMnistPNG();
            // Extract tar.gz file to output directory
            extractTarGz(archivePath, DATA_PATH);
        } else {
            // Assume if archive (.tar.gz) exists, then data has already been extracted
            System.out.println("Data (.tar.gz file) already exists at " + archiveFile.getAbsolutePath());
            if (!extractedFile.exists()) {
                // Extract tar.gz file to output directory
                extractTarGz(archivePath, DATA_PATH);
            } else {
                System.out.println("Data (extracted) already exists at " + extractedFile.getAbsolutePath());
            }
        }
    }

    private static final int BUFFER_SIZE = 4096;

    /**
     * Extracts a .tar.gz archive into {@code outputPath}.
     *
     * @param filePath   path of the .tar.gz archive to read
     * @param outputPath directory the archive contents are written under
     * @throws IOException on read/write failure, or if an archive entry would
     *                     escape {@code outputPath} (zip-slip protection)
     */
    private static void extractTarGz(String filePath, String outputPath) throws IOException {
        int fileCount = 0;
        int dirCount = 0;
        System.out.print("Extracting files");
        File outputDir = new File(outputPath);
        String canonicalOutput = outputDir.getCanonicalPath();
        try (TarArchiveInputStream tais = new TarArchiveInputStream(
            new GzipCompressorInputStream(new BufferedInputStream(new FileInputStream(filePath))))) {
            TarArchiveEntry entry;
            // Read the tar entries using the getNextEntry method
            while ((entry = (TarArchiveEntry) tais.getNextEntry()) != null) {
                File destination = new File(outputDir, entry.getName());
                // Zip-slip guard: refuse entries whose resolved path escapes the output directory
                String canonicalDest = destination.getCanonicalPath();
                if (!canonicalDest.equals(canonicalOutput)
                    && !canonicalDest.startsWith(canonicalOutput + File.separator)) {
                    throw new IOException("Archive entry is outside of the target dir: " + entry.getName());
                }
                if (entry.isDirectory()) {
                    destination.mkdirs();
                    dirCount++;
                } else {
                    // Ensure the parent exists even if the archive lists files before directories
                    File parent = destination.getParentFile();
                    if (parent != null) parent.mkdirs();
                    // try-with-resources so the stream is closed even if a write throws
                    try (BufferedOutputStream dest = new BufferedOutputStream(
                        new FileOutputStream(destination), BUFFER_SIZE)) {
                        byte[] data = new byte[BUFFER_SIZE];
                        int count;
                        while ((count = tais.read(data, 0, BUFFER_SIZE)) != -1) {
                            dest.write(data, 0, count);
                        }
                    }
                    fileCount++;
                }
                // Progress indicator: one dot per ~1000 files extracted
                if (fileCount % 1000 == 0) System.out.print(".");
            }
        }
        System.out.println("\n" + fileCount + " files and " + dirCount + " directories extracted to: " + outputPath);
    }

    /**
     * Downloads the MNIST PNG archive from {@link #DATA_URL} into
     * {@link #DATA_PATH} unless the archive file already exists.
     */
    public static void getMnistPNG() throws IOException {
        String tmpDirStr = System.getProperty("java.io.tmpdir");
        if (tmpDirStr == null) {
            throw new IOException("System property 'java.io.tmpdir' does not specify a tmp dir");
        }
        String archivePath = DATA_PATH + "/mnist_png.tar.gz";
        File f = new File(archivePath);
        if (!f.exists()) {
            // try-with-resources closes both the HTTP client and the response
            // (the original code never closed the client, leaking its connection pool)
            try (CloseableHttpClient client = HttpClientBuilder.create().build();
                 CloseableHttpResponse response = client.execute(new HttpGet(DATA_URL))) {
                HttpEntity entity = response.getEntity();
                if (entity != null) {
                    try (FileOutputStream outstream = new FileOutputStream(f)) {
                        entity.writeTo(outstream);
                        outstream.flush();
                    }
                }
            }
            System.out.println("Data downloaded to " + f.getAbsolutePath());
        } else {
            System.out.println("Using existing directory at " + f.getAbsolutePath());
        }
    }
}