package quickml.supervised.crossValidation.data; import com.google.common.collect.Lists; import java.util.List; import static com.google.common.base.Preconditions.checkArgument; public class FoldedData<I> implements TrainingDataCycler<I> { private final int numFolds; private final int foldsUsed; private final List<I> allData; private int currentFold; private List<I> trainingSet; private List<I> validationSet; public FoldedData(List<I> allData, int numFolds, int foldsUsed) { checkArguments(allData, numFolds, foldsUsed); this.allData = allData; this.numFolds = numFolds; this.foldsUsed = foldsUsed; this.currentFold = 0; reset(); } @Override public void reset() { currentFold = 0; setTrainingAndValidationSets(); } @Override public List<I> getTrainingSet() { return trainingSet; } @Override public List<I> getValidationSet() { return validationSet; } @Override public List<I> getAllData() { return allData; } @Override public boolean nextCycle() { currentFold++; if (hasMore()) { setTrainingAndValidationSets(); return true; } return false; } @Override public boolean hasMore() { return currentFold < foldsUsed; } private void checkArguments(List<I> allData, int numFolds, int foldsUsed) { checkArgument(allData.size() > 0, "Training set cannot be empty"); checkArgument(numFolds <= allData.size(), "Num Folds must be less than or equal to the data getSize"); checkArgument(foldsUsed <= numFolds, "Folds used must be less then or equal to the number of folds"); checkArgument(foldsUsed > 0, "Number of folds used must be greater than 0"); } private void setTrainingAndValidationSets() { trainingSet = Lists.newArrayList(); validationSet = Lists.newArrayList(); for (int i = 0; i < allData.size(); i++) { if (i % numFolds == currentFold) validationSet.add(allData.get(i)); else trainingSet.add(allData.get(i)); } } }