package edu.cmu.minorthird.classify;
import junit.framework.TestCase;
import edu.cmu.minorthird.classify.experiments.Evaluation;
import edu.cmu.minorthird.classify.experiments.Tester;
import edu.cmu.minorthird.classify.algorithms.linear.NaiveBayes;
import org.apache.log4j.Logger;
import org.apache.log4j.Level;
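/**
 * Base class for JUnit tests that train a classifier on one dataset, evaluate
 * it on another, and compare the resulting evaluation statistics against
 * reference values within a configurable tolerance.
 * @author ksteppe
 */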
abstract public class AbstractClassificationChecks extends TestCase
{
  /** Logger named after the concrete test class; its level is set to DEBUG in the constructor. */
  protected Logger log = Logger.getLogger(this.getClass());

  /**
   * Learner available to subclasses when a test does not need a specific one.
   * NOTE(review): this single instance is shared by every subclass; if NaiveBayes
   * keeps state between training runs, tests could interfere with each other — confirm.
   */
  protected final static ClassifierLearner DEFAULT_LEARNER = new NaiveBayes();

  /** When true, only the four standard statistics are compared (see {@link #checkClassify}). */
  private boolean checkStandardStatsOnly = false;

  /** Absolute tolerance used when comparing a statistic against its reference value. */
  private double delta = 0.001;

  /**
   * Creates the test case and raises this class's logger to DEBUG.
   *
   * @param name JUnit test name, passed to {@link TestCase#TestCase(String)}
   */
  public AbstractClassificationChecks(String name)
  {
    super(name);
    log.setLevel(Level.DEBUG);
  }

  /**
   * Trains {@code learner} on {@code trainData}, evaluates it on {@code testData},
   * and compares the resulting statistics with {@code referenceStats}.
   *
   * <p>When standard-only checking is enabled (see {@link #setCheckStandards}), the
   * compared statistics are, in order: error rate, average precision, max F1 and
   * average log loss. Otherwise they are the values returned by
   * {@code Evaluation.summaryStatistics()}, in that method's order.
   *
   * @param learner        learner to train and evaluate
   * @param trainData      training data
   * @param testData       evaluation data
   * @param referenceStats expected values in the order described above, or
   *                       {@code null} to only log the computed statistics
   * @throws IllegalStateException if {@code referenceStats} is non-null and its
   *                               length differs from the number of computed stats
   */
  public void checkClassify(ClassifierLearner learner, Dataset trainData, Dataset testData, double[] referenceStats)
  {
    Evaluation evaluation = Tester.evaluate(learner, trainData, testData);
    log.info("checking standard stats only: " + checkStandardStatsOnly);
    double[] stats = checkStandardStatsOnly ? standardStats(evaluation) : evaluation.summaryStatistics();
    if (referenceStats != null && stats.length != referenceStats.length) {
      throw new IllegalStateException("number of statistics to check is different from the number of reference stats given!");
    }
    checkStats(stats, referenceStats);
  }

  /** Collects error rate, average precision, max F1 and average log loss, in that order. */
  private static double[] standardStats(Evaluation evaluation)
  {
    return new double[] {
        evaluation.errorRate(),
        evaluation.averagePrecision(),
        evaluation.maxF1(),
        evaluation.averageLogLoss()
    };
  }

  /**
   * Logs each statistic and, when reference values are supplied, asserts that each
   * one equals its reference within {@link #getDelta()}.
   *
   * @param stats          computed statistics
   * @param referenceStats expected values aligned index-by-index with {@code stats},
   *                       or {@code null} to only log the computed statistics
   * @throws IllegalArgumentException if {@code referenceStats} is shorter than {@code stats}
   */
  protected void checkStats(double[] stats, double[] referenceStats)
  {
    log.info("checking " + stats.length + " stats...");
    // Longer reference arrays are tolerated (extra entries ignored) for backward
    // compatibility; shorter ones would otherwise fail with an opaque index error.
    if (referenceStats != null && referenceStats.length < stats.length) {
      throw new IllegalArgumentException("only " + referenceStats.length
          + " reference stats given for " + stats.length + " computed stats");
    }
    for (int i = 0; i < stats.length; i++) {
      double stat = stats[i];
      log.info("Predictedstat(" + i + ")=" + stat);
      // Read referenceStats only after the null check; null means "log only".
      if (referenceStats != null) {
        log.info("Referencestat(" + i + ")=" + referenceStats[i]);
        assertEquals("stat " + i, referenceStats[i], stat, delta);
      }
    }
  }

  /** @return the absolute tolerance used when comparing statistics */
  public double getDelta()
  {
    return delta;
  }

  /** @param delta the absolute tolerance used when comparing statistics */
  public void setDelta(double delta)
  {
    this.delta = delta;
  }

  /**
   * Selects which statistics {@link #checkClassify} compares.
   *
   * @param b true to compare only error rate, average precision, max F1 and
   *          average log loss; false to compare the full summary statistics
   */
  protected void setCheckStandards(boolean b)
  {
    checkStandardStatsOnly = b;
  }
}