package edu.stanford.nlp.semparse.open.model.candidate; import java.util.*; import edu.stanford.nlp.semparse.open.dataset.Example; import edu.stanford.nlp.semparse.open.ling.AveragedWordVector; import edu.stanford.nlp.semparse.open.ling.LingUtils; import edu.stanford.nlp.semparse.open.model.FeatureVector; import edu.stanford.nlp.semparse.open.model.tree.KNode; import edu.stanford.nlp.semparse.open.util.StringSampler; import fig.basic.Option; /** * A CandidateGroup is a collection of candidates with the same selected KNodes * (and thus the same selected entity strings). */ public class CandidateGroup { public static class Options { @Option(gloss = "level of entity string normalization when creating candidate group " + "(0 = none / 1 = whitespace / 2 = simple / 3 = aggressive)") public int lateNormalizeEntities = 2; } public static Options opts = new Options(); public final Example ex; public final List<KNode> selectedNodes; public final List<String> predictedEntities; final List<Candidate> candidates; public FeatureVector features; public AveragedWordVector averagedWordVector; public CandidateGroup(Example ex, List<KNode> selectedNodes) { this.ex = ex; this.selectedNodes = new ArrayList<>(selectedNodes); List<String> entities = new ArrayList<>(); for (KNode node : selectedNodes) { entities.add(LingUtils.normalize(node.fullText, opts.lateNormalizeEntities)); } predictedEntities = new ArrayList<>(entities); candidates = new ArrayList<>(); } public void initAveragedWordVector() { if (averagedWordVector == null) averagedWordVector = new AveragedWordVector(predictedEntities); } public int numEntities() { return predictedEntities.size(); } public int numCandidate() { return candidates.size(); } public List<Candidate> getCandidates() { return Collections.unmodifiableList(candidates); } public Candidate addCandidate(TreePattern pattern) { return new Candidate(this, pattern); } public double getReward() { return ex.expectedAnswer.reward(this); } // ============================================================ // Debug Print // ============================================================ public String sampleEntities() { return StringSampler.sampleEntities(predictedEntities, StringSampler.DEFAULT_LIMIT); } public String allEntities() { return StringSampler.sampleEntities(predictedEntities); } }