package edu.stanford.nlp.parser.lexparser; import java.io.BufferedReader; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStreamReader; import java.text.NumberFormat; import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.util.Generics; import edu.stanford.nlp.util.Index; import edu.stanford.nlp.util.Pair; import edu.stanford.nlp.util.Triple; import edu.stanford.nlp.util.logging.Redwood; /** * A Dependency grammar that smoothes by averaging over similar words. * * @author Galen Andrew * @author Pi-Chuan Chang */ @SuppressWarnings("deprecation") public class ChineseSimWordAvgDepGrammar extends MLEDependencyGrammar { /** A logger for this class */ private static Redwood.RedwoodChannels log = Redwood.channels(ChineseSimWordAvgDepGrammar.class); private static final long serialVersionUID = -1845503582705055342L; private static final double simSmooth = 10.0; private static final String argHeadFile = "simWords/ArgHead.5"; private static final String headArgFile = "simWords/HeadArg.5"; private Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> simArgMap; private Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> simHeadMap; private static final boolean debug = true; private static final boolean verbose = false; //private static final double MIN_PROBABILITY = Math.exp(-100.0); public ChineseSimWordAvgDepGrammar(TreebankLangParserParams tlpParams, boolean directional, boolean distance, boolean coarseDistance, boolean basicCategoryTagsInDependencyGrammar, Options op, Index<String> wordIndex, Index<String> tagIndex) { super(tlpParams, directional, distance, coarseDistance, basicCategoryTagsInDependencyGrammar, op, wordIndex, tagIndex); simHeadMap = getMap(headArgFile); simArgMap = getMap(argHeadFile); } public Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> getMap(String filename) { Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> hashMap = Generics.newHashMap(); try { BufferedReader wordMapBReader = new BufferedReader(new InputStreamReader(new FileInputStream(filename), "UTF-8")); String wordMapLine; Pattern linePattern = Pattern.compile("sim\\((.+)/(.+):(.+)/(.+)\\)=(.+)"); while ((wordMapLine = wordMapBReader.readLine()) != null) { Matcher m = linePattern.matcher(wordMapLine); if (!m.matches()) { log.info("Ill-formed line in similar word map file: " + wordMapLine); continue; } Pair<Integer, String> iTW = new Pair<>(wordIndex.addToIndex(m.group(1)), m.group(2)); double score = Double.parseDouble(m.group(5)); List<Triple<Integer, String, Double>> tripleList = hashMap.get(iTW); if (tripleList == null) { tripleList = new ArrayList<>(); hashMap.put(iTW, tripleList); } tripleList.add(new Triple<>(wordIndex.addToIndex(m.group(3)), m.group(4), score)); } } catch (IOException e) { throw new RuntimeException("Problem reading similar words file!"); } return hashMap; } @Override public double scoreTB(IntDependency dependency) { //return op.testOptions.depWeight * Math.log(probSimilarWordAvg(dependency)); return op.testOptions.depWeight * Math.log(probTBwithSimWords(dependency)); } public void setLex(Lexicon lex) { this.lex = lex; } private ClassicCounter<String> statsCounter = new ClassicCounter<>(); public void dumpSimWordAvgStats() { log.info("SimWordAvg stats:"); log.info(statsCounter); } /* ** An alternative kind of smoothing. ** The first one is "probSimilarWordAvg" implemented by Galen ** This one is trying to modify "probTB" in MLEDependencyGrammar using the simWords list we have ** -pichuan */ private double probTBwithSimWords(IntDependency dependency) { boolean leftHeaded = dependency.leftHeaded && directional; IntTaggedWord unknownHead = new IntTaggedWord(-1, dependency.head.tag); IntTaggedWord unknownArg = new IntTaggedWord(-1, dependency.arg.tag); if (verbose) { System.out.println("Generating " + dependency); } short distance = dependency.distance; // int hW = dependency.head.word; // int aW = dependency.arg.word; IntTaggedWord aTW = dependency.arg; // IntTaggedWord hTW = dependency.head; double pb_stop_hTWds = getStopProb(dependency); boolean isRoot = rootTW(dependency.head); if (dependency.arg.word == -2) { // did we generate stop? if (isRoot) { return 0.0; } return pb_stop_hTWds; } double pb_go_hTWds = 1.0 - pb_stop_hTWds; if (isRoot) { pb_go_hTWds = 1.0; } // generate the argument int valenceBinDistance = valenceBin(distance); // KEY: // c_ count of // p_ MLE prob of // pb_ MAP prob of // a arg // h head // T tag // W word // d direction // ds distance IntDependency temp = new IntDependency(dependency.head, dependency.arg, leftHeaded, valenceBinDistance); double c_aTW_hTWd = argCounter.getCount(temp); temp = new IntDependency(dependency.head, unknownArg, leftHeaded, valenceBinDistance); double c_aT_hTWd = argCounter.getCount(temp); temp = new IntDependency(dependency.head, wildTW, leftHeaded, valenceBinDistance); double c_hTWd = argCounter.getCount(temp); temp = new IntDependency(unknownHead, dependency.arg, leftHeaded, valenceBinDistance); double c_aTW_hTd = argCounter.getCount(temp); temp = new IntDependency(unknownHead, unknownArg, leftHeaded, valenceBinDistance); double c_aT_hTd = argCounter.getCount(temp); temp = new IntDependency(unknownHead, wildTW, leftHeaded, valenceBinDistance); double c_hTd = argCounter.getCount(temp); temp = new IntDependency(wildTW, dependency.arg, false, -1); double c_aTW = argCounter.getCount(temp); temp = new IntDependency(wildTW, unknownArg, false, -1); double c_aT = argCounter.getCount(temp); // do the magic double p_aTW_hTd = (c_hTd > 0.0 ? c_aTW_hTd / c_hTd : 0.0); double p_aT_hTd = (c_hTd > 0.0 ? c_aT_hTd / c_hTd : 0.0); double p_aTW_aT = (c_aTW > 0.0 ? c_aTW / c_aT : 1.0); double pb_aTW_hTWd; // = (c_aTW_hTWd + smooth_aTW_hTWd * p_aTW_hTd) / (c_hTWd + smooth_aTW_hTWd); double pb_aT_hTWd = (c_aT_hTWd + smooth_aT_hTWd * p_aT_hTd) / (c_hTWd + smooth_aT_hTWd); double score; // = (interp * pb_aTW_hTWd + (1.0 - interp) * p_aTW_aT * pb_aT_hTWd) * pb_go_hTWds; /* smooth by simWords -pichuan */ List<Triple<Integer, String, Double>> sim2arg = simArgMap.get(new Pair<>(dependency.arg.word, stringBasicCategory(dependency.arg.tag))); List<Triple<Integer, String, Double>> sim2head = simHeadMap.get(new Pair<>(dependency.head.word, stringBasicCategory(dependency.head.tag))); List<Integer> simArg = new ArrayList<>(); List<Integer> simHead= new ArrayList<>(); if (sim2arg != null) { for (Triple<Integer,String,Double> t : sim2arg) { simArg.add(t.first); } } if (sim2head != null) { for (Triple<Integer,String,Double> t : sim2head) { simHead.add(t.first); } } double cSim_aTW_hTd = 0; double cSim_hTd = 0; for (int h : simHead) { IntTaggedWord hWord = new IntTaggedWord(h, dependency.head.tag); temp = new IntDependency(hWord, dependency.arg, dependency.leftHeaded, dependency.distance); cSim_aTW_hTd += argCounter.getCount(temp); temp = new IntDependency(hWord, wildTW, dependency.leftHeaded, dependency.distance); cSim_hTd += argCounter.getCount(temp); } double pSim_aTW_hTd = (cSim_hTd > 0.0 ? cSim_aTW_hTd / cSim_hTd : 0.0); // P(Wa,Ta|Th) if (debug) { //if (simHead.size() > 0 && cSim_hTd == 0.0) { if (pSim_aTW_hTd > 0.0) { //System.out.println("# simHead("+dependency.head.word+"-"+wordNumberer.object(dependency.head.word)+") =\t"+cSim_hTd); System.out.println(dependency+"\t"+pSim_aTW_hTd); //System.out.println(wordNumberer); } } //pb_aTW_hTWd = (c_aTW_hTWd + smooth_aTW_hTWd * pSim_aTW_hTd + smooth_aTW_hTWd * p_aTW_hTd) / (c_hTWd + smooth_aTW_hTWd + smooth_aTW_hTWd); //if (pSim_aTW_hTd > 0.0) { double smoothSim_aTW_hTWd = 17.7; double smooth_aTW_hTWd = 17.7*2; //smooth_aTW_hTWd = smooth_aTW_hTWd*2; pb_aTW_hTWd = (c_aTW_hTWd + smoothSim_aTW_hTWd * pSim_aTW_hTd + smooth_aTW_hTWd * p_aTW_hTd) / (c_hTWd + smoothSim_aTW_hTWd + smooth_aTW_hTWd); System.out.println(dependency); System.out.println(c_aTW_hTWd+" + "+ smoothSim_aTW_hTWd+" * "+pSim_aTW_hTd+" + "+smooth_aTW_hTWd+" * "+p_aTW_hTd); System.out.println("-------------------------------- = "+pb_aTW_hTWd); System.out.println(c_hTWd+" + "+ smoothSim_aTW_hTWd+" + "+smooth_aTW_hTWd); System.out.println(); //} //pb_aT_hTWd = (c_aT_hTWd + smooth_aT_hTWd * p_aT_hTd) / (c_hTWd + smooth_aT_hTWd); score = (interp * pb_aTW_hTWd + (1.0 - interp) * p_aTW_aT * pb_aT_hTWd) * pb_go_hTWds; if (verbose) { NumberFormat nf = NumberFormat.getNumberInstance(); nf.setMaximumFractionDigits(2); System.out.println(" c_aTW_hTWd: " + c_aTW_hTWd + "; c_aT_hTWd: " + c_aT_hTWd + "; c_hTWd: " + c_hTWd); System.out.println(" c_aTW_hTd: " + c_aTW_hTd + "; c_aT_hTd: " + c_aT_hTd + "; c_hTd: " + c_hTd); System.out.println(" Generated with pb_go_hTWds: " + nf.format(pb_go_hTWds) + " pb_aTW_hTWd: " + nf.format(pb_aTW_hTWd) + " p_aTW_aT: " + nf.format(p_aTW_aT) + " pb_aT_hTWd: " + nf.format(pb_aT_hTWd)); System.out.println(" NoDist score: " + score); } if (op.testOptions.prunePunc && pruneTW(aTW)) { return 1.0; } if (Double.isNaN(score)) { score = 0.0; } //if (op.testOptions.rightBonus && ! dependency.leftHeaded) // score -= 0.2; if (score < MIN_PROBABILITY) { score = 0.0; } return score; } private double probSimilarWordAvg(IntDependency dep) { double regProb = probTB(dep); statsCounter.incrementCount("total"); List<Triple<Integer, String, Double>> sim2arg = simArgMap.get(new Pair<>(dep.arg.word, stringBasicCategory(dep.arg.tag))); List<Triple<Integer, String, Double>> sim2head = simHeadMap.get(new Pair<>(dep.head.word, stringBasicCategory(dep.head.tag))); if (sim2head == null && sim2arg == null) { return regProb; } double sumScores = 0, sumWeights = 0; if (sim2head == null) { statsCounter.incrementCount("aSim"); for (Triple<Integer, String, Double> simArg : sim2arg) { //double weight = 1 - simArg.third; double weight = Math.exp(-50*simArg.third); for (int tag = 0, numT = tagIndex.size(); tag < numT; tag++) { if (!stringBasicCategory(tag).equals(simArg.second)) { continue; } IntTaggedWord tempArg = new IntTaggedWord(simArg.first, tag); IntDependency tempDep = new IntDependency(dep.head, tempArg, dep.leftHeaded, dep.distance); double probArg = Math.exp(lex.score(tempArg, 0, wordIndex.get(tempArg.word), null)); if (probArg == 0.0) { continue; } sumScores += probTB(tempDep) * weight / probArg; sumWeights += weight; } } } else if (sim2arg == null) { statsCounter.incrementCount("hSim"); for (Triple<Integer, String, Double> simHead : sim2head) { //double weight = 1 - simHead.third; double weight = Math.exp(-50*simHead.third); for (int tag = 0, numT = tagIndex.size(); tag < numT; tag++) { if (!stringBasicCategory(tag).equals(simHead.second)) { continue; } IntTaggedWord tempHead = new IntTaggedWord(simHead.first, tag); IntDependency tempDep = new IntDependency(tempHead, dep.arg, dep.leftHeaded, dep.distance); sumScores += probTB(tempDep) * weight; sumWeights += weight; } } } else { statsCounter.incrementCount("hSim"); statsCounter.incrementCount("aSim"); statsCounter.incrementCount("aSim&hSim"); for (Triple<Integer, String, Double> simArg : sim2arg) { for (int aTag = 0, numT = tagIndex.size(); aTag < numT; aTag++) { if (!stringBasicCategory(aTag).equals(simArg.second)) { continue; } IntTaggedWord tempArg = new IntTaggedWord(simArg.first, aTag); double probArg = Math.exp(lex.score(tempArg, 0, wordIndex.get(tempArg.word), null)); if (probArg == 0.0) { continue; } for (Triple<Integer, String, Double> simHead : sim2head) { for (int hTag = 0; hTag < numT; hTag++) { if (!stringBasicCategory(hTag).equals(simHead.second)) { continue; } IntTaggedWord tempHead = new IntTaggedWord(simHead.first, aTag); IntDependency tempDep = new IntDependency(tempHead, tempArg, dep.leftHeaded, dep.distance); //double weight = (1-simHead.third) * (1-simArg.third); double weight = Math.exp(-50*simHead.third) * Math.exp(-50*simArg.third); sumScores += probTB(tempDep) * weight / probArg; sumWeights += weight; } } } } } IntDependency temp = new IntDependency(dep.head, wildTW, dep.leftHeaded, dep.distance); double countHead = argCounter.getCount(temp); double simProb; if (sim2arg == null) { simProb = sumScores / sumWeights; } else { double probArg = Math.exp(lex.score(dep.arg, 0, wordIndex.get(dep.arg.word), null)); simProb = probArg * sumScores / sumWeights; } if (simProb == 0) { statsCounter.incrementCount("simProbZero"); } if (regProb == 0) { // log.info("zero reg prob"); statsCounter.incrementCount("regProbZero"); } double smoothProb = (countHead * regProb + simSmooth * simProb) / (countHead + simSmooth); if (smoothProb == 0) { // log.info("zero smooth prob"); statsCounter.incrementCount("smoothProbZero"); } return smoothProb; } private String stringBasicCategory(int tag) { return tlp.basicCategory(tagIndex.get(tag)); } }