/**
 * This is a DataSetIterator specialized for the news headlines data set used in the TrainNews example.
 * It takes either the train or the test split of that data set, plus a WordVectors object generated by the
 * PrepareWordVector.java program, and generates training DataSets.<br>
 * Inputs/features: variable-length time series, where each word (with unknown words removed) is represented by
 * its Word2Vec vector representation.<br>
 * Labels/target: a single class (the news category, i.e. 0, 1, 2, ... according to the contents of the categories.txt
 * file referred to in the TrainNews.java program).
* <p>
 * Note:
 * - This class is a modification of the original example SentimentExampleIterator.java.
 * - More details are given in the comments of each method.
* <p>
* <b>KIT Solutions Pvt. Ltd. (www.kitsol.com)</b>
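 * <p>
 * A minimal usage sketch (the paths, file names and parameter values below are placeholders, not values fixed
 * by this class):
 * <pre>{@code
 * WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File("NewsWordVector.txt"));
 * TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
 * tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
 *
 * NewsIterator trainIter = NewsIterator.Builder()
 *     .dataDirectory("NewsData")
 *     .wordVectors(wordVectors)
 *     .batchSize(50)
 *     .truncateLength(300)
 *     .tokenizerFactory(tokenizerFactory)
 *     .train(true)
 *     .build();
 *
 * while (trainIter.hasNext()) {
 *     DataSet ds = trainIter.next();
 *     // net.fit(ds); // 'net' is the recurrent network built in TrainNews.java
 * }
 * }</pre>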
*/
package org.deeplearning4j.examples.recurrent.processnews;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import static org.nd4j.linalg.indexing.NDArrayIndex.*;
public class NewsIterator implements DataSetIterator {
private final WordVectors wordVectors;
private final int batchSize;
private final int vectorSize;
private final int truncateLength;
private int maxLength;
private final String dataDirectory;
private final List<Pair<String, List<String>>> categoryData = new ArrayList<>();
private int cursor = 0;
private int totalNews = 0;
private final TokenizerFactory tokenizerFactory;
private int newsPosition = 0;
private final List<String> labels;
private int currCategory = 0;
/**
* @param dataDirectory the directory of the news headlines data set
* @param wordVectors WordVectors object
* @param batchSize Size of each minibatch for training
 * @param truncateLength If a headline's length exceeds this size, it will be truncated to this size.
* @param train If true: return the training data. If false: return the testing data.
* <p>
 * - initializes various class variables
 * - calls populateData to load the news data into the categoryData list
 * - also populates the labels field (i.e. category-related information)
*/
private NewsIterator(String dataDirectory,
WordVectors wordVectors,
int batchSize,
int truncateLength,
boolean train,
TokenizerFactory tokenizerFactory) {
this.dataDirectory = dataDirectory;
this.batchSize = batchSize;
this.vectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;
this.wordVectors = wordVectors;
this.truncateLength = truncateLength;
this.tokenizerFactory = tokenizerFactory;
this.populateData(train);
this.labels = new ArrayList<>();
for (int i = 0; i < this.categoryData.size(); i++) {
this.labels.add(this.categoryData.get(i).getKey().split(",")[1]);
}
}
public static Builder Builder() {
return new Builder();
}
@Override
public DataSet next(int num) {
if (cursor >= this.totalNews) throw new NoSuchElementException();
try {
return nextDataSet(num);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
private DataSet nextDataSet(int num) throws IOException {
        // First: collect up to 'num' headlines from categoryData, round-robin across categories (the newsPosition-th
        // headline of category 0, 1, 2, ..., then position newsPosition + 1 of each category, and so on), and record
        // the numeric category index of each selected headline
List<String> news = new ArrayList<>(num);
int[] category = new int[num];
        for (int i = 0; i < num && cursor < totalExamples(); i++) {
            if (currCategory < categoryData.size()) {
                List<String> categoryNews = this.categoryData.get(currCategory).getValue();
                if (newsPosition < categoryNews.size()) {
                    news.add(categoryNews.get(newsPosition));
                    category[i] = Integer.parseInt(this.categoryData.get(currCategory).getKey().split(",")[0]);
                    cursor++;
                } else {
                    i--; //this category has no headline at this position (categories may differ in size); try the next one
                }
                currCategory++;
            } else {
                currCategory = 0;
                newsPosition++;
                i--;
            }
        }
//Second: tokenize news and filter out unknown words
List<List<String>> allTokens = new ArrayList<>(news.size());
maxLength = 0;
for (String s : news) {
List<String> tokens = tokenizerFactory.create(s).getTokens();
List<String> tokensFiltered = new ArrayList<>();
for (String t : tokens) {
if (wordVectors.hasWord(t)) tokensFiltered.add(t);
}
allTokens.add(tokensFiltered);
maxLength = Math.max(maxLength, tokensFiltered.size());
}
//If longest news exceeds 'truncateLength': only take the first 'truncateLength' words
//System.out.println("maxLength : " + maxLength);
if (maxLength > truncateLength) maxLength = truncateLength;
//Create data for training
//Here: we have news.size() examples of varying lengths
INDArray features = Nd4j.create(news.size(), vectorSize, maxLength);
        INDArray labels = Nd4j.create(news.size(), this.categoryData.size(), maxLength); //One label (column) per category listed in categories.txt (e.g. Crime, Politics, Bollywood)
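        // Both arrays follow DL4J's recurrent-data layout [miniBatchSize, featureSize, timeSeriesLength]:
        // features is [news.size(), vectorSize, maxLength], labels is [news.size(), numCategories, maxLength]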
//Because we are dealing with news of different lengths and only one output at the final time step: use padding arrays
//Mask arrays contain 1 if data is present at that time step for that example, or 0 if data is just padding
INDArray featuresMask = Nd4j.zeros(news.size(), maxLength);
INDArray labelsMask = Nd4j.zeros(news.size(), maxLength);
int[] temp = new int[2];
for (int i = 0; i < news.size(); i++) {
List<String> tokens = allTokens.get(i);
temp[0] = i;
//Get word vectors for each word in news, and put them in the training data
for (int j = 0; j < tokens.size() && j < maxLength; j++) {
String token = tokens.get(j);
INDArray vector = wordVectors.getWordVectorMatrix(token);
features.put(new INDArrayIndex[]{point(i),
all(),
point(j)}, vector);
temp[1] = j;
featuresMask.putScalar(temp, 1.0);
}
int idx = category[i];
int lastIdx = Math.min(tokens.size(), maxLength);
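            // Set the single output at the last real time step. For example, a 12-token headline in category 2
            // (with maxLength >= 12) gets labels[i, 2, 11] = 1.0 and labelsMask[i, 11] = 1.0.
            // Note: this assumes each headline contains at least one word known to the word vectors,
            // otherwise lastIdx - 1 would be negative.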
labels.putScalar(new int[]{i, idx, lastIdx - 1}, 1.0);
labelsMask.putScalar(new int[]{i, lastIdx - 1}, 1.0);
}
DataSet ds = new DataSet(features, labels, featuresMask, labelsMask);
return ds;
}
/**
     * Used after training to load a news item from a file into a features INDArray that can be passed to the network's output method
     *
     * @param file      File to load the news item from
     * @param maxLength Maximum length (if the news item is longer than this: truncate to maxLength). Use Integer.MAX_VALUE to not truncate
* @return Features array
* @throws IOException If file cannot be read
*/
public INDArray loadFeaturesFromFile(File file, int maxLength) throws IOException {
String news = FileUtils.readFileToString(file);
return loadFeaturesFromString(news, maxLength);
}
/**
     * Used after training to convert a String to a features INDArray that can be passed to the network's output method
*
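     * A minimal usage sketch ({@code iterator} stands for a NewsIterator instance and {@code net} for the
     * MultiLayerNetwork trained by TrainNews.java; neither is part of this class):
     * <pre>{@code
     * INDArray features = iterator.loadFeaturesFromString("some news headline text", 300);
     * INDArray output = net.output(features);                  // shape: [1, numCategories, timeSteps]
     * int lastStep = (int) output.size(2) - 1;
     * INDArray probabilities = output.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(lastStep));
     * }</pre>
     *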
     * @param reviewContents Contents of the news item to vectorize
     * @param maxLength      Maximum length (if the text is longer than this: truncate to maxLength). Use Integer.MAX_VALUE to not truncate
* @return Features array for the given input String
*/
public INDArray loadFeaturesFromString(String reviewContents, int maxLength) {
List<String> tokens = tokenizerFactory.create(reviewContents).getTokens();
List<String> tokensFiltered = new ArrayList<>();
for (String t : tokens) {
if (wordVectors.hasWord(t)) tokensFiltered.add(t);
}
        // Cap the output length at maxLength (per the javadoc, pass Integer.MAX_VALUE to avoid truncation)
        int outputLength = Math.min(maxLength, tokensFiltered.size());
        INDArray features = Nd4j.create(1, vectorSize, outputLength);
        // Iterate over the filtered tokens: unknown words have no word vector and are not counted in outputLength
        for (int j = 0; j < tokensFiltered.size() && j < maxLength; j++) {
            String token = tokensFiltered.get(j);
INDArray vector = wordVectors.getWordVectorMatrix(token);
features.put(new INDArrayIndex[]{point(0),
all(),
point(j)}, vector);
}
return features;
}
/*
    This function loads the news headlines (and their categories) from the data directory into the categoryData list.
*/
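    /*
     * Expected on-disk layout (inferred from the reads below; categories.txt is the file mentioned in TrainNews.java):
     *   <dataDirectory>/categories.txt    - one "<index>,<categoryName>" entry per line
     *   <dataDirectory>/train/<index>.txt - one headline per line for that category (training split)
     *   <dataDirectory>/test/<index>.txt  - one headline per line for that category (test split)
     */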
    private void populateData(boolean train) {
        File categories = new File(this.dataDirectory + File.separator + "categories.txt");
        try (BufferedReader brCategories = new BufferedReader(new FileReader(categories))) {
            String temp;
            while ((temp = brCategories.readLine()) != null) {
                String curFileName = train ?
                    this.dataDirectory + File.separator + "train" + File.separator + temp.split(",")[0] + ".txt" :
                    this.dataDirectory + File.separator + "test" + File.separator + temp.split(",")[0] + ".txt";
                File currFile = new File(curFileName);
                List<String> tempList = new ArrayList<>();
                try (BufferedReader currBR = new BufferedReader(new FileReader(currFile))) {
                    String tempCurrLine;
                    while ((tempCurrLine = currBR.readLine()) != null) {
                        tempList.add(tempCurrLine);
                        this.totalNews++;
                    }
                }
                this.categoryData.add(Pair.of(temp, tempList));
            }
        } catch (Exception e) {
            System.out.println("Exception in reading file: " + e.getMessage());
        }
    }
@Override
public int totalExamples() {
return this.totalNews;
}
@Override
public int inputColumns() {
return vectorSize;
}
@Override
public int totalOutcomes() {
return this.categoryData.size();
}
@Override
public void reset() {
cursor = 0;
newsPosition = 0;
currCategory = 0;
}
    @Override
    public boolean resetSupported() {
return true;
}
@Override
public boolean asyncSupported() {
return true;
}
@Override
public int batch() {
return batchSize;
}
@Override
public int cursor() {
return cursor;
}
@Override
public int numExamples() {
return totalExamples();
}
@Override
public void setPreProcessor(DataSetPreProcessor preProcessor) {
throw new UnsupportedOperationException();
}
@Override
public List<String> getLabels() {
return this.labels;
}
@Override
public boolean hasNext() {
return cursor < numExamples();
}
@Override
public DataSet next() {
return next(batchSize);
}
@Override
public void remove() {
}
@Override
public DataSetPreProcessor getPreProcessor() {
throw new UnsupportedOperationException("Not implemented");
}
public int getMaxLength() {
return this.maxLength;
}
public static class Builder {
private String dataDirectory;
private WordVectors wordVectors;
private int batchSize;
private int truncateLength;
TokenizerFactory tokenizerFactory;
private boolean train;
Builder() {
}
public NewsIterator.Builder dataDirectory(String dataDirectory) {
this.dataDirectory = dataDirectory;
return this;
}
public NewsIterator.Builder wordVectors(WordVectors wordVectors) {
this.wordVectors = wordVectors;
return this;
}
public NewsIterator.Builder batchSize(int batchSize) {
this.batchSize = batchSize;
return this;
}
public NewsIterator.Builder truncateLength(int truncateLength) {
this.truncateLength = truncateLength;
return this;
}
public NewsIterator.Builder train(boolean train) {
this.train = train;
return this;
}
public NewsIterator.Builder tokenizerFactory(TokenizerFactory tokenizerFactory) {
this.tokenizerFactory = tokenizerFactory;
return this;
}
public NewsIterator build() {
return new NewsIterator(dataDirectory,
wordVectors,
batchSize,
truncateLength,
train,
tokenizerFactory);
}
        @Override
        public String toString() {
            return "org.deeplearning4j.examples.recurrent.processnews.NewsIterator.Builder(dataDirectory=" +
this.dataDirectory + ", wordVectors=" + this.wordVectors +
", batchSize=" + this.batchSize + ", truncateLength="
+ this.truncateLength + ", train=" + this.train + ")";
}
}
}