package edu.stanford.nlp.semparse.open.dataset; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import edu.stanford.nlp.semparse.open.core.eval.CandidateStatistics; import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntity; import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntityString; import edu.stanford.nlp.semparse.open.model.candidate.Candidate; import edu.stanford.nlp.semparse.open.model.candidate.CandidateGroup; import edu.stanford.nlp.semparse.open.util.StringSampler; import fig.basic.LogInfo; import fig.basic.Option; public abstract class ExpectedAnswer { public static class Options { @Option public int logRewardVerbosity = 0; } public static Options opts = new Options(); public final List<TargetEntity> targetEntities; public ExpectedAnswer(TargetEntity... targetEntities) { this.targetEntities = Arrays.asList(targetEntities); } public ExpectedAnswer(List<TargetEntity> targetEntities) { this.targetEntities = new ArrayList<>(targetEntities); } public ExpectedAnswer(String... targetStrings) { this.targetEntities = new ArrayList<>(); for (String targetString : targetStrings) { this.targetEntities.add(new TargetEntityString(targetString)); } } public int size() { return targetEntities.size(); } // ============================================================ // Debug print entities // ============================================================ public String sampleEntities() { return StringSampler.sampleEntities(targetEntities, StringSampler.DEFAULT_LIMIT); } public String allEntities() { return StringSampler.sampleEntities(targetEntities); } // ============================================================ // Information Retrieval (Precision-Recall-F1) Scores // ============================================================ public IRScore getIRScore(Candidate candidate) { return getIRScore(candidate.predictedEntities); } abstract public IRScore getIRScore(List<String> predictedEntities); // ============================================================ // Reward function // ============================================================ /** * Use to control logging verbosity. * If logRewardVerbosity == 1, log the reward only when frozenReward is false * Otherwise, frozenReward is ignored: * - logRewardVerbosity < 1 : don't log * - logRewardVerbosity > 1 : always log */ public boolean frozenReward = false; /** * Compute the reward (in the range 0 - 1) */ public double reward(CandidateGroup group) { double reward = reward(group.predictedEntities); if (reward > 0 && (opts.logRewardVerbosity >= 2 || (opts.logRewardVerbosity >= 1 && !frozenReward))) { LogInfo.logs("reward = %s <<< %s", reward, group.sampleEntities()); } return reward; } /** * Compute the reward (in the range 0 - 1) */ public double reward(Candidate candidate) { double reward = reward(candidate.predictedEntities); if (reward > 0 && (opts.logRewardVerbosity >= 2 || (opts.logRewardVerbosity >= 1 && !frozenReward))) { LogInfo.logs("reward = %s <<< %s", reward, candidate.sampleEntities()); LogInfo.logs(" %s", candidate.pattern); } return reward; } /** * Compute the reward (in the range 0 - 1) */ abstract public double reward(List<String> predictedEntities); // ============================================================ // Count the number of correct entities // ============================================================ protected Map<List<String>, Integer> cachedCountCorrectEntities = new ConcurrentHashMap<>(); /** * Count the number of correct entities (cached version) */ public int countCorrectEntities(List<String> predictedEntities) { Integer count = cachedCountCorrectEntities.get(predictedEntities); if (count == null) { count = computeCountCorrectEntities(predictedEntities); cachedCountCorrectEntities.put(predictedEntities, count); } return count; } /** * Count the number of correct entities (cached version) */ public int countCorrectEntities(Candidate candidate) { return countCorrectEntities(candidate.predictedEntities); } /** * Count the number of correct entities (uncached version) */ public int countCorrectEntitiesNoCache(List<String> predictedEntities) { return computeCountCorrectEntities(predictedEntities); } /** * Count the number of correct entities (uncached version) */ public int countCorrectEntitiesNoCache(Candidate candidate) { return computeCountCorrectEntities(candidate.predictedEntities); } /** * Count the number of correct entities */ abstract public int computeCountCorrectEntities(List<String> predictedEntities); // ============================================================ // Check if the candidate is likely correct // ============================================================ /** * Return true if the candidate is probably the correct answer. * * Since the list of expected entities may be incomplete, we can only estimate * whether the given candidate is a correct one. */ public boolean isLikelyCorrect(Candidate candidate) { return isLikelyCorrect(candidate.predictedEntities); } /** * Return true if the candidate is probably the correct answer. * * Since the list of expected entities may be incomplete, we can only estimate * whether the given candidate is a correct one. */ abstract public boolean isLikelyCorrect(List<String> predictedEntities); /** * Find the FIRST likely correct candidate in the list. * Return null if no candidate is likely correct. */ public CandidateStatistics findFirstTrueCandidate(List<CandidateStatistics> rankedCandidateStats) { for (CandidateStatistics candidateStat : rankedCandidateStats) { if (isLikelyCorrect(candidateStat.candidate)) { return candidateStat; } } return null; } /** * Find the MOST likely correct candidate in the list. * Among the most likely correct candidates, choose the first one. * Return null if no candidate is likely correct. */ public abstract CandidateStatistics findBestCandidate(List<CandidateStatistics> rankedCandidateStats); }