package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.util.logging.Redwood;

import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;

import static edu.stanford.nlp.parser.lexparser.IntTaggedWord.ANY_WORD_INT;
import static edu.stanford.nlp.parser.lexparser.IntTaggedWord.ANY_TAG_INT;
import static edu.stanford.nlp.parser.lexparser.IntTaggedWord.STOP_WORD_INT;
import static edu.stanford.nlp.parser.lexparser.IntTaggedWord.STOP_TAG_INT;
import static edu.stanford.nlp.parser.lexparser.IntDependency.ANY_DISTANCE_INT;

import java.io.*;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;

/**
 * A dependency grammar whose probabilities are estimated from counts of
 * observed head/argument dependencies and of STOP events, smoothed with
 * Bayesian m-estimates and an interpolation weight.
 *
 * Counts are accumulated through {@link #addRule(IntDependency, double)} or
 * loaded with {@link #readData(BufferedReader)}; probabilities are computed by
 * {@link #probTB(IntDependency)} and {@link #getStopProb(IntDependency)}.
 */
public class MLEDependencyGrammar extends AbstractDependencyGrammar {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(MLEDependencyGrammar.class);

  final boolean useSmoothTagProjection;
  final boolean useUnigramWordSmoothing;

  static final boolean DEBUG = false;

  /** Incremented once per call to expandArg (i.e. per non-STOP dependency added). */
  protected int numWordTokens;

  /** Stores all the counts for dependencies (with and without the word
   *  being a wildcard) in the reduced tag space.
   */
  protected ClassicCounter<IntDependency> argCounter;
  protected ClassicCounter<IntDependency> stopCounter;  // reduced tag space

  /** Bayesian m-estimate prior for aT given hTWd against base distribution
   *  of aT given hTd.
   *  TODO: Note that these values are overwritten in the constructor. Find what is best and then maybe remove these defaults!
   */
  public double smooth_aT_hTWd = 32.0;

  /** Bayesian m-estimate prior for aTW given hTWd against base distribution
   *  of aTW given hTd.
   */
  public double smooth_aTW_hTWd = 16.0;

  /** Bayesian m-estimate prior used when smoothing the STOP probability (see getStopProb). */
  public double smooth_stop = 4.0;

  /** Interpolation between model that directly predicts aTW and model
   *  that predicts aT and then aW given aT.  This percent of the mass
   *  is on the model directly predicting aTW.
   */
  public double interp = 0.6;

  // Extra smoothing hyperparameters for tag projection backoff.
  // Only used if useSmoothTagProjection is true.

  /** Back off Bayesian m-estimate of aTW given aT to aPTW given aPT. */
  public double smooth_aTW_aT = 96.0;

  /** Back off Bayesian m-estimate of aTW_hTd to aPTW_hPTd (?? guessed, not tuned). */
  public double smooth_aTW_hTd = 32.0;

  /** Back off Bayesian m-estimate of aT_hTd to aPT_hPTd (?? guessed, not tuned). */
  public double smooth_aT_hTd = 32.0;

  /** Back off word prediction from tag to projected tag (only used if useUnigramWordSmoothing is true). */
  public double smooth_aPTW_aPT = 16.0;

  /**
   * Builds a grammar, choosing the tag projection from a flag.
   *
   * @param basicCategoryTagsInDependencyGrammar if true a
   *     BasicCategoryTagProjection is used, otherwise a TestTagProjection
   */
  public MLEDependencyGrammar(TreebankLangParserParams tlpParams, boolean directional, boolean distance,
                              boolean coarseDistance, boolean basicCategoryTagsInDependencyGrammar,
                              Options op, Index<String> wordIndex, Index<String> tagIndex) {
    this(basicCategoryTagsInDependencyGrammar
             ? new BasicCategoryTagProjection(tlpParams.treebankLanguagePack())
             : new TestTagProjection(),
         tlpParams, directional, distance, coarseDistance, op, wordIndex, tagIndex);
  }

  /**
   * Builds a grammar with an explicit tag projection. The smoothing defaults
   * declared on the fields are overwritten here by the four values returned
   * from {@code tlpParams.MLEDependencyGrammarSmoothingParams()}, in the order
   * smooth_aT_hTWd, smooth_aTW_hTWd, smooth_stop, interp.
   */
  public MLEDependencyGrammar(TagProjection tagProjection, TreebankLangParserParams tlpParams,
                              boolean directional, boolean useDistance, boolean useCoarseDistance,
                              Options op, Index<String> wordIndex, Index<String> tagIndex) {
    super(tlpParams.treebankLanguagePack(), tagProjection, directional, useDistance, useCoarseDistance,
          op, wordIndex, tagIndex);
    useSmoothTagProjection = op.useSmoothTagProjection;
    useUnigramWordSmoothing = op.useUnigramWordSmoothing;
    argCounter = new ClassicCounter<>();
    stopCounter = new ClassicCounter<>();

    double[] smoothParams = tlpParams.MLEDependencyGrammarSmoothingParams();
    smooth_aT_hTWd = smoothParams[0];
    smooth_aTW_hTWd = smoothParams[1];
    smooth_stop = smoothParams[2];
    interp = smoothParams[3];

    // cdm added Jan 2007 to play with dep grammar smoothing.  Integrate this better if we keep it!
    smoothTP = new BasicCategoryTagProjection(tlpParams.treebankLanguagePack());
  }

  /** Returns a short summary: simple class name, number of tag bins and number of word tokens. */
  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder(2000);
    String cl = getClass().getName();
    sb.append(cl.substring(cl.lastIndexOf('.') + 1)).append("[tagbins=");
    sb.append(numTagBins).append(",wordTokens=").append(numWordTokens).append("; head -> arg\n");
    sb.append("]");
    return sb.toString();
  }

  /**
   * Tests whether the tag of the given tagged word is one of the language
   * pack's punctuation tags.
   *
   * @param argTW the tagged word to test
   * @return true if its tag index equals the index of a punctuation tag
   */
  public boolean pruneTW(IntTaggedWord argTW) {
    String[] punctTags = tlp.punctuationTags();
    for (String punctTag : punctTags) {
      if (argTW.tag == tagIndex.indexOf(punctTag)) {
        return true;
      }
    }
    return false;
  }

  /** Result of analysing a subtree: the position just past its last word and the position of its head word. */
  static class EndHead {
    public int end;
    public int head;
  }

  /** Adds dependencies to list depList.  These are in terms of the original
   *  tag set not the reduced (projected) tag set.
   *
   *  For every binary node this adds one head/argument dependency plus a left
   *  and a right STOP dependency for the argument.
   *
   *  @param tree the (binarized) subtree to analyse
   *  @param depList list that receives the dependencies
   *  @param loc word position at which this subtree starts
   *  @return the end position and head position of the subtree
   */
  protected static EndHead treeToDependencyHelper(Tree tree, List<IntDependency> depList, int loc,
                                                  Index<String> wordIndex, Index<String> tagIndex) {
    if (tree.isLeaf() || tree.isPreTerminal()) {
      EndHead tempEndHead = new EndHead();
      tempEndHead.head = loc;
      tempEndHead.end = loc + 1;
      return tempEndHead;
    }
    Tree[] kids = tree.children();
    if (kids.length == 1) {
      return treeToDependencyHelper(kids[0], depList, loc, wordIndex, tagIndex);
    }

    EndHead tempEndHead = treeToDependencyHelper(kids[0], depList, loc, wordIndex, tagIndex);
    int lHead = tempEndHead.head;
    int split = tempEndHead.end;
    tempEndHead = treeToDependencyHelper(kids[1], depList, tempEndHead.end, wordIndex, tagIndex);
    int end = tempEndHead.end;
    int rHead = tempEndHead.head;

    String hTag = ((HasTag) tree.label()).tag();
    String lTag = ((HasTag) kids[0].label()).tag();
    String rTag = ((HasTag) kids[1].label()).tag();
    String hWord = ((HasWord) tree.label()).word();
    String lWord = ((HasWord) kids[0].label()).word();
    String rWord = ((HasWord) kids[1].label()).word();

    // The node is left-headed when its head word is the left child's word.
    boolean leftHeaded = hWord.equals(lWord);
    String aTag = (leftHeaded ? rTag : lTag);
    String aWord = (leftHeaded ? rWord : lWord);
    int hT = tagIndex.indexOf(hTag);
    int aT = tagIndex.indexOf(aTag);
    // Words missing from the index are mapped to the unknown-word entry.
    int hW = (wordIndex.contains(hWord) ? wordIndex.indexOf(hWord) : wordIndex.indexOf(Lexicon.UNKNOWN_WORD));
    int aW = (wordIndex.contains(aWord) ? wordIndex.indexOf(aWord) : wordIndex.indexOf(Lexicon.UNKNOWN_WORD));

    int head = (leftHeaded ? lHead : rHead);
    int arg = (leftHeaded ? rHead : lHead);

    IntDependency dependency = new IntDependency(hW, hT, aW, aT, leftHeaded,
        (leftHeaded ? split - head - 1 : head - split));
    depList.add(dependency);
    IntDependency stopL = new IntDependency(aW, aT, STOP_WORD_INT, STOP_TAG_INT, false,
        (leftHeaded ? arg - split : arg - loc));
    depList.add(stopL);
    IntDependency stopR = new IntDependency(aW, aT, STOP_WORD_INT, STOP_TAG_INT, true,
        (leftHeaded ? end - arg - 1 : split - arg - 1));
    depList.add(stopR);

    // tempEndHead already carries the end of the right child; only the head changes.
    tempEndHead.head = head;
    return tempEndHead;
  }

  /** Prints the number of distinct keys in the argument and stop counters to standard output. */
  public void dumpSizes() {
    System.out.println("arg counter " + argCounter.size());
    System.out.println("stop counter " + stopCounter.size());
  }

  /** Returns the List of dependencies for a binarized Tree.
   *  In this tree, one of the two children always equals the head.
   *  The dependencies are in terms of
   *  the original tag set not the reduced (projected) tag set.
   *
   *  @param tree A tree to be analyzed as dependencies
   *  @return The list of dependencies in the tree (int format)
   */
  public static List<IntDependency> treeToDependencyList(Tree tree, Index<String> wordIndex, Index<String> tagIndex) {
    List<IntDependency> depList = new ArrayList<>();
    treeToDependencyHelper(tree, depList, 0, wordIndex, tagIndex);
    if (DEBUG) {
      System.out.println("----------------------------");
      tree.pennPrint();
      System.out.println(depList);
    }
    return depList;
  }

  /**
   * Sums the scores of the given dependencies, skipping any whose score is
   * negative infinity.
   *
   * @param deps the dependencies to score
   * @return the sum of all finite-or-better scores
   */
  public double scoreAll(Collection<IntDependency> deps) {
    return sumFiniteScores(deps);
  }

  /** Adds up score(d) over deps, ignoring scores equal to negative infinity. */
  private double sumFiniteScores(Collection<IntDependency> deps) {
    double totalScore = 0.0;
    for (IntDependency d : deps) {
      double score = score(d);
      if (score > Double.NEGATIVE_INFINITY) {
        totalScore += score;
      }
    }
    return totalScore;
  }

  /** Tune the smoothing and interpolation parameters of the dependency
   *  grammar based on a tuning treebank.
   *
   *  First smooth_stop is chosen by grid search on the log-likelihood of the
   *  stop/go decisions; then the remaining parameters are chosen by an
   *  exhaustive grid search on the summed score of the non-STOP dependencies.
   *  The searched fields are left at their best values on return.
   *
   *  @param trees A Collection of Trees for setting parameters
   */
  @Override
  public void tune(Collection<Tree> trees) {
    List<IntDependency> deps = new ArrayList<>();
    for (Tree tree : trees) {
      deps.addAll(treeToDependencyList(tree, wordIndex, tagIndex));
    }

    double bestScore = Double.NEGATIVE_INFINITY;
    double bestSmooth_stop = 0.0;
    double bestSmooth_aTW_hTWd = 0.0;
    double bestSmooth_aT_hTWd = 0.0;
    double bestInterp = 0.0;

    log.info("Tuning smooth_stop...");
    for (smooth_stop = 1.0/100.0; smooth_stop < 100.0; smooth_stop *= 1.25) {
      double totalScore = 0.0;
      for (IntDependency dep : deps) {
        if (!rootTW(dep.head)) {
          double stopProb = getStopProb(dep);
          // For a real argument the relevant event is "did not stop".
          if (!dep.arg.equals(stopTW)) {
            stopProb = 1.0 - stopProb;
          }
          if (stopProb > 0.0) {
            totalScore += Math.log(stopProb);
          }
        }
      }
      if (totalScore > bestScore) {
        bestScore = totalScore;
        bestSmooth_stop = smooth_stop;
      }
    }
    smooth_stop = bestSmooth_stop;
    log.info("Tuning selected smooth_stop: " + smooth_stop);

    // Keep only real (non-STOP) dependencies for the remaining searches.
    // One filtering pass avoids repeated element shifting in the ArrayList.
    List<IntDependency> argDeps = new ArrayList<>(deps.size());
    for (IntDependency dep : deps) {
      if (!dep.arg.equals(stopTW)) {
        argDeps.add(dep);
      }
    }
    deps = argDeps;

    log.info("Tuning other parameters...");

    if ( ! useSmoothTagProjection) {
      bestScore = Double.NEGATIVE_INFINITY;
      for (smooth_aTW_hTWd = 0.5; smooth_aTW_hTWd < 100.0; smooth_aTW_hTWd *= 1.25) {
        log.info(".");
        for (smooth_aT_hTWd = 0.5; smooth_aT_hTWd < 100.0; smooth_aT_hTWd *= 1.25) {
          for (interp = 0.02; interp < 1.0; interp += 0.02) {
            double totalScore = sumFiniteScores(deps);
            if (totalScore > bestScore) {
              bestScore = totalScore;
              bestInterp = interp;
              bestSmooth_aTW_hTWd = smooth_aTW_hTWd;
              bestSmooth_aT_hTWd = smooth_aT_hTWd;
              log.info("Current best interp: " + interp + " with score " + totalScore);
            }
          }
        }
      }
      smooth_aTW_hTWd = bestSmooth_aTW_hTWd;
      smooth_aT_hTWd = bestSmooth_aT_hTWd;
      interp = bestInterp;

    } else {
      // for useSmoothTagProjection
      double bestSmooth_aTW_aT = 0.0;
      double bestSmooth_aTW_hTd = 0.0;
      double bestSmooth_aT_hTd = 0.0;

      bestScore = Double.NEGATIVE_INFINITY;
      for (smooth_aTW_hTWd = 1.125; smooth_aTW_hTWd < 100.0; smooth_aTW_hTWd *= 1.5) {
        log.info("#");
        for (smooth_aT_hTWd = 1.125; smooth_aT_hTWd < 100.0; smooth_aT_hTWd *= 1.5) {
          log.info(":");
          for (smooth_aTW_aT = 1.125; smooth_aTW_aT < 200.0; smooth_aTW_aT *= 1.5) {
            log.info(".");
            for (smooth_aTW_hTd = 1.125; smooth_aTW_hTd < 100.0; smooth_aTW_hTd *= 1.5) {
              for (smooth_aT_hTd = 1.125; smooth_aT_hTd < 100.0; smooth_aT_hTd *= 1.5) {
                for (interp = 0.2; interp <= 0.8; interp += 0.02) {
                  double totalScore = sumFiniteScores(deps);
                  if (totalScore > bestScore) {
                    bestScore = totalScore;
                    bestInterp = interp;
                    bestSmooth_aTW_hTWd = smooth_aTW_hTWd;
                    bestSmooth_aT_hTWd = smooth_aT_hTWd;
                    bestSmooth_aTW_aT = smooth_aTW_aT;
                    bestSmooth_aTW_hTd = smooth_aTW_hTd;
                    bestSmooth_aT_hTd = smooth_aT_hTd;
                    log.info("Current best interp: " + interp + " with score " + totalScore);
                  }
                }
              }
            }
          }
        }
        log.info();
      }
      smooth_aTW_hTWd = bestSmooth_aTW_hTWd;
      smooth_aT_hTWd = bestSmooth_aT_hTWd;
      smooth_aTW_aT = bestSmooth_aTW_aT;
      smooth_aTW_hTd = bestSmooth_aTW_hTd;
      smooth_aT_hTd = bestSmooth_aT_hTd;
      interp = bestInterp;
    }

    log.info("\nTuning selected smooth_aTW_hTWd: " + smooth_aTW_hTWd +
             " smooth_aT_hTWd: " + smooth_aT_hTWd + " interp: " + interp +
             " smooth_aTW_aT: " + smooth_aTW_aT +
             " smooth_aTW_hTd: " + smooth_aTW_hTd +
             " smooth_aT_hTd: " + smooth_aT_hTd);
  }

  /** Add this dependency with the given count to the grammar.
   *  This is the main entry point of MLEDependencyGrammarExtractor.
   *  This is a dependency represented in the full tag space.
   *
   *  When the grammar is not directional the dependency is first rebuilt
   *  with leftHeaded forced to false.
   */
  public void addRule(IntDependency dependency, double count) {
    if ( ! directional) {
      dependency = new IntDependency(dependency.head, dependency.arg, false, dependency.distance);
    }
    if (verbose) log.info("Adding dep " + dependency);
    expandDependency(dependency, count);
  }

  /** The indices of this list are in the tag binned space. */
  protected transient List<IntTaggedWord> tagITWList = null;

  /** This maps from a tag to a cached IntTagWord that represents the
   *  tag by having the wildcard word ANY_WORD_INT and the tag in the
   *  reduced tag space.
   *  The argument is in terms of the full tag space; internally this
   *  function maps to the reduced space.
   *  @param tag short representation of tag in full tag space
   *  @return an IntTaggedWord in the reduced tag space
   */
  private IntTaggedWord getCachedITW(short tag) {
    // The +2 below is because -1 and -2 are used with special meanings (see IntTaggedWord).
    if (tagITWList == null) {
      tagITWList = new ArrayList<>(numTagBins + 2);
      for (int i = 0; i < numTagBins + 2; i++) {
        tagITWList.add(i, null);
      }
    }
    IntTaggedWord headT = tagITWList.get(tagBin(tag) + 2);
    if (headT == null) {
      headT = new IntTaggedWord(ANY_WORD_INT, tagBin(tag));
      tagITWList.set(tagBin(tag) + 2, headT);
    }
    return headT;
  }

  /** The dependency arg is still in the full tag space.
   *  Dependencies with a null head or arg are ignored. A non-STOP dependency
   *  contributes to both the argument counts and the stop counts.
   *
   * @param dependency An observed dependency
   * @param count The weight of the dependency
   */
  protected void expandDependency(IntDependency dependency, double count) {
    if (dependency.head == null || dependency.arg == null) {
      return;
    }

    if (dependency.arg.word != STOP_WORD_INT) {
      expandArg(dependency, valenceBin(dependency.distance), count);
    }
    expandStop(dependency, distanceBin(dependency.distance), count, true);
  }

  private TagProjection smoothTP;
  private Index<String> smoothTPIndex;

  private static final String TP_PREFIX = ".*TP*.";

  /**
   * Maps a tag to the index of its projected tag. Projected tags are stored in
   * a private copy of the tag index under the name TP_PREFIX + projection, so
   * they are added past the original tag indices. Negative (special) tags are
   * returned unchanged.
   */
  private short tagProject(short tag) {
    if (smoothTPIndex == null) {
      smoothTPIndex = new HashIndex<>(tagIndex);
    }
    if (tag < 0) {
      return tag;
    } else {
      String tagStr = smoothTPIndex.get(tag);
      String binStr = TP_PREFIX + smoothTP.project(tagStr);
      return (short) smoothTPIndex.addToIndex(binStr);
    }
  }

  /** Collect counts for a non-STOP dependent.
   *  The dependency arg is still in the full tag space.
   *  Each call also increments numWordTokens by one, regardless of count.
   *
   *  @param dependency A non-stop dependency
   *  @param valBinDist A binned distance
   *  @param count The weight with which to add this dependency
   */
  private void expandArg(IntDependency dependency, short valBinDist, double count) {
    IntTaggedWord headT = getCachedITW(dependency.head.tag);
    IntTaggedWord argT = getCachedITW(dependency.arg.tag);
    IntTaggedWord head = new IntTaggedWord(dependency.head.word, tagBin(dependency.head.tag));
    IntTaggedWord arg = new IntTaggedWord(dependency.arg.word, tagBin(dependency.arg.tag));
    boolean leftHeaded = dependency.leftHeaded;

    // argCounter stores stuff in both the original and the reduced tag space???
    argCounter.incrementCount(intern(head, arg, leftHeaded, valBinDist), count);
    argCounter.incrementCount(intern(headT, arg, leftHeaded, valBinDist), count);
    argCounter.incrementCount(intern(head, argT, leftHeaded, valBinDist), count);
    argCounter.incrementCount(intern(headT, argT, leftHeaded, valBinDist), count);

    argCounter.incrementCount(intern(head, wildTW, leftHeaded, valBinDist), count);
    argCounter.incrementCount(intern(headT, wildTW, leftHeaded, valBinDist), count);

    // the WILD head stats are always directionless and not useDistance!
    argCounter.incrementCount(intern(wildTW, arg, false, (short) -1), count);
    argCounter.incrementCount(intern(wildTW, argT, false, (short) -1), count);

    if (useSmoothTagProjection) {
      // added stuff to do more smoothing.  CDM Jan 2007
      IntTaggedWord headP = new IntTaggedWord(dependency.head.word, tagProject(dependency.head.tag));
      IntTaggedWord headTP = new IntTaggedWord(ANY_WORD_INT, tagProject(dependency.head.tag));
      IntTaggedWord argP = new IntTaggedWord(dependency.arg.word, tagProject(dependency.arg.tag));
      IntTaggedWord argTP = new IntTaggedWord(ANY_WORD_INT, tagProject(dependency.arg.tag));

      argCounter.incrementCount(intern(headP, argP, leftHeaded, valBinDist), count);
      argCounter.incrementCount(intern(headTP, argP, leftHeaded, valBinDist), count);
      argCounter.incrementCount(intern(headP, argTP, leftHeaded, valBinDist), count);
      argCounter.incrementCount(intern(headTP, argTP, leftHeaded, valBinDist), count);

      argCounter.incrementCount(intern(headP, wildTW, leftHeaded, valBinDist), count);
      argCounter.incrementCount(intern(headTP, wildTW, leftHeaded, valBinDist), count);

      // the WILD head stats are always directionless and not useDistance!
      argCounter.incrementCount(intern(wildTW, argP, false, (short) -1), count);
      argCounter.incrementCount(intern(wildTW, argTP, false, (short) -1), count);
      // NOTE(review): this records the HEAD word under (wild head, word/ANY_TAG), but probTB
      // reads the same key shape with the ARG word as c_aW. Looks like it should be
      // dependency.arg.word; left unchanged because fixing it alters trained statistics.
      argCounter.incrementCount(intern(wildTW, new IntTaggedWord(dependency.head.word, ANY_TAG_INT), false, (short) -1), count);
    }
    numWordTokens++;
  }

  /**
   * Collect counts for the stop/go decision at a head.
   *
   * @param dependency the dependency, with tags in the full tag space
   * @param distBinDist a binned distance
   * @param count the weight with which to add this dependency
   * @param wildForStop whether a STOP dependency should also increment the
   *     wildcard-argument (denominator) counts; non-STOP dependencies always do
   */
  private void expandStop(IntDependency dependency, short distBinDist, double count, boolean wildForStop) {
    IntTaggedWord headT = getCachedITW(dependency.head.tag);
    IntTaggedWord head = new IntTaggedWord(dependency.head.word, tagBin(dependency.head.tag));
    IntTaggedWord arg = new IntTaggedWord(dependency.arg.word, tagBin(dependency.arg.tag));

    boolean leftHeaded = dependency.leftHeaded;

    if (arg.word == STOP_WORD_INT) {
      stopCounter.incrementCount(intern(head, arg, leftHeaded, distBinDist), count);
      stopCounter.incrementCount(intern(headT, arg, leftHeaded, distBinDist), count);
    }
    if (wildForStop || arg.word != STOP_WORD_INT) {
      stopCounter.incrementCount(intern(head, wildTW, leftHeaded, distBinDist), count);
      stopCounter.incrementCount(intern(headT, wildTW, leftHeaded, distBinDist), count);
    }
  }

  /**
   * Returns the count of the conditioning history of a dependency: its head
   * (with binned tag), direction and valence-binned distance, with any argument.
   */
  public double countHistory(IntDependency dependency) {
    IntDependency temp = new IntDependency(dependency.head.word,
        tagBin(dependency.head.tag), wildTW.word, wildTW.tag,
        dependency.leftHeaded, valenceBin(dependency.distance));
    return argCounter.getCount(temp);
  }

  /** Score a tag binned dependency. */
  public double scoreTB(IntDependency dependency) {
    return op.testOptions.depWeight * Math.log(probTB(dependency));
  }

  // Must remain a compile-time constant: probTB reads conditionally-assigned locals
  // inside "if (verbose)" blocks, which only compiles because the condition is constant false.
  private static final boolean verbose = false;

  protected static final double MIN_PROBABILITY = 1e-40;

  /** Calculate the probability of a dependency as a real probability between
   *  0 and 1 inclusive.
   *  @param dependency The dependency for which the probability is to be
   *       calculated.   The tags in this dependency are in the reduced
   *       TagProjection space.
   *  @return The probability of the dependency
   */
  protected double probTB(IntDependency dependency) {
    if (verbose) {
      log.info("Generating " + dependency);
    }

    boolean leftHeaded = dependency.leftHeaded && directional;

    boolean isRoot = rootTW(dependency.head);
    double pb_stop_hTWds;
    if (isRoot) {
      pb_stop_hTWds = 0.0;
    } else {
      pb_stop_hTWds = getStopProb(dependency);
    }

    if (dependency.arg.word == STOP_WORD_INT) {
      // did we generate stop?
      return pb_stop_hTWds;
    }

    double pb_go_hTWds = 1.0 - pb_stop_hTWds;

    // generate the argument

    short binDistance = valenceBin(dependency.distance);

    // KEY:
    // c_     count of (read as joint count of first and second)
    // p_     MLE prob of (or MAP if useSmoothTagProjection)
    // pb_    MAP prob of (read as prob of first given second thing)
    // a      arg
    // h      head
    // T      tag
    // PT     projected tag
    // W      word
    // d      direction
    // ds     distance (implicit: there when direction is mentioned!)

    IntTaggedWord anyHead = new IntTaggedWord(ANY_WORD_INT, dependency.head.tag);
    IntTaggedWord anyArg = new IntTaggedWord(ANY_WORD_INT, dependency.arg.tag);
    IntTaggedWord anyTagArg = new IntTaggedWord(dependency.arg.word, ANY_TAG_INT);

    IntDependency temp = new IntDependency(dependency.head, dependency.arg, leftHeaded, binDistance);
    double c_aTW_hTWd = argCounter.getCount(temp);
    temp = new IntDependency(dependency.head, anyArg, leftHeaded, binDistance);
    double c_aT_hTWd = argCounter.getCount(temp);
    temp = new IntDependency(dependency.head, wildTW, leftHeaded, binDistance);
    double c_hTWd = argCounter.getCount(temp);

    temp = new IntDependency(anyHead, dependency.arg, leftHeaded, binDistance);
    double c_aTW_hTd = argCounter.getCount(temp);
    temp = new IntDependency(anyHead, anyArg, leftHeaded, binDistance);
    double c_aT_hTd = argCounter.getCount(temp);
    temp = new IntDependency(anyHead, wildTW, leftHeaded, binDistance);
    double c_hTd = argCounter.getCount(temp);

    // for smooth tag projection
    short aPT = Short.MIN_VALUE;
    double c_aPTW_hPTd = Double.NaN;
    double c_aPT_hPTd = Double.NaN;
    double c_hPTd = Double.NaN;
    double c_aPTW_aPT = Double.NaN;
    double c_aPT = Double.NaN;

    if (useSmoothTagProjection) {
      aPT = tagProject(dependency.arg.tag);
      short hPT = tagProject(dependency.head.tag);

      IntTaggedWord projectedArg = new IntTaggedWord(dependency.arg.word, aPT);
      IntTaggedWord projectedAnyHead = new IntTaggedWord(ANY_WORD_INT, hPT);
      IntTaggedWord projectedAnyArg = new IntTaggedWord(ANY_WORD_INT, aPT);

      temp = new IntDependency(projectedAnyHead, projectedArg, leftHeaded, binDistance);
      c_aPTW_hPTd = argCounter.getCount(temp);
      temp = new IntDependency(projectedAnyHead, projectedAnyArg, leftHeaded, binDistance);
      c_aPT_hPTd = argCounter.getCount(temp);
      temp = new IntDependency(projectedAnyHead, wildTW, leftHeaded, binDistance);
      c_hPTd = argCounter.getCount(temp);

      temp = new IntDependency(wildTW, projectedArg, false, ANY_DISTANCE_INT);
      c_aPTW_aPT = argCounter.getCount(temp);
      temp = new IntDependency(wildTW, projectedAnyArg, false, ANY_DISTANCE_INT);
      c_aPT = argCounter.getCount(temp);
    }

    // wild head is always directionless and no use distance
    temp = new IntDependency(wildTW, dependency.arg, false, ANY_DISTANCE_INT);
    double c_aTW = argCounter.getCount(temp);
    temp = new IntDependency(wildTW, anyArg, false, ANY_DISTANCE_INT);
    double c_aT = argCounter.getCount(temp);
    temp = new IntDependency(wildTW, anyTagArg, false, ANY_DISTANCE_INT);
    double c_aW = argCounter.getCount(temp);

    // do the Bayesian magic
    // MLE probs
    double p_aTW_hTd;
    double p_aT_hTd;
    double p_aTW_aT;
    double p_aW;
    double p_aPTW_aPT;
    double p_aPTW_hPTd;
    double p_aPT_hPTd;

    // backoffs either mle or themselves bayesian smoothed depending on useSmoothTagProjection
    if (useSmoothTagProjection) {
      if (useUnigramWordSmoothing) {
        p_aW = c_aW > 0.0 ? (c_aW / numWordTokens) : 1.0;  // NEED this 1.0 for unknown words!!!
        p_aPTW_aPT = (c_aPTW_aPT + smooth_aPTW_aPT * p_aW) / (c_aPT + smooth_aPTW_aPT);
      } else {
        p_aPTW_aPT = c_aPTW_aPT > 0.0 ? (c_aPTW_aPT / c_aPT) : 1.0;  // NEED this 1.0 for unknown words!!!
      }
      p_aTW_aT = (c_aTW + smooth_aTW_aT * p_aPTW_aPT) / (c_aT + smooth_aTW_aT);

      p_aPTW_hPTd = c_hPTd > 0.0 ? (c_aPTW_hPTd / c_hPTd) : 0.0;
      p_aTW_hTd = (c_aTW_hTd + smooth_aTW_hTd * p_aPTW_hPTd) / (c_hTd + smooth_aTW_hTd);

      p_aPT_hPTd = c_hPTd > 0.0 ? (c_aPT_hPTd / c_hPTd) : 0.0;
      p_aT_hTd = (c_aT_hTd + smooth_aT_hTd * p_aPT_hPTd) / (c_hTd + smooth_aT_hTd);
    } else {
      // here word generation isn't smoothed - can't get previously unseen word with tag.  Ugh.
      if (op.testOptions.useLexiconToScoreDependencyPwGt) {
        // We don't know the position.  Now -1 means average over 0 and 1.
        p_aTW_aT = dependency.leftHeaded
            ? Math.exp(lex.score(dependency.arg, 1, wordIndex.get(dependency.arg.word), null))
            : Math.exp(lex.score(dependency.arg, -1, wordIndex.get(dependency.arg.word), null));
      } else {
        p_aTW_aT = c_aTW > 0.0 ? (c_aTW / c_aT) : 1.0;
      }
      p_aTW_hTd = c_hTd > 0.0 ? (c_aTW_hTd / c_hTd) : 0.0;
      p_aT_hTd = c_hTd > 0.0 ? (c_aT_hTd / c_hTd) : 0.0;
    }

    double pb_aTW_hTWd = (c_aTW_hTWd + smooth_aTW_hTWd * p_aTW_hTd) / (c_hTWd + smooth_aTW_hTWd);
    double pb_aT_hTWd = (c_aT_hTWd + smooth_aT_hTWd * p_aT_hTd) / (c_hTWd + smooth_aT_hTWd);

    double score = (interp * pb_aTW_hTWd + (1.0 - interp) * p_aTW_aT * pb_aT_hTWd) * pb_go_hTWds;

    if (verbose) {
      NumberFormat nf = NumberFormat.getNumberInstance();
      nf.setMaximumFractionDigits(2);
      if (useSmoothTagProjection) {
        if (useUnigramWordSmoothing) {
          log.info("  c_aW=" + c_aW + ", numWordTokens=" + numWordTokens + ", p(aW)=" + nf.format(p_aW));
        }
        log.info("  c_aPTW_aPT=" + c_aPTW_aPT + ", c_aPT=" + c_aPT + ", smooth_aPTW_aPT=" + smooth_aPTW_aPT + ", p(aPTW|aPT)=" + nf.format(p_aPTW_aPT));
      }
      log.info("  c_aTW=" + c_aTW + ", c_aT=" + c_aT + ", smooth_aTW_aT=" + smooth_aTW_aT +", ## p(aTW|aT)=" + nf.format(p_aTW_aT));

      if (useSmoothTagProjection) {
        log.info("  c_aPTW_hPTd=" + c_aPTW_hPTd + ", c_hPTd=" + c_hPTd + ", p(aPTW|hPTd)=" + nf.format(p_aPTW_hPTd));
      }
      log.info("  c_aTW_hTd=" + c_aTW_hTd + ", c_hTd=" + c_hTd + ", smooth_aTW_hTd=" + smooth_aTW_hTd +", p(aTW|hTd)=" + nf.format(p_aTW_hTd));

      if (useSmoothTagProjection) {
        log.info("  c_aPT_hPTd=" + c_aPT_hPTd + ", c_hPTd=" + c_hPTd + ", p(aPT|hPTd)=" + nf.format(p_aPT_hPTd));
      }
      log.info("  c_aT_hTd=" + c_aT_hTd + ", c_hTd=" + c_hTd + ", smooth_aT_hTd=" + smooth_aT_hTd +", p(aT|hTd)=" + nf.format(p_aT_hTd));

      log.info("  c_aTW_hTWd=" + c_aTW_hTWd + ", c_hTWd=" + c_hTWd + ", smooth_aTW_hTWd=" + smooth_aTW_hTWd +", ## p(aTW|hTWd)=" + nf.format(pb_aTW_hTWd));
      log.info("  c_aT_hTWd=" + c_aT_hTWd + ", c_hTWd=" + c_hTWd + ", smooth_aT_hTWd=" + smooth_aT_hTWd +", ## p(aT|hTWd)=" + nf.format(pb_aT_hTWd));

      log.info("  interp=" + interp + ", prescore=" + nf.format(interp * pb_aTW_hTWd + (1.0 - interp) * p_aTW_aT * pb_aT_hTWd) +
                         ", P(go|hTWds)=" + nf.format(pb_go_hTWds) + ", score=" + nf.format(score));
    }

    // Punctuation arguments are given probability 1 when punctuation pruning is on.
    if (op.testOptions.prunePunc && pruneTW(dependency.arg)) {
      return 1.0;
    }

    if (Double.isNaN(score)) {
      score = 0.0;
    }

    // Probabilities below the floor are treated as impossible.
    if (score < MIN_PROBABILITY) {
      score = 0.0;
    }

    return score;
  }

  /** Return the probability (as a real number between 0 and 1) of stopping
   *  rather than generating another argument at this position.
   *  @param dependency The dependency used as the basis for stopping on.
   *     Tags are assumed to be in the TagProjection space.
   *  @return The probability of generating this stop probability
   */
  protected double getStopProb(IntDependency dependency) {
    short binDistance = distanceBin(dependency.distance);
    IntTaggedWord unknownHead = new IntTaggedWord(-1, dependency.head.tag);
    IntTaggedWord anyHead = new IntTaggedWord(ANY_WORD_INT, dependency.head.tag);

    IntDependency temp = new IntDependency(dependency.head, stopTW, dependency.leftHeaded, binDistance);
    double c_stop_hTWds = stopCounter.getCount(temp);
    temp = new IntDependency(unknownHead, stopTW, dependency.leftHeaded, binDistance);
    double c_stop_hTds = stopCounter.getCount(temp);
    temp = new IntDependency(dependency.head, wildTW, dependency.leftHeaded, binDistance);
    double c_hTWds = stopCounter.getCount(temp);
    temp = new IntDependency(anyHead, wildTW, dependency.leftHeaded, binDistance);
    double c_hTds = stopCounter.getCount(temp);

    // Tag-only estimate is the base distribution; an unseen tag history defaults to "stop".
    double p_stop_hTds = (c_hTds > 0.0 ? c_stop_hTds / c_hTds : 1.0);

    double pb_stop_hTWds = (c_stop_hTWds + smooth_stop * p_stop_hTds) / (c_hTWds + smooth_stop);

    if (verbose) {
      System.out.println("  c_stop_hTWds: " + c_stop_hTWds + "; c_hTWds: " + c_hTWds + "; c_stop_hTds: " + c_stop_hTds + "; c_hTds: " + c_hTds);
      System.out.println("  Generate STOP prob: " + pb_stop_hTWds);
    }
    return pb_stop_hTWds;
  }

  /**
   * Deserializes the compressed counters written by writeObject and rebuilds
   * the full counters by re-expanding every stored dependency.
   */
  private void readObject(ObjectInputStream stream)
      throws IOException, ClassNotFoundException {
    stream.defaultReadObject();

    // defaultReadObject has already restored numWordTokens; expandArg below would
    // increment it again once per stored key, so remember the true value.
    int savedNumWordTokens = numWordTokens;

    ClassicCounter<IntDependency> compressedArgC = argCounter;
    argCounter = new ClassicCounter<>();
    ClassicCounter<IntDependency> compressedStopC = stopCounter;
    stopCounter = new ClassicCounter<>();
    for (IntDependency d : compressedArgC.keySet()) {
      double count = compressedArgC.getCount(d);
      expandArg(d, d.distance, count);
    }

    for (IntDependency d : compressedStopC.keySet()) {
      double count = compressedStopC.getCount(d);
      expandStop(d, d.distance, count, false);
    }

    numWordTokens = savedNumWordTokens;
    expandDependencyMap = null;
  }

  /**
   * Serializes the grammar with compressed counters: only entries whose head
   * and arg are not wildcards are written, since the rest can be regenerated
   * on reading. The full counters are put back afterwards, even if writing fails.
   */
  private void writeObject(ObjectOutputStream stream) throws IOException {
    ClassicCounter<IntDependency> fullArgCounter = argCounter;
    ClassicCounter<IntDependency> fullStopCounter = stopCounter;
    try {
      argCounter = new ClassicCounter<>();
      for (IntDependency dependency : fullArgCounter.keySet()) {
        if (dependency.head != wildTW && dependency.arg != wildTW &&
            dependency.head.word != -1 && dependency.arg.word != -1) {
          argCounter.incrementCount(dependency, fullArgCounter.getCount(dependency));
        }
      }

      stopCounter = new ClassicCounter<>();
      for (IntDependency dependency : fullStopCounter.keySet()) {
        if (dependency.head.word != -1) {
          stopCounter.incrementCount(dependency, fullStopCounter.getCount(dependency));
        }
      }

      stream.defaultWriteObject();
    } finally {
      // Never leave the live grammar holding the compressed counters.
      argCounter = fullArgCounter;
      stopCounter = fullStopCounter;
    }
  }

  /**
   * Populates data in this DependencyGrammar from the character stream
   * given by the Reader r.
   *
   * Each line holds one rule; lines after a line reading BEGIN_STOP are stop
   * counts. Reading ends at the first empty line or at end of input.
   *
   * @throws IOException if reading fails or a line cannot be parsed; the
   *     message names the 1-based line number and the cause is preserved
   */
  @Override
  public void readData(BufferedReader in) throws IOException {
    final String LEFT = "left";
    int lineNum = 0;
    // all lines have one rule per line
    boolean doingStop = false;

    for (String line = in.readLine(); line != null && line.length() > 0; line = in.readLine()) {
      // Count every line, including the BEGIN_STOP marker, so reported numbers match the file.
      lineNum++;
      try {
        if (line.equals("BEGIN_STOP")) {
          doingStop = true;
          continue;
        }
        // split on spaces, quote with doublequote, and escape with backslash
        String[] fields = StringUtils.splitOnCharWithQuoting(line, ' ', '\"', '\\');
        short distance = (short)Integer.parseInt(fields[4]);
        IntTaggedWord tempHead = new IntTaggedWord(fields[0], '/', wordIndex, tagIndex);
        IntTaggedWord tempArg = new IntTaggedWord(fields[2], '/', wordIndex, tagIndex);
        IntDependency tempDependency = new IntDependency(tempHead, tempArg, fields[3].equals(LEFT), distance);
        double count = Double.parseDouble(fields[5]);
        if (doingStop) {
          expandStop(tempDependency, distance, count, false);
        } else {
          expandArg(tempDependency, distance, count);
        }
      } catch (Exception e) {
        throw new IOException("Error on line " + lineNum + ": " + line, e);
      }
    }
  }

  /**
   * Writes out data from this Object to the Writer w.
   *
   * Argument rules come first, then a BEGIN_STOP line, then the stop rules;
   * entries involving wildcards are omitted.
   */
  @Override
  public void writeData(PrintWriter out) throws IOException {
    // all lines have one rule per line
    for (IntDependency dependency : argCounter.keySet()) {
      if (dependency.head != wildTW && dependency.arg != wildTW &&
          dependency.head.word != -1 && dependency.arg.word != -1) {
        double count = argCounter.getCount(dependency);
        out.println(dependency.toString(wordIndex, tagIndex) + " " + count);
      }
    }

    out.println("BEGIN_STOP");

    for (IntDependency dependency : stopCounter.keySet()) {
      if (dependency.head.word != -1) {
        double count = stopCounter.getCount(dependency);
        out.println(dependency.toString(wordIndex, tagIndex) + " " + count);
      }
    }

    out.flush();
  }

  private static final long serialVersionUID = 1L;

} // end class MLEDependencyGrammar