/* Copyright 2003, Carnegie Mellon, All Rights Reserved */ package edu.cmu.minorthird.classify; import java.util.ArrayList; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Random; import java.util.Set; import junit.framework.TestCase; import junit.framework.TestSuite; import org.apache.log4j.Logger; import edu.cmu.minorthird.classify.algorithms.knn.KnnLearner; import edu.cmu.minorthird.classify.algorithms.linear.BalancedWinnow; import edu.cmu.minorthird.classify.algorithms.linear.KernelVotedPerceptron; import edu.cmu.minorthird.classify.algorithms.linear.MaxEntLearner; import edu.cmu.minorthird.classify.algorithms.linear.NaiveBayes; import edu.cmu.minorthird.classify.algorithms.linear.PoissonLearner; import edu.cmu.minorthird.classify.algorithms.linear.VitorBalancedWinnow; import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron; import edu.cmu.minorthird.classify.algorithms.linear.Winnow; import edu.cmu.minorthird.classify.algorithms.svm.SVMLearner; import edu.cmu.minorthird.classify.algorithms.trees.AdaBoost; import edu.cmu.minorthird.classify.algorithms.trees.DecisionTreeLearner; import edu.cmu.minorthird.classify.experiments.CrossValSplitter; import edu.cmu.minorthird.classify.experiments.SubsamplingCrossValSplitter; import edu.cmu.minorthird.classify.experiments.Tester; import edu.cmu.minorthird.classify.semisupervised.SemiSupervisedNaiveBayesLearner; /** * * @author William Cohen */ public class TestPackage extends TestSuite{ private static Logger log=Logger.getLogger(TestPackage.class); public TestPackage(String name){ super(name); } public static TestSuite suite(){ TestSuite suite=new TestSuite(); // these are error rates that the learners empirically obtain // if we don't get these, something has changed---which doesn't // necessarily mean there's a bug... suite.addTest(new LearnerTest("bayesUnlabeled",new SemiSupervisedNaiveBayesLearner(),0.0,0.0)); suite.addTest(new LearnerTest("bayesExtreme",new PoissonLearner(),0.0,0.0)); suite.addTest(new LearnerTest("bayesExtreme",new NaiveBayes(),0.5,0.5)); suite.addTest(new LearnerTest("toy",new NaiveBayes(),1.0/7.0,1.0/7.0)); suite.addTest(new LearnerTest("bayes",new PoissonLearner(),1.0/7.0,1.0/7.0)); suite.addTest(new LearnerTest("toy",new BinaryBatchVersion(new VotedPerceptron()),0.0,1.0/7.0)); suite.addTest(new LearnerTest("toy",new BinaryBatchVersion(new KernelVotedPerceptron(),5),0.0,0.25)); suite.addTest(new LearnerTest("toy",new VotedPerceptron(),1.0/7.0,1.0/7.0)); suite.addTest(new LearnerTest("toy",new KernelVotedPerceptron(),0.0,0.0)); suite.addTest(new LearnerTest("toy",new Winnow(),0.0,0.0)); suite.addTest(new LearnerTest("toy",new BalancedWinnow(),1.0/7.0,0.0)); suite.addTest(new LearnerTest("toy",new VitorBalancedWinnow(),0.0,0.0)); suite.addTest(new LearnerTest("toy",new DecisionTreeLearner(5,2),1.0/7.0,1.0/7.0)); suite.addTest(new LearnerTest("toy",new KnnLearner(10),0.0,0.10)); suite.addTest(new LearnerTest("toy3",new KnnLearner(10),0.20,0.10)); suite.addTest(new LearnerTest("toy",new AdaBoost(new DecisionTreeLearner(5,2),10),1.0/7.0,1.0/7.0)); suite.addTest(new LearnerTest("num",new DecisionTreeLearner(5,2),0.05,0.10)); suite.addTest(new LearnerTest("sparseNum",new DecisionTreeLearner(5,2),0.0,0.10)); suite.addTest(new LogisticRegressionTest()); suite.addTest(new LearnerTest("toy",new SVMLearner(),0.0,0.0)); suite.addTest(new LearnerTest("toy3",new SVMLearner(),0.0,0.1)); suite.addTest(new XValTest(10,1)); suite.addTest(new XValTest(3,5)); suite.addTest(new XValTest(50,1,true)); suite.addTest(new XValTest(3,25,true)); return suite; } public static class LogisticRegressionTest extends TestCase{ public LogisticRegressionTest(){ super("doTest"); } public void doTest(){ MaxEntLearner lr=new MaxEntLearner(); Dataset data= SampleDatasets.makeLogisticRegressionData(new Random(0),1000,0.2,0.3); Classifier c=lr.batchTrain(data); double error=Tester.errorRate(c,data); assertEquals(0.415,error,0.05); } } public static class XValTest extends TestCase{ private int numSites,numPagesPerSite; private boolean subsample; public XValTest(int numSites,int numPagesPerSite){ this(numSites,numPagesPerSite,false); } public XValTest(int numSites,int numPagesPerSite,boolean subsample){ super("doTest"); this.numSites=numSites; this.numPagesPerSite=numPagesPerSite; this.subsample=subsample; } public void doTest(){ log.debug("[XValTest sites: "+numSites+" pages/site: "+numPagesPerSite+ "]"); List<Instance> list=new ArrayList<Instance>(); for(int site=1;site<=numSites;site++){ String subpop="www.site"+site+".com"; for(int page=1;page<=numPagesPerSite;page++){ MutableInstance inst=new MutableInstance("page"+page+".html",subpop); inst.addBinary(new Feature("site"+site+".page"+page)); list.add(inst); log.debug("instance: "+inst); } } int totalSize=list.size(); Splitter<Instance> splitter=null; if(subsample) splitter=new SubsamplingCrossValSplitter<Instance>(3,0.2); else splitter=new CrossValSplitter<Instance>(3); splitter.split(list.iterator()); assertEquals(3,splitter.getNumPartitions()); Set<Instance>[] train=new Set[3]; Set<Instance>[] test=new Set[3]; int totalTest=0; for(int i=0;i<3;i++){ log.debug("partition "+(i+1)+":"); train[i]=asSet(splitter.getTrain(i)); test[i]=asSet(splitter.getTest(i)); for(Iterator<Instance> j=test[i].iterator();j.hasNext();){ Instance inst=j.next(); log.debug(" test: "+inst); assertTrue(!train[i].contains(inst)); } log.debug(" -----\n "+test[i].size()+" total"); for(Iterator<Instance> j=train[i].iterator();j.hasNext();){ Instance inst=j.next(); log.debug(" train: "+inst); assertTrue(!test[i].contains(inst)); } log.debug(" -----\n "+train[i].size()+" total"); if(subsample){ assertTrue(totalSize>=(train[i].size()+test[i].size())); }else{ assertEquals(totalSize,train[i].size()+test[i].size()); } totalTest+=test[i].size(); } assertEquals(totalSize,totalTest); } private Set<Instance> asSet(Iterator<Instance> i){ Set<Instance> set=new HashSet<Instance>(); while(i.hasNext()) set.add(i.next()); return set; } } public static class LearnerTest extends TestCase{ private ClassifierLearner learner; private double expectedTestError; private double allowedVariance; private String testName; public LearnerTest(String testName,ClassifierLearner learner, double expectedTestError,double allowedVariance){ super("doTest"); this.learner=learner; this.expectedTestError=expectedTestError; this.testName=testName; this.allowedVariance=allowedVariance; } public void doTest(){ Dataset data=SampleDatasets.sampleData(testName,false); data.shuffle(new Random(0)); ClassifierTeacher teacher=new DatasetClassifierTeacher(data); Classifier c=teacher.train(learner); log.debug("classifier is "+c); System.out.println("classifier is "+c); Dataset testSet=SampleDatasets.sampleData(testName,true); double actualTestError=Tester.errorRate(c,testSet); log.debug("error of "+learner+" is "+actualTestError); System.out.println("error of "+learner+" is "+actualTestError); assertEquals(expectedTestError,actualTestError,allowedVariance+0.001); } } static public void main(String[] argv){ junit.textui.TestRunner.run(suite()); } }