/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ /** @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a> */ package cc.mallet.classify; import java.util.ArrayList; import cc.mallet.classify.Classifier; import cc.mallet.pipe.Pipe; import cc.mallet.types.Alphabet; import cc.mallet.types.FeatureVector; import cc.mallet.types.Instance; import cc.mallet.types.Label; import cc.mallet.types.LabelVector; public class ConfidencePredictingClassifier extends Classifier { Classifier underlyingClassifier; Classifier confidencePredictingClassifier; double totalCorrect; double totalIncorrect; double totalIncorrectIncorrect; double totalIncorrectCorrect; int numCorrectInstances; int numIncorrectInstances; int numConfidenceCorrect; int numFalsePositive; int numFalseNegative; public ConfidencePredictingClassifier (Classifier underlyingClassifier, Classifier confidencePredictingClassifier) { super (underlyingClassifier.getInstancePipe()); this.underlyingClassifier = underlyingClassifier; this.confidencePredictingClassifier = confidencePredictingClassifier; // for testing confidence accuracy totalCorrect = 0.0; totalIncorrect = 0.0; totalIncorrectIncorrect = 0.0; totalIncorrectCorrect = 0.0; numCorrectInstances = 0; numIncorrectInstances = 0; numConfidenceCorrect = 0; numFalsePositive = 0; numFalseNegative = 0; } public Classification classify (Instance instance) { Classification c = underlyingClassifier.classify (instance); Classification cpc = confidencePredictingClassifier.classify (c); LabelVector lv = c.getLabelVector(); int bestIndex = lv.getBestIndex(); double [] values = new double[lv.numLocations()]; //// Put score of "correct" into score of the winning class... // xxx Can't set lv - it's immutable. // Must create copy and new classification object // lv.set (bestIndex, cpc.getLabelVector().value("correct")); //for (int i = 0; i < lv.numLocations(); i++) // if (i != bestIndex) // lv.set (i, 0.0); // Put score of "correct" in winning class and // set rest to 0 for (int i = 0; i < lv.numLocations(); i++) { if (i != bestIndex) values[i] = 0.0; else values[i] = cpc.getLabelVector().value("correct"); } //return c; if(c.bestLabelIsCorrect()){ numCorrectInstances++; totalCorrect+=cpc.getLabelVector().value("correct"); totalIncorrectCorrect+=cpc.getLabelVector().value("incorrect"); String correct = new String("correct"); if(correct.equals(cpc.getLabelVector().getBestLabel().toString())) numConfidenceCorrect++; else numFalseNegative++; } else{ numIncorrectInstances++; totalIncorrect+=cpc.getLabelVector().value("correct"); totalIncorrectIncorrect+=cpc.getLabelVector().value("incorrect"); if((new String("incorrect")).equals(cpc.getLabelVector().getBestLabel().toString())) numConfidenceCorrect++; else numFalsePositive++; } return new Classification(instance, this, new LabelVector(lv.getLabelAlphabet(), values)); // return cpc; } public void printAverageScores() { System.out.println("Mean score of correct for correct instances = " + meanCorrect()); System.out.println("Mean score of correct for incorrect instances = " + meanIncorrect()); System.out.println("Mean score of incorrect for correct instances = " + this.totalIncorrectCorrect/this.numCorrectInstances); System.out.println("Mean score of incorrect for incorrect instances = " + this.totalIncorrectIncorrect/this.numIncorrectInstances); } public void printConfidenceAccuracy() { System.out.println("Confidence predicting accuracy = " + ((double)numConfidenceCorrect/(numIncorrectInstances + numCorrectInstances))+ " false negatives: "+ numFalseNegative + "/"+numCorrectInstances + " false positives: "+ numFalsePositive +" / " +numIncorrectInstances); } public double meanCorrect() { if(this.numCorrectInstances==0) return 0.0; return (this.totalCorrect/(double)this.numCorrectInstances); } public double meanIncorrect() { if(this.numIncorrectInstances==0) return 0.0; return (this.totalIncorrect/(double)this.numIncorrectInstances); } }