package org.deeplearning4j.examples.feedforward.anomalydetection;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.util.*;
import java.util.List;
/**Example: Anomaly Detection on MNIST using simple autoencoder without pretraining
* The goal is to identify outliers digits, i.e., those digits that are unusual or
* not like the typical digits.
* This is accomplished in this example by using reconstruction error: stereotypical
* examples should have low reconstruction error, whereas outliers should have high
* reconstruction error
*
* @author Alex Black
*/
public class MNISTAnomalyExample {
public static void main(String[] args) throws Exception {
//Set up network. 784 in/out (as MNIST images are 28x28).
//784 -> 250 -> 10 -> 250 -> 784
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345)
.iterations(1)
.weightInit(WeightInit.XAVIER)
.updater(Updater.ADAGRAD)
.activation(Activation.RELU)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.learningRate(0.05)
.regularization(true).l2(0.0001)
.list()
.layer(0, new DenseLayer.Builder().nIn(784).nOut(250)
.build())
.layer(1, new DenseLayer.Builder().nIn(250).nOut(10)
.build())
.layer(2, new DenseLayer.Builder().nIn(10).nOut(250)
.build())
.layer(3, new OutputLayer.Builder().nIn(250).nOut(784)
.lossFunction(LossFunctions.LossFunction.MSE)
.build())
.pretrain(false).backprop(true)
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(Collections.singletonList((IterationListener) new ScoreIterationListener(1)));
//Load data and split into training and testing sets. 40000 train, 10000 test
DataSetIterator iter = new MnistDataSetIterator(100,50000,false);
List<INDArray> featuresTrain = new ArrayList<>();
List<INDArray> featuresTest = new ArrayList<>();
List<INDArray> labelsTest = new ArrayList<>();
Random r = new Random(12345);
while(iter.hasNext()){
DataSet ds = iter.next();
SplitTestAndTrain split = ds.splitTestAndTrain(80, r); //80/20 split (from miniBatch = 100)
featuresTrain.add(split.getTrain().getFeatureMatrix());
DataSet dsTest = split.getTest();
featuresTest.add(dsTest.getFeatureMatrix());
INDArray indexes = Nd4j.argMax(dsTest.getLabels(),1); //Convert from one-hot representation -> index
labelsTest.add(indexes);
}
//Train model:
int nEpochs = 30;
for( int epoch=0; epoch<nEpochs; epoch++ ){
for(INDArray data : featuresTrain){
net.fit(data,data);
}
System.out.println("Epoch " + epoch + " complete");
}
//Evaluate the model on the test data
//Score each example in the test set separately
//Compose a map that relates each digit to a list of (score, example) pairs
//Then find N best and N worst scores per digit
Map<Integer,List<Pair<Double,INDArray>>> listsByDigit = new HashMap<>();
for( int i=0; i<10; i++ ) listsByDigit.put(i,new ArrayList<>());
for( int i=0; i<featuresTest.size(); i++ ){
INDArray testData = featuresTest.get(i);
INDArray labels = labelsTest.get(i);
int nRows = testData.rows();
for( int j=0; j<nRows; j++){
INDArray example = testData.getRow(j);
int digit = (int)labels.getDouble(j);
double score = net.score(new DataSet(example,example));
// Add (score, example) pair to the appropriate list
List digitAllPairs = listsByDigit.get(digit);
digitAllPairs.add(new ImmutablePair<>(score, example));
}
}
//Sort each list in the map by score
Comparator<Pair<Double, INDArray>> c = new Comparator<Pair<Double, INDArray>>() {
@Override
public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) {
return Double.compare(o1.getLeft(),o2.getLeft());
}
};
for(List<Pair<Double, INDArray>> digitAllPairs : listsByDigit.values()){
Collections.sort(digitAllPairs, c);
}
//After sorting, select N best and N worst scores (by reconstruction error) for each digit, where N=5
List<INDArray> best = new ArrayList<>(50);
List<INDArray> worst = new ArrayList<>(50);
for( int i=0; i<10; i++ ){
List<Pair<Double,INDArray>> list = listsByDigit.get(i);
for( int j=0; j<5; j++ ){
best.add(list.get(j).getRight());
worst.add(list.get(list.size()-j-1).getRight());
}
}
//Visualize the best and worst digits
MNISTVisualizer bestVisualizer = new MNISTVisualizer(2.0,best,"Best (Low Rec. Error)");
bestVisualizer.visualize();
MNISTVisualizer worstVisualizer = new MNISTVisualizer(2.0,worst,"Worst (High Rec. Error)");
worstVisualizer.visualize();
}
public static class MNISTVisualizer {
private double imageScale;
private List<INDArray> digits; //Digits (as row vectors), one per INDArray
private String title;
private int gridWidth;
public MNISTVisualizer(double imageScale, List<INDArray> digits, String title ) {
this(imageScale, digits, title, 5);
}
public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth ) {
this.imageScale = imageScale;
this.digits = digits;
this.title = title;
this.gridWidth = gridWidth;
}
public void visualize(){
JFrame frame = new JFrame();
frame.setTitle(title);
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
JPanel panel = new JPanel();
panel.setLayout(new GridLayout(0,gridWidth));
List<JLabel> list = getComponents();
for(JLabel image : list){
panel.add(image);
}
frame.add(panel);
frame.setVisible(true);
frame.pack();
}
private List<JLabel> getComponents(){
List<JLabel> images = new ArrayList<>();
for( INDArray arr : digits ){
BufferedImage bi = new BufferedImage(28,28,BufferedImage.TYPE_BYTE_GRAY);
for( int i=0; i<784; i++ ){
bi.getRaster().setSample(i % 28, i / 28, 0, (int)(255*arr.getDouble(i)));
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((int)(imageScale*28),(int)(imageScale*28),Image.SCALE_REPLICATE);
ImageIcon scaled = new ImageIcon(imageScaled);
images.add(new JLabel(scaled));
}
return images;
}
}
}