package edu.stanford.nlp.stats;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.pipeline.LabeledChunkIdentifier;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.StringUtils;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Properties;
/**
 * Calculates phrase based precision and recall (similar to conlleval).
 * Handles various encodings such as IO, IOB, IOE, BILOU, SBEIO, []
 *
 * <p>Usage: java edu.stanford.nlp.stats.MultiClassChunkEvalStats [options] &lt; filename <br>
 * -r - Do raw token based evaluation <br>
 * -d delimiter - Specifies delimiter to use (instead of tab) <br>
 * -b boundary - Boundary token (default is -X- ) <br>
 * -t defaultTag - Default tag to use if tag is not prefixed (i.e. is not X-xxx ); default is I <br>
 * -k backgroundLabel - Label treated as the negative/background label (default is O ) <br>
 * -i filename - Read the data from this file instead of standard input <br>
 * -format format - Output format; "conll" (default, case-insensitive) prints the conlleval-style
 *   report, any other value prints the generic description <br>
 * -ignoreProvidedTag - Discards the provided tag (i.e. if label is X-xxx, just use xxx for evaluation)
 *
 * @author Angel Chang
 */
public class MultiClassChunkEvalStats extends MultiClassPrecisionRecallExtendedStats.MultiClassStringLabelStats {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(MultiClassChunkEvalStats.class);

  /** True while the chunk currently being read has matched in both the guess and the gold sequence. */
  private boolean inCorrect = false;

  /** Gold tag of the previous token, or {@code null} at the start of a sequence. */
  private LabeledChunkIdentifier.LabelTagType prevCorrect = null;

  /** Guessed tag of the previous token, or {@code null} at the start of a sequence. */
  private LabeledChunkIdentifier.LabelTagType prevGuess = null;

  /**
   * Parses raw label strings into tag types and decides chunk boundaries.
   * NOTE(review): this is assigned after the superclass constructor returns; if a superclass
   * constructor scores data eagerly, {@code addGuess} would run while this is still null - confirm.
   */
  private final LabeledChunkIdentifier chunker;

  /**
   * Selects what is used as the counter key: the full label when true, the chunk type otherwise.
   * Nothing in this class ever sets it, so the chunk type is always used.
   */
  private boolean useLabel = false;

  /**
   * Creates chunk-level statistics via the superclass constructor taking a classifier and a dataset.
   *
   * @param classifier classifier passed through to the superclass
   * @param data dataset passed through to the superclass
   * @param negLabel the negative (background) label; also configured on the chunker
   */
  public <F> MultiClassChunkEvalStats(Classifier<String,F> classifier, GeneralDataset<String,F> data, String negLabel)
  {
    super(classifier, data, negLabel);
    chunker = newChunker(negLabel);
  }

  /**
   * Creates empty chunk-level statistics.
   *
   * @param negLabel the negative (background) label; also configured on the chunker
   */
  public MultiClassChunkEvalStats(String negLabel)
  {
    super(negLabel);
    chunker = newChunker(negLabel);
  }

  /**
   * Creates empty chunk-level statistics over a given label index.
   *
   * @param dataLabelIndex label index passed through to the superclass
   * @param negLabel the negative (background) label; also configured on the chunker
   */
  public MultiClassChunkEvalStats(Index<String> dataLabelIndex, String negLabel)
  {
    super(dataLabelIndex, negLabel);
    chunker = newChunker(negLabel);
  }

  /** Builds the chunk identifier shared by all constructors, configured with the negative label. */
  private static LabeledChunkIdentifier newChunker(String negLabel) {
    LabeledChunkIdentifier identifier = new LabeledChunkIdentifier();
    identifier.setNegLabel(negLabel);
    return identifier;
  }

  /**
   * Returns the chunk identifier used to interpret labels, so that callers can configure it
   * (for example its default tag or whether provided tags are ignored).
   *
   * @return the chunker owned by this object
   */
  public LabeledChunkIdentifier getChunker()
  {
    return chunker;
  }

  /** Clears the superclass counts and forgets any chunk that was in progress. */
  @Override
  public void clearCounts()
  {
    super.clearCounts();
    inCorrect = false;
    prevCorrect = null;
    prevGuess = null;
  }

  /** Closes any chunk still open at the end of the input before the superclass finalizes. */
  @Override
  protected void finalizeCounts() {
    markBoundary();
    super.finalizeCounts();
  }

  /** Returns the counter key for a tag: its label if {@code useLabel} is set, otherwise its type. */
  private String getTypeLabel(LabeledChunkIdentifier.LabelTagType tagType) {
    if (useLabel) {
      return tagType.label;
    } else {
      return tagType.type;
    }
  }

  /** Records the chunk that was matching so far as a correct guess, keyed by the previous gold tag. */
  private void closeCorrectChunk() {
    inCorrect = false;
    correctGuesses.incrementCount(getTypeLabel(prevCorrect));
  }

  /**
   * Handles a sequence boundary: a chunk that was still matching when the boundary is reached
   * counts as correct, and the previous-token state is reset so the next token starts fresh.
   */
  @Override
  protected void markBoundary() {
    if (inCorrect) {
      closeCorrectChunk();
    }
    prevGuess = null;
    prevCorrect = null;
  }

  /**
   * Converts both label strings to tag types with the chunker and records them.
   *
   * @param guess the predicted label string
   * @param trueLabel the gold label string
   * @param addUnknownLabels whether the labels should be added to the label index
   */
  @Override
  protected void addGuess(String guess, String trueLabel, boolean addUnknownLabels) {
    LabeledChunkIdentifier.LabelTagType guessTagType = chunker.getTagType(guess);
    LabeledChunkIdentifier.LabelTagType correctTagType = chunker.getTagType(trueLabel);
    addGuess(guessTagType, correctTagType, addUnknownLabels);
  }

  /**
   * Records one token's guessed and gold tags, updating the chunk-level counters
   * ({@code correctGuesses}, {@code foundCorrect}, {@code foundGuessed}) and the token-level
   * counters ({@code tokensCorrect}, {@code tokensCount}).
   *
   * @param guess the predicted tag for the current token
   * @param correct the gold tag for the current token
   * @param addUnknownLabels if true, both tags' keys are added to {@code labelIndex},
   *     creating the index first if it does not exist yet
   */
  protected void addGuess(LabeledChunkIdentifier.LabelTagType guess,
                          LabeledChunkIdentifier.LabelTagType correct, boolean addUnknownLabels) {
    if (addUnknownLabels) {
      if (labelIndex == null) {
        labelIndex = new HashIndex<>();
      }
      labelIndex.add(getTypeLabel(guess));
      labelIndex.add(getTypeLabel(correct));
    }

    // Step 1: decide the fate of a chunk that has been matching up to the previous token.
    if (inCorrect) {
      boolean prevCorrectEnded = chunker.isEndOfChunk(prevCorrect, correct);
      boolean prevGuessEnded = chunker.isEndOfChunk(prevGuess, guess);
      if (prevCorrectEnded && prevGuessEnded && prevGuess.typeMatches(prevCorrect)) {
        // Both chunks ended together with the same type: a fully correct chunk.
        closeCorrectChunk();
      } else if (prevCorrectEnded != prevGuessEnded || !guess.typeMatches(correct)) {
        // Only one of them ended, or the types diverged mid-chunk: the match is broken.
        inCorrect = false;
      }
    }

    // Step 2: detect chunks starting at this token. This must run after step 1 so that a
    // chunk ending and a new one starting on the same token are both accounted for.
    boolean correctStarted = LabeledChunkIdentifier.isStartOfChunk(prevCorrect, correct);
    boolean guessStarted = LabeledChunkIdentifier.isStartOfChunk(prevGuess, guess);
    if (correctStarted && guessStarted && guess.typeMatches(correct)) {
      inCorrect = true;
    }
    if (correctStarted) {
      foundCorrect.incrementCount(getTypeLabel(correct));
    }
    if (guessStarted) {
      foundGuessed.incrementCount(getTypeLabel(guess));
    }

    // Step 3: token-level accuracy. When provided tags are ignored only the type has to
    // match; otherwise the complete label must be equal.
    boolean tokenMatches = chunker.isIgnoreProvidedTag()
        ? guess.typeMatches(correct)
        : guess.label.equals(correct.label);
    if (tokenMatches) {
      tokensCorrect++;
    }
    tokensCount++;

    prevGuess = guess;
    prevCorrect = correct;
  }

  /**
   * Returns string precision recall in ConllEval format.
   * Delegates to the inherited {@code getConllEvalString(boolean)} with {@code true}.
   */
  @Override
  public String getConllEvalString()
  {
    return getConllEvalString(true);
  }

  /**
   * Command line entry point: scores a file (option {@code -i}) or standard input and prints
   * the result to standard output. See the class documentation for the supported options.
   *
   * @param args command line arguments, parsed with {@code StringUtils.argsToProperties}
   */
  public static void main(String[] args) {
    StringUtils.logInvocationString(log, args);
    Properties props = StringUtils.argsToProperties(args);
    String boundary = props.getProperty("b", "-X-");
    String delimiter = props.getProperty("d", "\t");
    String defaultPosTag = props.getProperty("t", "I");
    // parseBoolean yields the primitive directly instead of boxing and unboxing a Boolean.
    boolean raw = Boolean.parseBoolean(props.getProperty("r", "false"));
    boolean ignoreProvidedTag = Boolean.parseBoolean(props.getProperty("ignoreProvidedTag", "false"));
    String format = props.getProperty("format", "conll");
    String filename = props.getProperty("i");
    String backgroundLabel = props.getProperty("k", "O");
    try {
      MultiClassPrecisionRecallExtendedStats stats;
      if (raw) {
        // Raw mode: plain per-token label statistics, no chunk interpretation.
        stats = new MultiClassStringLabelStats(backgroundLabel);
      } else {
        MultiClassChunkEvalStats mstats = new MultiClassChunkEvalStats(backgroundLabel);
        mstats.getChunker().setDefaultPosTag(defaultPosTag);
        mstats.getChunker().setIgnoreProvidedTag(ignoreProvidedTag);
        stats = mstats;
      }
      if (filename != null) {
        stats.score(filename, delimiter, boundary);
      } else {
        // No file given: read from standard input (decoded with the platform default charset).
        stats.score(new BufferedReader(new InputStreamReader(System.in)), delimiter, boundary);
      }
      if ("conll".equalsIgnoreCase(format)) {
        System.out.println(stats.getConllEvalString());
      } else {
        System.out.println(stats.getDescription(6));
      }
    } catch (IOException ex) {
      log.info("Error processing file: " + ex.toString());
      ex.printStackTrace(System.err);
    }
  }

}