/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/**
@author Aron Culotta <a href="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</a>
*/
package cc.mallet.extract;
import java.io.*;
import java.util.*;
import cc.mallet.fst.*;
import cc.mallet.fst.confidence.*;
import cc.mallet.pipe.*;
import cc.mallet.types.*;
/**
* Estimates the confidence in the labeling of a LabeledSpan using a
* TransducerConfidenceEstimator.
*/
public class TransducerExtractionConfidenceEstimator extends ExtractionConfidenceEstimator implements Serializable
{
TransducerConfidenceEstimator confidenceEstimator;
Pipe featurePipe;
public TransducerExtractionConfidenceEstimator (TransducerConfidenceEstimator confidenceEstimator,
Object[] startTags,
Object[] continueTags,
Pipe featurePipe) {
super();
this.confidenceEstimator = confidenceEstimator;
this.featurePipe = featurePipe;
}
public void estimateConfidence (DocumentExtraction documentExtraction) {
Tokenization input = documentExtraction.getInput();
// WARNING: input Tokenization will likely already have many
// features appended from the last time it was passed through a
// featurePipe. To avoid a redundant calculation of features, the
// caller may want to set this.featurePipe =
// TokenSequence2FeatureVectorSequence
Instance carrier = this.featurePipe.pipe(new Instance(input, null, null, null));
Sequence pipedInput = (Sequence) carrier.getData();
Sequence prediction = documentExtraction.getPredictedLabels();
LabeledSpans labeledSpans = documentExtraction.getExtractedSpans();
SumLatticeDefault lattice = new SumLatticeDefault (this.confidenceEstimator.getTransducer(), pipedInput);
for (int i=0; i < labeledSpans.size(); i++) {
LabeledSpan span = labeledSpans.getLabeledSpan(i);
if (span.isBackground())
continue;
int[] segmentBoundaries = getSegmentBoundaries(input, span);
Segment segment = new Segment(pipedInput, prediction, prediction,
segmentBoundaries[0], segmentBoundaries[1],
null, null);
span.setConfidence(confidenceEstimator.estimateConfidenceFor(segment, lattice));
}
}
/** Convert the indices of a LabeledSpan into indices for a Tokenization.
* @return array of size two, where first index is start Token,
* second is end Token, inclusive */
private int[] getSegmentBoundaries (Tokenization tokens, LabeledSpan labeledSpan) {
int startCharIndex = labeledSpan.getStartIdx();
int endCharIndex = labeledSpan.getEndIdx()-1;
int[] ret = new int[]{-1,-1};
for (int i=0; i < tokens.size(); i++) {
int charIndex = tokens.getSpan(i).getStartIdx();
if (charIndex <= endCharIndex && charIndex >= startCharIndex) {
if (ret[0] == -1) {
ret[0] = i;
ret[1] = i;
}
else
ret[1] = i;
}
}
if (ret[0] == -1 || ret[1] == -1)
throw new IllegalArgumentException("Unable to find segment boundaries from span " + labeledSpan);
return ret;
}
}