package edu.stanford.nlp.parser.dvparser;

import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;

import org.ejml.simple.SimpleMatrix;

import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.Reranker;
import edu.stanford.nlp.parser.lexparser.RerankerQuery;
import edu.stanford.nlp.parser.metrics.Eval;
import edu.stanford.nlp.trees.DeepTree;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Trees;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.util.Generics;

/**
 * A {@link Reranker} that scores candidate trees with a {@link DVModel}.
 *
 * <p>The reranker itself only holds the model and the {@link Options} taken
 * from it; the per-sentence work is done by the {@link Query} objects handed
 * out by {@link #process(List)}.
 */
public class DVModelReranker implements Reranker {
  private final Options op;
  private final DVModel model;

  /**
   * Creates a reranker around the given model.
   *
   * @param model the model used for scoring; its {@code op} field supplies
   *              the options used by this reranker
   */
  public DVModelReranker(DVModel model) {
    this.op = model.op;
    this.model = model;
  }

  /** Returns the model this reranker was constructed with. */
  DVModel getModel() {
    return model;
  }

  /**
   * Returns a fresh {@link Query} for scoring candidate trees.
   *
   * @param sentence not consulted by this implementation; every call simply
   *                 yields a new, empty query
   * @return a new query bound to this reranker's model and options
   */
  @Override
  public Query process(List<? extends HasWord> sentence) {
    return new Query();
  }

  /**
   * Returns the evaluations contributed by this reranker.
   *
   * @return an immutable single-element list holding an
   *         {@link UnknownWordPrinter} built from the model
   */
  @Override
  public List<Eval> getEvals() {
    Eval eval = new UnknownWordPrinter(model);
    return Collections.singletonList(eval);
  }

  /**
   * Scores trees with the enclosing reranker's model and remembers a
   * {@link DeepTree} for every tree scored.
   *
   * <p>This is an inner (non-static) class on purpose: it reads {@code op}
   * and {@code model} from the enclosing {@link DVModelReranker}.
   */
  public class Query implements RerankerQuery {
    private final TreeTransformer transformer;
    private final DVParserCostAndGradient scorer;
    // Accumulates one entry per score() call, in call order.
    private final List<DeepTree> deepTrees;

    /** Builds the transformer and scorer from the enclosing reranker's options and model. */
    public Query() {
      this.transformer = LexicalizedParser.buildTrainTransformer(op);
      this.scorer = new DVParserCostAndGradient(null, null, model, op);
      this.deepTrees = Generics.newArrayList();
    }

    /**
     * Scores a single tree and records the result.
     *
     * <p>The tree is first passed through the training transformer. When
     * {@code op.trainOptions.useContextWords} is set, the transformed tree is
     * additionally converted to core labels and has its spans set before
     * scoring.
     *
     * @param tree the candidate tree to score
     * @return the score computed by the underlying
     *         {@link DVParserCostAndGradient}
     */
    @Override
    public double score(Tree tree) {
      IdentityHashMap<Tree, SimpleMatrix> nodeVectors = Generics.newIdentityHashMap();
      Tree transformedTree = transformer.transformTree(tree);
      if (op.trainOptions.useContextWords) {
        Trees.convertToCoreLabels(transformedTree);
        transformedTree.setSpans();
      }
      double score = scorer.score(transformedTree, nodeVectors);
      // NOTE(review): the DeepTree is built from the ORIGINAL tree, while
      // nodeVectors was filled while scoring transformedTree. Whether the
      // identity-keyed vectors line up with nodes of `tree` depends on
      // whether the transformer returns a new tree - confirm.
      deepTrees.add(new DeepTree(tree, nodeVectors, score));
      return score;
    }

    /**
     * Returns the trees scored so far.
     *
     * @return the live internal list (not a copy), one {@link DeepTree} per
     *         {@link #score(Tree)} call
     */
    public List<DeepTree> getDeepTrees() {
      return deepTrees;
    }
  }

  private static final long serialVersionUID = 7897546308624261207L;
}