package org.deeplearning4j.datasets.iterator;

import lombok.Getter;
import lombok.NonNull;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/**
 * Adapts an existing {@code Iterable<DataSet>} or {@code Iterator<DataSet>} to the
 * {@link DataSetIterator} interface.
 *
 * <p>When built from an {@code Iterable}, the wrapper can be {@link #reset() reset} by asking the
 * iterable for a fresh iterator. When built from a bare {@code Iterator}, it is single-pass and
 * {@link #reset()} throws.
 *
 * @author raver119@gmail.com
 */
public class ExistingDataSetIterator implements DataSetIterator {
    /** Optional pre-processor applied to each not-yet-pre-processed DataSet returned by {@link #next()}. */
    @Getter
    private DataSetPreProcessor preProcessor;

    // Both sources are transient, so they can be null on an instance that was restored
    // without them; hasNext() and next() account for that.
    // NOTE(review): presumably DataSetIterator is Serializable - confirm.
    private transient Iterable<DataSet> iterable;
    private transient Iterator<DataSet> iterator;

    // Caller-supplied metadata; all default to 0 unless set through the matching constructor.
    private int totalExamples = 0;
    private int numFeatures = 0;
    private int numLabels = 0;

    /** Optional label names; null unless supplied through a constructor. */
    private List<String> labels;

    /**
     * Wraps a single-pass iterator. {@link #reset()} is not supported for instances created this way.
     *
     * @param iterator source of DataSets; must not be null
     */
    public ExistingDataSetIterator(@NonNull Iterator<DataSet> iterator) {
        this.iterator = iterator;
    }

    /**
     * Wraps a single-pass iterator and records label names.
     *
     * @param iterator source of DataSets; must not be null
     * @param labels   label names, also used to compute {@link #totalOutcomes()}; must not be null
     */
    public ExistingDataSetIterator(@NonNull Iterator<DataSet> iterator, @NonNull List<String> labels) {
        this(iterator);
        this.labels = labels;
    }

    /**
     * Wraps an iterable. A first iterator is obtained immediately, and a new one on every
     * {@link #reset()}.
     *
     * @param iterable source of DataSets; must not be null
     */
    public ExistingDataSetIterator(@NonNull Iterable<DataSet> iterable) {
        this.iterable = iterable;
        this.iterator = iterable.iterator();
    }

    /**
     * Wraps an iterable and records label names.
     *
     * @param iterable source of DataSets; must not be null
     * @param labels   label names, also used to compute {@link #totalOutcomes()}; must not be null
     */
    public ExistingDataSetIterator(@NonNull Iterable<DataSet> iterable, @NonNull List<String> labels) {
        this(iterable);
        this.labels = labels;
    }

    /**
     * Wraps an iterable and records caller-supplied size metadata. The values are stored and
     * reported as-is; they are not checked against the data.
     *
     * @param iterable      source of DataSets; must not be null
     * @param totalExamples value reported by {@link #totalExamples()} and {@link #numExamples()}
     * @param numFeatures   value reported by {@link #inputColumns()}
     * @param numLabels     value reported by {@link #totalOutcomes()} when no label list is set
     */
    public ExistingDataSetIterator(@NonNull Iterable<DataSet> iterable, int totalExamples, int numFeatures,
                    int numLabels) {
        this(iterable);
        this.totalExamples = totalExamples;
        this.numFeatures = numFeatures;
        this.numLabels = numLabels;
    }

    /**
     * Not supported by this wrapper.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public DataSet next(int num) {
        // TODO: this might be changed
        throw new UnsupportedOperationException("next(int) isn't supported");
    }

    /** @return the total-examples value supplied at construction, or 0 if none was given */
    @Override
    public int totalExamples() {
        return totalExamples;
    }

    /** @return the feature count supplied at construction, or 0 if none was given */
    @Override
    public int inputColumns() {
        return numFeatures;
    }

    /**
     * @return the size of the label list when one was supplied; otherwise the label count
     *         supplied at construction (0 if none was given)
     */
    @Override
    public int totalOutcomes() {
        // An explicit label list takes precedence over the numeric count.
        if (labels != null) {
            return labels.size();
        }
        return numLabels;
    }

    /** @return true only when this wrapper holds an Iterable to obtain a fresh iterator from */
    @Override
    public boolean resetSupported() {
        return iterable != null;
    }

    /** @return always false */
    @Override
    public boolean asyncSupported() {
        //No need to asynchronously prefetch here: already in memory
        return false;
    }

    /**
     * Restarts iteration by obtaining a new iterator from the wrapped iterable.
     *
     * @throws IllegalStateException if no Iterable is held by this wrapper
     */
    @Override
    public void reset() {
        if (iterable == null) {
            throw new IllegalStateException(
                            "To use reset() method you need to provide Iterable<DataSet>, not Iterator");
        }
        this.iterator = iterable.iterator();
    }

    /** @return always 0; this wrapper does not track a batch size */
    @Override
    public int batch() {
        return 0;
    }

    /** @return always 0; this wrapper does not track a cursor position */
    @Override
    public int cursor() {
        return 0;
    }

    /** @return the same value as {@link #totalExamples()} */
    @Override
    public int numExamples() {
        return totalExamples;
    }

    /** @param preProcessor pre-processor to apply in {@link #next()}; null disables pre-processing */
    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /** @return the label names supplied at construction, or null if none were given */
    @Override
    public List<String> getLabels() {
        return labels;
    }

    /** @return true if an underlying iterator is present and has more elements; false otherwise */
    @Override
    public boolean hasNext() {
        return iterator != null && iterator.hasNext();
    }

    /**
     * Returns the next DataSet from the underlying iterator. If a pre-processor is set and the
     * DataSet is not already marked as pre-processed, it is pre-processed and marked before
     * being returned.
     *
     * @return the next DataSet
     * @throws NoSuchElementException if no underlying iterator is available (the state in which
     *         {@link #hasNext()} reports false); the underlying iterator may also throw when it
     *         is exhausted
     */
    @Override
    public DataSet next() {
        // Mirror the null guard in hasNext() so callers get the Iterator-contract exception
        // instead of a NullPointerException.
        if (iterator == null) {
            throw new NoSuchElementException("No underlying iterator is available");
        }

        DataSet ds = iterator.next();
        // The isPreProcessed() flag stops the same DataSet from being pre-processed twice.
        if (preProcessor != null && !ds.isPreProcessed()) {
            preProcessor.preProcess(ds);
            ds.markAsPreProcessed();
        }
        return ds;
    }

    /** Does nothing: elements are not removed from the underlying source. */
    @Override
    public void remove() {
        // no-op
    }
}