package edu.cmu.minorthird.classify; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.net.URI; import java.net.URL; import java.util.Random; import junit.framework.Test; import junit.framework.TestSuite; import org.apache.log4j.Level; import org.apache.log4j.Logger; import edu.cmu.minorthird.classify.algorithms.svm.SVMClassifier; import edu.cmu.minorthird.classify.algorithms.svm.SVMLearner; import edu.cmu.minorthird.classify.experiments.Evaluation; /** * * This class is responsible for testing Libsvm wrappers * * @author ksteppe */ public class LibsvmTest extends AbstractClassificationChecks{ Logger log=Logger.getLogger(this.getClass()); private static final String trainFile="edu/cmu/minorthird/classify/testcases/a1a.dat"; //private static final String model="modelFile.dat"; private static final String testFile="edu/cmu/minorthird/classify/testcases/a1a.t.dat"; /** * Standard test class constructior for LibsvmTest * @param name Name of the test */ public LibsvmTest(String name){ super(name); } /** * Convinence constructior for LibsvmTest */ public LibsvmTest(){ super("LibsvmTest"); } /** * setUp to run before each test */ protected void setUp(){ org.apache.log4j.Logger.getRootLogger().removeAllAppenders(); org.apache.log4j.BasicConfigurator.configure(); log.setLevel(Level.DEBUG); super.setCheckStandards(false); //TODO add initializations if needed } /** * clean up to run after each test */ protected void tearDown(){ //TODO clean up resources if needed } /** * use wrapper on the provided data, should get same results * as the direct */ public void testWrapper(){ try{ //get datasets URL url=this.getClass().getClassLoader().getResource(trainFile); Dataset trainData=DatasetLoader.loadSVMStyle(new File(new URI(url.toExternalForm()))); url=this.getClass().getClassLoader().getResource(testFile); Dataset testData=DatasetLoader.loadSVMStyle(new File(new URI(url.toExternalForm()))); //send expectations to checkClassifyText() double[] expect= new double[]{ 0.13769470404984424, 0.6011745705024105, 0.6934812760055479, // should be infinity if not calculating probabilities // 1.3132616875183545, Double.POSITIVE_INFINITY, }; super.setCheckStandards(true); super.checkClassify(new SVMLearner(),trainData,testData,expect); }catch(Exception e){ e.printStackTrace(); } } /** * run the svm wrapper on the sample data */ public void testSampleData(){ double[] refs=new double[]{ 0.0,0.0,0.0,0.0,0.0,0.0,0.0, //0-6 are 0 1.0,1.0, //7-8 are 1 1.3132616875182228,1.0,1.0,1.0, //10-12 are 1 1.0 //13 is 1 }; super.checkClassify(new SVMLearner(),SampleDatasets.toyTrain(), SampleDatasets.toyTest(),refs); } /** * Test a full cycle of training, testing, saving (serializing), loading, and testing again. **/ public void testSerialization(){ try{ // Create a classifier using the SVMLearner and the toyTrain dataset SVMLearner l=new SVMLearner(); Classifier c1= new DatasetClassifierTeacher(SampleDatasets.toyTrain()).train(l); File tempFile=File.createTempFile("SVMTest","classifier"); // Evaluate it immediately saving the stats Evaluation e1=new Evaluation(SampleDatasets.toyTrain().getSchema()); e1.extend(c1,SampleDatasets.toyTest(),1); double[] stats1=new double[4]; stats1[0]=e1.errorRate(); stats1[1]=e1.averagePrecision(); stats1[2]=e1.maxF1(); stats1[3]=e1.averageLogLoss(); // Serialize the classifier to disk //ObjectOutputStream out = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream("SVMTest.classifier"))); ObjectOutputStream out= new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream( tempFile))); out.writeObject(c1); out.flush(); out.close(); // Load it back in. //ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(new FileInputStream("SVMTest.classifier"))); ObjectInputStream in= new ObjectInputStream(new BufferedInputStream(new FileInputStream( tempFile))); Classifier c2=(Classifier)in.readObject(); in.close(); // Evaluate again saving the stats Evaluation e2=new Evaluation(SampleDatasets.toyTrain().getSchema()); e2.extend(c2,SampleDatasets.toyTest(),1); //double[] stats2 = e2.summaryStatistics(); double[] stats2=new double[4]; stats2[0]=e2.errorRate(); stats2[1]=e2.averagePrecision(); stats2[2]=e2.maxF1(); stats2[3]=e2.averageLogLoss(); // Only use the basic stats for now because some of the advanced stats // come back as NaN for both datasets and the check stats method can't // handle NaN's log.info("using Standard stats only (4 of them)"); // Compare the stats produced from each run to make sure they are identical checkStats(stats1,stats2); // Remove the temporary classifier file tempFile.delete(); }catch(Exception e){ e.printStackTrace(); } } /** * Test the MultiClass classification stuff. There are two cases to consider: with and without * calculation of probability estimates. The libsvm documentation states that these two cases * may return different classifications. These tests simply classify a sample dataset and * check the stats produced against expected values. */ public void testMultiClassClassification(){ Dataset trainSet=SampleDatasets.makeToy3ClassData(new Random(12345),100); Dataset testSet=SampleDatasets.makeToy3ClassData(new Random(67890),100); try{ // Create a classifier using the SVMLearner and the toyTrain dataset SVMLearner l=new SVMLearner(); // First run the test without probability estimates l.setDoProbabilityEstimates(false); SVMClassifier c1= (SVMClassifier)(new DatasetClassifierTeacher(trainSet) .train(l)); Evaluation e1=new Evaluation(trainSet.getSchema()); e1.extend(c1,testSet,1); double[] stats1=new double[4]; stats1[0]=e1.errorRate(); stats1[1]=e1.averagePrecision(); stats1[2]=e1.maxF1(); stats1[3]=e1.averageLogLoss(); System.out.println("Error Rate: "+e1.errorRate()); System.out.println("Avg Precision: "+e1.averagePrecision()); System.out.println("Max F1: "+e1.maxF1()); System.out.println("Avg Log Loss: "+e1.averageLogLoss()); // The stats we expect the classification to return. double[] expected=new double[4]; expected[0]=0.07; expected[1]=-1.0; expected[2]=-1.0; expected[3]=Double.POSITIVE_INFINITY; // Compare the stats produced from the run without probability estimates with expected values; checkStats(stats1,expected); // // On a small dataset libsvm may return vastly different stats from run to run so for now // this test is commented out. // // Now do it with probability estimates l.setDoProbabilityEstimates(true); SVMClassifier c2= (SVMClassifier)(new DatasetClassifierTeacher(trainSet) .train(l)); Evaluation e2=new Evaluation(trainSet.getSchema()); e2.extend(c2,testSet,1); double[] stats2=new double[4]; stats2[0]=e2.errorRate(); stats2[1]=e2.averagePrecision(); stats2[2]=e2.maxF1(); stats2[3]=e2.averageLogLoss(); System.out.println("Error Rate2: "+e2.errorRate()); System.out.println("Avg Precision2: "+e2.averagePrecision()); System.out.println("Max F1-2: "+e2.maxF1()); System.out.println("Avg Log Loss2: "+e2.averageLogLoss()); // The stats we expect the classification to return. expected[0]=0.08; expected[1]=-1.0; expected[2]=-1.0; expected[3]=1.194999431381944; // Compare the stats produced from the run with probability estimates with expected values. The libsvm // package doesn't always come up with the "exact" same stats, but they are within 0.05 of each other // so update the delta acordingly. setDelta(0.05); checkStats(stats2,expected); }catch(Exception e){ e.printStackTrace(); } } /** * Creates a TestSuite from all testXXX methods * @return TestSuite */ public static Test suite(){ return new TestSuite(LibsvmTest.class); } /** * Run the full suite of tests with text output * @param args - unused */ public static void main(String args[]){ junit.textui.TestRunner.run(suite()); } // // Crap from svm_predict.java // private double[] predict(BufferedReader input,DataOutputStream output, // svm_model model) throws IOException{ // int correct=0; // int total=0; // double error=0; // double sumv=0,sumy=0,sumvv=0,sumyy=0,sumvy=0; // // while(true){ // String line=input.readLine(); // if(line==null) // break; // // StringTokenizer st=new StringTokenizer(line," \t\n\r\f:"); // // double target=atof(st.nextToken()); // int m=st.countTokens()/2; // svm_node[] x=new svm_node[m]; // for(int j=0;j<m;j++){ // x[j]=new svm_node(); // x[j].index=atoi(st.nextToken()); // x[j].value=atof(st.nextToken()); // } // double v=svm.svm_predict(model,x); // if(v==target) // ++correct; // error+=(v-target)*(v-target); // sumv+=v; // sumy+=target; // sumvv+=v*v; // sumyy+=target*target; // sumvy+=v*target; // ++total; // //// output.writeBytes(v+"\n"); // } // log.debug("Accuracy = "+(double)correct/total*100+"% ("+correct+"/"+total+ // ") (classification)\n"); // log.debug("Mean squared error = "+error/total+" (regression)\n"); // log.debug("Squared correlation coefficient = "+ // ((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/ // ((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy))+" (regression)\n"); // // double[] rvalues=new double[3]; // rvalues[0]=(double)correct/(double)total; // rvalues[1]=error/(double)total; // rvalues[2]= // ((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/ // ((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy)); // // return rvalues; // // } // private double[] prediction(String argv[]) throws IOException{ // if(argv.length!=3){ // System.err.print("usage: svm-predict test_file model_file output_file\n"); // System.exit(1); // } // // BufferedReader input=new BufferedReader(new FileReader(argv[0])); // DataOutputStream output=new DataOutputStream(new FileOutputStream(argv[2])); // svm_model model=svm.svm_load_model(argv[1]); // return predict(input,output,model); // } // private static double atof(String s){ // return Double.valueOf(s).doubleValue(); // } // // private static int atoi(String s){ // return Integer.parseInt(s); // } }