package org.deeplearning4j.examples.unsupervised.variational;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.examples.unsupervised.variational.plot.PlotUtil;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
* A simple example of training a variational autoencoder on MNIST.
* This example intentionally uses a small latent state Z (2 values) so that it can be visualized on a 2d grid.
*
* After training, this example plots 2 things:
* 1. The MNIST digit reconstructions vs. the latent space
* 2. The latent space values for the MNIST test set, as training progresses (every N minibatches)
*
* Note that for both plots, there is a slider at the top - move it to see how the reconstructions and the latent
* space change over time.
*
* @author Alex Black
*/
public class VariationalAutoEncoderExample {
private static final Logger log = LoggerFactory.getLogger(VariationalAutoEncoderExample.class);
public static void main(String[] args) throws IOException {
int minibatchSize = 128;
int rngSeed = 12345;
int nEpochs = 20; //Total number of training epochs
//Plotting configuration
int plotEveryNMinibatches = 100; //Frequency with which to collect data for later plotting
double plotMin = -5; //Minimum values for plotting (x and y dimensions)
double plotMax = 5; //Maximum values for plotting (x and y dimensions)
int plotNumSteps = 16; //Number of steps for reconstructions, between plotMin and plotMax
//MNIST data for training
DataSetIterator trainIter = new MnistDataSetIterator(minibatchSize, true, rngSeed);
//Neural net configuration
Nd4j.getRandom().setSeed(rngSeed);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(rngSeed)
.iterations(1).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.learningRate(1e-2)
.updater(Updater.RMSPROP).rmsDecay(0.95)
.weightInit(WeightInit.XAVIER)
.regularization(true).l2(1e-4)
.list()
.layer(0, new VariationalAutoencoder.Builder()
.activation(Activation.LEAKYRELU)
.encoderLayerSizes(256, 256) //2 encoder layers, each of size 256
.decoderLayerSizes(256, 256) //2 decoder layers, each of size 256
.pzxActivationFunction("identity") //p(z|data) activation function
.reconstructionDistribution(new BernoulliReconstructionDistribution(Activation.SIGMOID.getActivationFunction())) //Bernoulli distribution for p(data|z) (binary or 0 to 1 data only)
.nIn(28 * 28) //Input size: 28x28
.nOut(2) //Size of the latent variable space: p(z|x). 2 dimensions here for plotting, use more in general
.build())
.pretrain(true).backprop(false).build();   //Train unsupervised (pretrain step) only; there is no supervised backprop phase
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
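//Optional: print the training score every 100 iterations; the frequency here is an arbitrary choice
net.setListeners(new ScoreIterationListener(100));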
//Get the variational autoencoder layer (the implementation class exposes VAE-specific methods such as generateAtMeanGivenZ)
org.deeplearning4j.nn.layers.variational.VariationalAutoencoder vae
= (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0);
//Test data for plotting
DataSet testdata = new MnistDataSetIterator(10000, false, rngSeed).next();
INDArray testFeatures = testdata.getFeatures();
INDArray testLabels = testdata.getLabels();
INDArray latentSpaceGrid = getLatentSpaceGrid(plotMin, plotMax, plotNumSteps); //X/Y grid values, between plotMin and plotMax
//Lists to store data for later plotting
List<INDArray> latentSpaceVsEpoch = new ArrayList<>(nEpochs + 1);   //Test set latent space values, collected before training and then every plotEveryNMinibatches
INDArray latentSpaceValues = vae.activate(testFeatures, false); //Collect and record the latent space values before training starts
latentSpaceVsEpoch.add(latentSpaceValues);
List<INDArray> digitsGrid = new ArrayList<>();   //Reconstructions (means of p(data|z)) over the latent space grid, collected every plotEveryNMinibatches
//Perform training
int iterationCount = 0;
for (int i = 0; i < nEpochs; i++) {
log.info("Starting epoch {} of {}",(i+1),nEpochs);
while (trainIter.hasNext()) {
DataSet ds = trainIter.next();
net.fit(ds);
//Every plotEveryNMinibatches (100 by default) minibatches:
// (a) collect the test set latent space values for later plotting
// (b) collect the reconstructions at each point in the grid
if (iterationCount++ % plotEveryNMinibatches == 0) {
latentSpaceValues = vae.activate(testFeatures, false);
latentSpaceVsEpoch.add(latentSpaceValues);
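//Reconstruct the mean of p(data|z) for each point in the latent space grid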
INDArray out = vae.generateAtMeanGivenZ(latentSpaceGrid);
digitsGrid.add(out);
}
}
trainIter.reset();
}
//Plot MNIST test set - latent space vs. iteration (every 100 minibatches by default)
PlotUtil.plotData(latentSpaceVsEpoch, testLabels, plotMin, plotMax, plotEveryNMinibatches);
//Plot reconstructions - latent space vs. grid
double imageScale = 2.0; //Increase/decrease this to zoom in on the digits
PlotUtil.MNISTLatentSpaceVisualizer v = new PlotUtil.MNISTLatentSpaceVisualizer(imageScale, digitsGrid, plotEveryNMinibatches);
v.visualize();
}
//This simply returns a 2d grid: (x,y) for x=plotMin to plotMax, and y=plotMin to plotMax
private static INDArray getLatentSpaceGrid(double plotMin, double plotMax, int plotSteps) {
INDArray data = Nd4j.create(plotSteps * plotSteps, 2);
INDArray linspaceRow = Nd4j.linspace(plotMin, plotMax, plotSteps);
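//Fill the grid one block of plotSteps rows at a time: column 0 (x) sweeps plotMin to plotMax within each block,
//while column 1 (y) is constant per block, decreasing from plotMax down to plotMin (rows are ordered top to bottom)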
for (int i = 0; i < plotSteps; i++) {
data.get(NDArrayIndex.interval(i * plotSteps, (i + 1) * plotSteps), NDArrayIndex.point(0)).assign(linspaceRow);
int yStart = plotSteps - i - 1;
data.get(NDArrayIndex.interval(yStart * plotSteps, (yStart + 1) * plotSteps), NDArrayIndex.point(1)).assign(linspaceRow.getDouble(i));
}
return data;
}
}