package edu.cmu.minorthird.classify.experiments; import java.util.ArrayList; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Random; import java.util.Set; import edu.cmu.minorthird.classify.HasSubpopulationId; import edu.cmu.minorthird.classify.Splitter; /** * Do N-fold cross-validation, where N is the number of different * subpopulations. * * @author William Cohen */ public class LeaveOneOutSplitter<T> implements Splitter<T>{ private Random random; private Splitter<T> crossValSplitter; public LeaveOneOutSplitter(Random random){ this.random=random; } public LeaveOneOutSplitter(){ this(new Random()); } @Override public void split(Iterator<T> i){ List<T> buf=new ArrayList<T>(); Set<String> subpops=new HashSet<String>(); while(i.hasNext()){ T t=i.next(); buf.add(t); // find subpop id, and record it String id; if(t instanceof HasSubpopulationId){ id=((HasSubpopulationId)t).getSubpopulationId(); } else{ id="youNeeekID#"+subpops.size(); } subpops.add(id); } crossValSplitter=new CrossValSplitter<T>(random,subpops.size()); crossValSplitter.split(buf.iterator()); } @Override public int getNumPartitions(){ return crossValSplitter.getNumPartitions(); } @Override public Iterator<T> getTrain(int k){ return crossValSplitter.getTrain(k); } @Override public Iterator<T> getTest(int k){ return crossValSplitter.getTest(k); } @Override public String toString(){ return "[LeaveOneOutSplitter]"; } }