package edu.stanford.nlp.semparse.open.dataset; import java.util.List; import edu.stanford.nlp.semparse.open.core.eval.CandidateStatistics; import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntity; import edu.stanford.nlp.semparse.open.util.BipartiteMatcher; import fig.basic.LogInfo; import fig.basic.Option; public class ExpectedAnswerInjectiveMatch extends ExpectedAnswer { public static class Options { @Option public double irThreshold = 0.8; @Option public String irCriterion = "recall"; } public static Options opts = new Options(); public ExpectedAnswerInjectiveMatch(TargetEntity... targetEntities) {super(targetEntities);} public ExpectedAnswerInjectiveMatch(List<TargetEntity> targetEntities) {super(targetEntities);} public ExpectedAnswerInjectiveMatch(String... targetStrings) {super(targetStrings);} @Override public IRScore getIRScore(List<String> predictedEntities) { return new IRScore(countCorrectEntities(predictedEntities), predictedEntities.size(), targetEntities.size()); } @Override public double reward(List<String> predictedEntities) { IRScore score = getIRScore(predictedEntities); double criterionScore = 0; switch (opts.irCriterion) { case "precision": case "p": criterionScore = score.precision; break; case "recall": case "r": criterionScore = score.recall; break; case "f1": criterionScore = score.f1; break; case "raw": return (score.numCorrect >= score.numGold - opts.irThreshold) ? 1 : 0; default: LogInfo.fails("IR Criterion %s not recognized", opts.irCriterion); } return (criterionScore < opts.irThreshold) ? 0 : criterionScore; } @Override public int computeCountCorrectEntities(List<String> predictedEntities) { return new BipartiteMatcher(targetEntities, predictedEntities).findMaximumMatch(); } @Override public boolean isLikelyCorrect(List<String> predictedEntities) { return reward(predictedEntities) > 0; } @Override public CandidateStatistics findBestCandidate(List<CandidateStatistics> rankedCandidateStats) { double bestReward = 0; CandidateStatistics best = null; for (CandidateStatistics candidateStat : rankedCandidateStats) { double reward = reward(candidateStat.candidate); if (reward > bestReward) { best = candidateStat; bestReward = reward; } } return best; } }