/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ /** @author Aron Culotta <a href="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</a> */ package cc.mallet.fst.confidence; import java.util.*; import java.io.*; import cc.mallet.fst.*; import cc.mallet.types.*; /** Calculates the effectiveness of "constrained viterbi" in propagating corrections in one segment of a sequence to other segments. */ public class ConfidenceCorrectorEvaluator { Object[] startTags; // to identify segment start/end boundaries Object[] inTags; public ConfidenceCorrectorEvaluator (Object[] startTags, Object[] inTags) { this.startTags = startTags; this.inTags = inTags; } /** Returns true if predSequence contains errors outside of correctedSegment. */ private boolean containsErrorInUncorrectedSegments (Sequence trueSequence, Sequence predSequence, Sequence correctedSequence, Segment correctedSegment) { for (int i=0; i < trueSequence.size(); i++) { if (correctedSegment.indexInSegment(i)) { if (!correctedSequence.get (i).equals (trueSequence.get (i))) { System.err.println ("\nTruth: "); for (int j=0; j < trueSequence.size(); j++) System.err.print (trueSequence.get (j) + " "); System.err.println ("\nPredicted: "); for (int j=0; j < trueSequence.size(); j++) System.err.print (predSequence.get (j) + " "); System.err.println ("\nCorrected: "); for (int j=0; j < trueSequence.size(); j++) System.err.print (correctedSequence.get (j) + " "); throw new IllegalStateException ("Corrected sequence does not have correct labels for corrected segment: " + correctedSegment); } } else { if (!predSequence.get (i).equals (trueSequence.get (i))) return true; } } return false; } /** Only evaluates over sequences which contain errors. Examine region not directly corrected by <code>correctedSegments </code> to measure effects of error propagation. @param model used to segment input sequence @param predictions list of the corrected segmentation @param ilist list of testing data @param correctedSegments list of {@link Segment}s in each sequence that were corrected...currently only allows one segment per instance. @param uncorrected true if we only evaluate sequences where errors remain after correction */ public void evaluate (Transducer model, ArrayList predictions, InstanceList ilist, ArrayList correctedSegments, String description, PrintStream outputStream, boolean errorsInUncorrected) { if (predictions.size() != ilist.size () || correctedSegments.size() != ilist.size ()) throw new IllegalArgumentException ("number of predicted sequence (" + predictions.size() + ") and number of corrected segments (" + correctedSegments.size() + ") must be equal to length of instancelist (" + ilist.size() + ")"); int numIncorrect2Correct = 0; // overall correction improvement int numCorrect2Incorrect = 0; // overall correction deprovement int numPropagatedIncorrect2Correct = 0; // count of propagated corrections int numPredictedCorrect = 0; // num tokens predicted correctly int numCorrectedCorrect = 0; // num tokens predicted correctly after correction // accuracy outside of corrected segment before and after propagation int numUncorrectedCorrectBeforePropagation = 0; int numUncorrectedCorrectAfterPropagation = 0; int totalTokens = 0; int totalTokensInUncorrectedRegion = 0; int numCorrectedSequences = 0; // count of sequences corrected for (int i=0; i < ilist.size(); i++) { Instance instance = ilist.get (i); Sequence input = (Sequence) instance.getData (); Sequence trueSequence = (Sequence) instance.getTarget (); Sequence predSequence = (Sequence) new MaxLatticeDefault (model, input).bestOutputSequence(); Sequence correctedSequence = (Sequence) predictions.get (i); Segment correctedSegment = (Segment) correctedSegments.get (i); // if any condition is true, do not evaluate this sequence if (correctedSegment == null || (errorsInUncorrected && !containsErrorInUncorrectedSegments ( trueSequence, predSequence, correctedSequence, correctedSegment))) continue; numCorrectedSequences++; totalTokens += trueSequence.size(); boolean[] predictedMatches = getMatches (trueSequence, predSequence); boolean[] correctedMatches = getMatches (trueSequence, correctedSequence); for (int j=0; j < predictedMatches.length; j++) { numPredictedCorrect += predictedMatches[j] ? 1 : 0; numCorrectedCorrect += correctedMatches[j] ? 1 : 0; if (predictedMatches[j] && !correctedMatches[j]) numCorrect2Incorrect++; else if (!predictedMatches[j] && correctedMatches[j]) numIncorrect2Correct++; // outside corrected segment if (j < correctedSegment.getStart() || j > correctedSegment.getEnd()) { totalTokensInUncorrectedRegion++; if (!predictedMatches[j] && correctedMatches[j]) numPropagatedIncorrect2Correct++; numUncorrectedCorrectBeforePropagation += predictedMatches[j] ? 1 : 0; numUncorrectedCorrectAfterPropagation += correctedMatches[j] ? 1 : 0; } } } double tokenAccuracyBeforeCorrection = (double)numPredictedCorrect / totalTokens; double tokenAccuracyAfterCorrection = (double)numCorrectedCorrect / totalTokens; double uncorrectedRegionAccuracyBeforeCorrection = (double)numUncorrectedCorrectBeforePropagation / totalTokensInUncorrectedRegion; double uncorrectedRegionAccuracyAfterCorrection = (double)numUncorrectedCorrectAfterPropagation / totalTokensInUncorrectedRegion; outputStream.println (description + "\nEvaluating effect of error-propagation in sequences containing at least one token error:" + "\ntotal number correctedsequences: " + numCorrectedSequences + "\ntotal number tokens: " + totalTokens + "\ntotal number tokens in \"uncorrected region\":" + totalTokensInUncorrectedRegion + "\ntotal number correct tokens before correction:" + numPredictedCorrect + "\ntotal number correct tokens after correction:" + numCorrectedCorrect + "\ntoken accuracy before correction: " + tokenAccuracyBeforeCorrection + "\ntoken accuracy after correction: " + tokenAccuracyAfterCorrection + "\nnumber tokens corrected by propagation: " + numPropagatedIncorrect2Correct + "\nnumber tokens made incorrect by propagation: " + numCorrect2Incorrect + "\ntoken accuracy of \"uncorrected region\" before propagation: " + uncorrectedRegionAccuracyBeforeCorrection + "\ntoken accuracy of \"uncorrected region\" after propagataion: " + uncorrectedRegionAccuracyAfterCorrection); } /** Returns a boolean array listing where two sequences have matching values. */ private boolean[] getMatches (Sequence s1, Sequence s2) { if (s1.size() != s2.size()) throw new IllegalArgumentException ("s1.size: " + s1.size() + " s2.size: " + s2.size()); boolean[] ret = new boolean [s1.size()]; for (int i=0; i < s1.size(); i++) ret[i] = s1.get (i).equals (s2.get(i)); return ret; } }