package org.deeplearning4j.examples.recurrent.encdec;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Scanner;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex;
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class EncoderDecoderLSTM {
/*
* This is a seq2seq encoder-decoder LSTM model made according to the Google's paper: [1] The model tries to predict the next dialog
* line using the provided one. It learns on the Cornell Movie Dialogs corpus. Unlike simple char RNNs this model is more sophisticated
* and theoretically, given enough time and data, can deduce facts from raw text. Your mileage may vary. This particular network
* architecture is based on AdditionRNN but changed to be used with a huge amount of possible tokens (10-40k) instead of just digits.
*
* Use the get_data.sh script to download, extract and optimize the train data. It's been only tested on Linux, it could work on OS X or
* even on Windows 10 in the Ubuntu shell.
*
* Special tokens used:
*
* <unk> - replaces any word or other token that's not in the dictionary (too rare to be included or completely unknown)
*
* <eos> - end of sentence, used only in the output to stop the processing; the model input and output length is limited by the ROW_SIZE
* constant.
*
* <go> - used only in the decoder input as the first token before the model produced anything
*
* The architecture is like this: Input => Embedding Layer => Encoder => Decoder => Output (softmax)
*
* The encoder layer produces a so called "thought vector" that contains a neurally-compressed representation of the input. Depending on
* that vector the model produces different sentences even if they start with the same token. There's one more input, connected directly
* to the decoder layer, it's used to provide the previous token of the output. For the very first output token we send a special <go>
* token there, on the next iteration we use the token that the model produced the last time. On the training stage everything is
* simple, we apriori know the desired output so the decoder input would be the same token set prepended with the <go> token and without
* the last <eos> token. Example:
*
* Input: "how" "do" "you" "do" "?"
*
* Output: "I'm" "fine" "," "thanks" "!" "<eos>"
*
* Decoder: "<go>" "I'm" "fine" "," "thanks" "!"
*
* Actually, the input is reversed as per [2], the most important words are usually in the beginning of the phrase and they would get
* more weight if supplied last (the model "forgets" tokens that were supplied "long ago", i.e. they have lesser weight than the recent
* ones). The output and decoder input sequence lengths are always equal. The input and output could be of any length (less than
* ROW_SIZE) so for purpose of batching we mask the unused part of the row. The encoder and decoder layers work sequentially. First the
* encoder creates the thought vector, that is the last activations of the layer. Those activations are then duplicated for as many time
* steps as there are elements in the output so that every output element can have its own copy of the thought vector. Then the decoder
* starts working. It receives two inputs, the thought vector made by the encoder and the token that it _should have produced_ (but
* usually it outputs something else so we have our loss metric and can compute gradients for the backward pass) on the previous step
* (or <go> for the very first step). These two vectors are simply concatenated by the merge vertex. The decoder's output goes to the
* softmax layer and that's it.
*
* The test phase is much more tricky. We don't know the decoder input because we don't know the output yet (unlike in the train phase),
* it could be anything. So we can't use methods like outputSingle() and have to do some manual work. Actually, we can but it would
* require full restarts of the entire process, it's super slow and ineffective.
*
* First, we do a single feed forward pass for the input with a single decoder element, <go>. We don't need the actual activations
* except the "thought vector". It resides in the second merge vertex input (named "dup"). So we get it and store for the entire
* response generation time. Then we put the decoder input (<go> for the first iteration) and the thought vector to the merge vertex
* inputs and feed it forward. The result goes to the decoder layer, now with rnnTimeStep() method so that the internal layer state is
* updated for the next iteration. The result is fed to the output softmax layer and then we sample it randomly (not with argMax(), it
* tends to give a lot of same tokens in a row). The resulting token is looked up in the dictionary, printed to the stdout and then it
* goes to the next iteration as the decoder input and so on until we get <eos>.
*
* To continue the training process from a specific batch number, enter it when prompted; batch numbers are printed after each processed
* macrobatch. If you've changed the minibatch size after the last launch, recalculate the number accordingly, i.e. if you doubled the
* minibatch size, specify half of the value and so on.
*
* [1] https://arxiv.org/abs/1506.05869 A Neural Conversational Model
*
* [2] https://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf Sequence to Sequence Learning with
* Neural Networks
*/
private final Map<String, Double> dict = new HashMap<>();
private final Map<Double, String> revDict = new HashMap<>();
private final String CHARS = "-\\/_&" + CorpusProcessor.SPECIALS;
private List<List<Double>> corpus = new ArrayList<>();
private static final int HIDDEN_LAYER_WIDTH = 512; // this is purely empirical, affects performance and VRAM requirement
private static final int EMBEDDING_WIDTH = 128; // one-hot vectors will be embedded to more dense vectors with this width
private static final String CORPUS_FILENAME = "movie_lines.txt"; // filename of data corpus to learn
private static final String MODEL_FILENAME = "rnn_train.zip"; // filename of the model
private static final String BACKUP_MODEL_FILENAME = "rnn_train.bak.zip"; // filename of the previous version of the model (backup)
private static final int MINIBATCH_SIZE = 32;
private static final Random rnd = new Random(new Date().getTime());
private static final long SAVE_EACH_MS = TimeUnit.MINUTES.toMillis(5); // save the model with this period
private static final long TEST_EACH_MS = TimeUnit.MINUTES.toMillis(1); // test the model with this period
private static final int MAX_DICT = 20000; // this number of most frequent words will be used, unknown words (that are not in the
// dictionary) are replaced with <unk> token
private static final int TBPTT_SIZE = 25;
private static final double LEARNING_RATE = 1e-1;
private static final double RMS_DECAY = 0.95;
private static final int ROW_SIZE = 40; // maximum line length in tokens
private static final int GC_WINDOW = 2000; // delay between garbage collections, try to reduce if you run out of VRAM or increase for
// better performance
private static final int MACROBATCH_SIZE = 20; // see CorpusIterator
private ComputationGraph net;
public static void main(String[] args) throws IOException {
new EncoderDecoderLSTM().run(args);
}
private void run(String[] args) throws IOException {
Nd4j.getMemoryManager().setAutoGcWindow(GC_WINDOW);
createDictionary();
File networkFile = new File(toTempPath(MODEL_FILENAME));
int offset = 0;
if (networkFile.exists()) {
System.out.println("Loading the existing network...");
net = ModelSerializer.restoreComputationGraph(networkFile);
System.out.print("Enter d to start dialog or a number to continue training from that minibatch: ");
String input;
try (Scanner scanner = new Scanner(System.in)) {
input = scanner.nextLine();
if (input.toLowerCase().equals("d")) {
startDialog(scanner);
} else {
offset = Integer.valueOf(input);
test();
}
}
} else {
System.out.println("Creating a new network...");
createComputationGraph();
}
System.out.println("Number of parameters: " + net.numParams());
net.setListeners(new ScoreIterationListener(1));
train(networkFile, offset);
}
private void createComputationGraph() {
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
builder.iterations(1).learningRate(LEARNING_RATE).rmsDecay(RMS_DECAY)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).miniBatch(true).updater(Updater.RMSPROP)
.weightInit(WeightInit.XAVIER).gradientNormalization(GradientNormalization.RenormalizeL2PerLayer);
GraphBuilder graphBuilder = builder.graphBuilder().pretrain(false).backprop(true).backpropType(BackpropType.Standard)
.tBPTTBackwardLength(TBPTT_SIZE).tBPTTForwardLength(TBPTT_SIZE);
graphBuilder.addInputs("inputLine", "decoderInput")
.setInputTypes(InputType.recurrent(dict.size()), InputType.recurrent(dict.size()))
.addLayer("embeddingEncoder", new EmbeddingLayer.Builder().nIn(dict.size()).nOut(EMBEDDING_WIDTH).build(), "inputLine")
.addLayer("encoder",
new GravesLSTM.Builder().nIn(EMBEDDING_WIDTH).nOut(HIDDEN_LAYER_WIDTH).activation(Activation.TANH).build(),
"embeddingEncoder")
.addVertex("thoughtVector", new LastTimeStepVertex("inputLine"), "encoder")
.addVertex("dup", new DuplicateToTimeSeriesVertex("decoderInput"), "thoughtVector")
.addVertex("merge", new MergeVertex(), "decoderInput", "dup")
.addLayer("decoder",
new GravesLSTM.Builder().nIn(dict.size() + HIDDEN_LAYER_WIDTH).nOut(HIDDEN_LAYER_WIDTH).activation(Activation.TANH)
.build(),
"merge")
.addLayer("output", new RnnOutputLayer.Builder().nIn(HIDDEN_LAYER_WIDTH).nOut(dict.size()).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "decoder")
.setOutputs("output");
net = new ComputationGraph(graphBuilder.build());
net.init();
}
private void train(File networkFile, int offset) throws IOException {
long lastSaveTime = System.currentTimeMillis();
long lastTestTime = System.currentTimeMillis();
CorpusIterator logsIterator = new CorpusIterator(corpus, MINIBATCH_SIZE, MACROBATCH_SIZE, dict.size(), ROW_SIZE);
for (int epoch = 1; epoch < 10000; ++epoch) {
System.out.println("Epoch " + epoch);
if (epoch == 1) {
logsIterator.setCurrentBatch(offset);
} else {
logsIterator.reset();
}
int lastPerc = 0;
while (logsIterator.hasNextMacrobatch()) {
net.fit(logsIterator);
logsIterator.nextMacroBatch();
System.out.println("Batch = " + logsIterator.batch());
int newPerc = (logsIterator.batch() * 100 / logsIterator.totalBatches());
if (newPerc != lastPerc) {
System.out.println("Epoch complete: " + newPerc + "%");
lastPerc = newPerc;
}
if (System.currentTimeMillis() - lastSaveTime > SAVE_EACH_MS) {
saveModel(networkFile);
lastSaveTime = System.currentTimeMillis();
}
if (System.currentTimeMillis() - lastTestTime > TEST_EACH_MS) {
test();
lastTestTime = System.currentTimeMillis();
}
}
}
}
private void startDialog(Scanner scanner) throws IOException {
System.out.println("Dialog started.");
while (true) {
System.out.print("In> ");
// input line is appended to conform to the corpus format
String line = "1 +++$+++ u11 +++$+++ m0 +++$+++ WALTER +++$+++ " + scanner.nextLine() + "\n";
CorpusProcessor dialogProcessor = new CorpusProcessor(new ByteArrayInputStream(line.getBytes(StandardCharsets.UTF_8)), ROW_SIZE,
false) {
@Override
protected void processLine(String lastLine) {
List<String> words = new ArrayList<>();
tokenizeLine(lastLine, words, true);
List<Double> wordIdxs = new ArrayList<>();
if (wordsToIndexes(words, wordIdxs)) {
System.out.print("Got words: ");
for (Double idx : wordIdxs) {
System.out.print(revDict.get(idx) + " ");
}
System.out.println();
System.out.print("Out> ");
output(wordIdxs, true);
}
}
};
dialogProcessor.setDict(dict);
dialogProcessor.start();
}
}
private void saveModel(File networkFile) throws IOException {
System.out.println("Saving the model...");
File backup = new File(toTempPath(BACKUP_MODEL_FILENAME));
if (networkFile.exists()) {
if (backup.exists()) {
backup.delete();
}
networkFile.renameTo(backup);
}
ModelSerializer.writeModel(net, networkFile, true);
System.out.println("Done.");
}
private void test() {
System.out.println("======================== TEST ========================");
int selected = rnd.nextInt(corpus.size());
List<Double> rowIn = new ArrayList<>(corpus.get(selected));
System.out.print("In: ");
for (Double idx : rowIn) {
System.out.print(revDict.get(idx) + " ");
}
System.out.println();
System.out.print("Out: ");
output(rowIn, true);
System.out.println("====================== TEST END ======================");
}
private void output(List<Double> rowIn, boolean printUnknowns) {
net.rnnClearPreviousState();
Collections.reverse(rowIn);
INDArray in = Nd4j.create(ArrayUtils.toPrimitive(rowIn.toArray(new Double[0])), new int[] { 1, 1, rowIn.size() });
double[] decodeArr = new double[dict.size()];
decodeArr[2] = 1;
INDArray decode = Nd4j.create(decodeArr, new int[] { 1, dict.size(), 1 });
net.feedForward(new INDArray[] { in, decode }, false);
org.deeplearning4j.nn.layers.recurrent.GravesLSTM decoder = (org.deeplearning4j.nn.layers.recurrent.GravesLSTM) net
.getLayer("decoder");
Layer output = net.getLayer("output");
GraphVertex mergeVertex = net.getVertex("merge");
INDArray thoughtVector = mergeVertex.getInputs()[1];
for (int row = 0; row < ROW_SIZE; ++row) {
mergeVertex.setInputs(decode, thoughtVector);
INDArray merged = mergeVertex.doForward(false);
INDArray activateDec = decoder.rnnTimeStep(merged);
INDArray out = output.activate(activateDec, false);
double d = rnd.nextDouble();
double sum = 0.0;
int idx = -1;
for (int s = 0; s < out.size(1); s++) {
sum += out.getDouble(0, s, 0);
if (d <= sum) {
idx = s;
if (printUnknowns || s != 0) {
System.out.print(revDict.get((double) s) + " ");
}
break;
}
}
if (idx == 1) {
break;
}
double[] newDecodeArr = new double[dict.size()];
newDecodeArr[idx] = 1;
decode = Nd4j.create(newDecodeArr, new int[] { 1, dict.size(), 1 });
}
System.out.println();
}
private void createDictionary() throws IOException, FileNotFoundException {
double idx = 3.0;
dict.put("<unk>", 0.0);
revDict.put(0.0, "<unk>");
dict.put("<eos>", 1.0);
revDict.put(1.0, "<eos>");
dict.put("<go>", 2.0);
revDict.put(2.0, "<go>");
for (char c : CHARS.toCharArray()) {
if (!dict.containsKey(c)) {
dict.put(String.valueOf(c), idx);
revDict.put(idx, String.valueOf(c));
++idx;
}
}
System.out.println("Building the dictionary...");
CorpusProcessor corpusProcessor = new CorpusProcessor(toTempPath(CORPUS_FILENAME), ROW_SIZE, true);
corpusProcessor.start();
Map<String, Double> freqs = corpusProcessor.getFreq();
Set<String> dictSet = new TreeSet<>(); // the tokens order is preserved for TreeSet
Map<Double, Set<String>> freqMap = new TreeMap<>(new Comparator<Double>() {
@Override
public int compare(Double o1, Double o2) {
return (int) (o2 - o1);
}
}); // tokens of the same frequency fall under the same key, the order is reversed so the most frequent tokens go first
for (Entry<String, Double> entry : freqs.entrySet()) {
Set<String> set = freqMap.get(entry.getValue());
if (set == null) {
set = new TreeSet<>(); // tokens of the same frequency would be sorted alphabetically
freqMap.put(entry.getValue(), set);
}
set.add(entry.getKey());
}
int cnt = 0;
dictSet.addAll(dict.keySet());
// get most frequent tokens and put them to dictSet
for (Entry<Double, Set<String>> entry : freqMap.entrySet()) {
for (String val : entry.getValue()) {
if (dictSet.add(val) && ++cnt >= MAX_DICT) {
break;
}
}
if (cnt >= MAX_DICT) {
break;
}
}
// all of the above means that the dictionary with the same MAX_DICT constraint and made from the same source file will always be
// the same, the tokens always correspond to the same number so we don't need to save/restore the dictionary
System.out.println("Dictionary is ready, size is " + dictSet.size());
// index the dictionary and build the reverse dictionary for lookups
for (String word : dictSet) {
if (!dict.containsKey(word)) {
dict.put(word, idx);
revDict.put(idx, word);
++idx;
}
}
System.out.println("Total dictionary size is " + dict.size() + ". Processing the dataset...");
corpusProcessor = new CorpusProcessor(toTempPath(CORPUS_FILENAME), ROW_SIZE, false) {
@Override
protected void processLine(String lastLine) {
ArrayList<String> words = new ArrayList<>();
tokenizeLine(lastLine, words, true);
if (!words.isEmpty()) {
List<Double> wordIdxs = new ArrayList<>();
if (wordsToIndexes(words, wordIdxs)) {
corpus.add(wordIdxs);
}
}
}
};
corpusProcessor.setDict(dict);
corpusProcessor.start();
System.out.println("Done. Corpus size is " + corpus.size());
}
private String toTempPath(String path) {
return System.getProperty("java.io.tmpdir") + "/" + path;
}
}