package edu.stanford.nlp.classify; import edu.stanford.nlp.util.Pair; import edu.stanford.nlp.util.Triple; import edu.stanford.nlp.util.Function; import java.util.Iterator; /** * This class is meant to simplify performing cross validation on * classifiers for hyper-parameters. It has the ability to save * state for each fold (for instance, the weights for a MaxEnt * classifier, and the alphas for an SVM). * * @author Aria Haghighi * @author Jenny Finkel * @author Sarah Spikes (Templatization) */ public class CrossValidator<L, F> { private GeneralDataset<L, F> originalTrainData; private int kfold; private SavedState[] savedStates; public CrossValidator(GeneralDataset<L, F> trainData) { this (trainData,5); } public CrossValidator(GeneralDataset<L, F> trainData, int kfold) { originalTrainData = trainData; this.kfold = kfold; savedStates = new SavedState[kfold]; for (int i = 0; i < savedStates.length; i++) { savedStates[i] = new SavedState(); } } /** * Returns and Iterator over train/test/saved states */ private Iterator<Triple<GeneralDataset<L, F>,GeneralDataset<L, F>,SavedState>> iterator() { return new CrossValidationIterator(); } /** * This computes the average over all folds of the function we're trying to optimize. * The input triple contains, in order, the train set, the test set, and the saved state. * You don't have to use the saved state if you don't want to. */ public double computeAverage (Function<Triple<GeneralDataset<L, F>,GeneralDataset<L, F>,SavedState>,Double> function) { double sum = 0; Iterator<Triple<GeneralDataset<L, F>,GeneralDataset<L, F>,SavedState>> foldIt = iterator(); while (foldIt.hasNext()) { sum += function.apply(foldIt.next()); } return sum / kfold; } class CrossValidationIterator implements Iterator<Triple<GeneralDataset<L, F>,GeneralDataset<L, F>,SavedState>> { int iter = 0; public boolean hasNext() { return iter < kfold; } public void remove() { throw new RuntimeException("CrossValidationIterator doesn't support remove()"); } public Triple<GeneralDataset<L, F>,GeneralDataset<L, F>,SavedState> next() { if (iter == kfold) return null; int start = originalTrainData.size() * iter / kfold; int end = originalTrainData.size() * (iter + 1) / kfold; //System.err.println("##train data size: " + originalTrainData.size() + " start " + start + " end " + end); Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> split = originalTrainData.split(start, end); return new Triple<GeneralDataset<L, F>,GeneralDataset<L, F>,SavedState>(split.first(),split.second(),savedStates[iter++]); } } public static class SavedState { public Object state; } public static void main(String[] args) { Dataset<String, String> d = Dataset.readSVMLightFormat(args[0]); Iterator<Triple<GeneralDataset<String, String>,GeneralDataset<String, String>,SavedState>> it = (new CrossValidator<String, String>(d)).iterator(); while (it.hasNext()) { it.next(); break; } } }