package org.deeplearning4j.examples.dataexamples;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.swing.*;
import java.io.File;
import java.util.Arrays;
import java.util.List;
/**
* This code example is featured in this youtube video
*
* http://www.youtube.com/watch?v=DRHIpeJpJDI
*
* This differs slightly from the Video Example,
* The Video example had the data already downloaded
* This example includes code that downloads the data
*
* Data is downloaded from
*
*
* wget http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz
* followed by tar xzvf mnist_png.tar.gz
* The data directory mnist_png will have two child directories, training and testing.
* The training and testing directories will each have directories 0-9 containing
* 28 * 28 PNG images of handwritten digits
*
*
*
*
*
* This example builds on the MnistImagePipelineExample
* by giving the user a file chooser to test an image of their choice
* against the neural net. Will the network think this cat is an 8 or a 1?
* Seriously, you can test anything, but the network was trained on handwritten digits
* 0-9 (white digit, black background), so it will work better with images closer to what it was
* designed for.
*
*/
public class MnistImagePipelineLoadChooser {
    private static final Logger log = LoggerFactory.getLogger(MnistImagePipelineLoadChooser.class);

    /**
     * Opens a file-chooser dialog so the user can pick an image file to test
     * against the trained neural network.
     *
     * @return the absolute path of the selected file, or {@code null} when the
     *         dialog was not approved (for example, it was cancelled or closed)
     */
    public static String fileChose() {
        JFileChooser fc = new JFileChooser();
        int ret = fc.showOpenDialog(null);
        if (ret != JFileChooser.APPROVE_OPTION) {
            return null;
        }
        File file = fc.getSelectedFile();
        return file.getAbsolutePath();
    }

    /**
     * Asks the user for an image, restores the saved network from
     * {@code trained_mnist_model.zip}, and logs the network's output for that image.
     *
     * <p>If no image is selected the program prints a message and returns. If the
     * saved model file does not exist the program prints instructions and exits.
     *
     * @param args command-line arguments (not used)
     * @throws Exception if restoring the model or loading the image fails
     */
    public static void main(String[] args) throws Exception {
        // Dimensions the chosen image is loaded at: 28 x 28, single channel (grayscale).
        int height = 28;
        int width = 28;
        int channels = 1;

        // Labels are always in order in this version, so a fixed 0-9 list is enough
        // (there is no need to read them back from a record reader).
        List<Integer> labelList = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);

        // Pop up the file chooser. fileChose() returns null when the user cancels,
        // so stop here instead of dereferencing a null path.
        String filechose = fileChose();
        if (filechose == null) {
            System.out.println("No image file was selected, nothing to test.");
            return;
        }

        // LOAD NEURAL NETWORK
        // Location of the previously saved model (relative to the working directory).
        File locationToSave = new File("trained_mnist_model.zip");

        // Check for presence of the saved model before trying to restore it.
        if (locationToSave.exists()) {
            System.out.println("\n######Saved Model Found######\n");
        } else {
            System.out.println("\n\n#######File not found!#######");
            System.out.println("This example depends on running ");
            System.out.println("MnistImagePipelineExampleSave");
            System.out.println("Run that Example First");
            System.out.println("#############################\n\n");
            System.exit(0);
        }

        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(locationToSave);

        log.info("*********TEST YOUR IMAGE AGAINST SAVED NETWORK********");

        // The chooser gave us a path string; the image loader needs a File.
        File file = new File(filechose);

        // Use NativeImageLoader to convert the image to a numerical matrix.
        NativeImageLoader loader = new NativeImageLoader(height, width, channels);

        // Get the image into an INDArray.
        INDArray image = loader.asMatrix(file);

        // Scale pixel values from the 0-255 range down to 0-1.
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        scaler.transform(image);

        // Pass the image through the neural net.
        INDArray output = model.output(image);

        log.info("## The FILE CHOSEN WAS {}", filechose);
        log.info("## The Neural Nets Pediction ##");
        log.info("## list of probabilities per label ##");
        // Labels are always in order in new versions, so the output array is
        // logged first and the 0-9 label list is logged right after it.
        log.info(output.toString());
        log.info(labelList.toString());
    }
}