package edu.stanford.nlp.classify;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Triple;
import java.util.function.Function;
import java.util.Iterator;
import java.util.NoSuchElementException;
/**
 * Simplifies k-fold cross validation of classifiers, e.g. when tuning
 * hyper-parameters. Besides producing the train/test split for each fold,
 * it keeps one {@link SavedState} per fold for the lifetime of this object,
 * so per-fold state (for instance, the weights of a MaxEnt classifier or the
 * alphas of an SVM) can be carried from one evaluation pass to the next.
 *
 * <p>Folds are contiguous slices of the original dataset in its existing
 * order; this class does not shuffle the data.
 *
 * @author Aria Haghighi
 * @author Jenny Finkel
 * @author Sarah Spikes (Templatization)
 */
public class CrossValidator<L, F> {

  /** Number of folds used by {@link #CrossValidator(GeneralDataset)}. */
  private static final int DEFAULT_K_FOLD = 10;

  private final GeneralDataset<L, F> originalTrainData;
  private final int kFold;
  /** One state holder per fold; element i is always handed out with fold i. */
  private final SavedState[] savedStates;

  /**
   * Creates a 10-fold cross validator over {@code trainData}.
   *
   * @param trainData the full dataset to be partitioned into folds
   * @throws NullPointerException if {@code trainData} is null
   */
  public CrossValidator(GeneralDataset<L, F> trainData) {
    this(trainData, DEFAULT_K_FOLD);
  }

  /**
   * Creates a k-fold cross validator over {@code trainData}.
   *
   * @param trainData the full dataset to be partitioned into folds
   * @param kFold the number of folds; must be at least 1. If it exceeds the
   *     dataset size, some folds get a zero-width test range.
   * @throws NullPointerException if {@code trainData} is null
   * @throws IllegalArgumentException if {@code kFold < 1}
   */
  public CrossValidator(GeneralDataset<L, F> trainData, int kFold) {
    if (trainData == null) {
      throw new NullPointerException("trainData");
    }
    // Reject non-positive fold counts up front: 0 would make computeAverage
    // divide by zero (silently returning NaN), and a negative count would
    // fail with an opaque NegativeArraySizeException.
    if (kFold < 1) {
      throw new IllegalArgumentException("kFold must be >= 1, got " + kFold);
    }
    this.originalTrainData = trainData;
    this.kFold = kFold;
    this.savedStates = new SavedState[kFold];
    for (int i = 0; i < kFold; i++) {
      savedStates[i] = new SavedState();
    }
  }

  /**
   * Returns a fresh iterator over the (train, test, saved state) triples,
   * one per fold, in fold order.
   *
   * @return an Iterator over train/test/saved states
   */
  private Iterator<Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, SavedState>> iterator() {
    return new CrossValidationIterator();
  }

  /**
   * Computes the average, over all folds, of the function being optimized.
   *
   * <p>The function is applied once per fold to a triple containing, in
   * order, the training set, the test set, and that fold's saved state.
   * Using the saved state is optional. The same {@link SavedState} instance
   * is passed for a given fold on every call, so anything stored in it is
   * visible on later calls to this method.
   *
   * @param function evaluates one fold; must not return null (the result is
   *     unboxed into a {@code double})
   * @return the arithmetic mean of {@code function}'s results over all folds
   */
  public double computeAverage(Function<Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, SavedState>, Double> function) {
    double sum = 0.0;
    Iterator<Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, SavedState>> foldIt = iterator();
    while (foldIt.hasNext()) {
      sum += function.apply(foldIt.next());
    }
    // The iterator yields exactly kFold folds, and the constructor guarantees kFold >= 1.
    return sum / kFold;
  }

  /**
   * Iterates over the folds, computing each train/test split lazily when
   * {@link #next()} is called.
   */
  class CrossValidationIterator implements Iterator<Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, SavedState>> {

    /** Index of the next fold to produce. */
    private int iter = 0;

    @Override
    public boolean hasNext() {
      return iter < kFold;
    }

    @Override
    public void remove() {
      throw new UnsupportedOperationException("CrossValidationIterator doesn't support remove()");
    }

    /**
     * Builds the split for the current fold and advances to the next one.
     * The fold's boundaries are {@code size * iter / kFold} (start) and
     * {@code size * (iter + 1) / kFold} (end), which are handed to
     * {@code GeneralDataset.split(start, end)}; presumably the test set is
     * the range [start, end) and the train set is everything else.
     *
     * @throws NoSuchElementException if all folds have been produced
     */
    @Override
    public Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, SavedState> next() {
      if (!hasNext()) {
        throw new NoSuchElementException("CrossValidatorIterator exhausted.");
      }
      int size = originalTrainData.size();
      // Use long arithmetic: size * (iter + 1) can exceed Integer.MAX_VALUE
      // for large datasets. Each quotient is at most size, so narrowing back
      // to int is safe.
      int start = (int) ((long) size * iter / kFold);
      int end = (int) ((long) size * (iter + 1) / kFold);
      Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> split = originalTrainData.split(start, end);
      return new Triple<>(split.first(), split.second(), savedStates[iter++]);
    }
  } // end class CrossValidationIterator

  /**
   * Mutable holder for arbitrary per-fold state that callers may store and
   * read back across evaluation passes.
   */
  public static class SavedState {
    /** The stored state; null until a caller assigns it. */
    public Object state;
  }

  /**
   * Smoke test: reads an SVMLight-format file and materializes every fold's
   * train/test split without evaluating anything.
   *
   * @param args {@code args[0]} is the path of the SVMLight-format file
   */
  public static void main(String[] args) {
    if (args.length < 1) {
      System.err.println("Usage: java " + CrossValidator.class.getName() + " svmLightFile");
      return;
    }
    Dataset<String, String> d = Dataset.readSVMLightFormat(args[0]);
    Iterator<Triple<GeneralDataset<String, String>, GeneralDataset<String, String>, SavedState>> it =
        (new CrossValidator<>(d)).iterator();
    while (it.hasNext()) {
      it.next();
    }
  }
}