package edu.cmu.minorthird.classify.sequential; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import junit.framework.Test; import junit.framework.TestSuite; import org.apache.log4j.Level; import org.apache.log4j.Logger; import edu.cmu.minorthird.classify.AbstractClassificationChecks; import edu.cmu.minorthird.classify.SampleDatasets; import edu.cmu.minorthird.classify.experiments.Evaluation; /** * * This class is responsible for... * * @author ksteppe */ public class CrfTest extends AbstractClassificationChecks { Logger log = Logger.getLogger(this.getClass()); /** * Standard test class constructior for CrfTest * @param name Name of the test */ public CrfTest(String name) { super(name); } /** * Convinence constructior for CrfTest */ public CrfTest() { super("CrfTest"); } /** * setUp to run before each test */ protected void setUp() { org.apache.log4j.Logger.getRootLogger().removeAllAppenders(); org.apache.log4j.BasicConfigurator.configure(); log.setLevel(Level.DEBUG); super.setCheckStandards(false); //TODO add initializations if needed } /** * clean up to run after each test */ protected void tearDown() { //TODO clean up resources if needed } /** * Creates a TestSuite from all testXXX methods * @return TestSuite */ public static Test suite() { return new TestSuite(CrfTest.class); } /** * Run the full suite of tests with text output * @param args - unused */ public static void main(String args[]) { junit.textui.TestRunner.run(suite()); } // Test the basic functions of CRFLearner to make sure they are working properly public void testBasicCRF() { double[] refs = new double[]{0.0, // Error Rate 0.0, // std. deviation of Error Rate 0.0, // Balanced Error Rate 0.0, // Error Rate on POS 0.0, // std. deviation of Error Ratr on POS 0.0, // Error Rate on NEG 0.0, // std. deviation of Error Ratr on NEG 1.0, // Average Precision 1.0, // Maximum F1 3.277534399186934, // Average Log Loss 1.0, // Recall 1.0, // Precision 1.0, // F1 1.0}; // Kappa CRFLearner l = new CRFLearner(); SequenceClassifier c = new DatasetSequenceClassifierTeacher(SampleDatasets.makeToySequenceData()).train(l); // Evaluate it immediately saving the stats Evaluation e = new Evaluation(SampleDatasets.makeToySequenceData().getSchema()); e.extend(c, SampleDatasets.makeToySequenceTestData()); checkStats(e.summaryStatistics(), refs); } // Test the SegmentCRFLearner subclass of CRFLearner to make sure that its basic // functions are working properly public void testSegmentCRF() { double[] refs = new double[]{0.0, // Error Rate 0.0, // std. deviation of Error Rate 0.0, // Balanced Error Rate 0.0, // Error Rate on POS 0.0, // std. deviation of Error Ratr on POS 0.0, // Error Rate on NEG 0.0, // std. deviation of Error Ratr on NEG 1.0, // Average Precision 1.0, // Maximum F1 3.277534399186934, // Average Log Loss 1.0, // Recall 1.0, // Precision 1.0, // F1 1.0}; // Kappa SegmentCRFLearner l = new SegmentCRFLearner(); SequenceClassifier c = new DatasetSequenceClassifierTeacher(SampleDatasets.makeToySequenceData()).train(l); // Evaluate it immediately saving the stats Evaluation e = new Evaluation(SampleDatasets.makeToySequenceData().getSchema()); e.extend(c, SampleDatasets.makeToySequenceTestData()); checkStats(e.summaryStatistics(), refs); } /** * Test a full cycle of training, testing, saving (serializing), loading, and testing again.<br> * <br> * This test was added when feature names were changed over from using the old Feature.Factory.getId() * method (or Feature.getNumericName(), which calls getId()) to the newer FeatureIdFactory methods. **/ public void testSerialization() { try { // Create a classifier using the CRFLearner and the toyTrain dataset CRFLearner l = new CRFLearner(); SequenceClassifier c1 = new DatasetSequenceClassifierTeacher(SampleDatasets.makeToySequenceData()).train(l); // Evaluate it immediately saving the stats Evaluation e1 = new Evaluation(SampleDatasets.makeToySequenceData().getSchema()); e1.extend(c1, SampleDatasets.makeToySequenceTestData()); double[] stats1 = new double[4]; stats1[0] = e1.errorRate(); stats1[1] = e1.averagePrecision(); stats1[2] = e1.maxF1(); stats1[3] = e1.averageLogLoss(); // Serialize the classifier to disk ObjectOutputStream out = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream("CRFTest.classifier"))); out.writeObject(c1); out.flush(); out.close(); // Load it back in. ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(new FileInputStream("CRFTest.classifier"))); SequenceClassifier c2 = (SequenceClassifier)in.readObject(); in.close(); // Evaluate again saving the stats Evaluation e2 = new Evaluation(SampleDatasets.makeToySequenceData().getSchema()); e2.extend(c2, SampleDatasets.makeToySequenceTestData()); //double[] stats2 = e2.summaryStatistics(); double[] stats2 = new double[4]; stats2[0] = e2.errorRate(); stats2[1] = e2.averagePrecision(); stats2[2] = e2.maxF1(); stats2[3] = e2.averageLogLoss(); // Only use the basic stats for now because some of the advanced stats // come back as NaN for both datasets and the check stats method can't // handle NaN's log.info("using Standard stats only (4 of them)"); // Compare the stats produced from each run to make sure they are identical checkStats(stats1, stats2); // Remove the temporary classifier file File theClassifier = new File("CRFTest.classifier"); theClassifier.delete(); } catch (Exception e) { e.printStackTrace(); } } }