package edu.cmu.minorthird.classify.experiments; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Random; import edu.cmu.minorthird.classify.Splitter; /** * Split into k separate disjoint folds, then return k train/test splits * where each train set is the union of k-1 folds, and the test set * is the k-th fold. Preserves subpopulation information. * * @author William Cohen */ public class CrossValSplitter<T> implements Splitter<T>{ private Random random; private int folds; private List<List<T>> subpops; public CrossValSplitter(Random random,int folds){ this.random=random; this.folds=folds; } public CrossValSplitter(int folds){ this(new Random(),folds); } public CrossValSplitter(){ this(new Random(),5); } public int getNumberOfFolds(){ return folds; } public void setNumberOfFolds(int k){ this.folds=k; } @Override public void split(Iterator<T> i){ subpops=new ArrayList<List<T>>(); for(Iterator<List<T>> j=new SubpopSorter<T>(random,i).subpopIterator();j.hasNext();){ subpops.add(j.next()); } } @Override public int getNumPartitions(){ return folds; } @Override public Iterator<T> getTrain(int k){ List<T> trainList=new ArrayList<T>(); for(int i=0;i<subpops.size();i++){ if(i%folds!=k){ trainList.addAll(subpops.get(i)); } } return trainList.iterator(); } @Override public Iterator<T> getTest(int k){ List<T> testList=new ArrayList<T>(); for(int i=0;i<subpops.size();i++){ if(i%folds==k){ testList.addAll(subpops.get(i)); } } return testList.iterator(); } @Override public String toString(){ return "["+folds+"-CV splitter]"; } }