package edu.cmu.minorthird.classify; import java.awt.BorderLayout; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.Iterator; import java.util.List; import javax.swing.JComponent; import javax.swing.JLabel; import javax.swing.JPanel; import javax.swing.JScrollPane; import javax.swing.border.TitledBorder; import org.apache.log4j.Logger; import edu.cmu.minorthird.classify.algorithms.linear.NaiveBayes; import edu.cmu.minorthird.classify.experiments.Evaluation.Matrix; import edu.cmu.minorthird.util.MathUtil; import edu.cmu.minorthird.util.gui.ComponentViewer; import edu.cmu.minorthird.util.gui.SmartVanillaViewer; import edu.cmu.minorthird.util.gui.Viewer; import edu.cmu.minorthird.util.gui.Visible; /** * A Tweaked Learner, with an optimization of the precision vs. recall * * @author Giora Unger * Created on May 19, 2005 * * A learner whose score was optimized according to an F_beta() function, * for a given beta. This optimization is used to fine-tune the precision * vs. recall for the underlying classification algorithm. * Values of beta<1.0 favor precision over recall, while values of * beta>1.0 favor recall over precision. beta=1.0 grants equal weight * to both precision and recall. * * <p>Reference: Jason D. M. Rennie, * <i>Derivation of the F-Measure</i>, * http://people.csail.mit.edu/jrennie/writing/fmeasure.pdf */ public class TweakedLearner extends BatchBinaryClassifierLearner{ // inner learner given to this class during construction private BinaryClassifierLearner innerLearner; // the beta according to which F_beta is to be maximized private double beta; // dataset given private Dataset m_dataset; // dataset schema private ExampleSchema schema; // flag indicating whether the given dataset is binary or not private boolean isBinary=true; // value to be returned if a non-binary dataset is given to // precision() or recall() methods private static final int ILLEGAL_VALUE=-1; private static final double UNINITIALIZED=-1; // actual data structure in which the examples are stored, along with // additional fields required for executing the tweaking private List<Row> tweakingTable=new ArrayList<Row>(); // confusion matrix used for efficiently perform the tweaking Matrix cm=null; // logger for this class private static Logger log=Logger.getLogger(TweakedLearner.class); /** ******************************************************************** * Public methods ******************************************************************** */ // TweakedLearner constructor public TweakedLearner(BinaryClassifierLearner innerLearner,double beta){ this.beta=beta; this.innerLearner=innerLearner; } /* * main method of the TweakedLearner class. Recieves a binary * training dataset and then: * 1. Trains on it, based on the innerLearner, namely its inherent * binary classifier. * 2. Tweaks the classifier (or, more precisely, the model this * classifier came up with), so that F_beta is maximized. This is * done by finding a threshold, see details below. * 3. Creates and returns a new TweakedClassifier, with the original * (inner) classifier and the threshold that was found. */ @Override public Classifier batchTrain(Dataset dataset){ // make sure the dataset given is indeed binary this.schema=dataset.getSchema(); isBinary=schema.equals(ExampleSchema.BINARY_EXAMPLE_SCHEMA); if(!isBinary) // make sure dataset is binary { throw new IllegalArgumentException( "Dataset given to TweakedLearner::batchTrain must be a binary dataset"); } if(dataset.size()==0) // make sure dataset is not empty { throw new IllegalArgumentException( "Dataset given to TweakedLearner::batchTrain is empty"); } this.m_dataset=dataset; // get the classifier resulting from training on the given dataset BinaryClassifier bc= (BinaryClassifier)new DatasetClassifierTeacher(m_dataset) .train(innerLearner); // Initialize the data structure required for the tweaking. Please note // that the ExecuteTweaking() method assumes that the rows in this table // are sorted by descending score initializeTable(); // Execute actual tweaking - figure out what threshold works best on the // given dataset w.r.t. F_beta double threshold=executeTweaking(); return new TweakedClassifier(bc,threshold); } /** ******************************************************************** * Getters & Setters ******************************************************************** */ /** * @return Returns the beta. */ public double getBeta(){ return beta; } /** * @param beta The beta to set. */ public void setBeta(double beta){ this.beta=beta; } /** * @return Returns the innerLearner. */ public BinaryClassifierLearner getInnerLearner(){ return innerLearner; } /** * @param learner The innerLearner to set. */ public void setInnerLearner(BinaryClassifierLearner learner){ this.innerLearner=learner; } /** ******************************************************************** * Private methods ******************************************************************** */ // This method initializes tweakingTable, which the data structure used for tweaking // It loops over the examples in the given dataset, and insert them into the // the table. According to the needs of the tweaking process, the rows are then // sorted by descending score (also called posWeight). private void initializeTable(){ int counter=0; for(Iterator<Example> i=m_dataset.iterator();i.hasNext();counter++){ Example ex=i.next(); ClassLabel predicted= innerLearner.getBinaryClassifier().classification(ex); // add example into the tweaking data structure. note that the tweaked // prediction given during initialization is NEG for all examples ! tweakingTable.add(new Row(ex.asInstance(),ex.getLabel(),predicted, ClassLabel.negativeLabel(-1.0))); // debug code //double score = innerLearner.getBinaryClassifier().score(ex); /* log.debug("Example number: "+ counter + ", posWeight: " + predicted.posWeight() + ", Score: " + score + ", Label: " + ex.getLabel()); */ } // sort the table, after it was filled, by descending score sortByScore(); } /* * This method is the very heart of the tweaking process. It assumes that * the tweakingTable data structure was initilized and filled, with a row * for every example. It further assumes that all the examples were given an * initial tweak_prediction of NEG and that they were sorted by descending score * The method then: * 1. Initialize a confusion matrix, based on a NEG prediction to all examples. * 2. For every example, starting with the on ewith highest positive score: * a. set the tweak_prediction to POS * b. update the confusion matrix accordingly * c. calculate precision, recall and F_beta with the new confusion matrix * and fill these values in the tweakingtable data structure * Please note, that in any such iteration, all the examples/rows above * the current example (including itself) have a POS prediction, while all the * examples/rows below the current example have a NEG prediction. * That is, we prectically evaluate the F_beta when the "dividing line" is on * the current example * 3. After all the rows/exmaples are handled, choose the row with the maximal F_beta * 4. Select the score of this row, or more precisely the average between this score * and the next row's score, to be the threshold. * 5. Return this number as the threshold constituting the new TweakedClassifier. * */ private double executeTweaking(){ double threshold=UNINITIALIZED; initConfusionMatrix(); // for every row, find and fill the precision, recall and F_beta // Note, that each row examined is first set to POS for(int i=0;i<tweakingTable.size();++i){ // set dummy prediction of current example to POS getRow(i).tweak_predicted=ClassLabel.positiveLabel(1.0); // update the confusion matrix based on this prediction change updateConfusionMatrix(i); // calculate the precision, recall and F_beta with the updated confusion matrix getRow(i).precision=getCurrentPrecision(); getRow(i).recall=getCurrentRecall(); getRow(i).F_beta=calculateFBeta(getRow(i).precision,getRow(i).recall); /* log.debug("row " + i + ", precision: " + getRow(i).precision + ", recall: " + getRow(i).recall + ", F_beta: " + getRow(i).F_beta + ", score: " + getRow(i).orig_predicted.posWeight()); */ } // choose the threshold row, that is with maximal F_beta // translate its score into the returned threshold int index=maxFBetaEntry(); // if the row that was found is the last row in the table (VERY unlikely), // set the threshold to be its score if((index+1)==tweakingTable.size()){ threshold=getRow(index).orig_predicted.posWeight(); }else // otherwise, set it to be the average between this row's score and { // the next row's score double maxRowScore=getRow(index).orig_predicted.posWeight(); double nextRowScore=getRow(index+1).orig_predicted.posWeight(); threshold=(maxRowScore+nextRowScore)/2; } log.debug("Threshold found: "+threshold+" (in row "+index+")"); return threshold; // return the threshold that was found } /** * Initializes the confusion matrix. This method is called in the first step * of the tweaking process. Please note that at this step, all the examples * are set to have a tweak_predited field of NEG class. */ private void initConfusionMatrix(){ String[] classes=getClasses(); // count up the errors double[][] confused=new double[classes.length][classes.length]; for(int i=0;i<tweakingTable.size();i++){ Row row=getRow(i); confused[classIndexOf(row.actual)][classIndexOf(row.tweak_predicted)]++; } cm=new Matrix(confused); } /* * During the tweaking process, in each iteration a single example * is handled, so that its tweaked_prediction is changed from NEG to POS * This method receives the index (in th etweakingTable) of the current example * and updates the confusion matrix accordingly */ private void updateConfusionMatrix(int index){ Row row=getRow(index); int actual=classIndexOf(row.actual); int p=classIndexOf(ExampleSchema.POS_CLASS_NAME); int n=classIndexOf(ExampleSchema.NEG_CLASS_NAME); // the confusion matrix (cm) is built as [actual][predicted] cm.values[actual][p]++; cm.values[actual][n]--; } // This method simply returns, given precision and recall, the value // of F_beta. It uses the "beta" data member of this class, to decide // which function is to be calculated // The formula used is: // F_beta = (beta+1) * precision * recall / // (beta * precision) + recall // // See also: // <p>Reference: Jason D. M. Rennie, // <i>Derivation of the F-Measure</i>, // http://people.csail.mit.edu/jrennie/writing/fmeasure.pdf private double calculateFBeta(double precision,double recall){ double divisor=((beta*precision)+recall); // in case a division by zero will occur, return F_beta=0.0 (instead of NaN) if(divisor==0.0){ log.warn("TweakedLearner::calculateFBeta, divisor of F_beta is zero !!!"); return 0.0; } // in case a division by NaN, return F_beta=0.0 (instead of NaN) if((new Double(divisor)).isNaN()){ log .warn("TweakedLearner::calculateFBeta, divisor of F_beta is a NaN !!!"); return 0.0; } return(((beta+1)*precision*recall)/divisor); } // This method returns the precision based on the current confusion matrix. // Note that during the tweaking process the confusion matrix is iteratively updated // Precision is defined as: // true_positive / (true_positive + false_positive) private double getCurrentPrecision(){ if(!isBinary) return ILLEGAL_VALUE; // to be on the safe side int p=classIndexOf(ExampleSchema.POS_CLASS_NAME); int n=classIndexOf(ExampleSchema.NEG_CLASS_NAME); // the confusion matrix (cm) is built as [actual][predicted] return cm.values[p][p]/(cm.values[p][p]+cm.values[n][p]); } // This method returns the recall based on the current confusion matrix. // Note that during the tweaking process the confusion matrix is iteratively updated // Recall is defined as: // true_positive / (true_positive + false_negative) private double getCurrentRecall(){ if(!isBinary) return ILLEGAL_VALUE; // to be on the safe side int p=classIndexOf(ExampleSchema.POS_CLASS_NAME); int n=classIndexOf(ExampleSchema.NEG_CLASS_NAME); // the confusion matrix (cm) is built as [actual][predicted] return cm.values[p][p]/(cm.values[p][p]+cm.values[p][n]); } /** ******************************************************************** * Private convenience methods ******************************************************************** */ // sort the tweakingTable, after it was filled, by descending score private void sortByScore(){ Collections.sort(tweakingTable,new Comparator<Row>(){ @Override public int compare(Row a,Row b){ return MathUtil.sign(b.orig_predicted.posWeight()-a.orig_predicted.posWeight()); } }); } /* * Returns the index (in tweakingTable) of the Row with maximal F_beta value */ private int maxFBetaEntry(){ double maxFBeta=ILLEGAL_VALUE; // initialize int maxIndex=(int)UNINITIALIZED; // index of the row with maximal F_beta for(int i=0;i<tweakingTable.size();++i){ if(getRow(i).F_beta>maxFBeta){ maxFBeta=getRow(i).F_beta; maxIndex=i; } } if(maxFBeta==ILLEGAL_VALUE){ log .error("In TweakedLearner::maxFBetaEntry, maxFBeta has an illegal value"); } return maxIndex; } private Row getRow(int i){ return tweakingTable.get(i); } private String[] getClasses(){ return schema.validClassNames(); } private int classIndexOf(ClassLabel classLabel){ return classIndexOf(classLabel.bestClassName()); } private int classIndexOf(String classLabelName){ return schema.getClassIndex(classLabelName); } // debug method - simply dumps the tweakingTable data structure to stdout // private void printTable(){ // for(int i=0;i<tweakingTable.size();++i){ // System.out.println("Row number "+i+": "+getRow(i)); // } // } /** ******************************************************************** ******************************************************************** * This class represents the information, about a single example, * needed for executing the tweaking: * 1. The example itself * 2. Its true label/class. Indeed this field can be accessed every time * by using example.getLabel(), but for convenience it is stored in the table * 3. The predicted class (orig_predicted), as given by the original * (untweaked) classifier. * 4. A dummy prediction (tweak_predicted), which is used in the actual * tweaking process. During construction, all rows are initialized as NEG examples, * commensurate with the way the tweaking process is executed. * * Please note that during the tweaking process, examples that were predicted * by the original (untweaked) classifier as POS can have a prediction of NEG, * and vice versa. * * Note also, that the actual score for an example is given using * predicted.posWeight(), where posWeight>0 means the original prediction * of the untweaked classifier was that this example is of a POSITIVE class, * and posWeight<0 means NEGATIVE class. * * In addition, for the actual tweaking process, 3 fields are * maintained for each example/row: * 5. Precision * 6. Recall * 7. F_beta value * * See the documentation of the actual tweaking method, ExecuteTweaking(), * for further details ******************************************************************** ******************************************************************** */ private static class Row implements Serializable{ private static final long serialVersionUID=-4069980043842319180L; transient public Instance instance=null; // the example public ClassLabel actual; // true label public ClassLabel orig_predicted; // predicted label - see documentation above public ClassLabel tweak_predicted; // temporary prediction, for tweaking process public double precision=UNINITIALIZED; public double recall=UNINITIALIZED; public double F_beta=UNINITIALIZED; public Row(Instance i,ClassLabel a,ClassLabel orig_p,ClassLabel tweak_p){ instance=i; actual=a; orig_predicted=orig_p; tweak_predicted=tweak_p; } @Override public String toString(){ return orig_predicted+"\t"+actual+"\t"+instance; } } /** ******************************************************************** ******************************************************************** * A Tweaked Classifier, with an optimization of the precision vs. recall * Please note that this is an internal class of the TweakedLearner class. * It is constructed and returned by the TweakedLearner, based on * an original untweaked binary clasifer, and a threshold which was found * to optimized precision vs. recall * * @author Giora Unger * Created on May 19, 2005 ******************************************************************** ******************************************************************** */ public static class TweakedClassifier extends BinaryClassifier implements Serializable,Visible{ static private final long serialVersionUID=20080128L; private double m_threshold; private BinaryClassifier m_classifier; public TweakedClassifier(BinaryClassifier classifier,double threshold){ m_classifier=classifier; m_threshold=threshold; } @Override public double score(Instance instance){ return m_classifier.score(instance)-m_threshold; } /* (non-Javadoc) * @see edu.cmu.minorthird.util.gui.Visible#toGUI() * * Shows the original (untweaked) classifier, and the threshold * that was found * Code was copied from file CMM.java and adjusted */ @Override public Viewer toGUI(){ final Viewer v=new ComponentViewer(){ static final long serialVersionUID=20080128L; @Override public JComponent componentFor(Object o){ TweakedClassifier c=(TweakedClassifier)o; JPanel mainPanel=new JPanel(); mainPanel.setLayout(new BorderLayout()); mainPanel.add(new JLabel("Optimal threshold for TweakedClassifier="+ c.m_threshold),BorderLayout.NORTH); mainPanel.add(new JLabel("Original classifier before tweaking:"), BorderLayout.CENTER); Viewer subView=new SmartVanillaViewer(c.m_classifier); subView.setSuperView(this); mainPanel.add(subView,BorderLayout.SOUTH); mainPanel.setBorder(new TitledBorder("TweakedClassifier class")); return new JScrollPane(mainPanel); } }; v.setContent(this); return v; } /* (non-Javadoc) * @see edu.cmu.minorthird.classify.Classifier#explain(edu.cmu.minorthird.classify.Instance) */ @Override public String explain(Instance instance){ StringBuffer buf=new StringBuffer(""); buf.append("Explanation of original untweaked classifier:\n"); buf.append(m_classifier.explain(instance)); buf.append("\nAdjusted score after tweaking = "+score(instance)); return buf.toString(); } @Override public Explanation getExplanation(Instance instance){ Explanation.Node top=new Explanation.Node("TweakedLearner Explanation"); Explanation.Node orig= new Explanation.Node("Explanation of original untweaked classifier"); Explanation.Node origEx= m_classifier.getExplanation(instance).getTopNode(); orig.add(origEx); top.add(orig); Explanation.Node adjusted= new Explanation.Node("\nAdjusted score after tweaking = "+ score(instance)); top.add(adjusted); Explanation ex=new Explanation(top); return ex; } } /** ******************************************************************** ******************************************************************** * Main method for testing purposes ******************************************************************** ******************************************************************** */ public static void main(String[] args){ System.out.println("Started the test program for TweakedLearner"); NaiveBayes nb=new NaiveBayes(); new TweakedLearner(nb,3.0); System.out.println("Created a TweakedLearner"); } }