/* Copyright 2003, Carnegie Mellon, All Rights Reserved */
package edu.cmu.minorthird.text.learn;
import java.util.Iterator;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import org.apache.log4j.Logger;
import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.MutableInstance;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
/**
 * JUnit tests for {@link CrossValSplitter}, exercised through {@link Dataset#split}.
 *
 * <p>Each check builds a small synthetic dataset of web "pages" grouped by
 * "site" (the site name is passed as the subpopulation argument of every
 * instance) and verifies that the resulting cross-validation split is a proper
 * partition: within every fold the train and test sets are disjoint and
 * together hold the whole dataset, and across folds every example is held out
 * for testing exactly once.
 *
 * @author William Cohen
 */
public class CVSplitterTest extends TestCase
{
  private static final Logger log = Logger.getLogger(CVSplitterTest.class);

  /** Number of cross-validation folds used by {@link #doTest(int,int)}. */
  private static final int DEFAULT_NUM_FOLDS = 3;

  public CVSplitterTest(String name) { super(name); }

  public CVSplitterTest() { super("CVSplitterTest"); }

  /** @return a suite containing every {@code test*} method of this class */
  public static Test suite() { return new TestSuite(CVSplitterTest.class); }

  /** Checks the splitter on one-page-per-site data and on several-pages-per-site data. */
  public void testCV()
  {
    doTest(10,1);
    doTest(3,5);
  }

  /**
   * Runs {@link #doTest(int,int,int)} with {@value #DEFAULT_NUM_FOLDS} folds.
   *
   * @param numSites number of distinct sites (subpopulations) to generate
   * @param numPagesPerSite number of pages (examples) generated per site
   */
  public void doTest(int numSites,int numPagesPerSite)
  {
    doTest(numSites,numPagesPerSite,DEFAULT_NUM_FOLDS);
  }

  /**
   * Builds a synthetic dataset and checks that a {@code numFolds}-way
   * {@link CrossValSplitter} split of it is a proper partition.
   *
   * @param numSites number of distinct sites (subpopulations) to generate
   * @param numPagesPerSite number of pages (examples) generated per site
   * @param numFolds number of cross-validation folds to request
   */
  public void doTest(int numSites,int numPagesPerSite,int numFolds)
  {
    log.debug("[CVSplitterTest sites: "+numSites+" pages/site: "+numPagesPerSite+"]");
    Dataset data = buildDataset(numSites,numPagesPerSite);
    int totalSize = data.size();
    Dataset.Split split = data.split(new CrossValSplitter<Example>(numFolds));
    assertEquals( numFolds, split.getNumPartitions() );
    // test sets are kept so later folds can be checked against earlier ones
    Dataset[] test = new Dataset[numFolds];
    int totalTest = 0;
    for (int i=0; i<numFolds; i++) {
      log.debug("partition "+(i+1)+":");
      Dataset train = split.getTrain(i);
      test[i] = split.getTest(i);
      for (Iterator<Example> j=test[i].iterator(); j.hasNext(); ) {
        Example e = j.next();
        log.debug(" test: "+e);
        assertFalse( "test example also in train set of partition "+(i+1), contains(train,e) );
        // every example must be held out in exactly one fold
        for (int k=0; k<i; k++) {
          assertFalse( "example tested in partitions "+(k+1)+" and "+(i+1), contains(test[k],e) );
        }
      }
      log.debug(" -----\n "+test[i].size()+" total");
      for (Iterator<Example> j=train.iterator(); j.hasNext(); ) {
        Example e = j.next();
        log.debug(" train: "+e);
        assertFalse( "train example also in test set of partition "+(i+1), contains(test[i],e) );
      }
      log.debug(" -----\n "+train.size()+" total");
      assertEquals( totalSize, train.size() + test[i].size() );
      totalTest += test[i].size();
    }
    // pairwise-disjoint test folds whose sizes sum to totalSize cover every example exactly once
    assertEquals( totalSize, totalTest );
  }

  /**
   * Builds {@code numSites*numPagesPerSite} examples, all labelled
   * {@code ClassLabel.binaryLabel(+1)}. Each page's instance gets its site
   * name as the subpopulation argument and a single binary feature unique
   * to that page.
   */
  private Dataset buildDataset(int numSites,int numPagesPerSite)
  {
    Dataset data = new BasicDataset();
    for (int site=1; site<=numSites; site++) {
      String subpop = "www.site"+site+".com";
      for (int page=1; page<=numPagesPerSite; page++) {
        MutableInstance inst = new MutableInstance("page"+page+".html", subpop);
        inst.addBinary( new Feature("site"+site+".page"+page) );
        data.add(new Example(inst, ClassLabel.binaryLabel(+1)));
        log.debug("instance: "+inst);
      }
    }
    return data;
  }

  /**
   * Tells whether {@code data} holds an example wrapping the very same
   * instance object as {@code e} (identity, not equals).
   *
   * <p>Both sides are unwrapped with {@code asInstance()}. Comparing
   * {@code e1.asInstance()} against the Example {@code e} itself compares
   * two different kinds of object, which never match, so that check always
   * answered false. NOTE(review): assumes the splitter reuses the original
   * instances rather than copying them.
   */
  private boolean contains(Dataset data,Example e)
  {
    for (Iterator<Example> j=data.iterator(); j.hasNext(); ) {
      Example e1 = j.next();
      if (e1.asInstance()==e.asInstance()) return true;
    }
    return false;
  }
}