package org.deeplearning4j.examples.recurrent.basic;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.ListBuilder;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Random;
/**
 * This example trains an RNN. Once trained, we only have to feed the first
 * character of LEARNSTRING into the RNN, and it will recite the remaining
 * characters.
 *
 * @author Peter Grossmann
 */
public class BasicRNNExample {
// define a sentence to learn (a German tongue twister; roughly: "The Cottbus
// post-coach driver cleans the Cottbus post-coach box")
public static final char[] LEARNSTRING = "Der Cottbuser Postkutscher putzt den Cottbuser Postkutschkasten.".toCharArray();
// a list of all possible characters
public static final List<Character> LEARNSTRING_CHARS_LIST = new ArrayList<Character>();
// RNN dimensions
public static final int HIDDEN_LAYER_WIDTH = 50;
public static final int HIDDEN_LAYER_COUNT = 2;
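// random number generator with a fixed seed, used by the optional
// sampleFromDistribution sketch at the bottom of this class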
public static final Random r = new Random(7894);
public static void main(String[] args) {
// create a dedicated list of possible chars in LEARNSTRING_CHARS_LIST
LinkedHashSet<Character> LEARNSTRING_CHARS = new LinkedHashSet<Character>();
for (char c : LEARNSTRING)
LEARNSTRING_CHARS.add(c);
LEARNSTRING_CHARS_LIST.addAll(LEARNSTRING_CHARS);
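// e.g. for the string "abcab" this yields [a, b, c]: the LinkedHashSet
// removes duplicates while preserving insertion order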
// some common parameters
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
builder.iterations(10);
builder.learningRate(0.001);
builder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
builder.seed(123);
builder.biasInit(0);
builder.miniBatch(false);
builder.updater(Updater.RMSPROP);
builder.weightInit(WeightInit.XAVIER);
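// note: RMSPROP keeps a per-parameter moving average of squared gradients
// and scales the learning rate accordingly, which tends to work well for
// RNNs; XAVIER scales the initial weights to the layer size so activations
// start in a reasonable range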
ListBuilder listBuilder = builder.list();
// first difference to a feed-forward network: for RNNs we need to use
// GravesLSTM.Builder for the hidden layers
for (int i = 0; i < HIDDEN_LAYER_COUNT; i++) {
GravesLSTM.Builder hiddenLayerBuilder = new GravesLSTM.Builder();
hiddenLayerBuilder.nIn(i == 0 ? LEARNSTRING_CHARS.size() : HIDDEN_LAYER_WIDTH);
hiddenLayerBuilder.nOut(HIDDEN_LAYER_WIDTH);
// activation function adopted from GravesLSTMCharModellingExample;
// TANH seems to work well with RNNs
hiddenLayerBuilder.activation(Activation.TANH);
listBuilder.layer(i, hiddenLayerBuilder.build());
}
// we need to use RnnOutputLayer for our RNN
RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.MCXENT);
// softmax normalizes the output neurons so that all outputs sum to 1;
// this lets us treat the output as a probability distribution over the
// next character (see findIndexOfHighestValue and the sampling sketch below)
outputLayerBuilder.activation(Activation.SOFTMAX);
outputLayerBuilder.nIn(HIDDEN_LAYER_WIDTH);
outputLayerBuilder.nOut(LEARNSTRING_CHARS.size());
listBuilder.layer(HIDDEN_LAYER_COUNT, outputLayerBuilder.build());
// finish builder
listBuilder.pretrain(false);
listBuilder.backprop(true);
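// pretrain(false): no unsupervised layer-wise pretraining is needed for
// this supervised task; backprop(true): train with standard
// backpropagation (through time, since the layers are recurrent)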
// create network
MultiLayerConfiguration conf = listBuilder.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new ScoreIterationListener(1));
/*
* CREATE OUR TRAINING DATA
*/
// create input and output arrays with the dimensions SAMPLE_INDEX,
// INPUT_NEURON, SEQUENCE_POSITION, i.e. shape
// [1, number of distinct characters, length of LEARNSTRING]
INDArray input = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), LEARNSTRING.length);
INDArray labels = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), LEARNSTRING.length);
// loop through our sample-sentence
int samplePos = 0;
for (char currentChar : LEARNSTRING) {
// small hack: when currentChar is the last character, wrap around and
// use the first character as nextChar - not strictly required
char nextChar = LEARNSTRING[(samplePos + 1) % (LEARNSTRING.length)];
// input neuron for current-char is 1 at "samplePos"
input.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(currentChar), samplePos }, 1);
// output neuron for next-char is 1 at "samplePos"
labels.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(nextChar), samplePos }, 1);
samplePos++;
}
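// e.g. at the first position: currentChar = 'D', nextChar = 'e', so
// input[0, indexOf('D'), 0] = 1.0 and labels[0, indexOf('e'), 0] = 1.0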
DataSet trainingData = new DataSet(input, labels);
// train for 100 epochs; after each one, let the net recite the sentence
for (int epoch = 0; epoch < 100; epoch++) {
System.out.println("Epoch " + epoch);
// train the network on the data
net.fit(trainingData);
// clear the RNN's previous state left over from the last example
net.rnnClearPreviousState();
// put the first character into the RNN as initialization
INDArray testInit = Nd4j.zeros(LEARNSTRING_CHARS_LIST.size());
testInit.putScalar(LEARNSTRING_CHARS_LIST.indexOf(LEARNSTRING[0]), 1);
// run one step -> IMPORTANT: rnnTimeStep() must be called, not output():
// rnnTimeStep() keeps the RNN's internal state between calls, so
// successive calls continue the sequence; the output shows which
// character the net thinks should come next
INDArray output = net.rnnTimeStep(testInit);
// now the net should guess LEARNSTRING.length more characters
for (int j = 0; j < LEARNSTRING.length; j++) {
// first map the network's last output to a concrete neuron; the
// neuron with the highest output has the highest chance of being
// chosen (here we simply take the maximum, i.e. greedy decoding)
double[] outputProbDistribution = new double[LEARNSTRING_CHARS.size()];
for (int k = 0; k < outputProbDistribution.length; k++) {
outputProbDistribution[k] = output.getDouble(k);
}
int sampledCharacterIdx = findIndexOfHighestValue(outputProbDistribution);
// print the chosen output
System.out.print(LEARNSTRING_CHARS_LIST.get(sampledCharacterIdx));
// use the last output as input
INDArray nextInput = Nd4j.zeros(LEARNSTRING_CHARS_LIST.size());
nextInput.putScalar(sampledCharacterIdx, 1);
output = net.rnnTimeStep(nextInput);
}
System.out.print("\n");
}
}
/**
 * Returns the index of the largest value in the distribution
 * (greedy / argmax decoding).
 */
private static int findIndexOfHighestValue(double[] distribution) {
int maxValueIndex = 0;
// softmax outputs are strictly positive, but initializing from the first
// element is more robust than starting from 0
double maxValue = distribution[0];
for (int i = 1; i < distribution.length; i++) {
if(distribution[i] > maxValue) {
maxValue = distribution[i];
maxValueIndex = i;
}
}
return maxValueIndex;
}
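/*
 * Optional alternative to the greedy argmax above: sample the next
 * character from the softmax distribution. The output-layer comment
 * mentions a "sampleFromDistribution" function; this is one possible
 * sketch of it (it is not called anywhere above). It reuses the
 * class-level Random r and assumes the values are non-negative and sum
 * to (approximately) 1, as softmax guarantees.
 */
private static int sampleFromDistribution(double[] distribution) {
double d = r.nextDouble();
double sum = 0.0;
for (int i = 0; i < distribution.length; i++) {
sum += distribution[i];
if (d <= sum)
return i;
}
// fall back to the last index in case of floating-point rounding
return distribution.length - 1;
}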
}