package edu.stanford.nlp.naturalli;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.semgraph.SemanticGraphEdge;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Quadruple;
import edu.stanford.nlp.util.Triple;

import java.io.BufferedReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
/**
* An encapsulation of the natural logic weights to use during forward inference.
*
* @see edu.stanford.nlp.naturalli.ForwardEntailer
*
* @author Gabor Angeli
*/
public class NaturalLogicWeights {
private final Map<Pair<String, String>, Double> verbPPAffinity = new HashMap<>();
private final Map<Triple<String, String, String>, Double> verbSubjPPAffinity = new HashMap<>();
private final Map<Quadruple<String, String, String, String>, Double> verbSubjObjPPAffinity = new HashMap<>();
private final Map<Quadruple<String, String, String, String>, Double> verbSubjPPPPAffinity = new HashMap<>();
private final Map<Quadruple<String, String, String, String>, Double> verbSubjPPObjAffinity = new HashMap<>();
private final Map<String, Double> verbObjAffinity = new HashMap<>();
private final double upperProbabilityCap;
public NaturalLogicWeights() {
this.upperProbabilityCap = 1.0;
}
public NaturalLogicWeights(double upperProbabilityCap) {
this.upperProbabilityCap = upperProbabilityCap;
}
public NaturalLogicWeights(String affinityModels, double upperProbabilityCap) throws IOException {
this.upperProbabilityCap = upperProbabilityCap;
String line;
// Simple PP attachments
BufferedReader ppReader = IOUtils.readerFromString(affinityModels + "/pp.tab.gz", "utf8");
while ( (line = ppReader.readLine()) != null) {
String[] fields = line.split("\t");
Pair<String, String> key = Pair.makePair(fields[0].intern(), fields[1].intern());
verbPPAffinity.put(key, Double.parseDouble(fields[2]));
}
ppReader.close();
// Subj PP attachments
BufferedReader subjPPReader = IOUtils.readerFromString(affinityModels + "/subj_pp.tab.gz", "utf8");
while ( (line = subjPPReader.readLine()) != null) {
String[] fields = line.split("\t");
Triple<String, String, String> key = Triple.makeTriple(fields[0].intern(), fields[1].intern(), fields[2].intern());
verbSubjPPAffinity.put(key, Double.parseDouble(fields[3]));
}
subjPPReader.close();
// Subj Obj PP attachments
BufferedReader subjObjPPReader = IOUtils.readerFromString(affinityModels + "/subj_obj_pp.tab.gz", "utf8");
while ( (line = subjObjPPReader.readLine()) != null) {
String[] fields = line.split("\t");
Quadruple<String, String, String, String> key = Quadruple.makeQuadruple(fields[0].intern(), fields[1].intern(), fields[2].intern(), fields[3].intern());
verbSubjObjPPAffinity.put(key, Double.parseDouble(fields[4]));
}
subjObjPPReader.close();
// Subj PP PP attachments
BufferedReader subjPPPPReader = IOUtils.readerFromString(affinityModels + "/subj_pp_pp.tab.gz", "utf8");
while ( (line = subjPPPPReader.readLine()) != null) {
String[] fields = line.split("\t");
Quadruple<String, String, String, String> key = Quadruple.makeQuadruple(fields[0].intern(), fields[1].intern(), fields[2].intern(), fields[3].intern());
verbSubjPPPPAffinity.put(key, Double.parseDouble(fields[4]));
}
subjPPPPReader.close();
// Subj PP PP attachments
BufferedReader subjPPObjReader = IOUtils.readerFromString(affinityModels + "/subj_pp_obj.tab.gz", "utf8");
while ( (line = subjPPObjReader.readLine()) != null) {
String[] fields = line.split("\t");
Quadruple<String, String, String, String> key = Quadruple.makeQuadruple(fields[0].intern(), fields[1].intern(), fields[2].intern(), fields[3].intern());
verbSubjPPObjAffinity.put(key, Double.parseDouble(fields[4]));
}
subjPPObjReader.close();
// Subj PP PP attachments
BufferedReader objReader = IOUtils.readerFromString(affinityModels + "/obj.tab.gz", "utf8");
while ( (line = objReader.readLine()) != null) {
String[] fields = line.split("\t");
verbObjAffinity.put(fields[0], Double.parseDouble(fields[1]));
}
objReader.close();
}
public double deletionProbability(String edgeType) {
// TODO(gabor) this is effectively assuming hard NatLog weights
if (edgeType.contains("prep")) {
return 0.9;
} else if (edgeType.contains("obj")) {
return 0.0;
} else {
return 1.0;
}
}
public double subjDeletionProbability(SemanticGraphEdge edge, Iterable<SemanticGraphEdge> neighbors) {
// Get information about the neighbors
// (in a totally not-creepy-stalker sort of way)
for (SemanticGraphEdge neighbor : neighbors) {
if (neighbor != edge) {
String neighborRel = neighbor.getRelation().toString();
if (neighborRel.contains("subj")) {
return 1.0;
}
}
}
return 0.0;
}
public double objDeletionProbability(SemanticGraphEdge edge, Iterable<SemanticGraphEdge> neighbors) {
// Get information about the neighbors
// (in a totally not-creepy-stalker sort of way)
Optional<String> subj = Optional.empty();
Optional<String> pp = Optional.empty();
for (SemanticGraphEdge neighbor : neighbors) {
if (neighbor != edge) {
String neighborRel = neighbor.getRelation().toString();
if (neighborRel.contains("subj")) {
subj = Optional.of(neighbor.getDependent().originalText().toLowerCase());
}
if (neighborRel.contains("prep")) {
pp = Optional.of(neighborRel);
}
if (neighborRel.contains("obj")) {
return 1.0; // allow deleting second object
}
}
}
String obj = edge.getDependent().originalText().toLowerCase();
String verb = edge.getGovernor().originalText().toLowerCase();
// Compute the most informative drop probability we can
Double rawScore = null;
if (subj.isPresent()) {
if (pp.isPresent()) {
// Case: subj+obj
rawScore = verbSubjPPObjAffinity.get(Quadruple.makeQuadruple(verb, subj.get(), pp.get(), obj));
}
}
if (rawScore == null) {
rawScore = verbObjAffinity.get(verb);
}
if (rawScore == null) {
return deletionProbability(edge.getRelation().toString());
} else {
return 1.0 - Math.min(1.0, rawScore / upperProbabilityCap);
}
}
public double ppDeletionProbability(SemanticGraphEdge edge, Iterable<SemanticGraphEdge> neighbors) {
// Get information about the neighbors
// (in a totally not-creepy-stalker sort of way)
Optional<String> subj = Optional.empty();
Optional<String> obj = Optional.empty();
Optional<String> pp = Optional.empty();
for (SemanticGraphEdge neighbor : neighbors) {
if (neighbor != edge) {
String neighborRel = neighbor.getRelation().toString();
if (neighborRel.contains("subj")) {
subj = Optional.of(neighbor.getDependent().originalText().toLowerCase());
}
if (neighborRel.contains("obj")) {
obj = Optional.of(neighbor.getDependent().originalText().toLowerCase());
}
if (neighborRel.contains("prep")) {
pp = Optional.of(neighborRel);
}
}
}
String prep = edge.getRelation().toString();
String verb = edge.getGovernor().originalText().toLowerCase();
// Compute the most informative drop probability we can
Double rawScore = null;
if (subj.isPresent()) {
if (obj.isPresent()) {
// Case: subj+obj
rawScore = verbSubjObjPPAffinity.get(Quadruple.makeQuadruple(verb, subj.get(), obj.get(), prep));
}
if (rawScore == null && pp.isPresent()) {
// Case: subj+other_pp
rawScore = verbSubjPPPPAffinity.get(Quadruple.makeQuadruple(verb, subj.get(), pp.get(), prep));
}
if (rawScore == null) {
// Case: subj
rawScore = verbSubjPPAffinity.get(Triple.makeTriple(verb, subj.get(), prep));
}
}
if (rawScore == null) {
// Case: just the original pp
rawScore = verbPPAffinity.get(Pair.makePair(verb, prep));
}
if (rawScore == null) {
return deletionProbability(prep);
} else {
return 1.0 - Math.min(1.0, rawScore / upperProbabilityCap);
}
}
public double deletionProbability(SemanticGraphEdge edge, Iterable<SemanticGraphEdge> neighbors) {
String edgeRel = edge.getRelation().toString();
if (edgeRel.contains("prep")) {
return ppDeletionProbability(edge, neighbors);
} else if (edgeRel.contains("obj")) {
return objDeletionProbability(edge, neighbors);
} else if (edgeRel.contains("subj")) {
return subjDeletionProbability(edge, neighbors);
} else if (edgeRel.equals("amod")) {
String word = (edge.getDependent().lemma() != null ? edge.getDependent().lemma() : edge.getDependent().word()).toLowerCase();
if (Util.PRIVATIVE_ADJECTIVES.contains(word)) {
return 0.0;
} else {
return 1.0;
}
} else {
return deletionProbability(edgeRel);
}
}
/*
private double backoffEdgeProbability(String edgeRel) {
return 1.0; // TODO(gabor) should probably learn these...
}
public double deletionProbability(String parent, String edgeRel) {
return deletionProbability(parent, edgeRel, false);
}
public double deletionProbability(String parent, String edgeRel, boolean isSecondaryEdgeOfType) {
if (edgeRel.startsWith("prep")) {
double affinity = ppAffinity.getCount(parent, edgeRel);
if (affinity != 0.0 && !isSecondaryEdgeOfType) {
return Math.sqrt(1.0 - Math.min(1.0, affinity));
} else {
return backoffEdgeProbability(edgeRel);
}
} else if (edgeRel.startsWith("dobj")) {
double affinity = dobjAffinity.getCount(parent);
if (affinity != 0.0 && !isSecondaryEdgeOfType) {
return Math.sqrt(1.0 - Math.min(1.0, affinity));
} else {
return backoffEdgeProbability(edgeRel);
}
} else {
return backoffEdgeProbability(edgeRel);
}
}
*/
public static NaturalLogicWeights fromString(String str) {
return new NaturalLogicWeights(); // TODO(gabor)
}
}