/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/**
@author Aron Culotta <a href="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</a>
*/
package cc.mallet.fst.confidence;
import java.util.logging.*;
import java.util.ArrayList;
import cc.mallet.fst.*;
import cc.mallet.types.*;
import cc.mallet.util.MalletLogger;
/**
* Corrects a subset of the {@link Segment}s produced by a {@link
* Transducer}. It's most useful to find the {@link Segment}s that the
* {@link Transducer} is least confident in and correct those using
* the true {@link Labeling}
* (<code>correctLeastConfidentSegments</code>). The corrected
* segment then propagates to other labelings in the sequence using
* "constrained viterbi" -- a viterbi calculation that requires the
* path to pass through the corrected segment states.
*/
public class ConstrainedViterbiTransducerCorrector implements TransducerCorrector
{
  private static final Logger logger =
      MalletLogger.getLogger (ConstrainedViterbiTransducerCorrector.class.getName());

  /** Ranks the segments of an instance by confidence. */
  TransducerConfidenceEstimator confidenceEstimator;

  /** Model used for both the unconstrained and the constrained decoding. */
  Transducer model;

  /**
     Segments chosen for correction by the most recent call to
     <code>correctLeastConfidentSegments</code>, one entry per
     instance; <code>null</code> before the first such call.
  */
  ArrayList leastConfidentSegments;

  /**
     @param confidenceEstimator used to rank segments by confidence
     @param model transducer used to decode the sequences
  */
  public ConstrainedViterbiTransducerCorrector (TransducerConfidenceEstimator confidenceEstimator,
                                                Transducer model) {
    this.confidenceEstimator = confidenceEstimator;
    this.model = model;
  }

  /**
     Uses a {@link ConstrainedForwardBackwardConfidenceEstimator}
     built on <code>model</code> to rank segments.
     @param model transducer used to decode the sequences
  */
  public ConstrainedViterbiTransducerCorrector (Transducer model) {
    this (new ConstrainedForwardBackwardConfidenceEstimator (model), model);
  }

  /**
     Delegates to the confidence estimator's
     <code>getSegmentConfidences</code>.
  */
  public java.util.Vector getSegmentConfidences () {
    return confidenceEstimator.getSegmentConfidences();
  }

  /**
     Returns the least confident segments from each sequence in the
     previous call to <code>correctLeastConfidentSegments</code>
     (<code>null</code> if that method has not been called yet).
  */
  public ArrayList getLeastConfidentSegments () {
    return this.leastConfidentSegments;
  }

  /**
     Returns the least confident segments in <code>ilist</code>
     @param ilist test instances
     @param startTags indicate the beginning of segments
     @param continueTags indicate "inside" of segments
     @return list of {@link Segment}s, one for each instance, that is
     least confident; the entry is <code>null</code> for an instance
     from which no segment was extracted
  */
  public ArrayList getLeastConfidentSegments (InstanceList ilist, Object[] startTags, Object[] continueTags) {
    ArrayList ret = new ArrayList ();
    for (int i = 0; i < ilist.size(); i++) {
      Segment[] orderedSegments = confidenceEstimator.rankSegmentsByConfidence (
        ilist.get (i), startTags, continueTags);
      // keep one entry per instance even when nothing was extracted
      ret.add (firstOrNull (orderedSegments));
    }
    return ret;
  }

  /**
     Same as the four-argument version with
     <code>findIncorrect</code> set to <code>false</code>.
  */
  public ArrayList correctLeastConfidentSegments (InstanceList ilist,
                                                  Object[] startTags,
                                                  Object[] continueTags) {
    return correctLeastConfidentSegments (ilist, startTags, continueTags, false);
  }

  /**
     Returns an ArrayList of corrected Sequences. Also stores
     leastConfidentSegments, an ArrayList of the segments to correct,
     where null entries mean no segment was corrected for that
     sequence.
     @param ilist test instances
     @param startTags indicate the beginning of segments
     @param continueTags indicate "inside" of segments
     @param findIncorrect true if we should cycle through least
     confident segments until find an incorrect one
     @return list of {@link Sequence}s corresponding to the corrected
     tagging of each instance in <code>ilist</code>
  */
  public ArrayList correctLeastConfidentSegments (InstanceList ilist, Object[] startTags,
                                                  Object[] continueTags, boolean findIncorrect) {
    ArrayList correctedPredictionList = new ArrayList ();
    this.leastConfidentSegments = new ArrayList ();
    logger.info (this.getClass().getName() + " ranking confidence using " +
                 confidenceEstimator.getClass().getName());
    // checked once so that disabled FINE logging costs no string building
    final boolean fine = logger.isLoggable (Level.FINE);

    for (int i = 0; i < ilist.size(); i++) {
      if (fine)
        logger.fine ("correcting instance# " + i + " / " + ilist.size());
      Instance instance = ilist.get (i);
      Sequence input = (Sequence) instance.getData ();
      Sequence truth = (Sequence) instance.getTarget ();
      Sequence predicted = new MaxLatticeDefault (model, input).bestOutputSequence();

      int numIncorrect = countMismatches (predicted, truth, predicted.size());
      if (numIncorrect == 0) { // nothing to correct
        this.leastConfidentSegments.add (null);
        correctedPredictionList.add (predicted);
        continue;
      }

      // rank segments by confidence
      Segment[] orderedSegments = confidenceEstimator.rankSegmentsByConfidence (
        instance, startTags, continueTags);
      if (fine) {
        logger.fine ("Ordered Segments:\n");
        for (int j = 0; j < orderedSegments.length; j++)
          logger.fine (String.valueOf (orderedSegments[j]));
        logSequence ("Correcting Segment: True Sequence:", truth);
        logger.fine ("");
      }

      Segment leastConfidentSegment = firstOrNull (orderedSegments);
      if (leastConfidentSegment == null) { // nothing extracted
        // must be tested before the segment is dereferenced below
        this.leastConfidentSegments.add (null);
        correctedPredictionList.add (predicted);
        continue;
      }

      // if <code>findIncorrect</code>, find the least confident
      // segment that is incorrectly labeled
      // else, use least confident segment
      if (findIncorrect) {
        leastConfidentSegment = firstIncorrect (orderedSegments);
        if (leastConfidentSegment == null) {
          logger.warning ("cannot find incorrect segment, probably because error is in background state\n");
          this.leastConfidentSegments.add (null);
          correctedPredictionList.add (predicted);
          continue;
        }
      }
      this.leastConfidentSegments.add (leastConfidentSegment);

      // create segmentCorrectedOutput, which has the true labels for
      // the leastConfidentSegment and null for other positions
      // (a freshly allocated String[] is already all null)
      String[] sequence = new String[truth.size()];
      int numCorrectedTokens = 0;
      for (int j = 0; j < truth.size(); j++) {
        if (leastConfidentSegment.indexInSegment (j)) {
          sequence[j] = (String) truth.get (j);
          numCorrectedTokens++;
        }
      }
      if (leastConfidentSegment.endsPrematurely ()) {
        // also constrain the token right after the segment's end
        int next = leastConfidentSegment.getEnd() + 1;
        sequence[next] = (String) truth.get (next);
        numCorrectedTokens++;
      }
      if (fine) {
        logger.fine ("Constrained Segment Sequence\n");
        for (int j = 0; j < sequence.length; j++)
          logger.fine (sequence[j]);
      }
      ArraySequence segmentCorrectedOutput = new ArraySequence (sequence);

      // run constrained viterbi on this sequence with the
      // constraint that this segment is tagged correctly
      // NOTE(review): input is taken from the first ranked segment,
      // presumably the same sequence as the instance data -- confirm
      Sequence correctedPrediction = new MaxLatticeDefault (model,
        orderedSegments[0].getInput (), segmentCorrectedOutput).bestOutputSequence();

      if (fine) {
        int numIncorrectAfterCorrection =
          countMismatches (correctedPrediction, truth, truth.size());
        logger.fine ("Num incorrect tokens in original prediction: " + numIncorrect);
        logger.fine ("Num corrected tokens: " + numCorrectedTokens);
        logger.fine ("Num incorrect tokens after correction-propagation: " + numIncorrectAfterCorrection);
        // print sequence info
        logSequence ("Correcting Segment: True Sequence:", truth);
        logSequence ("\nOriginal prediction: ", predicted);
        logSequence ("\nCorrected prediction: ", correctedPrediction);
        logger.fine ("");
      }
      correctedPredictionList.add (correctedPrediction);
    }
    return correctedPredictionList;
  }

  /** Returns the first element of <code>segments</code>, or null if the array is null or empty. */
  private static Segment firstOrNull (Segment[] segments) {
    return (segments == null || segments.length == 0) ? null : segments[0];
  }

  /** Returns the first non-null segment whose <code>correct()</code> is false, or null if there is none. */
  private static Segment firstIncorrect (Segment[] segments) {
    for (int j = 0; j < segments.length; j++) {
      if (segments[j] != null && !segments[j].correct())
        return segments[j];
    }
    return null;
  }

  /** Counts the positions in <code>[0, length)</code> at which the two sequences hold unequal elements. */
  private static int countMismatches (Sequence a, Sequence b, int length) {
    int mismatches = 0;
    for (int j = 0; j < length; j++) {
      if (!a.get (j).equals (b.get (j)))
        mismatches++;
    }
    return mismatches;
  }

  /** Logs <code>header</code> and then every element of <code>seq</code> at FINE level, one record per element. */
  private static void logSequence (String header, Sequence seq) {
    logger.fine (header);
    for (int j = 0; j < seq.size(); j++)
      logger.fine (seq.get (j) + "\t");
  }
}