package edu.cmu.minorthird.classify.experiments;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Splitter;
/**
* Works with datasets of binary examples. Splits the POS and NEG examples,
* separately, into k disjoint folds, and then returns k train/test splits
* where each train set is the union of k-1 folds and the test set is the
* remaining fold. Does NOT preserve subpopulation information.
*
* @author Edoardo Airoldi
* Date: Dec 8, 2003
*/
public class StratifiedCrossValSplitter implements Splitter<Example>{

  /** Source of randomness handed to the StrataSorter on every call to split. */
  private final Random random;

  /** Number of folds, which is also the number of train/test partitions. */
  private final int folds;

  /** Strata built by the most recent call to split; null until split has been called. */
  private List<List<Example>> strata;

  /**
   * Creates a splitter with the given random source and number of folds.
   *
   * @param random random source passed on to the StrataSorter
   * @param folds number of folds; must be at least 1
   * @throws IllegalArgumentException if folds is less than 1
   */
  public StratifiedCrossValSplitter(Random random,int folds){
    // Fold membership is computed as (index % folds): zero would raise an
    // ArithmeticException later on, and a negative count has no meaning.
    if(folds<1){
      throw new IllegalArgumentException("folds must be at least 1, but was "+folds);
    }
    this.random=random;
    this.folds=folds;
  }

  /**
   * Creates a splitter with a freshly constructed Random and the given number of folds.
   *
   * @param folds number of folds; must be at least 1
   * @throws IllegalArgumentException if folds is less than 1
   */
  public StratifiedCrossValSplitter(int folds){
    this(new Random(),folds);
  }

  /** Creates a 5-fold splitter with a freshly constructed Random. */
  public StratifiedCrossValSplitter(){
    this(5);
  }

  /**
   * Collects the strata that a StrataSorter produces for the given examples.
   * Any strata remembered from an earlier call are replaced.
   *
   * @param i iterator over the examples to be split
   */
  @Override
  public void split(Iterator<Example> i){
    // Fill a local list first so the field is only replaced once every
    // stratum has been read successfully.
    List<List<Example>> collected=new ArrayList<List<Example>>();
    for(Iterator<List<Example>> j=new StrataSorter(random,i).strataIterator();j.hasNext();){
      collected.add(j.next());
    }
    strata=collected;
  }

  /**
   * Returns the number of train/test partitions.
   *
   * @return the number of folds given at construction time
   */
  @Override
  public int getNumPartitions(){
    return folds;
  }

  /**
   * Returns the training examples of the k-th partition: within every stratum,
   * each example whose position modulo the number of folds differs from k.
   * If k lies outside [0, folds) every example is returned.
   *
   * @param k index of the partition
   * @return iterator over the training examples, stratum by stratum
   * @throws IllegalStateException if split has not been called yet
   */
  @Override
  public Iterator<Example> getTrain(int k){
    return select(k,false);
  }

  /**
   * Returns the test examples of the k-th partition: within every stratum,
   * each example whose position modulo the number of folds equals k.
   * If k lies outside [0, folds) the result is empty.
   *
   * @param k index of the partition
   * @return iterator over the test examples, stratum by stratum
   * @throws IllegalStateException if split has not been called yet
   */
  @Override
  public Iterator<Example> getTest(int k){
    return select(k,true);
  }

  /**
   * Gathers, in stratum order, the examples on one side of the k-th partition.
   *
   * @param k index of the partition
   * @param heldOut true to gather the test fold, false to gather the training folds
   * @return iterator over the selected examples
   * @throws IllegalStateException if split has not been called yet
   */
  private Iterator<Example> select(int k,boolean heldOut){
    if(strata==null){
      throw new IllegalStateException("split must be called before requesting partition "+k);
    }
    List<Example> selected=new ArrayList<Example>();
    for(List<Example> stratum:strata){
      for(int j=0;j<stratum.size();j++){
        // Position j of a stratum belongs to fold (j % folds).
        boolean inTestFold=(j%folds==k);
        if(inTestFold==heldOut){
          selected.add(stratum.get(j));
        }
      }
    }
    return selected.iterator();
  }

  /** Returns a short description that includes the number of folds. */
  @Override
  public String toString(){
    return "["+folds+"-Stratified CV splitter]";
  }
}