package org.deeplearning4j.examples.recurrent.seq2seq;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
/**
* Created by susaneraly on 3/27/16.
* This class generates a MultiDataSet for the AdditionRNN problem.
* Features of the multidataset
* - encoder input, eg. "12+13" and
* - decoder input, eg. "Go25 " for training and "Go " for test
* Labels of the multidataset
* - decoder output, "25 End"
* These strings are encoded as one hot vector sequences.
*
* Sequences generated during test are never before seen by the net
* The random number generator seed is used for repeatability so that each reset of the iterator gives the same data in the same order.
*/
public class CustomSequenceIterator implements MultiDataSetIterator {

    private Random randnumG;
    private final int seed;
    private final int batchSize;
    private final int totalBatches;

    private static final int numDigits = AdditionRNN.NUM_DIGITS;
    public static final int SEQ_VECTOR_DIM = AdditionRNN.FEATURE_VEC_SIZE;
    /** Token -> index of the "hot" position in the one hot vector. */
    public static final Map<String, Integer> oneHotMap = new HashMap<>();
    /** Index of the "hot" position -> token (inverse of {@link #oneHotMap}). */
    public static final String[] oneHotOrder = new String[SEQ_VECTOR_DIM];

    static {
        // Fill the lookup tables once at class-load time so that the public static
        // decode helpers also work before any iterator instance has been constructed.
        oneHotEncoding();
    }

    /** Addition problems ("num1+num2") already handed out since the last reset. */
    private Set<String> seenSequences = new HashSet<>();
    /** When true, next(int) blanks the decoder input after the leading "Go". */
    private boolean toTestSet = false;
    private int currentBatch = 0;

    /**
     * @param seed         seed for the random number generator; re-applied on every {@link #reset()}
     * @param batchSize    number of examples returned by {@link #next()}
     * @param totalBatches number of batches after which {@link #hasNext()} returns false
     */
    public CustomSequenceIterator(int seed, int batchSize, int totalBatches) {
        this.seed = seed;
        this.randnumG = new Random(seed);
        this.batchSize = batchSize;
        this.totalBatches = totalBatches;
    }

    /**
     * Generates a test set whose decoder input holds only "Go" followed by blanks.
     * The problems drawn are distinct from everything produced since the last reset.
     * The iterator is reset afterwards, even if generation fails.
     *
     * @param testSize number of test examples to generate
     * @return the generated test data
     */
    public MultiDataSet generateTest(int testSize) {
        toTestSet = true;
        try {
            return next(testSize);
        } finally {
            // reset() also clears toTestSet, so a failure cannot leave the iterator in test mode
            reset();
        }
    }

    /**
     * Generates {@code sampleSize} previously unseen addition problems and encodes them.
     *
     * @param sampleSize number of examples to generate
     * @return a MultiDataSet with inputs {encoder sequence, decoder sequence}, the decoder
     *         output as the single label array, and all-ones masks for each of them
     * @throws IllegalStateException if every distinct problem has already been handed out
     *                               since the last reset, so no unseen problem remains
     */
    @Override
    public MultiDataSet next(int sampleSize) {
        // Exclusive upper bound for each operand; hoisted because it never changes.
        final int operandBound = (int) Math.pow(10, numDigits);
        final long distinctProblems = (long) operandBound * operandBound;

        List<INDArray> encoderSeqList = new ArrayList<>();
        List<INDArray> decoderSeqList = new ArrayList<>();
        List<INDArray> outputSeqList = new ArrayList<>();

        for (int sample = 0; sample < sampleSize; sample++) {
            if (seenSequences.size() >= distinctProblems) {
                // Without this guard the rejection loop below would spin forever.
                throw new IllegalStateException("All " + distinctProblems
                    + " distinct addition problems have already been generated; call reset() first");
            }
            int num1;
            int num2;
            // Rejection sampling: redraw until the pair has not been seen before.
            do {
                num1 = randnumG.nextInt(operandBound);
                num2 = randnumG.nextInt(operandBound);
            } while (!seenSequences.add(num1 + "+" + num2));

            encoderSeqList.add(mapToOneHot(prepToString(num1, num2)));

            String[] decoderInput = prepToString(num1 + num2, true);
            if (toTestSet) {
                // wipe out everything after "Go"; not necessary since we do not use these at test time but here for clarity
                Arrays.fill(decoderInput, 1, decoderInput.length, " ");
            }
            decoderSeqList.add(mapToOneHot(decoderInput));

            outputSeqList.add(mapToOneHot(prepToString(num1 + num2, false)));
        }

        INDArray encoderSeq = Nd4j.vstack(encoderSeqList);
        INDArray decoderSeq = Nd4j.vstack(decoderSeqList);
        INDArray outputSeq = Nd4j.vstack(outputSeqList);

        INDArray[] inputs = new INDArray[]{encoderSeq, decoderSeq};
        INDArray[] inputMasks = new INDArray[]{Nd4j.ones(sampleSize, numDigits * 2 + 1), Nd4j.ones(sampleSize, numDigits + 1 + 1)};
        INDArray[] labels = new INDArray[]{outputSeq};
        INDArray[] labelMasks = new INDArray[]{Nd4j.ones(sampleSize, numDigits + 1 + 1)};
        currentBatch++;
        return new org.nd4j.linalg.dataset.MultiDataSet(inputs, labels, inputMasks, labelMasks);
    }

    /**
     * Restores the initial state: batch counter at zero, training mode, no seen
     * sequences, and a random number generator re-seeded with the original seed.
     */
    @Override
    public void reset() {
        currentBatch = 0;
        toTestSet = false;
        seenSequences = new HashSet<>();
        randnumG = new Random(seed);
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return false;
    }

    @Override
    public boolean hasNext() {
        return currentBatch < totalBatches;
    }

    @Override
    public MultiDataSet next() {
        return next(batchSize);
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * No-op: this iterator does not apply a preprocessor; the argument is ignored.
     */
    public void setPreProcessor(MultiDataSetPreProcessor multiDataSetPreProcessor) {
    }

    /**
     * Helper method for the encoder input.
     * Given two numbers, returns the string array that is fed to the encoder RNN.
     * Each number is left-padded with spaces to numDigits characters, the two are
     * joined with "+", and the resulting characters are stored in reverse order.
     * Eg. for numDigits = 2, num1 = 7, num2 = 13 returns {"3","1","+","7"," "}
     *
     * @param num1 first operand
     * @param num2 second operand
     * @return array of 2 * numDigits + 1 single-character tokens, reversed
     */
    public String[] prepToString(int num1, int num2) {
        String[] encoded = new String[numDigits * 2 + 1];
        String sumString = leftPad(String.valueOf(num1)) + "+" + leftPad(String.valueOf(num2));
        int last = encoded.length - 1;
        for (int i = 0; i < encoded.length; i++) {
            encoded[last - i] = Character.toString(sumString.charAt(i));
        }
        return encoded;
    }

    /**
     * Helper method for the decoder input (goFirst) or the decoder output (!goFirst).
     * eg. For numDigits = 2 and sum = 31
     * if goFirst will return {"Go","3","1"," "}
     * if !goFirst will return {"3","1"," ","End"}
     *
     * @param sum     the number to encode
     * @param goFirst true to start the sequence with "Go", false to end it with "End"
     * @return array of numDigits + 2 tokens
     */
    public String[] prepToString(int sum, boolean goFirst) {
        String[] decoded = new String[numDigits + 1 + 1];
        // Start from all-blank so every slot not overwritten below is padding.
        Arrays.fill(decoded, " ");
        int start;
        if (goFirst) {
            decoded[0] = "Go";
            start = 1;
        } else {
            decoded[decoded.length - 1] = "End";
            start = 0;
        }
        String sumString = String.valueOf(sum);
        for (int i = 0; i < sumString.length(); i++) {
            decoded[start + i] = Character.toString(sumString.charAt(i));
        }
        return decoded;
    }

    /**
     * Takes in an array of strings and returns a one hot encoded array of shape
     * 1 x SEQ_VECTOR_DIM x timesteps. Each element of the input is one time step.
     */
    private static INDArray mapToOneHot(String[] toEncode) {
        INDArray ret = Nd4j.zeros(1, SEQ_VECTOR_DIM, toEncode.length);
        for (int step = 0; step < toEncode.length; step++) {
            ret.putScalar(0, oneHotMap.get(toEncode[step]), step, 1);
        }
        return ret;
    }

    /**
     * Same as {@link #mapToString(INDArray, INDArray, String)} with " --> " as separator.
     */
    public static String mapToString(INDArray encodeSeq, INDArray decodeSeq) {
        return mapToString(encodeSeq, decodeSeq, " --> ");
    }

    /**
     * Decodes both one hot arrays and pairs them up, one example per line, formatted as
     * tab + encoder string + sep + decoder string + newline.
     *
     * @param encodeSeq one hot encoded encoder sequences
     * @param decodeSeq one hot encoded decoder sequences
     * @param sep       text placed between the encoder and decoder string
     * @return the formatted, newline-terminated lines
     */
    public static String mapToString(INDArray encodeSeq, INDArray decodeSeq, String sep) {
        String[] encodeSeqS = oneHotDecode(encodeSeq);
        String[] decodeSeqS = oneHotDecode(decodeSeq);
        StringBuilder ret = new StringBuilder();
        for (int i = 0; i < encodeSeqS.length; i++) {
            ret.append("\t").append(encodeSeqS[i]).append(sep).append(decodeSeqS[i]).append("\n");
        }
        return ret.toString();
    }

    /**
     * Helper method that takes in a one hot encoded INDArray and returns an interpreted array of strings.
     * toInterpret size: batchSize x one_hot_vector_size x time_steps
     *
     * @param toInterpret one hot encoded sequences
     * @return one decoded string per example
     */
    public static String[] oneHotDecode(INDArray toInterpret) {
        String[] decodedString = new String[toInterpret.size(0)];
        INDArray oneHotIndices = Nd4j.argMax(toInterpret, 1); //drops a dimension, so now a two dim array of shape batchSize x time_steps
        for (int i = 0; i < oneHotIndices.size(0); i++) {
            int[] currentSlice = oneHotIndices.slice(i).dup().data().asInt(); //each slice is a batch
            decodedString[i] = mapFromOneHot(currentSlice);
        }
        return decodedString;
    }

    /**
     * Concatenates the tokens for the given one hot indices. Sequences longer than
     * numDigits + 2 steps are treated as encoder sequences and reversed back into
     * reading order. Note that for numDigits == 1 the encoder and decoder lengths
     * coincide, so encoder sequences are not reversed in that configuration.
     */
    private static String mapFromOneHot(int[] toMap) {
        StringBuilder ret = new StringBuilder();
        for (int index : toMap) {
            ret.append(oneHotOrder[index]);
        }
        //encoder sequence, needs to be reversed
        if (toMap.length > numDigits + 1 + 1) {
            ret.reverse();
        }
        return ret.toString();
    }

    /**
     * Left-pads the given text with spaces until it is at least numDigits characters long.
     */
    private static String leftPad(String text) {
        StringBuilder padded = new StringBuilder();
        for (int i = text.length(); i < numDigits; i++) {
            padded.append(' ');
        }
        return padded.append(text).toString();
    }

    /**
     * One hot encoding map: digits 0-9 at indices 0-9, then " ", "+", "Go", "End".
     */
    private static void oneHotEncoding() {
        for (int digit = 0; digit < 10; digit++) {
            String token = String.valueOf(digit);
            oneHotOrder[digit] = token;
            oneHotMap.put(token, digit);
        }
        String[] specialTokens = {" ", "+", "Go", "End"};
        for (int offset = 0; offset < specialTokens.length; offset++) {
            oneHotOrder[10 + offset] = specialTokens[offset];
            oneHotMap.put(specialTokens[offset], 10 + offset);
        }
    }
}