package edu.stanford.nlp.semparse.open.dataset; import java.util.List; import edu.stanford.nlp.semparse.open.core.eval.CandidateStatistics; import fig.basic.Option; /** * Gives reward = 1 if the predicted entities match all criteria, and reward = 0 otherwise. */ public class ExpectedAnswerCriteriaMatch extends ExpectedAnswer { public static class Options { @Option(gloss = "Give partial reward for lists that don't exactly match the criteria") public boolean generous = false; } public static Options opts = new Options(); public final Criteria criteria; public ExpectedAnswerCriteriaMatch(Criteria criteria) { super(criteria.getTargetEntities()); this.criteria = criteria; } @Override public IRScore getIRScore(List<String> predictedEntities) { return criteria.getIRScore(predictedEntities); } @Override public double reward(List<String> predictedEntities) { if (!opts.generous) { return countCorrectEntities(predictedEntities) == criteria.numCriteria() ? 1 : 0; } else { // Generous reward double f1 = criteria.getIRScore(predictedEntities).f1; return f1 > ExpectedAnswerInjectiveMatch.opts.irThreshold ? f1 : 0; } } @Override public int computeCountCorrectEntities(List<String> predictedEntities) { return criteria.countMatchedCriteria(predictedEntities); } @Override public boolean isLikelyCorrect(List<String> predictedEntities) { return countCorrectEntities(predictedEntities) == criteria.numCriteria(); } @Override public CandidateStatistics findBestCandidate(List<CandidateStatistics> rankedCandidateStats) { double bestCorrectnessScore = 0; CandidateStatistics best = null; for (CandidateStatistics candidateStat : rankedCandidateStats) { double correctnessScore = criteria.getCorrectnessScore(candidateStat.candidate.predictedEntities); if (correctnessScore > bestCorrectnessScore) { best = candidateStat; bestCorrectnessScore = correctnessScore; } } return best; } }