package edu.cmu.minorthird.text.learn;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import org.apache.log4j.Logger;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.algorithms.svm.SVMLearner;
import edu.cmu.minorthird.classify.sequential.CMMLearner;
import edu.cmu.minorthird.text.Annotator;
import edu.cmu.minorthird.text.SpanDifference;
import edu.cmu.minorthird.text.TextBase;
import edu.cmu.minorthird.text.TextLabels;
import edu.cmu.minorthird.text.SpanDifference.Looper;
import edu.cmu.minorthird.ui.Recommended;
/**
 * Regression test for sequential span extraction on the data of
 * {@code SampleExtractionProblem}.
 *
 * <p>Each check trains a {@link SequenceAnnotatorLearner} on the training labels. It then
 * annotates both the training labels and the test labels with the learned annotator. Finally it
 * asserts that the token-level recall and precision of the predicted spans, measured against the
 * true spans, are within a tolerance of expected values.
 *
 * @author ksteppe
 */
public class SampleExtractionTest extends TestCase
{
    static final Logger log = Logger.getLogger(SampleExtractionTest.class);

    /** Span type under which the learned annotators write their predicted spans. */
    private static final String PREDICTION_TYPE = "prediction";

    /**
     * Required length of an expected-value array: {recall, precision, tolerance} for the
     * training data followed by the same triple for the test data.
     */
    private static final int EXPECTED_LENGTH = 6;

    // text base and labels of the training data
    protected TextBase base;
    protected TextLabels labels;
    // text base and labels of the testing data
    protected TextBase testBase;
    protected TextLabels testLabels;
    // span type of the true spans that predictions are compared against
    private String labelString;

    /**
     * Standard test-class constructor for SampleExtractionTest.
     *
     * @param name name of the test
     */
    public SampleExtractionTest(String name)
    {
        super(name);
    }

    /**
     * Convenience constructor; names the test "SampleExtractionTest".
     */
    public SampleExtractionTest()
    {
        super("SampleExtractionTest");
    }

    /**
     * Resets log4j to a basic console configuration and loads the training and test data
     * from SampleExtractionProblem. Runs before each test.
     */
    protected void setUp()
    {
        Logger.getRootLogger().removeAllAppenders();
        org.apache.log4j.BasicConfigurator.configure();

        // training data
        base = SampleExtractionProblem.trainBase();
        labels = SampleExtractionProblem.trainLabels();

        // test data
        testBase = SampleExtractionProblem.testBase();
        testLabels = SampleExtractionProblem.testLabels();

        // span type holding the true spans
        this.labelString = SampleExtractionProblem.LABEL;
    }

    /**
     * Clean up after each test. Nothing acquired in setUp needs releasing.
     */
    protected void tearDown()
    {
    }

    /**
     * Trains CMM sequence learners (voted perceptron and SVM base learners, history size 3)
     * with the recommended token feature extractor and checks their extraction quality.
     *
     * <p>Array layout: {trainRecall, trainPrecision, trainTolerance,
     * testRecall, testPrecision, testTolerance}.
     */
    public void testSampleExtractionTest()
    {
        SpanFeatureExtractor fe = new Recommended.TokenFE();
        doExtractionTest(new SequenceAnnotatorLearner(new CMMLearner(new VotedPerceptron(), 3), fe),
                         new double[]{0.93, 0.75, 0.25, 1.0, 0.6, 0.25});
        doExtractionTest(new SequenceAnnotatorLearner(new CMMLearner(new SVMLearner(), 3), fe),
                         new double[]{0.93, 1.0, 0.25, 1.0, 1.0, 0.25});
    }

    /**
     * Trains {@code learner} on the training labels. Then checks the learned annotator's
     * predictions on copies of both the training labels and the test labels.
     *
     * @param learner  learner to train; its annotation type is set to "prediction"
     * @param expected exactly six values laid out as
     *                 {trainRecall, trainPrecision, trainTolerance,
     *                  testRecall, testPrecision, testTolerance}
     */
    private void doExtractionTest(AnnotatorLearner learner, double[] expected)
    {
        // Fail with a clear message rather than an ArrayIndexOutOfBoundsException.
        assertNotNull("expected values must not be null", expected);
        assertEquals("length of expected-value array", EXPECTED_LENGTH, expected.length);

        AnnotatorTeacher annotatorTeacher = new TextLabelsAnnotatorTeacher(labels, labelString);
        learner.setAnnotationType(PREDICTION_TYPE);
        Annotator learnedAnnotator = annotatorTeacher.train(learner);

        TextLabels annotatedTrain = learnedAnnotator.annotatedCopy(labels);
        TextLabels annotatedTest = learnedAnnotator.annotatedCopy(testLabels);

        checkSpans("train", PREDICTION_TYPE, labelString, annotatedTrain,
                   expected[0], expected[1], expected[2]);
        checkSpans("test", PREDICTION_TYPE, labelString, annotatedTest,
                   expected[3], expected[4], expected[5]);
    }

    /**
     * Compares the spans of type {@code guessType} with those of type {@code truthType} in
     * {@code annotated}. Logs the differences and asserts token precision and recall.
     *
     * @param dataName  "train" or "test"; used only in log and assertion messages
     * @param guessType span type of the predicted spans
     * @param truthType span type of the true spans
     * @param annotated labels containing both span types
     * @param tokRec    expected token recall
     * @param tokPrec   expected token precision
     * @param epsilon   allowed absolute deviation from each expected value
     */
    private void checkSpans(String dataName, String guessType, String truthType,
                            TextLabels annotated, double tokRec, double tokPrec, double epsilon)
    {
        SpanDifference sd = new SpanDifference(annotated.instanceIterator(guessType),
                                               annotated.instanceIterator(truthType));
        log.info(dataName + " summary: " + sd.toSummary());
        log.info(sd);
        for (Looper l = sd.differenceIterator(); l.hasNext(); ) {
            log.info(">>" + l.next());
        }

        double precision = sd.tokenPrecision();
        double recall = sd.tokenRecall();
        log.info(dataName + " token precision: expected " + tokPrec + ", got " + precision
                 + ", tolerance " + epsilon);
        assertEquals(dataName + " token precision", tokPrec, precision, epsilon);
        log.info(dataName + " token recall: expected " + tokRec + ", got " + recall
                 + ", tolerance " + epsilon);
        assertEquals(dataName + " token recall", tokRec, recall, epsilon);
    }

    /**
     * Creates a TestSuite from all testXXX methods.
     *
     * @return TestSuite containing every test of this class
     */
    public static Test suite()
    {
        return new TestSuite(SampleExtractionTest.class);
    }

    /**
     * Runs the full suite of tests with text output.
     *
     * @param args unused
     */
    public static void main(String[] args)
    {
        junit.textui.TestRunner.run(suite());
    }
}