package edu.cmu.minorthird.classify.experiments;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import edu.cmu.minorthird.classify.Splitter;
/**
 * Variant of cross-validation in which not all training data is used.
 * Splits the data into k disjoint folds, then returns k train/test
 * splits where each train set is a sample of the union of k-1 folds,
 * and the test set is the k-th fold. Preserves subpopulation
 * information.
 *
 * <p>{@link #split(Iterator)} must be called before {@link #getTrain(int)}
 * or {@link #getTest(int)}. Changing the number of folds or the subsample
 * fraction only takes effect on the next call to {@code split}.
 *
 * @param <T> type of the examples being split
 *
 * @author William Cohen
 */
public class SubsamplingCrossValSplitter<T> implements Splitter<T>{

	/** Random source handed to both the cross-validation splitter and the subsampler. */
	private final Random random;

	/** Number of folds to use on the next call to {@link #split(Iterator)}. */
	private int folds;

	/** Fraction passed to the {@link RandomSplitter} that subsamples each training set. */
	private double subsampleFraction;

	/**
	 * Subsampled training examples, one list per fold; {@code null} until
	 * {@link #split(Iterator)} has completed. Stored as lists (not iterators)
	 * so that every {@link #getTrain(int)} call can hand out a fresh iterator.
	 */
	private List<List<T>> trainSamples;

	/** Cross-validation splitter of the most recent split; supplies the test folds. */
	private CrossValSplitter<T> cvs;

	/**
	 * @param random source of randomness used for fold assignment and subsampling
	 * @param folds number of cross-validation folds
	 * @param subsampleFraction fraction given to the {@link RandomSplitter} that
	 *        subsamples each fold's training set
	 */
	public SubsamplingCrossValSplitter(Random random,int folds,
			double subsampleFraction){
		this.random=random;
		this.folds=folds;
		this.subsampleFraction=subsampleFraction;
	}

	/**
	 * Creates a splitter backed by a new, unseeded {@link Random}.
	 *
	 * @param folds number of cross-validation folds
	 * @param subsampleFraction fraction used when subsampling each training set
	 */
	public SubsamplingCrossValSplitter(int folds,double subsampleFraction){
		this(new Random(),folds,subsampleFraction);
	}

	/** Creates a splitter with 5 folds and a subsample fraction of 0.5. */
	public SubsamplingCrossValSplitter(){
		this(5,0.5);
	}

	/** @return the number of folds that the next split will use */
	public int getNumberOfFolds(){
		return folds;
	}

	/**
	 * Sets the number of folds used by the next {@link #split(Iterator)};
	 * partitions that were already computed are left untouched.
	 *
	 * @param k the new number of folds
	 */
	public void setNumberOfFolds(int k){
		this.folds=k;
	}

	/** @return the fraction used when subsampling each training set */
	public double getSubsampleFraction(){
		return subsampleFraction;
	}

	/**
	 * Sets the subsample fraction used by the next {@link #split(Iterator)};
	 * partitions that were already computed are left untouched.
	 *
	 * @param d the new subsample fraction
	 */
	public void setSubsampleFraction(double d){
		this.subsampleFraction=d;
	}

	/**
	 * Partitions the examples with a {@link CrossValSplitter} over the current
	 * number of folds. For each fold, that fold's training iterator is then run
	 * through a {@link RandomSplitter} built with the current subsample
	 * fraction, and the training portion of its partition 0 is kept as the
	 * fold's (subsampled) training set.
	 *
	 * @param i iterator over the examples to split
	 */
	@Override
	public void split(Iterator<T> i){
		CrossValSplitter<T> crossVal=new CrossValSplitter<T>(random,folds);
		RandomSplitter<T> sampler=new RandomSplitter<T>(random,subsampleFraction);
		crossVal.split(i);
		List<List<T>> samples=new ArrayList<List<T>>(folds);
		for(int k=0;k<folds;k++){
			sampler.split(crossVal.getTrain(k));
			// Copy the subsample out of the iterator so it can be replayed later.
			List<T> sample=new ArrayList<T>();
			for(Iterator<T> it=sampler.getTrain(0);it.hasNext();){
				sample.add(it.next());
			}
			samples.add(sample);
		}
		// Publish the new state only once the whole split succeeded, so a
		// failure above leaves any previous split intact.
		this.cvs=crossVal;
		this.trainSamples=samples;
	}

	/**
	 * @return the number of partitions produced by the most recent split, or
	 *         the configured number of folds if no split has been done yet
	 */
	@Override
	public int getNumPartitions(){
		return trainSamples!=null?trainSamples.size():folds;
	}

	/**
	 * Returns a fresh iterator over the subsampled training set of a fold.
	 *
	 * @param k index of the fold
	 * @return iterator over the training sample for fold {@code k}
	 * @throws IllegalStateException if {@link #split(Iterator)} has not been called
	 * @throws IndexOutOfBoundsException if {@code k} is not a valid fold index
	 */
	@Override
	public Iterator<T> getTrain(int k){
		requireSplit();
		return trainSamples.get(k).iterator();
	}

	/**
	 * Returns the test set of a fold, as reported by the underlying
	 * {@link CrossValSplitter}; test folds are not subsampled.
	 *
	 * @param k index of the fold
	 * @return iterator over the test examples for fold {@code k}
	 * @throws IllegalStateException if {@link #split(Iterator)} has not been called
	 */
	@Override
	public Iterator<T> getTest(int k){
		requireSplit();
		return cvs.getTest(k);
	}

	@Override
	public String toString(){
		return "[SubCV "+folds+";"+subsampleFraction+"]";
	}

	/** Fails with a descriptive error when partitions are requested before any split. */
	private void requireSplit(){
		if(trainSamples==null){
			throw new IllegalStateException(
					"split() must be called before requesting train/test partitions");
		}
	}
}