package edu.stanford.nlp.ie.crf;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.optimization.CmdEvaluator;
import edu.stanford.nlp.stats.MultiClassChunkEvalStats;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.*;
import java.util.Collection;
import java.util.List;
/**
 * Evaluates a {@link CRFClassifier} on a set of data.
 * This can be called by QNMinimizer periodically.
 * If an evaluation command is set (see {@link #setEvalCmd(String)}), the classifier's
 * answers are handed to that command through {@link #outputToCmd(OutputStream)} and the
 * FB1 score is parsed from the command's captured output; otherwise the evaluation is
 * done internally with {@link MultiClassChunkEvalStats}.
 *
 * <p>NOTE: when running conlleval with exec on Linux, Linux will first
 * fork the process by duplicating the memory of the current process. So if the
 * JVM has lots of memory, it will all be duplicated when the
 * child process is initially forked, which can be unfortunate.
 *
 * @author Angel Chang
 */
public class CRFClassifierEvaluator<IN extends CoreMap> extends CmdEvaluator {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(CRFClassifierEvaluator.class);

  /** Token in the command's output that immediately precedes the F1 score. */
  private static final String FB1_LABEL = "FB1:";

  /** Score reported when the external command's output cannot be interpreted. */
  private static final double UNPARSEABLE_SCORE = -1;

  private final CRFClassifier<IN> classifier;
  /** NOTE: Default uses -r, specify without -r if IOB. */
  private String cmdStr = "/u/nlp/bin/conlleval -r";
  /** Tokenized evaluation command; {@code null} means "evaluate internally". */
  private String[] cmd;

  // TODO: Use data structure to hold data + features
  // Cache already featurized documents
  // Original object bank
  Collection<List<IN>> data;
  // Featurized data
  List<Triple<int[][][], int[], double[][][]>> featurizedData;

  /**
   * Creates an evaluator over the given test data that, by default, scores by running
   * the default conlleval command line.
   *
   * @param description description of this evaluator
   * @param classifier the classifier whose weights are evaluated
   * @param data the original test documents
   * @param featurizedData the already featurized form of {@code data}
   */
  public CRFClassifierEvaluator(String description,
                                CRFClassifier<IN> classifier,
                                Collection<List<IN>> data,
                                List<Triple<int[][][], int[], double[][][]>> featurizedData) {
    this.description = description;
    this.classifier = classifier;
    this.data = data;
    this.featurizedData = featurizedData;
    cmd = getCmd(cmdStr);
    saveOutput = true;
  }

  /**
   * Creates an evaluator without test data. The evaluation command is left unset
   * ({@code null}), so {@link #evaluate(double[])} scores internally unless
   * {@link #setEvalCmd(String)} is called. Test data presumably has to be supplied via
   * {@link #setTestData(Collection, List)} before evaluating.
   *
   * @param description description of this evaluator
   * @param classifier the classifier whose weights are evaluated
   */
  public CRFClassifierEvaluator(String description,
                                CRFClassifier<IN> classifier) {
    this.description = description;
    this.classifier = classifier;
    saveOutput = true;
  }

  /**
   * Set the data to test on.
   *
   * @param data the original test documents
   * @param featurizedData the already featurized form of {@code data}
   */
  public void setTestData(Collection<List<IN>> data, List<Triple<int[][][], int[], double[][][]>> featurizedData) {
    this.data = data;
    this.featurizedData = featurizedData;
  }

  /**
   * Set the evaluation command (set to null to skip evaluation using command line).
   * A blank command is treated the same as {@code null}.
   *
   * @param evalCmd the command line to run, or {@code null}/blank to evaluate internally
   */
  public void setEvalCmd(String evalCmd) {
    log.info("setEvalCmd to " + evalCmd);
    this.cmdStr = evalCmd;
    if (cmdStr != null) {
      cmdStr = cmdStr.trim();
      if (cmdStr.isEmpty()) {
        cmdStr = null;
      }
    }
    cmd = getCmd(cmdStr);
  }

  /** Loads the candidate weights {@code x} into the classifier for testing. */
  @Override
  public void setValues(double[] x) {
    classifier.updateWeightsForTest(x);
  }

  /** Returns the tokenized evaluation command, or {@code null} if evaluation is internal. */
  @Override
  public String[] getCmd() {
    return cmd;
  }

  /**
   * Extracts the score that follows the first {@code "FB1:"} token in the captured
   * command output.
   *
   * @return the parsed score, or {@code -1} if the output is missing, has no
   *     {@code "FB1:"} token followed by another token, or that token is not a number
   */
  private double interpretCmdOutput() {
    String output = getOutput();
    if (output == null) {
      log.error("in CRFClassifierEvaluator.interpretCmdOutput(), no output captured from evaluation command");
      return UNPARSEABLE_SCORE;
    }
    String[] parts = output.split("\\s+");
    // Stop one short of the end: the score is the token after the label.
    for (int i = 0; i + 1 < parts.length; i++) {
      if (parts[i].equals(FB1_LABEL)) {
        try {
          return Double.parseDouble(parts[i + 1]);
        } catch (NumberFormatException nfe) {
          log.error("in CRFClassifierEvaluator.interpretCmdOutput(), cannot parse FB1 score in output:\n" + output);
          return UNPARSEABLE_SCORE;
        }
      }
    }
    log.error("in CRFClassifierEvaluator.interpretCmdOutput(), cannot find FB1 score in output:\n" + output);
    return UNPARSEABLE_SCORE;
  }

  /**
   * Writes the classifier's answers for the test data to {@code outputStream}.
   * The stream is supplied by the caller and is left open, but the writer is flushed so
   * that no text remains buffered in it.
   *
   * @throws RuntimeIOException if writing fails
   */
  @Override
  public void outputToCmd(OutputStream outputStream) {
    try {
      PrintWriter pw = IOUtils.encodedOutputStreamPrintWriter(outputStream, null, true);
      classifier.classifyAndWriteAnswers(data, featurizedData, pw,
          classifier.makeReaderAndWriter());
      // Autoflush only triggers on println/printf/format, so flush explicitly.
      pw.flush();
    } catch (IOException ex) {
      throw new RuntimeIOException(ex);
    }
  }

  /**
   * Scores the weights {@code x} on the test data, either with the external command
   * (if one is set) or internally.
   *
   * @param x the weights to evaluate
   * @return the FB1 score ({@code -1} if the external command's output is unusable)
   * @throws RuntimeIOException if the internal evaluation hits an I/O error
   */
  @Override
  public double evaluate(double[] x) {
    setValues(x);
    String[] evalCmd = getCmd();
    if (evalCmd != null) {
      evaluateCmd(evalCmd);
      return interpretCmdOutput();
    }
    return evaluateInternally();
  }

  /**
   * Writes the classifier's answers to a temporary file and scores them with
   * {@link MultiClassChunkEvalStats}, logging the conlleval-style report. Every stream is
   * closed and the temporary file is deleted, even on failure.
   */
  private double evaluateInternally() {
    File f = null;
    try {
      // TODO: Classify in memory instead of writing to tmp file
      f = File.createTempFile("CRFClassifierEvaluator", ".txt");
      // Fallback in case the delete in the finally block below does not succeed.
      f.deleteOnExit();
      try (OutputStream outputStream = new BufferedOutputStream(new FileOutputStream(f))) {
        PrintWriter pw = IOUtils.encodedOutputStreamPrintWriter(outputStream, null, true);
        classifier.classifyAndWriteAnswers(data, featurizedData, pw,
            classifier.makeReaderAndWriter());
        // Push any writer-buffered text into the stream before it is closed.
        pw.flush();
      }
      try (BufferedReader br = new BufferedReader(new FileReader(f))) {
        MultiClassChunkEvalStats stats = new MultiClassChunkEvalStats("O");
        double score = stats.score(br, "\t");
        log.info(stats.getConllEvalString());
        return score;
      }
    } catch (IOException ex) {
      throw new RuntimeIOException(ex);
    } finally {
      if (f != null) {
        f.delete();
      }
    }
  }

}