package org.deeplearning4j.examples;
import java.io.IOException;
import java.util.NoSuchElementException;
import java.util.Stack;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.LocatedFileStatus;
import org.apache.hadoop.fs.Path;
import org.deeplearning4j.examples.conf.Constants;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
/**
 * Returns a {@link MultiDataSet} each time its {@link #next()} method is called. The data in HDFS is structured as
 * follows: a main directory contains three sub-folders, /train, /test, and /predict. Depending on the value of the
 * {@code flag} constructor argument, this class iterates over one of these sub-folders and creates MultiDataSets. In
 * the sub-folders, each HDFS file consists of rows representing time-steps (as in an LSTM), with each row starting
 * with a time index; every file is therefore one input sequence. The time index is required so that the iterator can
 * re-order the sequence, or pad it in case some time-steps are missing.
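 *
 * <p>Example usage (a sketch; the HDFS URL, the dimensions, and {@code net} are illustrative, and
 * {@code flag == 0} selects the training data as described above):
 * <pre>{@code
 * Configuration conf = new Configuration();
 * MDSIterator trainIter = new MDSIterator(conf, "hdfs://your_cluster:8020/data", 32, 300, 5, 6, 0);
 * while (trainIter.hasNext()) {
 *     MultiDataSet mds = trainIter.next(); // one mini-batch of re-ordered/padded sequences
 *     net.fit(mds);                        // e.g. a ComputationGraph
 * }
 * trainIter.reset();
 * }</pre>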
*
 * @author Ousmane A. Dia
*/
public class MDSIterator extends BaseDataSetIterator implements MultiDataSetIterator {
    private static final long serialVersionUID = -2132071188514707198L;

    private int vectorSize = 0;
    private int labelSize = 0;
    private final int batchSize;
    private int numSteps = 6;
    private final Stack<Path> stack = new Stack<>();
    private StackSequenceRecordReader ssRecordReader;
    /**
     * @param configuration Hadoop configuration used to access the HDFS cluster
     * @param dataDirectory HDFS directory that contains the training, validation, or testing data. The dataDirectory
     *        should be of the following form: hdfs://your_cluster:port/folder
     * @param batchSize Size of the mini-batch
     * @param vectorSize Number of features of your embeddings
     * @param labelSize Number of classes
     * @param numSteps Maximum number of time-steps to read per sequence
     * @param flag Takes the values 0 (training), 1 (validation), or 2 (testing) and helps set the start and end of
     *        your sequence.
     */
public MDSIterator(Configuration configuration, String dataDirectory, int batchSize, int vectorSize, int labelSize,
int numSteps, int flag) {
super(configuration, dataDirectory);
this.batchSize = batchSize;
this.vectorSize = vectorSize;
this.labelSize = labelSize;
        // The testing window (flag == 2) is shifted one time-step later than the
        // training/validation window; both windows have the same length.
        int start = flag == 2 ? Constants.START_SEQ + 1 : Constants.START_SEQ;
        int end = flag == 2 ? Constants.END_SEQ : Constants.END_SEQ - 1;
ssRecordReader = new StackSequenceRecordReader(fs, start, end);
this.numSteps = numSteps;
}
@Override
public boolean hasNext() {
try {
return hdfsIterator != null && hdfsIterator.hasNext();
        } catch (IOException e) {
            // Treat an I/O failure while listing HDFS files as the end of the data.
            return false;
}
}
@Override
public MultiDataSet next() {
return next(batchSize);
}
@Override
public boolean asyncSupported() {
return false;
}
@Override
public MultiDataSet next(int num) {
try {
if (!hdfsIterator.hasNext())
throw new NoSuchElementException();
            MultiDataSet mds = nextMultiDataSet(num);
            // A window may come back empty (null); keep scanning until a non-empty
            // MultiDataSet is built or the file listing is exhausted.
            while (mds == null && hdfsIterator.hasNext()) {
                mds = nextMultiDataSet(num);
            }
return mds;
} catch (IOException e) {
throw new RuntimeException(e);
}
}
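    /**
     * Pushes {@code path} onto the stack and flushes the stack to the record reader as one complete
     * sequence. If {@code path} does not belong to the same sequence as the current top of the stack
     * (i.e. the filename prefix before the first '_' differs), the previously accumulated sequence
     * is flushed first so that two sequences are never mixed in one record.
     */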
    private void pushAndClear(Path path, String index) {
        String top = stack.isEmpty() ? "" : stack.peek().toUri().toString();
        if (!top.contains(index.split("_")[0])) {
            // The path starts a new sequence: flush the previous one before pushing.
            ssRecordReader.newRecord(stack);
        }
        stack.push(path);
        ssRecordReader.newRecord(stack);
    }
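    /**
     * Builds the next MultiDataSet by scanning up to {@code num} sequences of at most {@code numSteps}
     * HDFS files each. Files whose names share the prefix before the first '_' (for example
     * "seq3_t0" and "seq3_t1") are grouped into the same sequence; whenever the prefix changes, the
     * accumulated stack is flushed to the record reader as one complete sequence.
     */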
    private MultiDataSet nextMultiDataSet(int num) throws IOException {
        String previousPath = stack.isEmpty() ? "" : stack.peek().toUri().getPath();
        for (int i = 0; i < num && hdfsIterator.hasNext(); i++) {
            for (int j = 0; j < numSteps; j++) {
                if (!hdfsIterator.hasNext())
                    break;
                LocatedFileStatus next = hdfsIterator.next();
                Path path = next.getPath();
                String currentPath = path.toUri().getPath();
                String index = getRelativeFilename(currentPath);
                if (previousPath.contains(index.split("_")[0])) {
                    // Same sequence as the previous file: accumulate it, and flush once
                    // the window is full or the input is exhausted.
                    if (j >= numSteps - 1 || !hdfsIterator.hasNext()) {
                        pushAndClear(path, index);
                    } else {
                        stack.push(path);
                    }
                    previousPath = currentPath;
                } else {
                    // A new sequence starts here: flush whatever was accumulated and
                    // begin a fresh record with this path.
                    if (j >= numSteps - 1 || !hdfsIterator.hasNext()) {
                        pushAndClear(path, index);
                    }
                    ssRecordReader.newRecord(stack);
                    stack.push(path);
                    if (!previousPath.isEmpty()) {
                        break;
                    }
                    previousPath = currentPath;
                }
            }
        }
        return ssRecordReader.toMultiDataSet(vectorSize, labelSize);
    }
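    /** Resets the iterator by re-initializing the underlying HDFS file listing. */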
@Override
public void reset() {
super.initIterator(hdfsUrl);
}
@Override
public boolean resetSupported() {
return true;
}
@Override
    public void setPreProcessor(MultiDataSetPreProcessor preprocessor) {
        // No-op: this iterator does not apply a pre-processor.
    }
@Override
public void remove() {
throw new UnsupportedOperationException("Remove not yet supported");
}
}