/* Copyright 2003, Carnegie Mellon, All Rights Reserved */
package edu.cmu.minorthird.classify.experiments;
import java.util.Iterator;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.DatasetClassifierTeacher;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.StackedDatasetClassifierTeacher;
import edu.cmu.minorthird.classify.multi.MultiClassifier;
import edu.cmu.minorthird.classify.multi.MultiDataset;
import edu.cmu.minorthird.classify.multi.MultiDatasetClassifierTeacher;
import edu.cmu.minorthird.classify.multi.MultiEvaluation;
import edu.cmu.minorthird.classify.multi.MultiExample;
import edu.cmu.minorthird.classify.relational.RealRelationalDataset;
import edu.cmu.minorthird.classify.relational.StackedBatchClassifierLearner;
import edu.cmu.minorthird.classify.relational.StackedGraphicalLearner;
import edu.cmu.minorthird.classify.semisupervised.DatasetSemiSupervisedClassifierTeacher;
import edu.cmu.minorthird.classify.semisupervised.SemiSupervisedClassifier;
import edu.cmu.minorthird.classify.semisupervised.SemiSupervisedClassifierLearner;
import edu.cmu.minorthird.classify.semisupervised.SemiSupervisedDataset;
import edu.cmu.minorthird.classify.sequential.DatasetSequenceClassifierTeacher;
import edu.cmu.minorthird.classify.sequential.SequenceClassifier;
import edu.cmu.minorthird.classify.sequential.SequenceClassifierLearner;
import edu.cmu.minorthird.classify.sequential.SequenceDataset;
import edu.cmu.minorthird.classify.transform.AbstractInstanceTransform;
import edu.cmu.minorthird.classify.transform.PredictedClassTransform;
import edu.cmu.minorthird.classify.transform.TransformingMultiClassifier;
import edu.cmu.minorthird.util.ProgressCounter;
/** Test a classifier, in a number of ways.
*
* @author William Cohen
*/
public class Tester
{
static private Logger log = Logger.getLogger(Tester.class);
private static final boolean DEBUG = log.getEffectiveLevel().isGreaterOrEqual( Level.DEBUG );
/** Do some sort of hold-out experiment, as determined by the splitter */
static public Evaluation evaluate(StackedBatchClassifierLearner learner,RealRelationalDataset d,Splitter<Example> splitter, String stacked)
{
Evaluation v = new Evaluation(d.getSchema());
RealRelationalDataset.Split s = d.split(splitter);
//System.out.println("Test Splitter: "+splitter);
ProgressCounter pc = new ProgressCounter("train/test","fold",s.getNumPartitions());
for (int k=0; k<s.getNumPartitions(); k++) {
RealRelationalDataset trainData = (RealRelationalDataset)s.getTrain(k);
RealRelationalDataset testData = (RealRelationalDataset)s.getTest(k);
log.info("splitting with "+splitter+", preparing to train on "+trainData.size()
+" and test on "+testData.size());
Classifier c = new StackedDatasetClassifierTeacher(trainData).trainStacked(learner);
if (DEBUG) log.debug("classifier for fold "+(k+1)+"/"+s.getNumPartitions()+" is:\n" + c);
v.extend4SGM( (StackedGraphicalLearner.StackedGraphicalClassifier)c, testData, k );
log.info("splitting with "+splitter+", completed train-test round");
pc.progress();
}
pc.finished();
return v;
}
/** Do some sort of hold-out experiment, as determined by the splitter */
static public Evaluation evaluate(ClassifierLearner learner,Dataset d,Splitter<Example> splitter)
{
Evaluation v = new Evaluation(d.getSchema());
Dataset.Split s = d.split(splitter);
ProgressCounter pc = new ProgressCounter("train/test","fold",s.getNumPartitions());
for (int k=0; k<s.getNumPartitions(); k++) {
Dataset trainData = s.getTrain(k);
Dataset testData = s.getTest(k);
log.info("splitting with "+splitter+", preparing to train on "+trainData.size()
+" and test on "+testData.size());
Classifier c = new DatasetClassifierTeacher(trainData).train(learner);
if (DEBUG) log.debug("classifier for fold "+(k+1)+"/"+s.getNumPartitions()+" is:\n" + c);
v.extend( c, testData, k );
log.info("splitting with "+splitter+", completed train-test round");
pc.progress();
}
pc.finished();
return v;
}
/** Do some sort of hold-out experiment, as determined by the splitter */
static public MultiEvaluation multiEvaluate(ClassifierLearner learner,MultiDataset d,Splitter<MultiExample> splitter)
{
return multiEvaluate(learner, d, splitter, false);
}
/** Do some sort of hold-out experiment, as determined by the splitter */
static public MultiEvaluation multiEvaluate(ClassifierLearner learner,MultiDataset d,Splitter<MultiExample> splitter, boolean cross)
{
MultiEvaluation v = new MultiEvaluation(d.getMultiSchema());
MultiDataset.MultiSplit s = d.MultiSplit(splitter);
ProgressCounter pc = new ProgressCounter("train/test","fold",s.getNumPartitions());
for (int k=0; k<s.getNumPartitions(); k++) {
//for (int k=0; k<1; k++) {
MultiDataset trainData = s.getTrain(k);
if(cross) trainData=trainData.annotateData();
MultiDataset testData = s.getTest(k);
log.info("splitting with "+splitter+", preparing to train on "+trainData.size()
+" and test on "+testData.size());
MultiClassifier c = new MultiDatasetClassifierTeacher(trainData).train(learner);
//if(cross) testData=testData.annotateData(c);
if(cross) {
AbstractInstanceTransform transformer = new PredictedClassTransform(c);
c = new TransformingMultiClassifier(c, transformer);
}
if (DEBUG) log.debug("classifier for fold "+(k+1)+"/"+s.getNumPartitions()+" is:\n" + c);
v.extend( c, testData);
log.info("splitting with "+splitter+", completed train-test round");
pc.progress();
}
pc.finished();
return v;
}
/** Do some sort of hold-out experiment, as determined by the splitter */
static public Evaluation evaluate(SequenceClassifierLearner learner,SequenceDataset d,Splitter<Example[]> splitter)
{
Evaluation v = new Evaluation(d.getSchema());
Dataset.Split s = d.splitSequence(splitter);
ProgressCounter pc = new ProgressCounter("train/test","fold",s.getNumPartitions());
for (int k=0; k<s.getNumPartitions(); k++) {
SequenceDataset trainData = (SequenceDataset)s.getTrain(k);
SequenceDataset testData = (SequenceDataset)s.getTest(k);
log.info("splitting with "+splitter+", preparing to train on "+trainData.size()
+" and test on "+testData.size());
SequenceClassifier c = new DatasetSequenceClassifierTeacher(trainData).train(learner);
if (DEBUG) log.debug("classifier for fold "+(k+1)+"/"+s.getNumPartitions()+" is:\n" + c);
v.extend( c, testData );
log.info("splitting with "+splitter+", completed train-test round");
pc.progress();
}
pc.finished();
return v;
}
/** Do some sort of hold-out experiment, as determined by the splitter */
static public Evaluation evaluate(SemiSupervisedClassifierLearner learner,SemiSupervisedDataset d,Splitter<Example> splitter)
{
Evaluation v = new Evaluation(d.getSchema());
Dataset.Split s = d.split(splitter);
ProgressCounter pc = new ProgressCounter("train/test","fold",s.getNumPartitions());
for (int k=0; k<s.getNumPartitions(); k++) {
SemiSupervisedDataset trainData = (SemiSupervisedDataset)s.getTrain(k); // Use the Interface ?
SemiSupervisedDataset testData = (SemiSupervisedDataset)s.getTest(k);
log.info("splitting with "+splitter+", preparing to train on "+trainData.size()
+" and test on "+testData.size());
SemiSupervisedClassifier c = new DatasetSemiSupervisedClassifierTeacher(trainData).train(learner);
if (DEBUG) log.debug("classifier for fold "+(k+1)+"/"+s.getNumPartitions()+" is:\n" + c);
v.extend( c, testData, k );
log.info("splitting with "+splitter+", completed train-test round");
pc.progress();
}
pc.finished();
return v;
}
/** Do a train and test experiment */
static public Evaluation evaluate(ClassifierLearner learner,Dataset trainData,Dataset testData)
{
Splitter<Example> trainTestSplitter = new FixedTestSetSplitter<Example>(testData.iterator());
return evaluate(learner,trainData,trainTestSplitter);
}
/** Do a train and test experiment */
static public Evaluation
evaluate(SequenceClassifierLearner learner,SequenceDataset trainData,SequenceDataset testData)
{
Splitter<Example[]> trainTestSplitter = new FixedTestSetSplitter<Example[]>(testData.sequenceIterator());
return evaluate(learner,trainData,trainTestSplitter);
}
/** Do a train and test experiment */
static public Evaluation
evaluate(SemiSupervisedClassifierLearner learner,SemiSupervisedDataset trainData,SemiSupervisedDataset testData)
{
Splitter<Example> trainTestSplitter = new FixedTestSetSplitter<Example>(testData.iterator());
return evaluate(learner,trainData,trainTestSplitter);
}
/** Return the log loss on an example with known true class. */
static public double logLoss(BinaryClassifier c, Example e)
{
return Math.log( 1.0 + Math.exp( e.getLabel().numericLabel() * c.score(e) ) );
}
/** Return the average log loss on a dataset. */
static public double logLoss(BinaryClassifier c,Dataset d)
{
double loss = 0;
for (Iterator<Example> i=d.iterator(); i.hasNext(); ) {
Example e = i.next();
loss += logLoss(c, e);
}
return loss/d.size();
}
/** Return the error rate of a classifier on a dataset. */
static public double errorRate(Classifier c,Dataset d)
{
double errors = 0;
for (Iterator<Example> i=d.iterator(); i.hasNext(); ) {
Example e = i.next();
if (! c.classification(e).isCorrect( e.getLabel())) {
errors++;
}
}
return errors/d.size();
}
}