/* Copyright 2003, Carnegie Mellon, All Rights Reserved */ package edu.cmu.minorthird.classify.experiments; import java.awt.Color; import java.awt.Component; import java.awt.GridBagConstraints; import java.awt.GridBagLayout; import java.awt.event.ActionEvent; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStreamReader; import java.io.LineNumberReader; import java.io.PrintStream; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Properties; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; import javax.swing.AbstractAction; import javax.swing.JButton; import javax.swing.JComponent; import javax.swing.JLabel; import javax.swing.JPanel; import javax.swing.JScrollPane; import javax.swing.JTable; import javax.swing.JTextField; import javax.swing.table.DefaultTableCellRenderer; import org.apache.log4j.Logger; import edu.cmu.minorthird.classify.ClassLabel; import edu.cmu.minorthird.classify.Classifier; import edu.cmu.minorthird.classify.Dataset; import edu.cmu.minorthird.classify.Example; import edu.cmu.minorthird.classify.ExampleSchema; import edu.cmu.minorthird.classify.Instance; import edu.cmu.minorthird.classify.MutableInstance; import edu.cmu.minorthird.classify.SGMExample; import edu.cmu.minorthird.classify.relational.RealRelationalDataset; import edu.cmu.minorthird.classify.relational.StackedGraphicalLearner; import edu.cmu.minorthird.classify.semisupervised.SemiSupervisedClassifier; import edu.cmu.minorthird.classify.semisupervised.SemiSupervisedDataset; import edu.cmu.minorthird.classify.sequential.SequenceClassifier; import edu.cmu.minorthird.classify.sequential.SequenceDataset; import edu.cmu.minorthird.util.MathUtil; import edu.cmu.minorthird.util.ProgressCounter; import edu.cmu.minorthird.util.Saveable; import edu.cmu.minorthird.util.StringUtil; import edu.cmu.minorthird.util.gui.ComponentViewer; import edu.cmu.minorthird.util.gui.LineCharter; import edu.cmu.minorthird.util.gui.ParallelViewer; import edu.cmu.minorthird.util.gui.VanillaViewer; import edu.cmu.minorthird.util.gui.Viewer; import edu.cmu.minorthird.util.gui.ViewerFrame; import edu.cmu.minorthird.util.gui.Visible; /** Stores some detailed results of evaluating a classifier on data. * * @author William Cohen */ public class Evaluation implements Visible,Serializable,Saveable{ private static Logger log=Logger.getLogger(Evaluation.class); // serialization stuff static final long serialVersionUID=20080130L; // private data // all entries static public final int DEFAULT_PARTITION_ID=0; private List<Entry> entryList=new ArrayList<Entry>(); // cached values transient private Matrix cachedPRCMatrix=null; transient private Matrix cachedTPFPMatrix=null; transient private Matrix cachedConfusionMatrix=null; // dataset schema private ExampleSchema schema; // properties private Properties properties=new Properties(); private List<String> propertyKeyList=new ArrayList<String>(); // are all classes binary? private boolean isBinary=true; /** Create an evaluation for databases with this schema */ public Evaluation(ExampleSchema schema){ this.schema=schema; isBinary=schema.equals(ExampleSchema.BINARY_EXAMPLE_SCHEMA); } //classification(RealRelationalDataset dataset) /** Test the classifier on the examples in the relational dataset and store the results. */ public void extend4SGM(StackedGraphicalLearner.StackedGraphicalClassifier c, RealRelationalDataset d,int cvID){ //ProgressCounter pc=new ProgressCounter("classifying","example",d.size()); Map<String,ClassLabel> rlt=c.classification(d); for(Iterator<String> i=rlt.keySet().iterator();i.hasNext();){ String ID=i.next(); ClassLabel predicted=rlt.get(ID); SGMExample example=d.getExampleWithID(ID); if(predicted.bestClassName()==null) throw new IllegalArgumentException( "predicted can't be null! for example: "+example); if(example.getLabel()==null) throw new IllegalArgumentException("predicted can't be null!"); if(log.isDebugEnabled()){ String ok=predicted.isCorrect(example.getLabel())?"Y":"N"; log.debug("ok: "+ok+"\tpredict: "+predicted+"\ton: "+example); } entryList.add(new Entry(example.asInstance(),predicted, example.getLabel(),entryList.size(),cvID)); // calling these extends the schema to cover these classes extendSchema(example.getLabel()); extendSchema(predicted); // clear caches cachedPRCMatrix=null; } } /** Test the classifier on the examples in the dataset and store the results. */ public void extend(Classifier c,Dataset d,int cvID){ ProgressCounter pc=new ProgressCounter("classifying","example",d.size()); for(Iterator<Example> i=d.iterator();i.hasNext();){ Example ex=i.next(); ClassLabel p=c.classification(ex); extend(p,ex,cvID); pc.progress(); } pc.finished(); } /** Test the SequenceClassifier on the examples in the dataset and store the results. */ public void extend(SequenceClassifier c,SequenceDataset d){ for(Iterator<Example[]> i=d.sequenceIterator();i.hasNext();){ Example[] seq=i.next(); ClassLabel[] pred=c.classification(seq); for(int j=0;j<seq.length;j++){ extend(pred[j],seq[j],DEFAULT_PARTITION_ID); } } } /** Test the classifier on the examples in the dataset and store the results. */ public void extend(SemiSupervisedClassifier c,SemiSupervisedDataset d,int cvID){ ProgressCounter pc=new ProgressCounter("classifying","example",d.size()); for(Iterator<Example> i=d.iterator();i.hasNext();){ Example ex=i.next(); ClassLabel p=c.classification(ex); extend(p,ex,cvID); pc.progress(); } pc.finished(); } /** Record the result of predicting the give class label on the given example */ public void extend(ClassLabel predicted,Example example,int cvID){ if(predicted.bestClassName()==null){ // for(String label:predicted.possibleLabels()){ // log.info(label+"="+predicted.getWeight(label)); // } throw new IllegalArgumentException("Best predicted class name is NULL: "+predicted); } if(example.getLabel()==null){ throw new IllegalArgumentException("True label is NULL: "+example); } if(log.isDebugEnabled()){ String ok=predicted.isCorrect(example.getLabel())?"Y":"N"; log.debug("ok: "+ok+"\tpredict: "+predicted+"\ton: "+example); } entryList.add(new Entry(example.asInstance(),predicted,example.getLabel(), entryList.size(),cvID)); // calling these extends the schema to cover these classes extendSchema(example.getLabel()); extendSchema(predicted); // clear caches cachedPRCMatrix=null; } public void setProperty(String prop,String value){ if(properties.getProperty(prop)==null){ propertyKeyList.add(prop); } properties.setProperty(prop,value); } public String getProperty(String prop){ return properties.getProperty(prop,"=unassigned="); } // // low-level access // public ClassLabel getPrediction(int i){ return ((Entry)entryList.get(i)).predicted; } public ClassLabel getActual(int i){ return ((Entry)entryList.get(i)).actual; } public boolean isCorrect(int i){ return getPrediction(i).isCorrect(getActual(i)); } // // simple statistics // /** Weighted total errors. */ public double errors(){ double errs=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if(e.actual.bestClassName()==null) throw new IllegalArgumentException("actual label is null?"); errs+=e.predicted.isCorrect(e.actual)?0:e.w; } return errs; } /** Weighted total errors on examples with partitionID = ID.. */ public double errors(int ID){ double errs=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if(e.partitionID==ID){ errs+=e.predicted.isCorrect(e.actual)?0:e.w; } } return errs; } /** Weighted total errors for classes 1 to K. */ public double[] errorsByClass(){ int K=schema.getNumberOfClasses(); double[] err=new double[K]; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); String actualLabel=e.actual.bestClassName(); int index=schema.getClassIndex(actualLabel); err[index]+=e.predicted.isCorrect(e.actual)?0:e.w; } return err; } /** Weighted total errors for classes 1 to K on examples with partitionID = ID. */ public double[] errorsByClass(int ID){ int K=schema.getNumberOfClasses(); double[] err=new double[K]; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if(e.partitionID==ID){ String actualLabel=e.actual.bestClassName(); int index=schema.getClassIndex(actualLabel); err[index]+=e.predicted.isCorrect(e.actual)?0:e.w; } } return err; } /** Weighted total errors on POSITIVE examples. */ public double errorsPos(){ if(!isBinary) return -1; double errsPos=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if("POS".equals(e.actual.bestClassName())){ errsPos+=e.predicted.isCorrect(e.actual)?0:e.w; } } return errsPos; } /** Weighted total errors on POSITIVE examples with partitionID = ID. */ public double errorsPos(int ID){ if(!isBinary) return -1; double errsPos=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if("POS".equals(e.actual.bestClassName())&e.partitionID==ID){ errsPos+=e.predicted.isCorrect(e.actual)?0:e.w; } } return errsPos; } /** Weighted total errors on NEGATIVE examples. */ public double errorsNeg(){ double errsNeg=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if("NEG".equals(e.actual.bestClassName())){ errsNeg+=e.predicted.isCorrect(e.actual)?0:e.w; } } return errsNeg; } /** Weighted total errors on NEGATIVE examples with partitionID = ID. */ public double errorsNeg(int ID){ double errsNeg=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if("NEG".equals(e.actual.bestClassName())&e.partitionID==ID){ errsNeg+=e.predicted.isCorrect(e.actual)?0:e.w; } } return errsNeg; } /** standard deviation of total errors. */ public double stDevErrors(){ int cvFolds=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if(e.partitionID>cvFolds){ cvFolds=e.partitionID+1; } } double mean=errorRate(); double stDev=0.0; for(int k=0;k<cvFolds;k++){ stDev+=Math.pow(errors(k)/numberOfInstances(k)-mean,2)/((double)cvFolds); } return Math.sqrt(stDev); } /** standard deviation of total errors for classes 1 to K. */ public double[] stDevErrorsByClass(){ int K=schema.getNumberOfClasses(); int cvFolds=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if(e.partitionID>cvFolds){ cvFolds=e.partitionID+1; } } double[] mean=errorRateByClass(); double[] stdev=new double[K]; for(int k=0;k<cvFolds;k++){ double[] errorsByClass=errorsByClass(k); double[] numerOfExamplesByClass=numberOfExamplesByClass(k); for(int i=0;i<K;i++){ stdev[i]+= Math.pow(errorsByClass[i]/numerOfExamplesByClass[i]-mean[i],2)/ ((double)cvFolds); } } for(int i=0;i<K;i++){ stdev[i]=Math.sqrt(stdev[i]); } return stdev; } /** standard deviation of total errors on POSITIVE examples. */ public double stDevErrorsPos(){ if(!isBinary) return -1; int cvFolds=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if(e.partitionID>cvFolds){ cvFolds=e.partitionID+1; } } double mean=errorsPos()/numberOfPositiveExamples(); double variance=0.0; for(int k=0;k<cvFolds;k++){ variance+= Math.pow(errorsPos(k)/numberOfPositiveExamples(k)-mean,2)/ ((double)cvFolds); } return Math.sqrt(variance); } /** standard deviation of total errors on NEGATIVE examples. */ public double stDevErrorsNeg(){ if(!isBinary) return -1; int cvFolds=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if(e.partitionID>cvFolds){ cvFolds=e.partitionID+1; } } double mean=errorsNeg()/numberOfNegativeExamples(); double variance=0.0; for(int k=0;k<cvFolds;k++){ variance+= Math.pow(errorsNeg(k)/numberOfNegativeExamples(k)-mean,2)/ ((double)cvFolds); } return Math.sqrt(variance); } /** Total weight of all instances. */ public double numberOfInstances(){ double n=0; for(int i=0;i<entryList.size();i++){ n+=getEntry(i).w; } return n; } /** Total weight of all instances. */ public double numberOfInstances(int ID){ double n=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if(e.partitionID==ID){ n+=e.w; } } return n; } /** Total weight of examples in all classes 1 to K. */ public double[] numberOfExamplesByClass(){ int K=schema.getNumberOfClasses(); double[] wgt=new double[K]; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); String actualLabel=e.actual.bestClassName(); int index=schema.getClassIndex(actualLabel); wgt[index]+=e.w; } return wgt; } /** Total weight of examples in all classes 1 to K with partitionID = ID. */ public double[] numberOfExamplesByClass(int ID){ int K=schema.getNumberOfClasses(); double[] wgt=new double[K]; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); String actualLabel=e.actual.bestClassName(); int index=schema.getClassIndex(actualLabel); if(e.partitionID==ID){ wgt[index]+=e.w; } } return wgt; } /** Total weight of all POSITIVE examples. */ public double numberOfPositiveExamples(){ if(!isBinary) return -1; double n=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if("POS".equals(e.actual.bestClassName())){ n+=e.w; } } return n; } /** Total weight of all POSITIVE examples with partitionID = ID. */ public double numberOfPositiveExamples(int ID){ if(!isBinary) return -1; double n=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if("POS".equals(e.actual.bestClassName())&e.partitionID==ID){ n+=e.w; } } return n; } /** Total weight of all NEGATIVE examples. */ public double numberOfNegativeExamples(){ double n=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if("NEG".equals(e.actual.bestClassName())){ n+=e.w; } } return n; } /** Total weight of all NEGATIVE examples with partitionID = ID. */ public double numberOfNegativeExamples(int ID){ double n=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); if("NEG".equals(e.actual.bestClassName())&e.partitionID==ID){ n+=e.w; } } return n; } /** Error rate. */ public double errorRate(){ if(numberOfInstances()>0){ return errors()/numberOfInstances(); } else{ return 0.0; } } /** Error rate by Class. */ public double[] errorRateByClass(){ int K=schema.getNumberOfClasses(); double[] errRate=new double[K]; double[] err=errorsByClass(); double[] wgt=numberOfExamplesByClass(); for(int i=0;i<K;i++){ if(wgt[i]==0){ errRate[i]=0.0; } else{ errRate[i]=err[i]/wgt[i]; } } return errRate; } /** Error rate on Positive examples. */ public double errorRatePos(){ return errorsPos()/numberOfPositiveExamples(); } /** Error rate on Negative examples. */ public double errorRateNeg(){ return errorsNeg()/numberOfNegativeExamples(); } /** Balanced Error rate. */ public double errorRateBalanced(){ double errorBalanced=0.0; int K=schema.getNumberOfClasses(); double[] errorsByClass=errorsByClass(); double[] numberOfExamplesByClass=numberOfExamplesByClass(); for(int i=0;i<K;i++){ if(numberOfExamplesByClass[i]>0){ errorBalanced+=1.0/(double)K*errorsByClass[i]/numberOfExamplesByClass[i]; } else{ errorBalanced+=0.0; } } return errorBalanced; } /** Recall in the top K, excluding items with score<threshold */ public double recallTopK(int k,double minScore){ if(!isBinary) return -1; if(numberOfPositiveExamples()==0) return 1.0; // special case double lastRecall=0; // detect a postive example double numPositiveExamplesInTopK=0; Matrix m=precisionRecallScore(); for(int i=0;i<Math.min(m.values.length,k);i++){ if(m.values[i][1]>lastRecall&&m.values[i][2]>minScore){ numPositiveExamplesInTopK++; } lastRecall=m.values[i][1]; } return numPositiveExamplesInTopK/numberOfPositiveExamples(); } /** Non-interpolated average precision. */ public double averagePrecision(){ if(!isBinary) return -1; if(numberOfInstances()==0) return Double.NaN; // undefined! double total=0,n=0; Matrix m=precisionRecallScore(); double lastRecall=0; // detect a postive example for(int i=0;i<m.values.length;i++){ if(m.values[i][1]>lastRecall){ n++; total+=m.values[i][0]; } lastRecall=m.values[i][1]; } return total/n; } /** Max f1 values at any cutoff. */ public double maxF1(){ return maxF1(Double.MIN_VALUE); } /** Max f1 values for any threshold above the specified cutoff. */ public double maxF1(double minThreshold){ if(!isBinary) return -1; if(numberOfPositiveExamples()==0) return 1.0; double maxF1=0; Matrix m=precisionRecallScore(); for(int i=0;i<m.values.length;i++){ double p=m.values[i][0]; double r=m.values[i][1]; if((p>0||r>0)&&m.values[i][2]>=minThreshold){ double f1=(2*p*r)/(p+r); maxF1=Math.max(maxF1,f1); } } return maxF1; } public double kappa(){ Matrix cm=confusionMatrix(); double n=entryList.size(); int k=schema.getNumberOfClasses(); if(n<1){ return 0.0; } double[] numActual=new double[k]; double[] numPredicted=new double[k]; double numAgree=0.0; for(int i=0;i<k;i++){ numAgree+=cm.values[i][i]; for(int j=0;j<k;j++){ numActual[i]+=cm.values[i][j]; numPredicted[i]+=cm.values[j][i]; } } double randomAgreement=0.0; for(int i=0;i<k;i++){ randomAgreement+=(numActual[i]/n)*(numPredicted[i]/n); } return (numAgree/n-randomAgreement)/(1.0-randomAgreement); } public int numExamples(){ return entryList.size(); } /** Average logloss on all examples. */ public double averageLogLoss(){ double tot=0; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); double confidence=e.predicted.getWeight(e.actual.bestClassName()); double error=e.predicted.isCorrect(e.actual)?+1:-1; tot+=Math.log(1.0+Math.exp(confidence*error)); } return tot/entryList.size(); } public double precision(){ if(!isBinary) return -1; Matrix cm=confusionMatrix(); int p=classIndexOf(ExampleSchema.POS_CLASS_NAME); int n=classIndexOf(ExampleSchema.NEG_CLASS_NAME); //cm is actual, predicted return cm.values[p][p]/(cm.values[p][p]+cm.values[n][p]); } public double recall(){ if(!isBinary) return -1; Matrix cm=confusionMatrix(); int p=classIndexOf(ExampleSchema.POS_CLASS_NAME); int n=classIndexOf(ExampleSchema.NEG_CLASS_NAME); //cm is actual, predicted return cm.values[p][p]/(cm.values[p][p]+cm.values[p][n]); } public double f1(){ if(!isBinary) return -1; double p=precision(); double r=recall(); return (2*p*r)/(p+r); } public double[] summaryStatistics(){ int K=schema.getNumberOfClasses(); if(isBinary){ double[] stats=new double[(10+2*K)]; stats[0]=errorRate(); stats[1]=stDevErrors(); stats[2]=errorRateBalanced(); double[] err=errorRateByClass(); double[] sd=stDevErrorsByClass(); for(int i=0;i<K;i++){ stats[(2+2*i+1)]=err[i]; stats[(2+2*i+2)]=sd[i]; } stats[(3+2*K)]=averagePrecision(); stats[(4+2*K)]=maxF1(); stats[(5+2*K)]=averageLogLoss(); stats[(6+2*K)]=recall(); stats[(7+2*K)]=precision(); stats[(8+2*K)]=f1(); stats[(9+2*K)]=kappa(); return stats; }else{ double[] stats=new double[(4+2*K)]; stats[0]=errorRate(); stats[1]=stDevErrors(); stats[2]=errorRateBalanced(); double[] err=errorRateByClass(); double[] sd=stDevErrorsByClass(); for(int i=0;i<K;i++){ stats[(2+2*i+1)]=err[i]; stats[(2+2*i+2)]=sd[i]; } stats[(3+2*K)]=kappa(); return stats; } } public String[] summaryStatisticNames(){ int K=schema.getNumberOfClasses(); if(isBinary){ String[] names=new String[(10+2*K)]; names[0]="Error Rate"; names[1]=". std. deviation error rate"; names[2]="Balanced Error Rate"; for(int i=0;i<K;i++){ String classname=schema.getClassName(i); names[(2+2*i+1)]=new String(". error rate on "+classname); names[(2+2*i+2)]=new String(". std. deviation on "+classname); } names[(3+2*K)]="Average Precision"; names[(4+2*K)]="Maximium F1"; names[(5+2*K)]="Average Log Loss"; names[(6+2*K)]="Recall"; names[(7+2*K)]="Precision"; names[(8+2*K)]="F1"; names[(9+2*K)]="Kappa"; return names; }else{ String[] names=new String[(4+2*K)]; names[0]="Error Rate"; names[1]=". std. deviation error rate"; names[2]="Balanced Error Rate"; for(int i=0;i<K;i++){ String classname=schema.getClassName(i); names[(2+2*i+1)]=new String(". error rate on "+classname); names[(2+2*i+2)]=new String(". std. deviation on "+classname); } names[(3+2*K)]="Kappa"; return names; } } // // complex statistics, ie ones that are harder to visualize // public static class Matrix{ public double[][] values; public Matrix(double[][] values){ this.values=values; } @Override public String toString(){ StringBuffer buf=new StringBuffer(""); for(int i=0;i<values.length;i++){ buf.append(StringUtil.toString(values[i])+"\n"); } return buf.toString(); } public double getValue(int row,int col){ return values[row][col]; } } /** Return a confusion matrix. */ public Matrix confusionMatrix(){ if(cachedConfusionMatrix!=null) return cachedConfusionMatrix; String[] classes=getClasses(); // count up the errors double[][] confused=new double[classes.length][classes.length]; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); confused[classIndexOf(e.actual)][classIndexOf(e.predicted)]++; } cachedConfusionMatrix=new Matrix(confused); return cachedConfusionMatrix; } public double numErrors(){ Matrix m=confusionMatrix(); double errors=m.getValue(0,1)+m.getValue(1,0); return errors; } public String[] getClasses(){ return schema.validClassNames(); } /** Return array of true positive,false positive,logitScore. */ public Matrix TPfractionFPfractionScore(){ if(cachedTPFPMatrix!=null) return cachedTPFPMatrix; if(!isBinary) throw new IllegalArgumentException( "can't compute precisionRecallScore for non-binary data"); byBinaryScore(); int allActualPos=0; int allActualNeg=0; int lastIndexOfActualPos=0; int firstIndexOfActualNeg=0; boolean notFoundYet=true; ProgressCounter pc= new ProgressCounter("counting positive examples","examples",entryList .size()); for(int i=0;i<entryList.size();i++){ if(getEntry(i).actual.isPositive()){ allActualPos++; lastIndexOfActualPos=i; }else{ allActualNeg++; if(notFoundYet){ firstIndexOfActualNeg=i; notFoundYet=false; } } pc.progress(); } pc.finished(); //System.out.println("all pos = "+allActualPos+" all neg = "+allActualNeg); int length=Math.abs(lastIndexOfActualPos-firstIndexOfActualNeg)+4; int min=Math.min(lastIndexOfActualPos,firstIndexOfActualNeg); int max=Math.max(lastIndexOfActualPos,firstIndexOfActualNeg); //System.out.println("min="+min+" max="+max+" length="+length); double truePosSoFar=0; double falsePosSoFar=0; double tpf=1,fpf=1,score=0; ProgressCounter pc2= new ProgressCounter("computing statistics","examples",entryList.size()); double[][] result=new double[length][3]; for(int i=0;i<entryList.size();i++){ Entry e=getEntry(i); score=e.predicted.posWeight(); if(e.actual.isPositive()) truePosSoFar++; else falsePosSoFar++; if(allActualPos>0) tpf=truePosSoFar/allActualPos; if(allActualNeg>0) fpf=falsePosSoFar/allActualNeg; if(i==0){ result[0][0]=0.0; result[0][1]=0.0; result[0][2]=score; } if(i>=(min-1)&i<=max){ result[i-min+2][0]=tpf; result[i-min+2][1]=fpf; result[i-min+2][2]=score; //System.out.println("tpf="+tpf+" fpf="+fpf+" score="+score); } result[length-1][0]=1.0; result[length-1][1]=1.0; result[length-1][2]=score; pc2.progress(); } pc2.finished(); cachedTPFPMatrix=new Matrix(result); return cachedTPFPMatrix; } /** Return actual ROC curve. * At most 1000 points are kept. * */ public Matrix thousandPointROC(){ Matrix m=TPfractionFPfractionScore(); int N=m.values.length-2; if(N>1000){ double[][] v=new double[1002][3]; int mod=N/1000; // add (0,0,score) v[0][0]=m.values[0][0]; v[0][1]=m.values[0][1]; v[0][2]=m.values[0][2]; // fill in 1000 values for(int i=1;i<=1000;i++){ int k=(i-1)*mod; v[i][0]=m.values[k+1][0]; v[i][1]=m.values[k+1][1]; v[i][2]=m.values[k+1][2]; } // add (1,1,score) v[1001][0]=m.values[N+1][0]; v[1001][1]=m.values[N+1][1]; v[1001][2]=m.values[N+1][2]; return new Matrix(v); }else{ return m; } } /** Return array of precision,recall,logitScore. */ public Matrix precisionRecallScore(){ if(cachedPRCMatrix!=null) return cachedPRCMatrix; if(!isBinary) throw new IllegalArgumentException( "can't compute precisionRecallScore for non-binary data"); byBinaryScore(); int allActualPos=0; int lastIndexOfActualPos=0; ProgressCounter pc= new ProgressCounter("counting positive examples","examples",entryList .size()); for(int i=0;i<entryList.size();i++){ if(getEntry(i).actual.isPositive()){ allActualPos++; lastIndexOfActualPos=i; } pc.progress(); } pc.finished(); double truePosSoFar=0; double falsePosSoFar=0; double precision=1,recall=1,score=0; ProgressCounter pc2= new ProgressCounter("computing statistics","examples", lastIndexOfActualPos); double[][] result=new double[lastIndexOfActualPos+1][3]; for(int i=0;i<=lastIndexOfActualPos;i++){ Entry e=getEntry(i); score=e.predicted.posWeight(); if(e.actual.isPositive()) truePosSoFar++; else falsePosSoFar++; if(truePosSoFar+falsePosSoFar>0) precision=truePosSoFar/(truePosSoFar+falsePosSoFar); if(allActualPos>0) recall=truePosSoFar/allActualPos; result[i][0]=precision; result[i][1]=recall; result[i][2]=score; pc2.progress(); } pc2.finished(); cachedPRCMatrix=new Matrix(result); return cachedPRCMatrix; } /** Return eleven-point interpolated precision. * Precisely, result is an array p[] of doubles * such that p[i] is the maximal precision value * for any point with recall>=i/10. * */ public double[] elevenPointPrecision(){ Matrix m=precisionRecallScore(); //System.out.println("prs = "+m); double[] p=new double[11]; p[0]=1.0; for(int i=0;i<m.values.length;i++){ double r=m.values[i][1]; //System.out.println("row "+i+", recall "+r+": "+StringUtil.toString(m.values[i])); for(int j=1;j<=10;j++){ if(r>=j/10.0){ p[j]=Math.max(p[j],m.values[i][0]); //System.out.println("update p["+j+"] => "+p[j]); } } } return p; } // // views of data // /** Detailed view. */ @Override public String toString(){ StringBuffer buf=new StringBuffer(""); for(int i=0;i<entryList.size();i++){ buf.append(getEntry(i)+"\n"); } return buf.toString(); } static public class PropertyViewer extends ComponentViewer{ static final long serialVersionUID=20080130L; @Override public JComponent componentFor(Object o){ final Evaluation e=(Evaluation)o; final JPanel panel=new JPanel(); final JTextField propField=new JTextField(10); final JTextField valField=new JTextField(10); final JTable table=makePropertyTable(e); final JScrollPane tableScroller=new JScrollPane(table); final JButton addButton= new JButton(new AbstractAction("Insert Property"){ static final long serialVersionUID=20080130L; @Override public void actionPerformed(ActionEvent event){ e.setProperty(propField.getText(),valField.getText()); tableScroller.getViewport().setView(makePropertyTable(e)); tableScroller.revalidate(); panel.revalidate(); } }); panel.setLayout(new GridBagLayout()); GridBagConstraints gbc=fillerGBC(); //gbc.fill = GridBagConstraints.HORIZONTAL; gbc.gridwidth=3; panel.add(tableScroller,gbc); panel.add(addButton,myGBC(0)); panel.add(propField,myGBC(1)); panel.add(valField,myGBC(2)); return panel; } private GridBagConstraints myGBC(int col){ GridBagConstraints gbc=fillerGBC(); gbc.fill=GridBagConstraints.HORIZONTAL; gbc.gridx=col; gbc.gridy=1; return gbc; } private JTable makePropertyTable(final Evaluation e){ Object[][] table=new Object[e.propertyKeyList.size()][2]; for(int i=0;i<e.propertyKeyList.size();i++){ table[i][0]=e.propertyKeyList.get(i); table[i][1]=e.properties.get(e.propertyKeyList.get(i)); } String[] colNames=new String[]{"Property","Property's Value"}; return new JTable(table,colNames); } } public class SummaryViewer extends ComponentViewer{ static final long serialVersionUID=20080130L; @Override public JComponent componentFor(Object o){ Evaluation e=(Evaluation)o; double[] ss=e.summaryStatistics(); String[] ssn=e.summaryStatisticNames(); Object[][] oss=new Object[ss.length][2]; for(int i=0;i<ss.length;i++){ oss[i][0]=ssn[i]; oss[i][1]=new Double(ss[i]); } JTable jtable=new JTable(oss,new String[]{"Statistic","Value"}); jtable.setDefaultRenderer(Object.class,new MyTableCellRenderer()); jtable.setVisible(true); return new JScrollPane(jtable); } } static public class ElevenPointPrecisionViewer extends ComponentViewer{ static final long serialVersionUID=20080130L; @Override public JComponent componentFor(Object o){ Evaluation e=(Evaluation)o; double[] p=e.elevenPointPrecision(); LineCharter lc=new LineCharter(); lc.startCurve("Interpolated Precision"); for(int i=0;i<p.length;i++){ lc.addPoint(i/10.0,p[i]); } return lc.getPanel("11-Pt Interpolated Precision vs. Recall","Recall", "Precision"); } } static public class ROCViewer extends ComponentViewer{ static final long serialVersionUID=20080130L; @Override public JComponent componentFor(Object o){ Evaluation e=(Evaluation)o; Matrix p=e.thousandPointROC(); LineCharter lc=new LineCharter(); lc.startCurve("Actual ROC"); for(int i=0;i<p.values.length;i++){ lc.addPoint(p.values[i][1],p.values[i][0]); //System.out.println(p.values[i][0]+" "+p.values[i][1]); // Uncomment for MATLAB } // compute area under the curve double area=0.0; for(int i=0;i<(p.values.length-1);i++){ area+= (p.values[i][0]+p.values[i+1][0])* (p.values[i+1][1]-p.values[i][1])/2.0; //System.out.println("("+p.values[i][0]+"+"+p.values[i+1][0]+") * ("+p.values[i+1][1]+"-"+p.values[i][1]+") /2.0"); } return lc.getPanel("Actual ROC Curve", "False Positive / All Negative (AUC = "+area+")", "True Positive / All Positive"); } } static public class ConfusionMatrixViewer extends ComponentViewer{ static final long serialVersionUID=20080130L; @Override public JComponent componentFor(Object o){ Evaluation e=(Evaluation)o; JPanel panel=new JPanel(); Matrix m=e.confusionMatrix(); String[] classes=e.getClasses(); panel.setLayout(new GridBagLayout()); //add( new JLabel("Actual class"), cmGBC(0,0) ); GridBagConstraints gbc=cmGBC(0,1); gbc.gridwidth=classes.length; panel.add(new JLabel("Predicted Class"),gbc); for(int i=0;i<classes.length;i++){ panel.add(new JLabel(classes[i]),cmGBC(1,i+1)); } for(int i=0;i<classes.length;i++){ panel.add(new JLabel(classes[i]),cmGBC(i+2,0)); for(int j=0;j<classes.length;j++){ panel.add(new JLabel(Double.toString(m.values[i][j])),cmGBC(i+2,j+1)); } } return panel; } private GridBagConstraints cmGBC(int i,int j){ GridBagConstraints gbc=new GridBagConstraints(); //gbc.fill = GridBagConstraints.BOTH; gbc.weightx=gbc.weighty=0; gbc.gridy=i; gbc.gridx=j; gbc.ipadx=gbc.ipady=20; return gbc; } } /** Print summary statistics */ public void summarize(){ double[] stats=summaryStatistics(); String[] statNames=summaryStatisticNames(); int maxLen=0; for(int i=0;i<statNames.length;i++){ maxLen=Math.max(statNames[i].length(),maxLen); } for(int i=0;i<statNames.length;i++){ System.out.print(statNames[i]+": "); for(int j=0;j<maxLen-statNames[i].length();j++) System.out.print(" "); System.out.println(stats[i]); } } @Override public Viewer toGUI(){ ParallelViewer main=new ParallelViewer(); main.addSubView("Summary",new SummaryViewer()); main.addSubView("Properties",new PropertyViewer()); if(isBinary) main.addSubView("11Pt Precision/Recall",new ElevenPointPrecisionViewer()); if(isBinary) main.addSubView(" ROC & AUC ",new ROCViewer()); main.addSubView("Confusion Matrix",new ConfusionMatrixViewer()); main.addSubView("Debug",new VanillaViewer()); main.setContent(this); return main; } // // one entry in the evaluation // private static class Entry implements Serializable{ private static final long serialVersionUID=-4069980043842319179L; transient public Instance instance=null; public int partitionID; public int index; public ClassLabel predicted,actual; //public int h; public double w=1.0; public Entry(Instance i,ClassLabel p,ClassLabel a,int k,int id){ instance=i; predicted=p; actual=a; index=k; partitionID=id; //h=instance.hashCode(); } @Override public String toString(){ //double w=predicted.bestWeight(); return predicted+"\t"+actual+"\t"+instance; } } // // implement Saveable // final static public String EVAL_FORMAT_NAME="Minorthird Evaluation"; final static public String EVAL_EXT=".eval"; @Override public String[] getFormatNames(){ return new String[]{EVAL_FORMAT_NAME}; } @Override public String getExtensionFor(String format){ return EVAL_EXT; } @Override public void saveAs(File file,String formatName) throws IOException{ save(file); } @Override public Object restore(File file) throws IOException{ return load(file); } // // public void save(File file) throws IOException{ PrintStream out= new PrintStream(new GZIPOutputStream(new FileOutputStream(file))); save(out); } public void save(PrintStream out) throws IOException{ out.println(StringUtil.toString(schema.validClassNames())); for(Iterator<String> i=propertyKeyList.iterator();i.hasNext();){ String prop=(String)i.next(); String value=properties.getProperty(prop); out.println(prop+"="+value); } byOriginalPosition(); for(Iterator<Entry> i=entryList.iterator();i.hasNext();){ Entry e=(Entry)i.next(); out.println(e.predicted.bestClassName()+" "+e.predicted.bestWeight()+" "+ e.actual.bestClassName()); } out.close(); } static public Evaluation load(File file) throws IOException{ // disabled to avoid looping, since this is how we now de-serialize // first try loading a serialized version //try { return (Evaluation)IOUtil.loadSerialized(file); } catch (Exception ex) { ; } LineNumberReader in= new LineNumberReader(new InputStreamReader(new GZIPInputStream( new FileInputStream(file)))); String line=in.readLine(); if(line==null) throw new IllegalArgumentException("no class list on line 1 of file "+ file.getName()); String[] classes=line.substring(1,line.length()-1).split(","); ExampleSchema schema=new ExampleSchema(classes); Evaluation result=new Evaluation(schema); while((line=in.readLine())!=null){ if(line.indexOf('=')>=0){ // property String[] propValue=line.split("="); if(propValue.length==2){ result.setProperty(propValue[0],propValue[1]); }else if(propValue.length==1){ result.setProperty(propValue[0],""); }else{ throw new IllegalArgumentException(file.getName()+" line "+ in.getLineNumber()+": illegal format"); } }else{ String[] words=line.split(" "); if(words.length<3) throw new IllegalArgumentException(file.getName()+" line "+ in.getLineNumber()+": illegal format"); ClassLabel predicted=new ClassLabel(words[0],StringUtil.atof(words[1])); ClassLabel actual=new ClassLabel(words[2]); //double instanceWeight = StringUtil.atof(words[3]); MutableInstance instance=new MutableInstance("dummy"); //instance.setWeight( instanceWeight ); Example example=new Example(instance,actual); result.extend(predicted,example,DEFAULT_PARTITION_ID); } } in.close(); return result; } // // getters / setters // /** Returns whether this Evaluation refers to a binary classifier */ public boolean isBinary(){ return this.isBinary; } /** Returns whether the ExampleSchema this Evaluation is based upon */ public ExampleSchema getSchema(){ return this.schema; } // // convenience methods // private Entry getEntry(int i){ return (Entry)entryList.get(i); } private int classIndexOf(ClassLabel classLabel){ return classIndexOf(classLabel.bestClassName()); } private int classIndexOf(String classLabelName){ return schema.getClassIndex(classLabelName); } private void extendSchema(ClassLabel classLabel){ //System.out.println("classLabel: "+classLabel); if(!classLabel.isBinary()) isBinary=false; int r=classIndexOf(classLabel.bestClassName()); if(r<0){ //System.out.println("extending"); // extend the schema //Add the provided label to the set of valid values //for the class using the extend method on the //schema object schema.extend(classLabel.bestClassName()); //commented old code //String[] currentNames = schema.validClassNames(); //String[] newNames = new String[currentNames.length+1]; //for (int i=0; i<currentNames.length; i++) newNames[i] = currentNames[i]; //newNames[currentNames.length] = classLabel.bestClassName(); } } private void byBinaryScore(){ Collections.sort(entryList,new Comparator<Entry>(){ @Override public int compare(Entry a,Entry b){ return MathUtil.sign(b.predicted.posWeight()-a.predicted.posWeight()); } }); } private void byOriginalPosition(){ Collections.sort(entryList,new Comparator<Entry>(){ @Override public int compare(Entry a,Entry b){ return a.index-b.index; } }); } // table renderer public class MyTableCellRenderer extends DefaultTableCellRenderer{ static final long serialVersionUID=20080130L; @Override public Component getTableCellRendererComponent(JTable table,Object value, boolean isSelected,boolean hasFocus,int row,int column){ JLabel label= (JLabel)super.getTableCellRendererComponent(table,value,isSelected, hasFocus,row,column); if((row%2)!=0){ label.setBackground(Color.lightGray); label.setOpaque(true); }else{ label.setBackground(Color.white); label.setOpaque(true); } return label; } } // // test routine // static public void main(String[] args){ try{ Evaluation v=Evaluation.load(new File(args[0])); if(args.length>1) v.save(new File(args[1])); new ViewerFrame("From file "+args[0],v.toGUI()); }catch(Exception e){ System.out .println("usage: Evaluation [serializedFile|evaluationFile] [evaluationFile]"); e.printStackTrace(); } } }