package org.deeplearning4j.examples.multigpu.w2vsentiment;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;
/** This is a DataSetIterator that is specialized for the IMDB review dataset used in the Word2VecSentimentRNN example
* It takes either the train or test set data from this data set, plus a WordVectors object (typically the Google News
* 300 pretrained vectors from https://code.google.com/p/word2vec/) and generates training data sets.<br>
* Inputs/features: variable-length time series, where each word (with unknown words removed) is represented by
* its Word2Vec vector representation.<br>
* Labels/target: a single class (negative or positive), predicted at the final time step (word) of each review
*
* @author Alex Black
*/
public class SentimentExampleIterator implements DataSetIterator {
    private final WordVectors wordVectors;
    private final int batchSize;
    private final int vectorSize;
    private final int truncateLength;

    /** Index of the next review to return; even values map to positive reviews, odd values to negative ones. */
    private int cursor = 0;
    private final File[] positiveFiles;
    private final File[] negativeFiles;
    private final TokenizerFactory tokenizerFactory;

    /**
     * @param dataDirectory the directory of the IMDB review data set (the directory containing "aclImdb")
     * @param wordVectors WordVectors object
     * @param batchSize Size of each minibatch for training
     * @param truncateLength If reviews exceed this number of (known) words, only the first truncateLength words are used
     * @param train If true: return the training data. If false: return the testing data.
     * @throws IOException If the positive or negative review directory cannot be listed
     */
    public SentimentExampleIterator(String dataDirectory, WordVectors wordVectors, int batchSize, int truncateLength, boolean train) throws IOException {
        this.batchSize = batchSize;
        this.vectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;

        String subset = "aclImdb/" + (train ? "train" : "test");
        File p = new File(FilenameUtils.concat(dataDirectory, subset + "/pos/") + "/");
        File n = new File(FilenameUtils.concat(dataDirectory, subset + "/neg/") + "/");
        positiveFiles = listReviewFiles(p);
        negativeFiles = listReviewFiles(n);

        this.wordVectors = wordVectors;
        this.truncateLength = truncateLength;

        tokenizerFactory = new DefaultTokenizerFactory();
        tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
    }

    /** Lists the files of a review directory, failing with a clear message instead of returning null. */
    private static File[] listReviewFiles(File directory) throws IOException {
        File[] files = directory.listFiles();
        if (files == null) {
            // File.listFiles() returns null when the path is not a directory or cannot be read
            throw new IOException("Cannot list review files: " + directory.getAbsolutePath()
                    + " does not exist or is not a readable directory");
        }
        return files;
    }

    /**
     * Returns the next minibatch of up to {@code num} reviews (fewer if the data set is nearly exhausted).
     *
     * @param num maximum number of examples in the returned DataSet
     * @throws NoSuchElementException if all reviews have already been returned
     * @throws RuntimeException wrapping the IOException if a review file cannot be read
     */
    @Override
    public DataSet next(int num) {
        if (cursor >= totalExamples()) throw new NoSuchElementException();
        try {
            return nextDataSet(num);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds features, labels and both mask arrays for the next (up to) {@code num} reviews and advances the cursor.
     * Reviews are indexed as in {@link #loadReviewToString(int)}: positive and negative reviews alternate, so the
     * two directories are expected to contain (almost) the same number of files.
     */
    private DataSet nextDataSet(int num) throws IOException {
        //First: load reviews to String. Alternate positive and negative reviews
        List<String> reviews = new ArrayList<>(num);
        boolean[] positive = new boolean[num];
        for (int i = 0; i < num && cursor < totalExamples(); i++) {
            reviews.add(loadReviewToString(cursor));
            positive[i] = isPositiveReview(cursor);
            cursor++;
        }

        //Second: tokenize reviews and filter out unknown words
        List<List<String>> allTokens = new ArrayList<>(reviews.size());
        int maxLength = 0;
        for (String review : reviews) {
            List<String> knownTokens = tokenizeKnownWords(review);
            allTokens.add(knownTokens);
            maxLength = Math.max(maxLength, knownTokens.size());
        }

        //If longest review exceeds 'truncateLength': only take the first 'truncateLength' words
        if (maxLength > truncateLength) maxLength = truncateLength;

        //Always allocate at least one time step, so a batch in which no review has a known word
        //still yields well-formed (fully masked) arrays
        int timeSteps = Math.max(1, maxLength);

        //Create data for training
        //Here: we have reviews.size() examples of varying lengths
        INDArray features = Nd4j.create(reviews.size(), vectorSize, timeSteps);
        INDArray labels = Nd4j.create(reviews.size(), 2, timeSteps);    //Two labels: positive or negative
        //Because we are dealing with reviews of different lengths and only one output at the final time step: use padding arrays
        //Mask arrays contain 1 if data is present at that time step for that example, or 0 if data is just padding
        INDArray featuresMask = Nd4j.zeros(reviews.size(), timeSteps);
        INDArray labelsMask = Nd4j.zeros(reviews.size(), timeSteps);

        for (int i = 0; i < reviews.size(); i++) {
            List<String> tokens = allTokens.get(i);
            //Get word vectors for each word in review, and put them in the training data
            for (int j = 0; j < tokens.size() && j < maxLength; j++) {
                INDArray vector = wordVectors.getWordVectorMatrix(tokens.get(j));
                features.put(new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
                featuresMask.putScalar(new int[]{i, j}, 1.0);   //Word is present (not padding) for this example + time step -> 1.0 in features mask
            }

            int lastIdx = Math.min(tokens.size(), maxLength);
            if (lastIdx <= 0) {
                //No usable words in this review: there is no final time step to attach a label to
                //(lastIdx-1 would be -1), so the example is left entirely masked out
                continue;
            }
            int idx = (positive[i] ? 0 : 1);
            labels.putScalar(new int[]{i, idx, lastIdx - 1}, 1.0);   //Set label: [0,1] for negative, [1,0] for positive
            labelsMask.putScalar(new int[]{i, lastIdx - 1}, 1.0);    //Specify that an output exists at the final time step for this example
        }

        return new DataSet(features, labels, featuresMask, labelsMask);
    }

    /** Tokenizes the text and keeps, in order, only the tokens that the word vectors model knows. */
    private List<String> tokenizeKnownWords(String text) {
        List<String> tokens = tokenizerFactory.create(text).getTokens();
        List<String> known = new ArrayList<>();
        for (String t : tokens) {
            if (wordVectors.hasWord(t)) known.add(t);
        }
        return known;
    }

    /** @return total number of reviews (positive plus negative files) */
    @Override
    public int totalExamples() {
        return positiveFiles.length + negativeFiles.length;
    }

    /** @return size of a single word vector, i.e. the number of input features per time step */
    @Override
    public int inputColumns() {
        return vectorSize;
    }

    /** @return 2 (positive or negative) */
    @Override
    public int totalOutcomes() {
        return 2;
    }

    /** Rewinds the iterator to the first review. */
    @Override
    public void reset() {
        cursor = 0;
    }

    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    /** @return the minibatch size used by {@link #next()} */
    @Override
    public int batch() {
        return batchSize;
    }

    /** @return index of the next review that will be returned */
    @Override
    public int cursor() {
        return cursor;
    }

    @Override
    public int numExamples() {
        return totalExamples();
    }

    /** Pre-processors are not supported by this iterator. */
    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        throw new UnsupportedOperationException();
    }

    /** @return label names, in the same order as the label array columns: index 0 = positive, index 1 = negative */
    @Override
    public List<String> getLabels() {
        return Arrays.asList("positive", "negative");
    }

    @Override
    public boolean hasNext() {
        return cursor < numExamples();
    }

    /** Returns the next minibatch of (up to) {@link #batch()} reviews. */
    @Override
    public DataSet next() {
        return next(batchSize);
    }

    /** No-op: nothing is removed from the underlying data. */
    @Override
    public void remove() {

    }

    /** Pre-processors are not supported by this iterator. */
    @Override
    public DataSetPreProcessor getPreProcessor() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Convenience method for loading review to String.
     * Even indices select positive review number index/2, odd indices select negative review number index/2.
     */
    public String loadReviewToString(int index) throws IOException {
        File f;
        if (index % 2 == 0) f = positiveFiles[index / 2];
        else f = negativeFiles[index / 2];
        return FileUtils.readFileToString(f);
    }

    /** Convenience method to get label for review: even indices are positive reviews */
    public boolean isPositiveReview(int index) {
        return index % 2 == 0;
    }

    /**
     * Used post training to load a review from a file to a features INDArray that can be passed to the network output method
     *
     * @param file      File to load the review from
     * @param maxLength Maximum length (if review is longer than this: truncate to maxLength). Use Integer.MAX_VALUE to not truncate
     * @return Features array
     * @throws IOException If file cannot be read
     */
    public INDArray loadFeaturesFromFile(File file, int maxLength) throws IOException {
        String review = FileUtils.readFileToString(file);
        return loadFeaturesFromString(review, maxLength);
    }

    /**
     * Used post training to convert a String to a features INDArray that can be passed to the network output method.
     * Words unknown to the word vectors model are dropped; the time dimension of the result is the number of
     * remaining words, capped at maxLength.
     *
     * @param reviewContents Contents of the review to vectorize
     * @param maxLength Maximum length (if review is longer than this: truncate to maxLength). Use Integer.MAX_VALUE to not truncate
     * @return Features array for the given input String, created with dimensions [1, vectorSize, min(maxLength, number of known words)]
     */
    public INDArray loadFeaturesFromString(String reviewContents, int maxLength) {
        List<String> knownTokens = tokenizeKnownWords(reviewContents);

        //Truncate (never pad) to maxLength: the output has one time step per retained word
        int outputLength = Math.min(maxLength, knownTokens.size());
        INDArray features = Nd4j.create(1, vectorSize, outputLength);

        //Iterate the filtered list so that only known words are looked up and time steps stay contiguous
        for (int j = 0; j < outputLength; j++) {
            INDArray vector = wordVectors.getWordVectorMatrix(knownTokens.get(j));
            features.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
        }

        return features;
    }
}