package org.deeplearning4j.examples.feedforward.anomalydetection;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import java.io.IOException;
import java.util.*;
/**
* This example performs unsupervised anomaly detection on MNIST using a variational autoencoder, trained with a Bernoulli
* reconstruction distribution.
*
* For details on the variational autoencoder, see:
* - Kingma and Welling, 2013 - Auto-Encoding Variational Bayes - https://arxiv.org/abs/1312.6114
*
* For the use of VAEs for anomaly detection using reconstruction probability see:
* - An & Cho, 2015 - Variational Autoencoder based Anomaly Detection using Reconstruction Probability
* http://dm.snu.ac.kr/static/docs/TR/SNUDM-TR-2015-03.pdf
*
*
* Unsupervised training is performed on the entire data set at once in this example. An alternative approach would be to
* train one model for each digit.
*
* After unsupervised training, examples are scored using the VAE layer (reconstruction probability). Here, we are using the
* labels to get the examples with the highest and lowest reconstruction probabilities for each digit for plotting. In a general
* unsupervised anomaly detection situation, these labels would not be available, and hence highest/lowest probabilities
* for the entire data set would be used instead.
*
* @author Alex Black
*/
public class VaeMNISTAnomaly {

    /** Number of distinct digit classes (0-9) used to bucket the scored test examples. */
    private static final int NUM_DIGITS = 10;

    /** How many best-scoring and how many worst-scoring examples are plotted for each digit. */
    private static final int EXAMPLES_PER_DIGIT = 5;

    /**
     * Trains a variational autoencoder on the MNIST training set, scores every MNIST test example by its
     * reconstruction log probability, and then plots (per digit) the highest and lowest scoring examples
     * together with their reconstructions.
     *
     * @param args command line arguments (not used)
     * @throws IOException declared because the MNIST data set iterators may fail while loading the data
     */
    public static void main(String[] args) throws IOException {
        int minibatchSize = 128;
        int rngSeed = 12345;
        int nEpochs = 5;                    //Total number of training epochs
        int reconstructionNumSamples = 16;  //Reconstruction probabilities are estimated using Monte-Carlo techniques; see An & Cho for details

        //MNIST data for training
        DataSetIterator trainIter = new MnistDataSetIterator(minibatchSize, true, rngSeed);

        //Neural net configuration
        Nd4j.getRandom().setSeed(rngSeed);
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(rngSeed)
            .learningRate(0.05)
            .updater(Updater.ADAM).adamMeanDecay(0.9).adamVarDecay(0.999)
            .weightInit(WeightInit.XAVIER)
            .regularization(true).l2(1e-4)
            .list()
            .layer(0, new VariationalAutoencoder.Builder()
                .activation(Activation.LEAKYRELU)
                .encoderLayerSizes(256, 256)                    //2 encoder layers, each of size 256
                .decoderLayerSizes(256, 256)                    //2 decoder layers, each of size 256
                .pzxActivationFunction(Activation.IDENTITY)     //p(z|data) activation function
                //Bernoulli reconstruction distribution + sigmoid activation - for modelling binary data (or data in range 0 to 1)
                .reconstructionDistribution(new BernoulliReconstructionDistribution(Activation.SIGMOID))
                .nIn(28 * 28)                                   //Input size: 28x28
                .nOut(32)                                       //Size of the latent variable space: p(z|x) - 32 values
                .build())
            .pretrain(true).backprop(false).build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(100));

        //Fit the data (unsupervised training)
        for (int epoch = 0; epoch < nEpochs; epoch++) {
            net.fit(trainIter);
            System.out.println("Finished epoch " + (epoch + 1) + " of " + nEpochs);
        }

        //Perform anomaly detection on the test set, by calculating the reconstruction probability for each example
        //Then add pair (reconstruction probability, INDArray data) to lists and sort by score
        //This allows us to get best N and worst N digits for each digit type
        DataSetIterator testIter = new MnistDataSetIterator(minibatchSize, false, rngSeed);

        //Get the variational autoencoder layer:
        org.deeplearning4j.nn.layers.variational.VariationalAutoencoder vae
            = (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0);

        //One list of (score, example) pairs per digit label
        Map<Integer, List<Pair<Double, INDArray>>> listsByDigit = new HashMap<>();
        for (int digit = 0; digit < NUM_DIGITS; digit++) {
            listsByDigit.put(digit, new ArrayList<>());
        }

        //Iterate over the test data, calculating reconstruction probabilities
        while (testIter.hasNext()) {
            DataSet ds = testIter.next();
            INDArray features = ds.getFeatures();
            INDArray labels = Nd4j.argMax(ds.getLabels(), 1);   //Labels as integer indexes (from one hot), shape [minibatchSize, 1]
            int nRows = features.rows();

            //Calculate the log probability for reconstructions as per An & Cho
            //Higher is better, lower is worse
            INDArray reconstructionLogProbEachExample
                = vae.reconstructionLogProbability(features, reconstructionNumSamples);   //Shape: [minibatchSize, 1]

            for (int row = 0; row < nRows; row++) {
                INDArray example = features.getRow(row);
                int label = (int) labels.getDouble(row);
                double score = reconstructionLogProbEachExample.getDouble(row);
                listsByDigit.get(label).add(new Pair<>(score, example));
            }
        }

        //Sort data by score, separately for each digit.
        //Arguments are swapped (rather than negating the result) to get descending order:
        //highest reconstruction probabilities first -> sorted from best to worst
        Comparator<Pair<Double, INDArray>> bestFirst =
            (o1, o2) -> Double.compare(o2.getFirst(), o1.getFirst());
        for (List<Pair<Double, INDArray>> list : listsByDigit.values()) {
            list.sort(bestFirst);
        }

        //Select the best and worst numbers (by reconstruction probability) for each digit
        int plotCapacity = NUM_DIGITS * EXAMPLES_PER_DIGIT;
        List<INDArray> best = new ArrayList<>(plotCapacity);
        List<INDArray> worst = new ArrayList<>(plotCapacity);
        List<INDArray> bestReconstruction = new ArrayList<>(plotCapacity);
        List<INDArray> worstReconstruction = new ArrayList<>(plotCapacity);
        for (int digit = 0; digit < NUM_DIGITS; digit++) {
            List<Pair<Double, INDArray>> list = listsByDigit.get(digit);

            //Never index past the end of the list if a digit has fewer scored examples than requested.
            //(With fewer than 2 * EXAMPLES_PER_DIGIT examples the best and worst selections overlap.)
            int nSelect = Math.min(EXAMPLES_PER_DIGIT, list.size());
            for (int j = 0; j < nSelect; j++) {
                INDArray b = list.get(j).getSecond();                       //j-th best: front of the sorted list
                INDArray w = list.get(list.size() - j - 1).getSecond();     //j-th worst: back of the sorted list

                //Encode each example with the VAE layer, then decode that latent value to obtain its reconstruction
                INDArray pzxMeanBest = vae.preOutput(b);
                INDArray reconstructionBest = vae.generateAtMeanGivenZ(pzxMeanBest);
                INDArray pzxMeanWorst = vae.preOutput(w);
                INDArray reconstructionWorst = vae.generateAtMeanGivenZ(pzxMeanWorst);

                best.add(b);
                bestReconstruction.add(reconstructionBest);
                worst.add(w);
                worstReconstruction.add(reconstructionWorst);
            }
        }

        //Visualize the best and worst digits, each next to a window showing their reconstructions
        double imageScale = 2.0;
        MNISTAnomalyExample.MNISTVisualizer bestVisualizer
            = new MNISTAnomalyExample.MNISTVisualizer(imageScale, best, "Best (Highest Rec. Prob)");
        bestVisualizer.visualize();

        MNISTAnomalyExample.MNISTVisualizer bestReconstructions
            = new MNISTAnomalyExample.MNISTVisualizer(imageScale, bestReconstruction, "Best - Reconstructions");
        bestReconstructions.visualize();

        MNISTAnomalyExample.MNISTVisualizer worstVisualizer
            = new MNISTAnomalyExample.MNISTVisualizer(imageScale, worst, "Worst (Lowest Rec. Prob)");
        worstVisualizer.visualize();

        MNISTAnomalyExample.MNISTVisualizer worstReconstructions
            = new MNISTAnomalyExample.MNISTVisualizer(imageScale, worstReconstruction, "Worst - Reconstructions");
        worstReconstructions.visualize();
    }
}