package org.wikibrain.sr.evaluation;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.filefilter.DirectoryFileFilter;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikibrain.core.WikiBrainException;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.sr.dataset.Dataset;

import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * An evaluator for SR metrics. Writes a directory structure of evaluation results like:
 *
 * baseDir/local-similarity/    Or local-mostSimilar, universal-similarity, etc.
 *     summary.tsv              Tab separated spreadsheet of sr metric results
 *     lang/split-group/run#-metric/
 *         overall.summary      Human-readable summary of metric results
 *         overall.log
 *         splitname1.summary   Human-readable summary of splitname1 within group
 *         splitname2.summary
 *         splitname3.summary
 *         splitname1.log       Log of results from splitname1 within group
 *         splitname2.log
 *         splitname3.log
 *         splitname1.err       Error logs for splitname1, within group
 *         splitname2.err
 *         splitname3.err
 *
 * @author Shilad Sen
 */
public abstract class Evaluator<T extends BaseEvaluationLog<T>> {

    /**
     * Guards run-number allocation across all Evaluator instances in this JVM,
     * so two concurrent evaluations cannot claim the same run directory.
     */
    private static final Object LOCK = new Object();

    private static final Logger LOG = LoggerFactory.getLogger(Evaluator.class);

    /**
     * Matches run directories named like "7-metricName"; group 1 is the run number.
     * Compiled once: a constant Pattern should be static final, not a per-instance field.
     */
    private static final Pattern MATCH_RUN = Pattern.compile("^(\\d+)-.*");

    private final File baseDir;
    private final String modeName;
    private final File modeDir;

    // if true, the id-based similarity and mostSimilar methods should be used.
    private boolean resolvePhrases = false;
    private boolean writeToStdout = true;
    private List<Split> splits = new ArrayList<Split>();

    /**
     * @param baseDir baseDir in structure shown above
     * @param modeName "local-similarity", etc
     */
    public Evaluator(File baseDir, String modeName) {
        this.baseDir = baseDir;
        this.modeName = modeName;
        this.modeDir = new File(baseDir, modeName);
        ensureIsDirectory(modeDir);
    }

    public void setWriteToStdout(boolean writeToStdout) {
        this.writeToStdout = writeToStdout;
    }

    /**
     * Splits the dataset into cross-validation folds and registers the resulting splits.
     *
     * @param ds dataset to fold
     * @param numFolds number of folds to create
     */
    public abstract void addCrossfolds(Dataset ds, int numFolds);

    /**
     * Adds a single split.
     * @param split
     */
    public void addSplit(Split split) {
        this.splits.add(split);
    }

    /**
     * Creates a directory if it does not exist already.
     * If the path exists but is not a directory, it is deleted first.
     * @param dirPath
     */
    private void ensureIsDirectory(File dirPath) {
        if (!dirPath.isDirectory()) {
            FileUtils.deleteQuietly(dirPath);
            dirPath.mkdirs();
            LOG.info("making " + dirPath);
        }
    }

    /**
     * @return One more than the max run number across all modes, languages, splits and metrics.
     */
    private int getNextRunNumber() {
        int runNum = 0;
        FileFilter dirFilter = DirectoryFileFilter.INSTANCE;
        // File.listFiles() returns null on I/O error or for a non-directory;
        // guard each level so a transient failure does not NPE the scan.
        for (File modeFile : listDirs(baseDir, dirFilter)) {
            for (File langFile : listDirs(modeFile, dirFilter)) {
                for (File groupFile : listDirs(langFile, dirFilter)) {
                    for (File runFile : listDirs(groupFile, dirFilter)) {
                        Matcher matcher = MATCH_RUN.matcher(runFile.getName());
                        if (matcher.matches()) {
                            runNum = Math.max(runNum, Integer.parseInt(matcher.group(1)) + 1);
                        }
                    }
                }
            }
        }
        return runNum;
    }

    /**
     * Null-safe wrapper around File.listFiles(FileFilter).
     * @return the matching children, or an empty array if listing failed
     */
    private static File[] listDirs(File dir, FileFilter filter) {
        File[] children = dir.listFiles(filter);
        return (children == null) ? new File[0] : children;
    }

    /**
     * @return the per-language, per-group directory for a split (no run component).
     */
    private File getLocalDir(Split split) {
        return FileUtils.getFile(
                modeDir,
                split.getTest().getLanguage().getLangCode(),
                split.getGroup());
    }

    /**
     * @return the "run#-metric" directory for a split within a particular run.
     */
    private File getLocalDir(Split split, int runNumber, String metricName) {
        return new File(getLocalDir(split), runNumber + "-" + metricName);
    }

    /**
     * Creates an evaluation log writing to the given path (null for an in-memory log).
     */
    public abstract T createResults(File path) throws IOException;

    /**
     * @return the ordered column names for the summary.tsv spreadsheet.
     */
    public abstract List<String> getSummaryFields();

    /**
     * Evaluates the metric produced by the factory against every registered split,
     * writing per-split, per-group, and overall logs and summaries.
     *
     * @param factory factory that builds the metric under evaluation
     * @return the merged evaluation results across all splits
     */
    public synchronized T evaluate(MonolingualSRFactory factory) throws IOException, DaoException, WikiBrainException {
        T overall = createResults(null);
        overall.setConfig("dataset", "overall");

        String metricName;
        int runNumber;

        // Claim the run number and create its directories atomically so that a
        // concurrent evaluator scanning the tree cannot be handed the same number.
        synchronized (LOCK) {
            runNumber = getNextRunNumber();
            metricName = factory.getName();
            for (Split split : splits) {
                ensureIsDirectory(getLocalDir(split, runNumber, metricName));
            }
        }

        // One merged log per split group, lazily created on first sight of the group.
        Map<String, T> groupEvals = new HashMap<String, T>();

        for (Split split : splits) {
            T splitEval = evaluateSplitInternal(factory, split, runNumber);
            overall.merge(splitEval);
            if (!groupEvals.containsKey(split.getGroup())) {
                File gfile = new File(getLocalDir(split, runNumber, metricName), "overall.log");
                groupEvals.put(split.getGroup(), createResults(gfile));
            }
            groupEvals.get(split.getGroup()).merge(splitEval);
            IOUtils.closeQuietly(splitEval);
        }

        for (String group : groupEvals.keySet()) {
            Split gsplit = getSplitWithGroup(group);
            File gfile = getLocalDir(gsplit, runNumber, metricName);
            BaseEvaluationLog geval = groupEvals.get(group);
            geval.summarize(new File(gfile, "overall.summary"));
            maybeWriteToStdout("Split " + group + ", " + metricName + ", " + runNumber, geval);
            if (writeToStdout) geval.summarize();
            updateOverallTsv(geval);
            IOUtils.closeQuietly(geval);
        }
        maybeWriteToStdout("Overall for run " + runNumber, overall);
        updateOverallTsv(overall);
        return overall;
    }

    /**
     * @return the first registered split belonging to the given group, or null if none.
     */
    private Split getSplitWithGroup(String group) {
        for (Split s : splits) {
            if (s.getGroup().equals(group)) {
                return s;
            }
        }
        return null;
    }

    /**
     * Appends one row for the evaluation to the overall tsv file,
     * writing the header row first if the file does not yet exist.
     * @param eval
     */
    private void updateOverallTsv(BaseEvaluationLog eval) throws IOException {
        List<String> fields = getSummaryFields();
        File tsv = FileUtils.getFile(modeDir, "summary.tsv");
        StringBuilder toWrite = new StringBuilder();
        if (!tsv.isFile()) {
            toWrite.append(StringUtils.join(fields, "\t")).append('\n');
        }
        Map<String, String> summary = eval.getSummaryAsMap();
        for (int i = 0; i < fields.size(); i++) {
            String value = summary.get(fields.get(i));
            if (value == null) value = "";
            if (i > 0) {
                toWrite.append('\t');
            }
            // Tabs inside a value would corrupt the tsv columns.
            toWrite.append(value.replace('\t', ' '));
        }
        toWrite.append('\n');
        FileUtils.write(tsv, toWrite.toString(), true);  // true = append
    }

    /**
     * Evaluates an sr metric against a single split and writes log, error, and summary files.
     *
     * @param factory
     * @param split
     * @param runNumber
     * @return
     * @throws IOException
     * @throws DaoException
     */
    private T evaluateSplitInternal(MonolingualSRFactory factory, Split split, int runNumber) throws IOException, DaoException, WikiBrainException {
        File dir = getLocalDir(split, runNumber, factory.getName());
        ensureIsDirectory(dir);
        File log = new File(dir, split.getName() + ".log");
        File err = new File(dir, split.getName() + ".err");
        File summary = new File(dir, split.getName() + ".summary");

        // LinkedHashMap keeps the config keys in insertion order for readable logs.
        Map<String, String> config = new LinkedHashMap<String, String>();
        config.put("lang", split.getTest().getLanguage().getLangCode());
        config.put("dataset", split.getGroup());
        // Locale.ROOT keeps the lowercasing locale-independent (e.g. Turkish dotless i).
        config.put("mode", modeName.toLowerCase(Locale.ROOT));
        config.put("metricName", factory.getName());
        config.put("runNumber", "" + runNumber);
        config.put("metricConfig", factory.describeMetric());
        config.put("disambigConfig", factory.describeDisambiguator());
        config.put("resolvePhrases", String.valueOf(resolvePhrases));

        T splitEval = evaluateSplit(factory, split, log, err, config);
        splitEval.summarize(summary);
        maybeWriteToStdout(
                "Split " + modeName + ", " + split.getGroup() + ", " + split.getName() + ", " + factory.getName() + ", " + runNumber,
                splitEval);
        return splitEval;
    }

    /**
     * Runs the actual evaluation of one split; implemented per mode (similarity, mostSimilar, ...).
     */
    protected abstract T evaluateSplit(MonolingualSRFactory factory, Split split, File log, File err, Map<String, String> conf) throws DaoException, IOException, WikiBrainException;

    /**
     * Writes a captioned summary of the evaluation to stdout, if enabled.
     */
    private void maybeWriteToStdout(String caption, BaseEvaluationLog eval) throws IOException {
        if (!writeToStdout) {
            return;
        }
        System.out.println("Similarity evaluation for " + caption);
        eval.summarize(System.out);
    }

    /**
     * @return a read-only view of the registered splits.
     */
    public List<Split> getSplits() {
        return Collections.unmodifiableList(splits);
    }

    public void setResolvePhrases(boolean resolvePhrases) {
        this.resolvePhrases = resolvePhrases;
    }

    public boolean shouldResolvePhrases() {
        return resolvePhrases;
    }
}