package edu.berkeley.nlp.PCFGLA; import edu.berkeley.nlp.PCFGLA.smoothing.*; import edu.berkeley.nlp.math.SloppyMath; import edu.berkeley.nlp.syntax.StateSet; import edu.berkeley.nlp.syntax.Tree; import edu.berkeley.nlp.util.*; import edu.berkeley.nlp.util.PriorityQueue; import java.io.IOException; import java.io.PrintWriter; import java.io.Writer; import java.util.*; import opennlp.maxent.GIS; import opennlp.maxent.GISModel; /** * Simple implementation of a PCFG grammar, offering the ability to look up * rules by their child symbols. Rule probability estimates are just relative * frequency estimates off of training trees. */ public class Grammar implements java.io.Serializable { /** * @author leon * */ public static enum RandomInitializationType { INITIALIZE_WITH_SMALL_RANDOMIZATION, INITIALIZE_LIKE_MMT // initialize like in the Matzuyaki, Miyao, and Tsujii paper } public static class RuleNotFoundException extends Exception { private static final long serialVersionUID = 2L; } public int finalLevel; public boolean[] isGrammarTag; public boolean useEntropicPrior = false; private List<BinaryRule>[] binaryRulesWithParent; private List<BinaryRule>[] binaryRulesWithLC; private List<BinaryRule>[] binaryRulesWithRC; private BinaryRule[][] splitRulesWithLC; private BinaryRule[][] splitRulesWithRC; private BinaryRule[][] splitRulesWithP; public List<UnaryRule>[] unaryRulesWithParent; public List<UnaryRule>[] unaryRulesWithC; private List<UnaryRule>[] sumProductClosedUnaryRulesWithParent; /** the number of states */ public short numStates; /** the number of substates per state */ public short[] numSubStates; // private List<Rule> allRules; public Map<BinaryRule, BinaryRule> binaryRuleMap; BinaryRule bSearchRule; public Map<UnaryRule, UnaryRule> unaryRuleMap; UnaryRule uSearchRule; UnaryCounterTable unaryRuleCounter = null; BinaryCounterTable binaryRuleCounter = null; CounterMap<Integer, Integer> symbolCounter = new CounterMap<Integer, Integer>(); private static final long 
serialVersionUID = 1L; protected Numberer tagNumberer; public List<UnaryRule>[] closedSumRulesWithParent = null; public List<UnaryRule>[] closedSumRulesWithChild = null; public List<UnaryRule>[] closedViterbiRulesWithParent = null; public List<UnaryRule>[] closedViterbiRulesWithChild = null; public UnaryRule[][] closedSumRulesWithP = null; public UnaryRule[][] closedSumRulesWithC = null; public UnaryRule[][] closedViterbiRulesWithP = null; public UnaryRule[][] closedViterbiRulesWithC = null; private Map bestSumRulesUnderMax = null; private Map bestViterbiRulesUnderMax = null; public double threshold; public Smoother smoother = null; /** * A policy giving what state to go to next, starting from a given state, * going to a given state. This array is indexed by the start state, the end * state, the start substate, and the end substate. */ private int[][] closedViterbiPaths = null; private int[][] closedSumPaths = null; public boolean findClosedPaths; /** * If we are in logarithm mode, then this grammar's scores are all given as * logarithms. The default is to have a score plus a scale factor. 
*/ boolean logarithmMode; public Tree<Short>[] splitTrees; public void clearUnaryIntermediates() { ArrayUtil.fill(closedSumPaths, 0); ArrayUtil.fill(closedViterbiPaths, 0); } public void addBinary(BinaryRule br) { // System.out.println("BG adding rule " + br); binaryRulesWithParent[br.parentState].add(br); binaryRulesWithLC[br.leftChildState].add(br); binaryRulesWithRC[br.rightChildState].add(br); // allRules.add(br); binaryRuleMap.put(br, br); } public void addUnary(UnaryRule ur) { // System.out.println(" UG adding rule " + ur); // closeRulesUnderMax(ur); if (!unaryRulesWithParent[ur.parentState].contains(ur)) { unaryRulesWithParent[ur.parentState].add(ur); unaryRulesWithC[ur.childState].add(ur); // allRules.add(ur); unaryRuleMap.put(ur, ur); } } public Numberer getTagNumberer() { return tagNumberer; } // @SuppressWarnings("unchecked") // public List<BinaryRule> getBinaryRulesByParent(int state) { // if (state >= binaryRulesWithParent.length) { // return Collections.EMPTY_LIST; // } // return binaryRulesWithParent[state]; // } // @SuppressWarnings("unchecked") public List<UnaryRule> getUnaryRulesByParent(int state) { if (state >= unaryRulesWithParent.length) { return Collections.EMPTY_LIST; } return unaryRulesWithParent[state]; } @SuppressWarnings("unchecked") public List<UnaryRule>[] getSumProductClosedUnaryRulesByParent() { return sumProductClosedUnaryRulesWithParent; } @SuppressWarnings("unchecked") public List<BinaryRule> getBinaryRulesByLeftChild(int state) { // System.out.println("getBinaryRulesByLeftChild not supported anymore."); // return null; if (state >= binaryRulesWithLC.length) { return Collections.EMPTY_LIST; } return binaryRulesWithLC[state]; } @SuppressWarnings("unchecked") public List<BinaryRule> getBinaryRulesByRightChild(int state) { // System.out.println("getBinaryRulesByRightChild not supported anymore."); // return null; if (state >= binaryRulesWithRC.length) { return Collections.EMPTY_LIST; } return binaryRulesWithRC[state]; } 
@SuppressWarnings("unchecked") public List<UnaryRule> getUnaryRulesByChild(int state) { // System.out.println("getUnaryRulesByChild not supported anymore."); // return null; if (state >= unaryRulesWithC.length) { return Collections.EMPTY_LIST; } return unaryRulesWithC[state]; } public String toString_old() { /* * StringBuilder sb = new StringBuilder(); List<String> ruleStrings = * new ArrayList<String>(); for (int state = 0; state < numStates; * state++) { List<BinaryRule> leftRules = * getBinaryRulesByLeftChild(state); for (BinaryRule r : leftRules) { * ruleStrings.add(r.toString()); } } for (int state = 0; state < * numStates; state++) { UnaryRule[] unaries = * getClosedViterbiUnaryRulesByChild(state); for (int r = 0; r < * unaries.length; r++) { UnaryRule ur = unaries[r]; * ruleStrings.add(ur.toString()); } } for (String ruleString : * CollectionUtils.sort(ruleStrings)) { sb.append(ruleString); * sb.append("\n"); } */ return null;// sb.toString(); } public void writeData(Writer w) throws IOException { finalLevel = (short) (Math.log(numSubStates[1]) / Math.log(2)); PrintWriter out = new PrintWriter(w); for (int state = 0; state < numStates; state++) { BinaryRule[] parentRules = this.splitRulesWithP(state); for (int i = 0; i < parentRules.length; i++) { BinaryRule r = parentRules[i]; out.print(r.toString()); } } for (int state = 0; state < numStates; state++) { UnaryRule[] unaries = this .getClosedViterbiUnaryRulesByParent(state); for (int r = 0; r < unaries.length; r++) { UnaryRule ur = unaries[r]; out.print(ur.toString()); } } out.flush(); } public String toString() { // splitRules(); StringBuilder sb = new StringBuilder(); List<String> ruleStrings = new ArrayList<String>(); for (int state = 0; state < numStates; state++) { BinaryRule[] parentRules = this.splitRulesWithP(state); for (int i = 0; i < parentRules.length; i++) { BinaryRule r = parentRules[i]; ruleStrings.add(r.toString()); } } for (int state = 0; state < numStates; state++) { UnaryRule[] unaries = 
this.getClosedSumUnaryRulesByParent(state); // this.getClosedSumUnaryRulesByParent(state);// for (int r = 0; r < unaries.length; r++) { UnaryRule ur = unaries[r]; ruleStrings.add(ur.toString()); } // UnaryRule[] unaries2 = // this.getClosedViterbiUnaryRulesByParent(state); // for (int r = 0; r < unaries2.length; r++) { // UnaryRule ur = unaries2[r]; // ruleStrings.add(ur.toString()); // } } for (String ruleString : CollectionUtils.sort(ruleStrings)) { sb.append(ruleString); // sb.append("\n"); } return sb.toString(); } public int getNumberOfRules() { int nRules = 0; for (int state = 0; state < numStates; state++) { BinaryRule[] parentRules = this.splitRulesWithP(state); for (int i = 0; i < parentRules.length; i++) { BinaryRule bRule = parentRules[i]; double[][][] scores = bRule.getScores2(); for (int j = 0; j < scores.length; j++) { for (int k = 0; k < scores[j].length; k++) { if (scores[j][k] != null) { nRules += scores[j][k].length; } } } } UnaryRule[] unaries = this.getClosedSumUnaryRulesByParent(state); for (int r = 0; r < unaries.length; r++) { UnaryRule uRule = unaries[r]; // List<UnaryRule> unaries = this.getUnaryRulesByParent(state); // for (UnaryRule uRule : unaries){ if (uRule.childState == uRule.parentState) continue; double[][] scores = uRule.getScores2(); for (int j = 0; j < scores.length; j++) { if (scores[j] != null) { nRules += scores[j].length; } } } } return nRules; } public void printUnaryRules() { // System.out.println("BY PARENT"); for (int state1 = 0; state1 < numStates; state1++) { List<UnaryRule> unaries = this.getUnaryRulesByParent(state1); for (UnaryRule uRule : unaries) { UnaryRule uRule2 = (UnaryRule) unaryRuleMap.get(uRule); if (!uRule.getScores2().equals(uRule2.getScores2())) System.out.print("BY PARENT:\n" + uRule + "" + uRule2 + "\n"); } } // System.out.println("VITERBI CLOSED"); for (int state1 = 0; state1 < numStates; state1++) { UnaryRule[] unaries = this .getClosedViterbiUnaryRulesByParent(state1); for (int r = 0; r < 
unaries.length; r++) { UnaryRule uRule = unaries[r]; // System.out.print(uRule); UnaryRule uRule2 = (UnaryRule) unaryRuleMap.get(uRule); if (unariesAreNotEqual(uRule, uRule2)) System.out.print("VITERBI CLOSED:\n" + uRule + "" + uRule2 + "\n"); } } /* * System.out.println("FROM RULE MAP"); for (UnaryRule uRule : * unaryRuleMap.keySet()){ System.out.print(uRule); } */ // System.out.println("AND NOW THE BINARIES"); // System.out.println("BY PARENT"); for (int state1 = 0; state1 < numStates; state1++) { BinaryRule[] parentRules = this.splitRulesWithP(state1); for (int i = 0; i < parentRules.length; i++) { BinaryRule bRule = parentRules[i]; BinaryRule bRule2 = (BinaryRule) binaryRuleMap.get(bRule); if (!bRule.getScores2().equals(bRule2.getScores2())) System.out.print("BINARY: " + bRule + "" + bRule2 + "\n"); } } /* * System.out.println("FROM RULE MAP"); for (BinaryRule bRule : * binaryRuleMap.keySet()){ System.out.print(bRule); } */ } public boolean unariesAreNotEqual(UnaryRule u1, UnaryRule u2) { // two cases: // 1. 
u2 is null and u1 is a selfRule if (u2 == null) { return false; /* * double[][] s1 = u1.getScores2(); for (int i=0; i<s1.length; i++){ * if (s1[i][i] != 1.0) return true; } */ } else { // compare all entries double[][] s1 = u1.getScores2(); double[][] s2 = u2.getScores2(); for (int i = 0; i < s1.length; i++) { if (s1[i] == null || s2[i] == null) continue; for (int j = 0; j < s1[i].length; j++) { if (s1[i][j] != s2[i][j]) return true; } } } return false; } @SuppressWarnings("unchecked") public void init() { binaryRuleMap = new HashMap<BinaryRule, BinaryRule>(); unaryRuleMap = new HashMap<UnaryRule, UnaryRule>(); // allRules = new ArrayList<Rule>(); bestSumRulesUnderMax = new HashMap(); bestViterbiRulesUnderMax = new HashMap(); binaryRulesWithParent = new List[numStates]; binaryRulesWithLC = new List[numStates]; binaryRulesWithRC = new List[numStates]; unaryRulesWithParent = new List[numStates]; unaryRulesWithC = new List[numStates]; closedSumRulesWithParent = new List[numStates]; closedSumRulesWithChild = new List[numStates]; closedViterbiRulesWithParent = new List[numStates]; closedViterbiRulesWithChild = new List[numStates]; isGrammarTag = new boolean[numStates]; // if (findClosedPaths) { closedViterbiPaths = new int[numStates][numStates]; // } closedSumPaths = new int[numStates][numStates]; for (short s = 0; s < numStates; s++) { binaryRulesWithParent[s] = new ArrayList<BinaryRule>(); binaryRulesWithLC[s] = new ArrayList<BinaryRule>(); binaryRulesWithRC[s] = new ArrayList<BinaryRule>(); unaryRulesWithParent[s] = new ArrayList<UnaryRule>(); unaryRulesWithC[s] = new ArrayList<UnaryRule>(); closedSumRulesWithParent[s] = new ArrayList<UnaryRule>(); closedSumRulesWithChild[s] = new ArrayList<UnaryRule>(); closedViterbiRulesWithParent[s] = new ArrayList<UnaryRule>(); closedViterbiRulesWithChild[s] = new ArrayList<UnaryRule>(); double[][] scores = new double[numSubStates[s]][numSubStates[s]]; for (int i = 0; i < scores.length; i++) { scores[i][i] = 1; } UnaryRule selfR 
= new UnaryRule(s, s, scores); // relaxSumRule(selfR); relaxViterbiRule(selfR); } } /** * Construct the optimal grammar from a list of Tree<StateSet>. This * assumes that trainTrees has all the inside/outside probabilities correct. * It performs the M step of EM--it finds the best grammar given the * sufficient statistics (i.e. i/o probabilities). * * @param trainTrees * A set of StateSet trees which have had their inside/outside * probabilities calculated. * @param dummy * A dummy parameter because otherwise java objects that this * method shares the signature of Grammar(List of Tree of * Strings). */ /* * comment out unused constructor public Grammar(List<Tree<StateSet>> * trainTrees, Grammar old_grammar) { this.tagNumberer = * Numberer.getGlobalNumberer("tags"); unaryRuleCounter = new * Counter<UnaryRule>(); binaryRuleCounter = new Counter<BinaryRule>(); * symbolCounter = new CounterMap<Integer,Integer>(); numStates = * tagNumberer.total(); numSubStates = old_grammar.numSubStates; init(); * * for (Tree<StateSet> trainTree : trainTrees) { * tallyStateSetTree(trainTree, old_grammar); } for (UnaryRule unaryRule : * unaryRuleCounter.keySet()) { double unaryProbability = * unaryRuleCounter.getCount(unaryRule) / * symbolCounter.getCount(unaryRule.getParentState * (),unaryRule.getParentSubState()); * unaryRule.setScore(Math.log(unaryProbability)); addUnary(unaryRule); } * for (BinaryRule binaryRule : binaryRuleCounter.keySet()) { double * binaryProbability = binaryRuleCounter.getCount(binaryRule) / * symbolCounter * .getCount(binaryRule.getParentState(),binaryRule.getParentSubState()); * binaryRule.setScore(Math.log(binaryProbability)); addBinary(binaryRule); * } } */ /** * This constructor generates a grammar with the rule probabilities read as * though there were no substates, but with a bit of randomness added. This * is the way we should initialize the EM algorithm. 
* * @param trainTrees * The training trees, which don't need to have their * inside-outside probabilities calculated correctly. * @param randomness * The size of the region to be uniformly sampled from in adding * extra random weight to the rules. */ /* * comment out unused constructor public Grammar(List<Tree<StateSet>> * trainTrees, int[] nSubStates, int maxN, double randomness) { * this.tagNumberer = Numberer.getGlobalNumberer("tags"); unaryRuleCounter = * new Counter<UnaryRule>(); binaryRuleCounter = new Counter<BinaryRule>(); * symbolCounter = new CounterMap<Integer, Integer>(); numStates = * tagNumberer.total(); numSubStates = nSubStates; maxNumSubStates = maxN; * init(); * * //tally trees as though there were no subsymbols for (Tree<StateSet> * trainTree : trainTrees) { tallyUninitializedStateSetTree(trainTree); } * //add randomness Random random = new Random(); for (UnaryRule unaryRule : * unaryRuleCounter.keySet()) { double r = random.nextDouble()*randomness; * unaryRuleCounter.incrementCount(unaryRule,r); } for (BinaryRule * binaryRule : binaryRuleCounter.keySet()) { double r = * random.nextDouble()*randomness; * binaryRuleCounter.incrementCount(binaryRule,r); } //re-tally the parent * counts because adding the randomness ruined them symbolCounter = new * CounterMap<Integer, Integer>(); for (UnaryRule unaryRule : * unaryRuleCounter.keySet()) { symbolCounter.incrementCount( * unaryRule.getParentState(), unaryRule.getParentSubState(), * unaryRuleCounter.getCount(unaryRule)); } for (BinaryRule binaryRule : * binaryRuleCounter.keySet()) { * symbolCounter.incrementCount(binaryRule.getParentState * (),binaryRule.getParentSubState(), * binaryRuleCounter.getCount(binaryRule)); } //set the scores of all the * rules based on these counts for (UnaryRule unaryRule : * unaryRuleCounter.keySet()) { double unaryProbability = * unaryRuleCounter.getCount(unaryRule) / * symbolCounter.getCount(unaryRule.getParentState(), * unaryRule.getParentSubState()); * 
unaryRule.setScore(Math.log(unaryProbability)); addUnary(unaryRule); } * for (BinaryRule binaryRule : binaryRuleCounter.keySet()) { double * binaryProbability = binaryRuleCounter.getCount(binaryRule) / * symbolCounter * .getCount(binaryRule.getParentState(),binaryRule.getParentSubState()); * binaryRule.setScore(Math.log(binaryProbability)); addBinary(binaryRule); * } } */ /** * Rather than calling some all-in-one constructor that takes a list of * trees as training data, you call Grammar() to create an empty grammar, * call tallyTree() repeatedly to include all the training data, then call * optimize() to take it into account. * * @param oldGrammar * This is the previous grammar. We use this to copy the split * trees that record how each state is split recursively. These * parameters are intialized if oldGrammar is null. */ @SuppressWarnings("unchecked") public Grammar(short[] nSubStates, boolean findClosedPaths, Smoother smoother, Grammar oldGrammar, double thresh) { this.tagNumberer = Numberer.getGlobalNumberer("tags"); this.findClosedPaths = findClosedPaths; this.smoother = smoother; this.threshold = thresh; unaryRuleCounter = new UnaryCounterTable(nSubStates); binaryRuleCounter = new BinaryCounterTable(nSubStates); symbolCounter = new CounterMap<Integer, Integer>(); numStates = (short) tagNumberer.total(); numSubStates = nSubStates; bSearchRule = new BinaryRule((short) 0, (short) 0, (short) 0); uSearchRule = new UnaryRule((short) 0, (short) 0); logarithmMode = false; if (oldGrammar != null) { splitTrees = oldGrammar.splitTrees; } else { splitTrees = new Tree[numStates]; boolean hasAnySplits = false; for (int tag = 0; !hasAnySplits && tag < numStates; tag++) { hasAnySplits = hasAnySplits || numSubStates[tag] > 1; } for (int tag = 0; tag < numStates; tag++) { ArrayList<Tree<Short>> children = new ArrayList<Tree<Short>>( numSubStates[tag]); if (hasAnySplits) { for (short substate = 0; substate < numSubStates[tag]; substate++) { children.add(substate, new 
Tree<Short>(substate)); } } splitTrees[tag] = new Tree<Short>((short) 0, children); } } init(); } public void setSmoother(Smoother smoother) { this.smoother = smoother; } public static double generateMMTRandomNumber(Random r) { double f = r.nextDouble(); f = f * 2 - 1; f = f * Math.log(3); return Math.exp(f); } public void optimize(double randomness) { // System.out.print("Optimizing Grammar..."); init(); // checkNumberOfSubstates(); if (randomness > 0.0) { Random random = GrammarTrainer.RANDOM; // switch (randomInitializationType ) { // case INITIALIZE_WITH_SMALL_RANDOMIZATION: // add randomness for (UnaryRule unaryRule : unaryRuleCounter.keySet()) { double[][] unaryCounts = unaryRuleCounter.getCount(unaryRule); for (int i = 0; i < unaryCounts.length; i++) { if (unaryCounts[i] == null) unaryCounts[i] = new double[numSubStates[unaryRule .getParentState()]]; for (int j = 0; j < unaryCounts[i].length; j++) { double r = random.nextDouble() * randomness; unaryCounts[i][j] += r; } } unaryRuleCounter.setCount(unaryRule, unaryCounts); } for (BinaryRule binaryRule : binaryRuleCounter.keySet()) { double[][][] binaryCounts = binaryRuleCounter .getCount(binaryRule); for (int i = 0; i < binaryCounts.length; i++) { for (int j = 0; j < binaryCounts[i].length; j++) { if (binaryCounts[i][j] == null) binaryCounts[i][j] = new double[numSubStates[binaryRule .getParentState()]]; for (int k = 0; k < binaryCounts[i][j].length; k++) { double r = random.nextDouble() * randomness; binaryCounts[i][j][k] += r; } } } binaryRuleCounter.setCount(binaryRule, binaryCounts); } // break; // case INITIALIZE_LIKE_MMT: // //multiply by a random factor // for (UnaryRule unaryRule : unaryRuleCounter.keySet()) { // double[][] unaryCounts = unaryRuleCounter.getCount(unaryRule); // for (int i = 0; i < unaryCounts.length; i++) { // if (unaryCounts[i]==null) // continue; // for (int j = 0; j < unaryCounts[i].length; j++) { // double r = generateMMTRandomNumber(random); // unaryCounts[i][j] *= r; // } // } // 
unaryRuleCounter.setCount(unaryRule, unaryCounts); // } // for (BinaryRule binaryRule : binaryRuleCounter.keySet()) { // double[][][] binaryCounts = // binaryRuleCounter.getCount(binaryRule); // for (int i = 0; i < binaryCounts.length; i++) { // for (int j = 0; j < binaryCounts[i].length; j++) { // if (binaryCounts[i][j]==null) // continue; // for (int k = 0; k < binaryCounts[i][j].length; k++) { // double r = generateMMTRandomNumber(random); // binaryCounts[i][j][k] *= r; // } // } // } // binaryRuleCounter.setCount(binaryRule, binaryCounts); // } // break; // } } // smooth // if (useEntropicPrior) { // System.out.println("\nGrammar uses entropic prior!"); // normalizeWithEntropicPrior(); // } //SSIE overwriteWithMaxent(); normalize(); smooth(false); // this also adds the rules to the proper arrays // System.out.println("done."); } public void removeUnlikelyRules(double thresh, double power) { // System.out.print("Removing everything below "+thresh+" and rasiing rules to the " // +power+"th power... 
"); if (isLogarithmMode()) power = Math.log(power); int total = 0, removed = 0; for (int state = 0; state < numStates; state++) { for (int r = 0; r < splitRulesWithP[state].length; r++) { BinaryRule rule = splitRulesWithP[state][r]; for (int lC = 0; lC < rule.scores.length; lC++) { for (int rC = 0; rC < rule.scores[lC].length; rC++) { if (rule.scores[lC][rC] == null) continue; boolean isNull = true; for (int p = 0; p < rule.scores[lC][rC].length; p++) { total++; if (rule.scores[lC][rC][p] < thresh) { // System.out.print("."); rule.scores[lC][rC][p] = 0; removed++; } else { if (power != 1) rule.scores[lC][rC][p] = Math.pow( rule.scores[lC][rC][p], power); isNull = false; } } if (isNull) rule.scores[lC][rC] = null; } } splitRulesWithP[state][r] = rule; } for (UnaryRule rule : unaryRulesWithParent[state]) { for (int c = 0; c < rule.scores.length; c++) { if (rule.scores[c] == null) continue; boolean isNull = true; for (int p = 0; p < rule.scores[c].length; p++) { total++; if (rule.scores[c][p] <= thresh) { removed++; rule.scores[c][p] = 0; } else { if (power != 1) rule.scores[c][p] = Math.pow(rule.scores[c][p], power); isNull = false; } } if (isNull) rule.scores[c] = null; } } } // System.out.print("done.\nRemoved "+removed+" out of "+total+" rules.\n"); } public void smooth(boolean noNormalize) { smoother.smooth(unaryRuleCounter, binaryRuleCounter); if (!noNormalize) normalize(); // if (threshold>0){ // removeUnlikelyRules(threshold); // normalize(); // } // compress and add the rules for (UnaryRule unaryRule : unaryRuleCounter.keySet()) { double[][] unaryCounts = unaryRuleCounter.getCount(unaryRule); for (int i = 0; i < unaryCounts.length; i++) { if (unaryCounts[i] == null) continue; /** * allZero records if all probabilities are 0. If so, we want to * null out the matrix element. 
*/ double allZero = 0; int j = 0; while (allZero == 0 && j < unaryCounts[i].length) { allZero += unaryCounts[i][j++]; } if (allZero == 0) { unaryCounts[i] = null; } } unaryRule.setScores2(unaryCounts); addUnary(unaryRule); } computePairsOfUnaries(); for (BinaryRule binaryRule : binaryRuleCounter.keySet()) { double[][][] binaryCounts = binaryRuleCounter.getCount(binaryRule); for (int i = 0; i < binaryCounts.length; i++) { for (int j = 0; j < binaryCounts[i].length; j++) { if (binaryCounts[i][j] == null) continue; /** * allZero records if all probabilities are 0. If so, we * want to null out the matrix element. */ double allZero = 0; int k = 0; while (allZero == 0 && k < binaryCounts[i][j].length) { allZero += binaryCounts[i][j][k++]; } if (allZero == 0) { binaryCounts[i][j] = null; } } } binaryRule.setScores2(binaryCounts); addBinary(binaryRule); } // Reset all counters: unaryRuleCounter = new UnaryCounterTable(numSubStates); binaryRuleCounter = new BinaryCounterTable(numSubStates); symbolCounter = new CounterMap<Integer, Integer>(); /* * // tally usage of closed unary rule paths if (findClosedPaths) { int * maxSize = numStates * numStates; int size = 0; for (int i=0; * i<numStates; i++) { for (int j=0; j<numStates; j++) { if * (closedViterbiPaths[i][j]!=null) size++; } } * System.out.println("Closed viterbi unary path data structure covers " * + size + " / " + maxSize + " = " + (((double) size) / maxSize) + * " state pairs"); } */ // checkNumberOfSubstates(); // Romain: added the computation for the sum-product closure // TODO: fix the code and add this back in // sumProductClosedUnaryRulesWithParent = // sumProductUnaryClosure(unaryRulesWithParent); } public void clearCounts() { unaryRuleCounter = new UnaryCounterTable(numSubStates); binaryRuleCounter = new BinaryCounterTable(numSubStates); symbolCounter = new CounterMap<Integer, Integer>(); } /** * Normalize the unary & binary probabilities so that they sum to 1 for each * parent. 
The binaryRuleCounter and unaryRuleCounter are assumed to contain * probabilities, NOT log probabilities! */ public void normalize() { // tally the parent counts tallyParentCounts(); // turn the rule scores into fractions for (UnaryRule unaryRule : unaryRuleCounter.keySet()) { double[][] unaryCounts = unaryRuleCounter.getCount(unaryRule); int parentState = unaryRule.getParentState(); int nParentSubStates = numSubStates[parentState]; int nChildStates = numSubStates[unaryRule.childState]; double[] parentCount = new double[nParentSubStates]; for (int i = 0; i < nParentSubStates; i++) { parentCount[i] = symbolCounter.getCount(parentState, i); } boolean allZero = true; for (int j = 0; j < nChildStates; j++) { if (unaryCounts[j] == null) continue; for (int i = 0; i < nParentSubStates; i++) { if (parentCount[i] != 0) { double nVal = (unaryCounts[j][i] / parentCount[i]); if (nVal < threshold || SloppyMath.isVeryDangerous(nVal)) nVal = 0; unaryCounts[j][i] = nVal; } allZero = allZero && (unaryCounts[j][i] == 0); } } if (allZero) { System.out.println("Maybe an underflow? 
Rule: " + unaryRule + "\n" + ArrayUtil.toString(unaryCounts)); } unaryRuleCounter.setCount(unaryRule, unaryCounts); } for (BinaryRule binaryRule : binaryRuleCounter.keySet()) { double[][][] binaryCounts = binaryRuleCounter.getCount(binaryRule); int parentState = binaryRule.parentState; int nParentSubStates = numSubStates[parentState]; double[] parentCount = new double[nParentSubStates]; for (int i = 0; i < nParentSubStates; i++) { parentCount[i] = symbolCounter.getCount(parentState, i); } for (int j = 0; j < binaryCounts.length; j++) { for (int k = 0; k < binaryCounts[j].length; k++) { if (binaryCounts[j][k] == null) continue; for (int i = 0; i < nParentSubStates; i++) { if (parentCount[i] != 0) { double nVal = (binaryCounts[j][k][i] / parentCount[i]); if (nVal < threshold || SloppyMath.isVeryDangerous(nVal)) nVal = 0; binaryCounts[j][k][i] = nVal; } } } } binaryRuleCounter.setCount(binaryRule, binaryCounts); } } // public void normalizeWithEntropicPrior(){ // for (int iter=1; iter<=6; iter++){ // tallyParentCounts(); // // turn the rule scores into fractions // for (UnaryRule unaryRule : unaryRuleCounter.keySet()) { // double[][] unaryCounts = unaryRuleCounter.getCount(unaryRule); // int parentState = unaryRule.getParentState(); // int nParentSubStates = numSubStates[parentState]; // double[] parentCount = new double[nParentSubStates]; // for (int i = 0; i < nParentSubStates; i++) { // parentCount[i] = symbolCounter.getCount(parentState, i); // } // double[][] theta = new double[unaryCounts.length][]; // if (iter==1){ // // initialize the thetas // for (int j = 0; j < unaryCounts.length; j++) { // if (unaryCounts[j]==null) continue; // theta[j] = new double[nParentSubStates]; // for (int i = 0; i < nParentSubStates; i++) { // double val = unaryCounts[j][i]; // theta[j][i]=Math.pow(val,1-(1/parentCount[i])); // } // } // unaryRule.setScores2(unaryCounts); // unaryRuleCounter.setCount(unaryRule,theta); // } // // compute lambdas // else{ // theta = unaryCounts; // 
unaryCounts = unaryRule.getScores2(); // for (int j = 0; j < unaryCounts.length; j++) { // if (unaryCounts[j]==null) continue; // for (int i = 0; i < nParentSubStates; i++) { // theta[j][i] = unaryCounts[j][i]/parentCount[i]; // } // } // for (int j = 0; j < unaryCounts.length; j++) { // if (unaryCounts[j]==null) continue; // for (int i = 0; i < nParentSubStates; i++) { // if (unaryCounts[j][i]==0) { // theta[j][i]=0; // continue; // } // double val = theta[j][i]; // double lambda = -((unaryCounts[j][i]/val) + Math.log(val) + 1); // // compute thetas // val = -1.0*unaryCounts[j][i]; // theta[j][i]= val/SloppyMath.lambert(val,(1+lambda)); // // if (SloppyMath.isDangerous(theta[j][i])) // // // System.out.println("Maybe an underflow: count "+val+" lambda "+lambda+" theta " // +theta[j][i] + " div "+SloppyMath.lambert(val,(1+lambda))); // } // } // unaryRuleCounter.setCount(unaryRule,theta); // } // } // for (BinaryRule binaryRule : binaryRuleCounter.keySet()) { // double[][][] binaryCounts = binaryRuleCounter.getCount(binaryRule); // int parentState = binaryRule.parentState; // int nParentSubStates = numSubStates[parentState]; // double[] parentCount = new double[nParentSubStates]; // for (int i = 0; i < nParentSubStates; i++) { // parentCount[i] = symbolCounter.getCount(parentState, i); // } // double[][][] theta = new // double[binaryCounts.length][binaryCounts[0].length][]; // if (iter==1){ // // initialize the thetas // for (int j = 0; j < binaryCounts.length; j++) { // for (int k = 0; k < binaryCounts[j].length; k++) { // if (binaryCounts[j][k]==null) continue; // theta[j][k] = new double[nParentSubStates]; // for (int i = 0; i < nParentSubStates; i++) { // double val = binaryCounts[j][k][i]; // theta[j][k][i]=Math.pow(val,1-(1/parentCount[i])); // } // } // } // binaryRule.setScores2(binaryCounts); // binaryRuleCounter.setCount(binaryRule,theta); // } // else{ // theta = binaryCounts; // binaryCounts = binaryRule.getScores2(); // for (int j = 0; j < 
binaryCounts.length; j++) { // for (int k = 0; k < binaryCounts[j].length; k++) { // if (binaryCounts[j][k]==null) continue; // for (int i = 0; i < nParentSubStates; i++) { // theta[j][k][i] = binaryCounts[j][k][i]/parentCount[i]; // } // } // } // binaryCounts = binaryRule.getScores2(); // for (int j = 0; j < binaryCounts.length; j++) { // for (int k = 0; k < binaryCounts[j].length; k++) { // if (binaryCounts[j][k]==null) continue; // for (int i = 0; i < nParentSubStates; i++) { // if (binaryCounts[j][k][i]==0) { // theta[j][k][i]=0; // continue; // } // double val = theta[j][k][i]; // double lambda = -((binaryCounts[j][k][i]/val) + Math.log(val) + 1); // // compute thetas // val = -1.0*binaryCounts[j][k][i]; // theta[j][k][i]= val/SloppyMath.lambert(val,(1+lambda)); // if (SloppyMath.isDangerous(theta[j][k][i])) // System.out.println("Maybe an underflow: count "+val+" lambda "+lambda+" theta " // +theta[j][k][i] + " div "+SloppyMath.lambert(val,(1+lambda))); // if (val==0) theta[j][k][i]=0; // } // } // } // binaryRuleCounter.setCount(binaryRule,theta); // } // } // } // } /* * Check number of substates */ public void checkNumberOfSubstates() { for (UnaryRule unaryRule : unaryRuleCounter.keySet()) { double[][] unaryCounts = unaryRuleCounter.getCount(unaryRule); int nParentSubStates = numSubStates[unaryRule.parentState]; int nChildSubStates = numSubStates[unaryRule.childState]; if (unaryCounts.length != nChildSubStates) { System.out.println("Unary Rule " + unaryRule + " should have " + nChildSubStates + " childsubstates."); } if (unaryCounts[0] != null && unaryCounts[0].length != nParentSubStates) { System.out.println("Unary Rule " + unaryRule + " should have " + nParentSubStates + " parentsubstates."); } } for (BinaryRule binaryRule : binaryRuleCounter.keySet()) { double[][][] binaryCounts = binaryRuleCounter.getCount(binaryRule); int nParentSubStates = numSubStates[binaryRule.parentState]; int nLeftChildSubStates = numSubStates[binaryRule.leftChildState]; int 
nRightChildSubStates = numSubStates[binaryRule.rightChildState];
			if (binaryCounts.length != nLeftChildSubStates) {
				System.out.println("Unary Rule " + binaryRule + " should have "
						+ nLeftChildSubStates + " left childsubstates.");
			}
			if (binaryCounts[0].length != nRightChildSubStates) {
				System.out.println("Unary Rule " + binaryRule + " should have "
						+ nRightChildSubStates + " right childsubstates.");
			}
			if (binaryCounts[0][0] != null
					&& binaryCounts[0][0].length != nParentSubStates) {
				System.out.println("Unary Rule " + binaryRule + " should have "
						+ nParentSubStates + " parentsubstates.");
			}
		}
		System.out.println("Done with checks.");
	}

	/**
	 * Sum the parent symbol counter, symbolCounter. This is needed when the
	 * rule counters are altered, such as when adding randomness in optimize().
	 * <p>
	 * This assumes that the unaryRuleCounter and binaryRuleCounter contain
	 * probabilities, NOT log probabilities!
	 * <p>
	 * Side effects: replaces {@code symbolCounter} with a fresh counter and sets
	 * {@code isGrammarTag[parent] = true} for the parent of every rule present
	 * in either counter. Self-loop unaries (child == parent) still mark the tag
	 * but contribute nothing to the parent sums.
	 */
	private void tallyParentCounts() {
		symbolCounter = new CounterMap<Integer, Integer>();
		for (UnaryRule unaryRule : unaryRuleCounter.keySet()) {
			double[][] unaryCounts = unaryRuleCounter.getCount(unaryRule);
			int parentState = unaryRule.getParentState();
			isGrammarTag[parentState] = true;
			// self-loop unaries are excluded from the parent mass
			if (unaryRule.childState == parentState)
				continue;
			int nParentSubStates = numSubStates[parentState];
			double[] sum = new double[nParentSubStates];
			// unary counts are indexed [childSubState][parentSubState]; a child
			// row may be null when it was never allocated during tallying
			for (int j = 0; j < unaryCounts.length; j++) {
				if (unaryCounts[j] == null)
					continue;
				for (int i = 0; i < nParentSubStates; i++) {
					double val = unaryCounts[j][i];
					// if (val>=threshold)
					sum[i] += val;
				}
			}
			for (int i = 0; i < nParentSubStates; i++) {
				symbolCounter.incrementCount(parentState, i, sum[i]);
			}
		}
		for (BinaryRule binaryRule : binaryRuleCounter.keySet()) {
			double[][][] binaryCounts = binaryRuleCounter.getCount(binaryRule);
			int parentState = binaryRule.parentState;
			isGrammarTag[parentState] = true;
			int nParentSubStates = numSubStates[parentState];
			double[] sum = new double[nParentSubStates];
			// binary counts are indexed [leftSubState][rightSubState][parentSubState];
			// the innermost (parent) array may be null
			for (int j = 0; j < binaryCounts.length; j++) {
				for (int k = 0; k < binaryCounts[j].length; k++) {
					if (binaryCounts[j][k] == null)
						continue;
					for (int i = 0; i < nParentSubStates; i++) {
						double val = binaryCounts[j][k][i];
						// if (val>=threshold)
						sum[i] += val;
					}
				}
			}
			for (int i = 0; i < nParentSubStates; i++) {
				symbolCounter.incrementCount(parentState, i, sum[i]);
			}
		}
	}

	/**
	 * Entry point for tallying one scored tree. Leaves and preterminals are
	 * ignored. The root must be unsplit (exactly one substate); otherwise the
	 * process is terminated with {@code System.exit(1)}. The root's inside
	 * score and scale are then used as the normalizer for
	 * {@link #tallyStateSetTree(Tree, double, double, Grammar)}. Trees whose
	 * root inside score is 0 are skipped with a message.
	 * 
	 * @param tree
	 *            the tree to tally; its labels are expected to already carry
	 *            inside/outside scores (TODO confirm with caller)
	 * @param old_grammar
	 *            grammar supplying the rule scores used for the counts
	 */
	public void tallyStateSetTree(Tree<StateSet> tree, Grammar old_grammar) {
		// Check that the top node is not split (it has only one substate)
		if (tree.isLeaf())
			return;
		if (tree.isPreTerminal())
			return;
		StateSet node = tree.getLabel();
		if (node.numSubStates() != 1) {
			System.err.println("The top symbol is split!");
			System.out.println(tree);
			System.exit(1);
		}
		// The inside score of its only substate is the (log) probability of the
		// tree
		double tree_score = node.getIScore(0);
		int tree_scale = node.getIScale();
		if (tree_score == 0) {
			System.out.println("Something is wrong with this tree. I will skip it.");
			return;
		}
		tallyStateSetTree(tree, tree_score, tree_scale, old_grammar);
	}

	/**
	 * Recursively accumulates rule counts for every internal node of
	 * {@code tree} into {@code unaryRuleCounter} / {@code binaryRuleCounter}.
	 * For each combination of child substate(s) and parent substate the added
	 * amount is
	 * {@code ruleScore * childInside(s) / tree_score * scalingFactor * parentOutside},
	 * where ruleScore comes from {@code old_grammar} and scalingFactor is
	 * {@code ScalingTools.calcScaleFactor(parentOScale + childIScale(s) - tree_scale)}.
	 * Presumably these are the posterior (E-step) expected rule counts of EM;
	 * TODO confirm. Leaves and preterminals are skipped here.
	 * 
	 * @param tree
	 *            subtree to tally
	 * @param tree_score
	 *            inside score of the whole tree (normalizer); if it is 0 it is
	 *            replaced by 1 and that value is passed down to the children
	 * @param tree_scale
	 *            scale exponent matching {@code tree_score}
	 * @param old_grammar
	 *            grammar whose rule scores weight the counts
	 * @throws Error
	 *             if a node has more than two children
	 */
	public void tallyStateSetTree(Tree<StateSet> tree, double tree_score,
			double tree_scale, Grammar old_grammar) {
		if (tree.isLeaf())
			return;
		if (tree.isPreTerminal())
			return;
		List<Tree<StateSet>> children = tree.getChildren();
		StateSet parent = tree.getLabel();
		short parentState = parent.getState();
		int nParentSubStates = numSubStates[parentState];
		switch (children.size()) {
		case 0:
			// This is a leaf (a preterminal node, if we count the words
			// themselves),
			// nothing to do
			break;
		case 1:
			StateSet child = children.get(0).getLabel();
			short childState = child.getState();
			int nChildSubStates = numSubStates[childState];
			UnaryRule urule = new UnaryRule(parentState, childState);
			double[][] oldUScores = old_grammar.getUnaryScore(urule); // rule
			// score
			double[][] ucounts = unaryRuleCounter.getCount(urule);
			// counts are [childSubState][parentSubState]; rows are allocated lazily
			if (ucounts == null)
				ucounts = new double[nChildSubStates][];
			double scalingFactor = ScalingTools.calcScaleFactor(parent.getOScale()
					+ child.getIScale() - tree_scale);
			// if (scalingFactor==0){
			// System.out.println("p: "+parent.getOScale()+" c: "+child.getIScale()+" t:"+tree_scale);
			// }
			for (short i = 0; i < nChildSubStates; i++) {
				if (oldUScores[i] == null)
					continue;
				double cIS = child.getIScore(i);
				if (cIS == 0)
					continue;
				if (ucounts[i] == null)
					ucounts[i] = new double[nParentSubStates];
				for (short j = 0; j < nParentSubStates; j++) {
					double pOS = parent.getOScore(j); // Parent outside score
					if (pOS == 0)
						continue;
					double rS = oldUScores[i][j];
					if (rS == 0)
						continue;
					// guard against division by zero; the reset value persists
					if (tree_score == 0)
						tree_score = 1;
					// despite the name, this is a plain (non-log) product
					double logRuleCount = (rS * cIS / tree_score) * scalingFactor * pOS;
					ucounts[i][j] += logRuleCount;
				}
			}
			// urule.setScores2(ucounts);
			unaryRuleCounter.setCount(urule, ucounts);
			break;
		case 2:
			StateSet leftChild = children.get(0).getLabel();
			short lChildState = leftChild.getState();
			StateSet rightChild = children.get(1).getLabel();
			short rChildState = rightChild.getState();
			int nLeftChildSubStates = numSubStates[lChildState];
			int nRightChildSubStates = numSubStates[rChildState];
			// new double[nLeftChildSubStates][nRightChildSubStates][];
			BinaryRule brule = new BinaryRule(parentState, lChildState,
					rChildState);
			double[][][] oldBScores = old_grammar.getBinaryScore(brule); // rule
			// score
			// NOTE(review): the getBinaryScore implementations visible in this
			// file return a 1.0-filled array instead of null, so this branch
			// looks unreachable -- confirm before relying on it.
			if (oldBScores == null) {
				// rule was not in the grammar
				// parent.setIScores(iScores2);
				// break;
				oldBScores = new double[nLeftChildSubStates][nRightChildSubStates][nParentSubStates];
				ArrayUtil.fill(oldBScores, 1.0);
			}
			double[][][] bcounts = binaryRuleCounter.getCount(brule);
			// counts are [leftSubState][rightSubState][parentSubState]; the
			// parent arrays are allocated lazily
			if (bcounts == null)
				bcounts = new double[nLeftChildSubStates][nRightChildSubStates][];
			scalingFactor = ScalingTools.calcScaleFactor(parent.getOScale()
					+ leftChild.getIScale() + rightChild.getIScale()
					- tree_scale);
			// if (scalingFactor==0){
			// System.out.println("p: "+parent.getOScale()+" l: "+leftChild.getIScale()+" r:"+rightChild.getIScale()+" t:"+tree_scale);
			// }
			for (short i = 0; i < nLeftChildSubStates; i++) {
				double lcIS = leftChild.getIScore(i);
				if (lcIS == 0)
					continue;
				for (short j = 0; j < nRightChildSubStates; j++) {
					if (oldBScores[i][j] == null)
						continue;
					double rcIS = rightChild.getIScore(j);
					if (rcIS == 0)
						continue;
					// allocate parent array
					if (bcounts[i][j] == null)
						bcounts[i][j] = new double[nParentSubStates];
					for (short k = 0; k < nParentSubStates; k++) {
						double pOS = parent.getOScore(k); // Parent outside
						// score
						if (pOS == 0)
							continue;
						double rS = oldBScores[i][j][k];
						if (rS == 0)
							continue;
						// guard against division by zero; the reset value persists
						if (tree_score == 0)
							tree_score = 1;
						// despite the name, this is a plain (non-log) product
						double logRuleCount = (rS * lcIS / tree_score) * rcIS
								* scalingFactor * pOS;
						/*
						 * if (logRuleCount == 0) {
						 * System.out.println("rS "+rS+", lcIS "
						 * +lcIS+", rcIS "+rcIS+", tree_score "+tree_score+
						 * ", scalingFactor "+scalingFactor+", pOS "+pOS);
						 * System.out.println("Possibly underflow?"); //
						 * logRuleCount = Double.MIN_VALUE; }
						 */
						bcounts[i][j][k] += logRuleCount;
					}
				}
			}
			binaryRuleCounter.setCount(brule, bcounts);
			break;
		default:
			throw new Error("Malformed tree: more than two children");
		}

		// recurse with the (possibly reset) tree_score
		for (Tree<StateSet> child : children) {
			tallyStateSetTree(child, tree_score, tree_scale, old_grammar);
		}
	}

	/**
	 * Registers the rule shapes of a tree whose scores are not yet usable.
	 * Every unary/binary node calls {@code incrementCount(rule, 1.0)} on the
	 * corresponding counter. The rule carries a zero-filled score array sized
	 * from the nodes' own {@code numSubStates()}, not from
	 * {@code numSubStates[]}. Leaves and preterminals are skipped (the lexicon
	 * handles preterminals). The exact effect of incrementCount on the stored
	 * arrays is defined in the counter classes; TODO confirm.
	 * 
	 * @param tree
	 *            tree to walk
	 * @throws Error
	 *             if a node has more than two children
	 */
	public void tallyUninitializedStateSetTree(Tree<StateSet> tree) {
		if (tree.isLeaf())
			return;
		// the lexicon handles preterminal nodes
		if (tree.isPreTerminal())
			return;
		List<Tree<StateSet>> children = tree.getChildren();
		StateSet parent = tree.getLabel();
		short parentState = parent.getState();
		int nParentSubStates = parent.numSubStates(); // numSubStates[parentState];
		switch (children.size()) {
		case 0:
			// This is a leaf (a preterminal node, if we count the words
			// themselves), nothing to do
			break;
		case 1:
			StateSet child = children.get(0).getLabel();
			short childState = child.getState();
			int nChildSubStates = child.numSubStates(); // numSubStates[childState];
			double[][] counts = new double[nChildSubStates][nParentSubStates];
			UnaryRule urule = new UnaryRule(parentState, childState, counts);
			unaryRuleCounter.incrementCount(urule, 1.0);
			break;
		case 2:
			StateSet leftChild = children.get(0).getLabel();
			short lChildState = leftChild.getState();
			StateSet rightChild = children.get(1).getLabel();
			short rChildState = rightChild.getState();
			int nLeftChildSubStates = leftChild.numSubStates(); // numSubStates[lChildState];
			int nRightChildSubStates = rightChild.numSubStates();// numSubStates[rChildState];
			double[][][] bcounts = new double[nLeftChildSubStates][nRightChildSubStates][nParentSubStates];
			BinaryRule brule = new BinaryRule(parentState, lChildState,
					rChildState, bcounts);
			binaryRuleCounter.incrementCount(brule, 1.0);
			break;
		default:
			throw new Error("Malformed tree: more than two children");
		}
		for (Tree<StateSet> child : children) {
			tallyUninitializedStateSetTree(child);
		}
	}

	// Disabled, commented-out chart-based variant of tallyStateSetTree; kept
	// for reference only.
	/*
	 * public void tallyChart(Pair<double[][][][], double[][][][]> chart, double
	 * tree_score, Grammar old_grammar) { double[][][][] iScore =
	 * chart.getFirst(); double[][][][] oScore = chart.getSecond(); if
	 * (tree.isLeaf()) return; if (tree.isPreTerminal()) return;
	 * List<Tree<StateSet>> children = tree.getChildren(); StateSet parent =
	 * tree.getLabel(); short parentState = parent.getState(); int
	 * nParentSubStates = numSubStates[parentState]; switch (children.size()) {
	 * case 0: // This is a leaf (a preterminal node, if we count the words
	 * themselves), // nothing to do break; case 1: StateSet child =
	 * children.get(0).getLabel(); short childState = child.getState(); int
	 * nChildSubStates = numSubStates[childState]; UnaryRule urule = new
	 * UnaryRule(parentState, childState); double[][] oldUScores =
	 * old_grammar.getUnaryScore(urule); // rule score double[][] ucounts =
	 * unaryRuleCounter.getCount(urule); if (ucounts==null) ucounts = new
	 * double[nChildSubStates][]; double scalingFactor =
	 * Math.pow(GrammarTrainer.SCALE,
	 * parent.getOScale()+child.getIScale()-tree_scale); if (scalingFactor==0){
	 * System
	 * .out.println("p: "+parent.getOScale()+" c: "+child.getIScale()+" t:"
	 * +tree_scale); } for (short i = 0; i < nChildSubStates; i++) { if
	 * (oldUScores[i]==null) continue; double cIS =
child.getIScore(i); if * (cIS==0) continue; if (ucounts[i]==null) ucounts[i] = new * double[nParentSubStates]; for (short j = 0; j < nParentSubStates; j++) { * double pOS = parent.getOScore(j); // Parent outside score if (pOS==0) * continue; double rS = oldUScores[i][j]; if (rS==0) continue; if * (tree_score==0) tree_score = 1; double logRuleCount = (rS * cIS / * tree_score) * scalingFactor * pOS; ucounts[i][j] += logRuleCount; } } * //urule.setScores2(ucounts); unaryRuleCounter.setCount(urule, ucounts); * break; case 2: StateSet leftChild = children.get(0).getLabel(); short * lChildState = leftChild.getState(); StateSet rightChild = * children.get(1).getLabel(); short rChildState = rightChild.getState(); * int nLeftChildSubStates = numSubStates[lChildState]; int * nRightChildSubStates = numSubStates[rChildState]; //new * double[nLeftChildSubStates][nRightChildSubStates][]; BinaryRule brule = * new BinaryRule(parentState, lChildState, rChildState); double[][][] * oldBScores = old_grammar.getBinaryScore(brule); // rule score if * (oldBScores==null){ //rule was not in the grammar * //parent.setIScores(iScores2); //break; oldBScores=new * double[nLeftChildSubStates][nRightChildSubStates][nParentSubStates]; * ArrayUtil.fill(oldBScores,1.0); } double[][][] bcounts = * binaryRuleCounter.getCount(brule); if (bcounts==null) bcounts = new * double[nLeftChildSubStates][nRightChildSubStates][]; scalingFactor = * Math.pow(GrammarTrainer.SCALE, * parent.getOScale()+leftChild.getIScale()+rightChild * .getIScale()-tree_scale); if (scalingFactor==0){ * System.out.println("p: "+ * parent.getOScale()+" l: "+leftChild.getIScale()+" r:" * +rightChild.getIScale()+" t:"+tree_scale); } for (short i = 0; i < * nLeftChildSubStates; i++) { double lcIS = leftChild.getIScore(i); if * (lcIS==0) continue; for (short j = 0; j < nRightChildSubStates; j++) { if * (oldBScores[i][j]==null) continue; double rcIS = rightChild.getIScore(j); * if (rcIS==0) continue; // allocate parent array if 
(bcounts[i][j]==null) * bcounts[i][j] = new double[nParentSubStates]; for (short k = 0; k < * nParentSubStates; k++) { double pOS = parent.getOScore(k); // Parent * outside score if (pOS==0) continue; double rS = oldBScores[i][j][k]; if * (rS==0) continue; if (tree_score==0) tree_score = 1; double logRuleCount * = (rS * lcIS / tree_score) * rcIS * scalingFactor * pOS; * * bcounts[i][j][k] += logRuleCount; } } } binaryRuleCounter.setCount(brule, * bcounts); break; default: throw new * Error("Malformed tree: more than two children"); } * * for (Tree<StateSet> child : children) { tallyStateSetTree(child, * tree_score, tree_scale, old_grammar); } } */ /* * private UnaryRule makeUnaryRule(Tree<String> tree) { int parent = * tagNumberer.number(tree.getLabel()); int child = * tagNumberer.number(tree.getChildren().get(0).getLabel()); return new * UnaryRule(parent, child); } * * private BinaryRule makeBinaryRule(Tree<String> tree) { int parent = * tagNumberer.number(tree.getLabel()); int lChild = * tagNumberer.number(tree.getChildren().get(0).getLabel()); int rChild = * tagNumberer.number(tree.getChildren().get(1).getLabel()); return new * BinaryRule(parent, lChild, rChild); } */ public void makeCRArrays() { // int numStates = closedRulesWithParent.length; closedSumRulesWithP = new UnaryRule[numStates][]; closedSumRulesWithC = new UnaryRule[numStates][]; closedViterbiRulesWithP = new UnaryRule[numStates][]; closedViterbiRulesWithC = new UnaryRule[numStates][]; for (int i = 0; i < numStates; i++) { closedSumRulesWithP[i] = (UnaryRule[]) closedSumRulesWithParent[i] .toArray(new UnaryRule[0]); closedSumRulesWithC[i] = (UnaryRule[]) closedSumRulesWithChild[i] .toArray(new UnaryRule[0]); closedViterbiRulesWithP[i] = (UnaryRule[]) closedViterbiRulesWithParent[i] .toArray(new UnaryRule[0]); closedViterbiRulesWithC[i] = (UnaryRule[]) closedViterbiRulesWithChild[i] .toArray(new UnaryRule[0]); } } public UnaryRule[] getClosedSumUnaryRulesByParent(int state) { if (closedSumRulesWithP 
== null) { makeCRArrays(); } if (state >= closedSumRulesWithP.length) { return new UnaryRule[0]; } return closedSumRulesWithP[state]; } public UnaryRule[] getClosedSumUnaryRulesByChild(int state) { if (closedSumRulesWithC == null) { makeCRArrays(); } if (state >= closedSumRulesWithC.length) { return new UnaryRule[0]; } return closedSumRulesWithC[state]; } public UnaryRule[] getClosedViterbiUnaryRulesByParent(int state) { if (closedViterbiRulesWithP == null) { makeCRArrays(); } if (state >= closedViterbiRulesWithP.length) { return new UnaryRule[0]; } return closedViterbiRulesWithP[state]; } public UnaryRule[] getClosedViterbiUnaryRulesByChild(int state) { if (closedViterbiRulesWithC == null) { makeCRArrays(); } if (state >= closedViterbiRulesWithC.length) { return new UnaryRule[0]; } return closedViterbiRulesWithC[state]; } @SuppressWarnings("unchecked") public void purgeRules() { Map bR = new HashMap(); Map bR2 = new HashMap(); for (Iterator i = bestSumRulesUnderMax.keySet().iterator(); i.hasNext();) { UnaryRule ur = (UnaryRule) i.next(); if ((ur.parentState != ur.childState)) { bR.put(ur, ur); bR2.put(ur, ur); } } bestSumRulesUnderMax = bR; bestViterbiRulesUnderMax = bR2; } @SuppressWarnings("unchecked") public List<short[]> getBestViterbiPath(short pState, short np, short cState, short cp) { ArrayList<short[]> path = new ArrayList<short[]>(); short[] state = new short[2]; state[0] = pState; state[1] = np; // if we haven't built the data structure of closed paths, then // return the simplest possible path if (!findClosedPaths) { path.add(state); state = new short[2]; state[0] = cState; state[1] = cp; path.add(state); return path; } else { // read the best paths off of the closedViterbiPaths list if (pState == cState && np == cp) { path.add(state); path.add(state); return path; } while (state[0] != cState || state[1] != cp) { path.add(state); state[0] = (short) closedViterbiPaths[state[0]][state[1]]; } // add the destination state as well path.add(state); return 
path; } } @SuppressWarnings("unchecked") private void closeRulesUnderMax(UnaryRule ur) { short pState = ur.parentState; int nPSubStates = numSubStates[pState]; short cState = ur.childState; double[][] uScores = ur.getScores2(); // do all sum rules for (int i = 0; i < closedSumRulesWithChild[pState].size(); i++) { UnaryRule pr = (UnaryRule) closedSumRulesWithChild[pState].get(i); for (int j = 0; j < closedSumRulesWithParent[cState].size(); j++) { short parentState = pr.parentState; int nParentSubStates = numSubStates[parentState]; UnaryRule cr = (UnaryRule) closedSumRulesWithParent[cState] .get(j); UnaryRule resultR = new UnaryRule(parentState, cr .getChildState()); double[][] scores = new double[numSubStates[cr.getChildState()]][nParentSubStates]; for (int np = 0; np < scores[0].length; np++) { for (int cp = 0; cp < scores.length; cp++) { // sum over intermediate substates double sum = 0; for (int unp = 0; unp < nPSubStates; unp++) { for (int ucp = 0; ucp < uScores.length; ucp++) { sum += pr.getScore(np, unp) * cr.getScore(ucp, cp) * ur.getScore(unp, ucp); } } scores[cp][np] = sum; } } resultR.setScores2(scores); // add rule to bestSumRulesUnderMax if it's better relaxSumRule(resultR, pState, cState); } } // do viterbi rules also for (short i = 0; i < closedViterbiRulesWithChild[pState].size(); i++) { UnaryRule pr = (UnaryRule) closedViterbiRulesWithChild[pState] .get(i); for (short j = 0; j < closedViterbiRulesWithParent[cState].size(); j++) { UnaryRule cr = (UnaryRule) closedViterbiRulesWithParent[cState] .get(j); short parentState = pr.parentState; int nParentSubStates = numSubStates[parentState]; UnaryRule resultR = new UnaryRule(parentState, cr .getChildState()); double[][] scores = new double[numSubStates[cr.getChildState()]][nParentSubStates]; short[][] intermediateSubState1 = new short[nParentSubStates][numSubStates[cr .getChildState()]]; short[][] intermediateSubState2 = new short[nParentSubStates][numSubStates[cr .getChildState()]]; for (int np = 0; np < 
scores[0].length; np++) { for (int cp = 0; cp < scores.length; cp++) { // sum over intermediate substates double max = 0; for (short unp = 0; unp < nPSubStates; unp++) { for (short ucp = 0; ucp < uScores.length; ucp++) { double score = pr.getScore(np, unp) * cr.getScore(ucp, cp) * ur.getScore(unp, ucp); if (score > max) { max = score; intermediateSubState1[np][cp] = unp; intermediateSubState2[np][cp] = ucp; } } } scores[cp][np] = max; } } resultR.setScores2(scores); // add rule to bestSumRulesUnderMax if it's better relaxViterbiRule(resultR, pState, intermediateSubState1, cState, intermediateSubState2); } } } public int getUnaryIntermediate(short start, short end) { return closedSumPaths[start][end]; } @SuppressWarnings("unchecked") private boolean relaxSumRule(UnaryRule ur, int intState1, int intState2) { // TODO: keep track of path UnaryRule bestR = (UnaryRule) bestSumRulesUnderMax.get(ur); if (bestR == null) { bestSumRulesUnderMax.put(ur, ur); closedSumRulesWithParent[ur.parentState].add(ur); closedSumRulesWithChild[ur.childState].add(ur); return true; } else { boolean change = false; for (int i = 0; i < ur.scores[0].length; i++) { for (int j = 0; j < ur.scores.length; j++) { if (bestR.scores[j][i] < ur.scores[j][i]) { bestR.scores[j][i] = ur.scores[j][i]; change = true; } } } return change; } } public void computePairsOfUnaries() { // closedSumRulesWithParent = closedViterbiRulesWithParent = // unaryRulesWithParent; for (short parentState = 0; parentState < numStates; parentState++) { for (short childState = 0; childState < numStates; childState++) { if (parentState == childState) continue; int nParentSubStates = numSubStates[parentState]; int nChildSubStates = numSubStates[childState]; UnaryRule resultRsum = new UnaryRule(parentState, childState); UnaryRule resultRmax = new UnaryRule(parentState, childState); double[][] scoresSum = new double[nChildSubStates][nParentSubStates]; double[][] scoresMax = new double[nChildSubStates][nParentSubStates]; double 
maxSumScore = -1; short bestSumIntermed = -1; short bestMaxIntermed = -2; for (int i = 0; i < unaryRulesWithParent[parentState].size(); i++) { UnaryRule pr = (UnaryRule) unaryRulesWithParent[parentState] .get(i); short state = pr.getChildState(); if (state == childState) { double total = 0; double[][] scores = pr.getScores2(); for (int cp = 0; cp < nChildSubStates; cp++) { if (scores[cp] == null) continue; for (int np = 0; np < nParentSubStates; np++) { // sum over intermediate substates double sum = scores[cp][np]; scoresSum[cp][np] += sum; total += sum; if (sum > scoresMax[cp][np]) { scoresMax[cp][np] = sum; bestMaxIntermed = -1; } } } if (total > maxSumScore) { bestSumIntermed = -1; maxSumScore = total; } } else { for (int j = 0; j < unaryRulesWithC[childState].size(); j++) { UnaryRule cr = (UnaryRule) unaryRulesWithC[childState] .get(j); if (state != cr.getParentState()) continue; int nMySubStates = numSubStates[state]; double total = 0; for (int np = 0; np < nParentSubStates; np++) { for (int cp = 0; cp < nChildSubStates; cp++) { // sum over intermediate substates double sum = 0; double max = 0; for (int unp = 0; unp < nMySubStates; unp++) { double val = pr.getScore(np, unp) * cr.getScore(unp, cp); sum += val; max = Math.max(max, val); } scoresSum[cp][np] += sum; total += sum; if (max > scoresMax[cp][np]) { scoresMax[cp][np] = max; bestMaxIntermed = state; } } } if (total > maxSumScore) { maxSumScore = total; bestSumIntermed = state; } } } } if (maxSumScore > -1) { resultRsum.setScores2(scoresSum); addUnary(resultRsum); closedSumRulesWithParent[parentState].add(resultRsum); closedSumRulesWithChild[childState].add(resultRsum); closedSumPaths[parentState][childState] = bestSumIntermed; } if (bestMaxIntermed > -2) { resultRmax.setScores2(scoresMax); // addUnary(resultR); closedViterbiRulesWithParent[parentState].add(resultRmax); closedViterbiRulesWithChild[childState].add(resultRmax); closedViterbiPaths[parentState][childState] = bestMaxIntermed; /* * if 
(bestMaxIntermed > -1){ * System.out.println("NEW RULE CREATED"); } */ } } } } /* * @SuppressWarnings("unchecked") private boolean relaxSumRule(UnaryRule * rule) { bestSumRulesUnderMax.put(rule, rule); * closedSumRulesWithParent[rule.parentState].add(rule); * closedSumRulesWithChild[rule.childState].add(rule); return true; } */ /** * Update the best unary chain probabilities and paths with this new rule. * * @param ur * @param subStates1 * @param subStates2 * @return */ @SuppressWarnings("unchecked") private void relaxViterbiRule(UnaryRule ur, short intState1, short[][] intSubStates1, short intState2, short[][] intSubStates2) { throw new Error("Viterbi closure is broken!"); /* * UnaryRule bestR = (UnaryRule) bestViterbiRulesUnderMax.get(ur); * boolean isNewRule = (bestR==null); if (isNewRule) { * bestViterbiRulesUnderMax.put(ur, ur); * closedViterbiRulesWithParent[ur.parentState].add(ur); * closedViterbiRulesWithChild[ur.childState].add(ur); bestR = ur; } for * (int i=0; i<ur.scores[0].length; i++) { for (int j=0; * j<ur.scores.length; j++) { if (isNewRule || bestR.scores[j][i] < * ur.scores[j][i]) { bestR.scores[j][i] = ur.scores[j][i]; // update * best path information if (findClosedPaths) { short[] intermediate = * null; if (ur.parentState==intState1 && intSubStates1[i][j]==i) { * intermediate = new short[2]; intermediate[0] = intState2; * intermediate[1] = intSubStates2[i][j]; } else { //intermediate = * closedViterbiPaths * [ur.parentState][intState1][i][intSubStates1[i][j]]; } if * (closedViterbiPaths[ur.parentState][ur.childState]==null) { * closedViterbiPaths[ur.parentState][ur.childState] = new * short[numSubStates[ur.parentState]][numSubStates[ur.childState]][]; } * closedViterbiPaths[ur.parentState][ur.childState][i][j] = * intermediate; } } } } */} /** * Initialize the best unary chain probabilities and paths with this rule. 
*
	 * <p>
	 * Stores {@code rule} in {@code bestViterbiRulesUnderMax} (keyed by the
	 * rule itself) and appends it to the closed-Viterbi lists of its parent and
	 * child states.
	 * 
	 * @param rule
	 *            the unary rule to register
	 */
	@SuppressWarnings("unchecked")
	private void relaxViterbiRule(UnaryRule rule) {
		bestViterbiRulesUnderMax.put(rule, rule);
		closedViterbiRulesWithParent[rule.parentState].add(rule);
		closedViterbiRulesWithChild[rule.childState].add(rule);
		// No best-path bookkeeping is recorded here. The loop that used to sit
		// here (under findClosedPaths) only built throw-away {childState,
		// substate} pairs, since its store into closedViterbiPaths was already
		// commented out. It could also throw a NullPointerException on a null
		// score row, so it was removed.
	}

	/**
	 * 'parentRules', 'childRules' and the return value all have the same format
	 * as 'unaryRulesWithParent', but can be thought of as square matrices. All
	 * this function does is a matrix multiplication, but operating directly on
	 * this non-standard matrix representation. 'parentRules' gives the
	 * probability of going from A to B, 'childRules' from B to C, and the
	 * return value from A to C (summing out B). This function is intended
	 * primarily to compute unaryRulesWithParent^n.
*
	 * <p>
	 * NOTE(review): currently unusable -- this method unconditionally throws
	 * {@link Error}. The former log-space implementation (it used
	 * SloppyMath.logAdd) is kept as a comment inside the body.
	 */
	private List<UnaryRule>[] matrixMultiply(List<UnaryRule>[] parentRules,
			List<UnaryRule>[] childRules) {
		throw new Error("I'm broken by parent first");
		/* * double[][][][] scores = new double[numStates][numStates][][]; for ( * short A=0; A<numStates; A++ ) { for ( UnaryRule rAB : parentRules[A] * ) { short B = rAB.childState; double[][] scoresAB = rAB.getScores(); * for ( UnaryRule rBC : childRules[B] ) { short C = rBC.childState; if * ( scores[A][C] == null ) { scores[A][C] = new * double[numSubStates[A]][numSubStates[C]]; * ArrayUtil.fill(scores[A][C], Double.NEGATIVE_INFINITY); } double[][] * scoresBC = rBC.getScores(); double[] scoresToAdd = new * double[numSubStates[B]+1]; for ( int a = 0; a < numSubStates[A]; a++ * ) { for ( int c = 0; c < numSubStates[C]; c++ ) { // * Arrays.fill(scoresToAdd, Double.NEGATIVE_INFINITY); // No need to * here scoresToAdd[scoresToAdd.length-1] = scores[A][C][a][c]; // The * current score to which to add the new contributions for ( int b = 0; * b < numSubStates[B]; b++ ) { scoresToAdd[b] = scoresAB[a][b] + * scoresBC[b][c]; } scores[A][C][a][c] = * SloppyMath.logAdd(scoresToAdd); } } } } } * * @SuppressWarnings("unchecked") List<UnaryRule>[] result = new * List[numStates]; for ( short A=0; A<numStates; A++ ) { result[A] = * new ArrayList<UnaryRule>(); for ( short C=0; C<numStates; C++ ) { if * ( scores[A][C] != null ) { result[A].add(new * UnaryRule(A,C,scores[A][C])); } } } return result; */
	}

	/**
	 * rules1 += rules2 (adds rules2 into rules1, destroying rules1) No sharing
	 * of score arrays occurs because of this operation since rules2 data is
	 * either added in or copied.
*
	 * <p>
	 * NOTE(review): currently unusable -- this method unconditionally throws
	 * {@link Error}. The former implementation is kept as a comment inside the
	 * body.
	 * 
	 * @param rules1
	 *            destination lists (would be modified in place)
	 * @param rules2
	 *            lists to add into rules1
	 */
	private void matrixAdd(List<UnaryRule>[] rules1, List<UnaryRule>[] rules2) {
		throw new Error("I'm broken by parent first");
		/* * for ( short A=0; A<numStates; A++ ) { for ( UnaryRule r2 : rules2[A] * ) { short child2 = r2.getChildState(); double[][] scores2 = * r2.getScores(); boolean matchFound = false; for ( UnaryRule r1 : * rules1[A] ) { short child1 = r1.getChildState(); if ( child1 == * child2 ) { double[][] scores1 = r1.getScores(); for ( int a = 0; a < * numSubStates[A]; a++ ) { for ( int c = 0; c < numSubStates[child1]; * c++ ) { scores1[a][c] = SloppyMath.logAdd(scores1[a][c], * scores2[a][c]); } } matchFound = true; break; } } if (!matchFound) { * // Make a (deep) copy of rule r2 UnaryRule ruleCopy = new * UnaryRule(r2); double[][] scoresCopy = new * double[numSubStates[A]][numSubStates[child2]]; for ( int a = 0; a < * numSubStates[A]; a++ ) { for ( int c = 0; c < numSubStates[child2]; * c++ ) { scoresCopy[a][c] = scores2[a][c]; } } * ruleCopy.setScores(scoresCopy); rules1[A].add(ruleCopy); } } } */
	}

	/**
	 * Intended to build the identity "matrix" in the by-parent rule-list
	 * format. NOTE(review): currently unconditionally throws {@link Error};
	 * the former log-space implementation is kept as comments below.
	 */
	private List<UnaryRule>[] matrixUnity() {
		throw new Error("I'm broken by parent first");
		// List<UnaryRule>[] result = new List[numStates];
		// for ( short A=0; A<numStates; A++ ) {
		// result[A] = new ArrayList<UnaryRule>();
		// double[][] scores = new double[numSubStates[A]][numSubStates[A]];
		// ArrayUtil.fill(scores, Double.NEGATIVE_INFINITY);
		// for ( int a = 0; a < numSubStates[A]; a++ ) {
		// scores[a][a] = 0;
		// }
		// UnaryRule rule = new UnaryRule(A, A, scores);
		// result[A].add(rule);
		// }
		// return result;
	}

	/**
	 * NOTE(review): currently unconditionally throws {@link Error}; the former
	 * implementation is kept as a comment inside the body.
	 * 
	 * @param P
	 * @return I + P + P^2 + P^3 + ... (approximation by truncation after some
	 *         power)
	 */
	private List<UnaryRule>[] sumProductUnaryClosure(List<UnaryRule>[] P) {
		throw new Error("I'm broken by parent first");
		/* * List<UnaryRule>[] R = matrixUnity(); matrixAdd(R, P); // R = I + P + * P^2 + P^3 + ... List<UnaryRule>[] Q = P; // Q = P^k int maxPower = 3; * for ( int i = 1; i < maxPower; i++ ) { Q = matrixMultiply(Q, P); * matrixAdd(R, Q); } return R; */
	}

	/**
	 * Assumption: A in possibleSt ==> V[A] != null. This property is true of
	 * the result as well. The converse is not true because of a workaround for
	 * part of speech tags that we must handle here.
	 * <p>
	 * NOTE(review): currently unconditionally throws {@link Error}; the former
	 * log-space implementation is kept as a comment inside the body.
	 * 
	 * @param V
	 *            (considered a row vector, indexed by (state, substate))
	 * @param M
	 *            (a matrix represented in List<UnaryRule>[] (by parent) format)
	 * @param possibleSt
	 *            (a list of possible states to consider)
	 * @return U=V*M (row vector)
	 */
	public double[][] matrixVectorPreMultiply(double[][] V,
			List<UnaryRule>[] M, List<Integer> possibleSt) {
		throw new Error("I'm broken by parent first");
		/* * double[][] U = new double[numStates][]; for (int pState : * possibleSt){ U[pState] = new double[numSubStates[pState]]; * Arrays.fill(U[pState], Double.NEGATIVE_INFINITY); UnaryRule[] unaries * = M[pState].toArray(new UnaryRule[0]); for ( UnaryRule ur : unaries ) * { int cState = ur.childState; if ( V[cState] == null ) { continue; } * double[][] scores = ur.getScores(); // numSubStates[pState] * * numSubStates[cState] int nParentStates = numSubStates[pState]; int * nChildStates = numSubStates[cState]; double[] termsToAdd = new * double[nChildStates+1]; // Could be inside the for(np) loop for (int * np = 0; np < nParentStates; np++) { Arrays.fill(termsToAdd, * Double.NEGATIVE_INFINITY); double currentVal = U[pState][np]; * termsToAdd[termsToAdd.length-1] = currentVal; for (int cp = 0; cp < * nChildStates; cp++) { double iS = V[cState][cp]; if (iS == * Double.NEGATIVE_INFINITY) { continue; } double pS = scores[np][cp]; * termsToAdd[cp] = iS + pS; } * * double newVal = SloppyMath.logAdd(termsToAdd); if (newVal > * currentVal) { U[pState][np] = newVal; } } } } return U; */
	}

	/**
	 * Assumption: A in possibleSt ==> V[A] != null. This property is true of
	 * the result as well.
The converse is not true because of a workaround for
	 * part of speech tags that we must handle here.
	 * <p>
	 * NOTE(review): currently unconditionally throws {@link Error}; the former
	 * log-space implementation is kept as a comment inside the body.
	 * 
	 * @param M
	 *            (a matrix represented in List<UnaryRule>[] (by parent) format)
	 * @param V
	 *            (considered a column vector, indexed by (state, substate))
	 * @param possibleSt
	 *            (a list of possible states to consider)
	 * @return U=M*V (column vector)
	 */
	public double[][] matrixVectorPostMultiply(List<UnaryRule>[] M,
			double[][] V, List<Integer> possibleSt) {
		throw new Error("I'm broken by parent first");
		/* * double[][] U = new double[numStates][]; for (int cState : * possibleSt){ U[cState] = new double[numSubStates[cState]]; * Arrays.fill(U[cState], Double.NEGATIVE_INFINITY); } for (int pState : * possibleSt){ UnaryRule[] unaries = M[pState].toArray(new * UnaryRule[0]); for ( UnaryRule ur : unaries ) { int cState = * ur.childState; if ( U[cState] == null ) { continue; } double[][] * scores = ur.getScores(); // numSubStates[pState] * * numSubStates[cState] int nParentStates = numSubStates[pState]; int * nChildStates = numSubStates[cState]; double[] termsToAdd = new * double[nParentStates+1]; // Could be inside the for(np) loop for (int * cp = 0; cp < nChildStates; cp++) { Arrays.fill(termsToAdd, * Double.NEGATIVE_INFINITY); double currentVal = U[cState][cp]; * termsToAdd[termsToAdd.length-1] = currentVal; for (int np = 0; np < * nParentStates; np++) { double oS = V[pState][np]; if (oS == * Double.NEGATIVE_INFINITY) { continue; } double pS = scores[np][cp]; * termsToAdd[cp] = oS + pS; } * * double newVal = SloppyMath.logAdd(termsToAdd); if (newVal > * currentVal) { U[cState][cp] = newVal; } } } } return U; */
	}

	/**
	 * Populates the "splitRules" accessor lists using the existing rule lists.
	 * If the state is synthetic, these lists contain all rules for the state.
	 * If the state is NOT synthetic, these lists contain only the rules in
	 * which both children are not synthetic.
* <p> * <i>This method must be called before the grammar is used, either after * training or deserializing grammar.</i> */ @SuppressWarnings("unchecked") public void splitRules() { // splitRulesWithLC = new BinaryRule[numStates][]; // splitRulesWithRC = new BinaryRule[numStates][]; // makeRulesAccessibleByChild(); if (binaryRulesWithParent == null) return; splitRulesWithP = new BinaryRule[numStates][]; splitRulesWithLC = new BinaryRule[numStates][]; splitRulesWithRC = new BinaryRule[numStates][]; for (int state = 0; state < numStates; state++) { splitRulesWithLC[state] = toBRArray(binaryRulesWithLC[state]); splitRulesWithRC[state] = toBRArray(binaryRulesWithRC[state]); splitRulesWithP[state] = toBRArray(binaryRulesWithParent[state]); } // we don't need the original lists anymore binaryRulesWithParent = null; binaryRulesWithLC = null; binaryRulesWithRC = null; makeCRArrays(); } public BinaryRule[] splitRulesWithLC(int state) { // System.out.println("splitRulesWithLC not supported anymore."); // return null; if (state >= splitRulesWithLC.length) { return new BinaryRule[0]; } return splitRulesWithLC[state]; } public BinaryRule[] splitRulesWithRC(int state) { // System.out.println("splitRulesWithLC not supported anymore."); // return null; if (state >= splitRulesWithRC.length) { return new BinaryRule[0]; } return splitRulesWithRC[state]; } public BinaryRule[] splitRulesWithP(int state) { if (splitRulesWithP == null) splitRules(); if (state >= splitRulesWithP.length) { return new BinaryRule[0]; } return splitRulesWithP[state]; } private BinaryRule[] toBRArray(List<BinaryRule> list) { // Collections.sort(list, Rule.scoreComparator()); // didn't seem to // help BinaryRule[] array = new BinaryRule[list.size()]; for (int i = 0; i < array.length; i++) { array[i] = list.get(i); } return array; } public double[][] getUnaryScore(short pState, short cState) { UnaryRule r = getUnaryRule(pState, cState); if (r != null) return r.getScores2(); if (GrammarTrainer.VERBOSE) 
System.out.println("The requested rule (" + uSearchRule + ") is not in the grammar!"); double[][] uscores = new double[numSubStates[cState]][numSubStates[pState]]; ArrayUtil.fill(uscores, 1.0); return uscores; } /** * @param pState * @param cState * @return */ public UnaryRule getUnaryRule(short pState, short cState) { UnaryRule uRule = new UnaryRule(pState, cState); UnaryRule r = unaryRuleMap.get(uRule); return r; } public double[][] getUnaryScore(UnaryRule rule) { UnaryRule r = unaryRuleMap.get(rule); if (r != null) return r.getScores2(); if (GrammarTrainer.VERBOSE) System.out.println("The requested rule (" + rule + ") is not in the grammar!"); double[][] uscores = new double[numSubStates[rule.getChildState()]][numSubStates[rule .getParentState()]]; ArrayUtil.fill(uscores, 1.0); return uscores; } public double[][][] getBinaryScore(short pState, short lState, short rState) { BinaryRule r = getBinaryRule(pState, lState, rState); if (r != null) return r.getScores2(); if (GrammarTrainer.VERBOSE) System.out.println("The requested rule (" + bSearchRule + ") is not in the grammar!"); double[][][] bscores = new double[numSubStates[lState]][numSubStates[rState]][numSubStates[pState]]; ArrayUtil.fill(bscores, 1.0); return bscores; } /** * @param pState * @param lState * @param rState * @return */ public BinaryRule getBinaryRule(short pState, short lState, short rState) { BinaryRule bRule = new BinaryRule(pState, lState, rState); BinaryRule r = binaryRuleMap.get(bRule); return r; } public double[][][] getBinaryScore(BinaryRule rule) { BinaryRule r = binaryRuleMap.get(rule); if (r != null) return r.getScores2(); else { if (GrammarTrainer.VERBOSE) System.out.println("The requested rule (" + rule + ") is not in the grammar!"); double[][][] bscores = new double[numSubStates[rule .getLeftChildState()]][numSubStates[rule .getRightChildState()]][numSubStates[rule.getParentState()]]; ArrayUtil.fill(bscores, 1.0); return bscores; } } public void printSymbolCounter(Numberer 
tagNumberer) { Set<Integer> set = symbolCounter.keySet(); PriorityQueue<String> pq = new PriorityQueue<String>(set.size()); for (Integer i : set) { pq .add((String) tagNumberer.object(i), symbolCounter .getCount(i, 0)); // System.out.println(i+". "+(String)tagNumberer.object(i)+"\t // "+symbolCounter.getCount(i,0)); } int i = 0; while (pq.hasNext()) { i++; int p = (int) pq.getPriority(); System.out.println(i + ". " + pq.next() + "\t " + p); } } public int getSymbolCount(Integer i) { return (int) symbolCounter.getCount(i, 0); } private void makeRulesAccessibleByChild() { // first the binaries if (true) return; for (int state = 0; state < numStates; state++) { if (!isGrammarTag[state]) continue; if (binaryRulesWithParent == null) continue; for (BinaryRule rule : binaryRulesWithParent[state]) { binaryRulesWithLC[rule.leftChildState].add(rule); binaryRulesWithRC[rule.rightChildState].add(rule); } // for (UnaryRule rule : unaryRulesWithParent[state]){ // unaryRulesWithC[rule.childState].add(rule); // } } } /** * Split all substates into two new ones. This produces a new Grammar with * updated rules. 
* * @param randomness * percent randomness applied in splitting rules * @param mode * 0: normalized (at least almost) 1: not normalized (when * splitting a log-linear grammar) 2: just noise (for log-linear * grammars with cascading regularization) * @return */ public Grammar splitAllStates(double randomness, int[] counts, boolean moreSubstatesThanCounts, int mode) { if (logarithmMode) { throw new Error( "Do not split states when Grammar is in logarithm mode"); } short[] newNumSubStates = new short[numSubStates.length]; for (short i = 0; i < numSubStates.length; i++) { // don't split a state into more substates than times it was // actaully seen // if (!moreSubstatesThanCounts && numSubStates[i]>=counts[i]) { // newNumSubStates[i]=numSubStates[i]; // } // else{ newNumSubStates[i] = (short) (numSubStates[i] * 2); // } } boolean doNotNormalize = (mode == 1); newNumSubStates[0] = 1; // never split ROOT // create the new grammar Grammar grammar = new Grammar(newNumSubStates, findClosedPaths, smoother, this, threshold); Random random = GrammarTrainer.RANDOM; for (BinaryRule oldRule : binaryRuleMap.keySet()) { BinaryRule newRule = oldRule.splitRule(numSubStates, newNumSubStates, random, randomness, doNotNormalize, mode); grammar.addBinary(newRule); } for (UnaryRule oldRule : unaryRuleMap.keySet()) { UnaryRule newRule = oldRule.splitRule(numSubStates, newNumSubStates, random, randomness, doNotNormalize, mode); grammar.addUnary(newRule); } grammar.isGrammarTag = this.isGrammarTag; grammar.extendSplitTrees(splitTrees, numSubStates); grammar.computePairsOfUnaries(); return grammar; } @SuppressWarnings("unchecked") public void extendSplitTrees(Tree<Short>[] trees, short[] oldNumSubStates) { this.splitTrees = new Tree[numStates]; for (int tag = 0; tag < splitTrees.length; tag++) { Tree<Short> splitTree = trees[tag].shallowClone(); for (Tree<Short> leaf : splitTree.getTerminals()) { List<Tree<Short>> children = leaf.getChildren(); if (numSubStates[tag] > oldNumSubStates[tag]) { 
children .add(new Tree<Short>((short) (2 * leaf.getLabel()))); children.add(new Tree<Short>( (short) (2 * leaf.getLabel() + 1))); } else { children.add(new Tree<Short>(leaf.getLabel())); } } this.splitTrees[tag] = splitTree; } } public int totalSubStates() { int count = 0; for (int i = 0; i < numStates; i++) { count += numSubStates[i]; } return count; } /** * Tally the probability of seeing each substate. This data is needed for * tallyMergeScores. mergeWeights is indexed as [state][substate]. This data * should be normalized before being used by another function. * * @param tree * @param mergeWeights * The probability of seeing substate given state. */ public void tallyMergeWeights(Tree<StateSet> tree, double mergeWeights[][]) { if (tree.isLeaf()) return; StateSet label = tree.getLabel(); short state = label.getState(); double probs[] = new double[label.numSubStates()]; double total = 0, tmp; for (short i = 0; i < label.numSubStates(); i++) { tmp = label.getIScore(i) * label.getOScore(i); // TODO: put in the scale parameters??? probs[i] = tmp; total += tmp; } if (total == 0) total = 1; for (short i = 0; i < label.numSubStates(); i++) { mergeWeights[state][i] += probs[i] / total; } for (Tree<StateSet> child : tree.getChildren()) { tallyMergeWeights(child, mergeWeights); } } /* * normalize merge weights. assumes that the mergeWeights are given as logs. * the normalized weights are returned as probabilities. */ public void normalizeMergeWeights(double[][] mergeWeights) { for (int state = 0; state < mergeWeights.length; state++) { double sum = 0; for (int subState = 0; subState < numSubStates[state]; subState++) { sum += mergeWeights[state][subState]; } if (sum == 0) sum = 1; for (int subState = 0; subState < numSubStates[state]; subState++) { mergeWeights[state][subState] /= sum; } } } /** * Calculate the log likelihood gain of merging pairs of split states * together. This information is returned in deltas[state][merged substate]. 
* It requires mergeWeights to be calculated by tallyMergeWeights. * * @param tree * @param deltas * The log likelihood gained by merging pairs of substates. * @param mergeWeights * The probability of seeing substate given state. */ public void tallyMergeScores(Tree<StateSet> tree, double[][][] deltas, double[][] mergeWeights) { if (tree.isLeaf()) return; StateSet label = tree.getLabel(); short state = label.getState(); double[] separatedScores = new double[label.numSubStates()]; double[] combinedScores = new double[label.numSubStates()]; double combinedScore; // calculate separated scores double separatedScoreSum = 0, tmp; // don't need to deal with scale factor because we divide below for (int i = 0; i < label.numSubStates(); i++) { tmp = label.getIScore(i) * label.getOScore(i); combinedScores[i] = separatedScores[i] = tmp; separatedScoreSum += tmp; } // calculate merged scores for (short i = 0; i < numSubStates[state]; i++) { for (short j = (short) (i + 1); j < numSubStates[state]; j++) { short[] map = new short[2]; map[0] = i; map[1] = j; double[] tmp1 = new double[2], tmp2 = new double[2]; double mergeWeightSum = 0; for (int k = 0; k < 2; k++) { mergeWeightSum += mergeWeights[state][map[k]]; } if (mergeWeightSum == 0) mergeWeightSum = 1; for (int k = 0; k < 2; k++) { tmp1[k] = label.getIScore(map[k]) * mergeWeights[state][map[k]] / mergeWeightSum; tmp2[k] = label.getOScore(map[k]); } combinedScore = (tmp1[0] + tmp1[1]) * (tmp2[0] + tmp2[1]); combinedScores[i] = combinedScore; combinedScores[j] = 0; if (combinedScore != 0 && separatedScoreSum != 0) deltas[state][i][j] += Math.log(separatedScoreSum / ArrayUtil.sum(combinedScores)); for (int k = 0; k < 2; k++) combinedScores[map[k]] = separatedScores[map[k]]; if (Double.isNaN(deltas[state][i][j])) { System.out.println(" deltas[" + tagNumberer.object(state) + "][" + i + "][" + j + "] = NaN"); System.out.println(Arrays.toString(separatedScores) + " " + Arrays.toString(tmp1) + " " + Arrays.toString(tmp2) + " " + 
combinedScore + " " + Arrays.toString(mergeWeights[state])); } } } for (Tree<StateSet> child : tree.getChildren()) { tallyMergeScores(child, deltas, mergeWeights); } } /** * This merges the substate pairs indicated by * mergeThesePairs[state][substate pair]. It requires merge weights * calculated by tallyMergeWeights. * * @param mergeThesePairs * Which substate pairs to merge. * @param mergeWeights * The probability of seeing each substate. */ public Grammar mergeStates(boolean[][][] mergeThesePairs, double[][] mergeWeights) { if (logarithmMode) { throw new Error("Do not merge grammars in logarithm mode!"); } short[] newNumSubStates = new short[numSubStates.length]; short[][] mapping = new short[numSubStates.length][]; // invariant: if partners[state][substate][0] == substate, it's the 1st // one short[][][] partners = new short[numSubStates.length][][]; calculateMergeArrays(mergeThesePairs, newNumSubStates, mapping, partners, numSubStates); // create the new grammar Grammar grammar = new Grammar(newNumSubStates, findClosedPaths, smoother, this, threshold); // for (Rule r : allRules) { // if (r instanceof BinaryRule) { for (BinaryRule oldRule : binaryRuleMap.keySet()) { // BinaryRule oldRule = r; short pS = oldRule.getParentState(), lcS = oldRule .getLeftChildState(), rcS = oldRule.getRightChildState(); double[][][] oldScores = oldRule.getScores2(); // merge binary rule double[][][] newScores = new double[newNumSubStates[lcS]][newNumSubStates[rcS]][newNumSubStates[pS]]; for (int i = 0; i < numSubStates[pS]; i++) { if (partners[pS][i][0] == i) { int parentSplit = partners[pS][i].length; for (int j = 0; j < numSubStates[lcS]; j++) { if (partners[lcS][j][0] == j) { int leftSplit = partners[lcS][j].length; for (int k = 0; k < (numSubStates[rcS]); k++) { if (partners[rcS][k][0] == k) { int rightSplit = partners[rcS][k].length; double[][][] scores = new double[leftSplit][rightSplit][parentSplit]; for (int js = 0; js < leftSplit; js++) { for (int ks = 0; ks < rightSplit; 
ks++) { if (oldScores[partners[lcS][j][js]][partners[rcS][k][ks]] == null) continue; for (int is = 0; is < parentSplit; is++) { scores[js][ks][is] = oldScores[partners[lcS][j][js]][partners[rcS][k][ks]][partners[pS][i][is]]; } } } if (rightSplit == 2) { for (int is = 0; is < parentSplit; is++) { for (int js = 0; js < leftSplit; js++) { scores[js][0][is] = scores[js][1][is] = scores[js][0][is] + scores[js][1][is]; } } } if (leftSplit == 2) { for (int is = 0; is < parentSplit; is++) { for (int ks = 0; ks < rightSplit; ks++) { scores[0][ks][is] = scores[1][ks][is] = scores[0][ks][is] + scores[1][ks][is]; } } } if (parentSplit == 2) { for (int js = 0; js < leftSplit; js++) { for (int ks = 0; ks < rightSplit; ks++) { double mergeWeightSum = mergeWeights[pS][partners[pS][i][0]] + mergeWeights[pS][partners[pS][i][1]]; if (SloppyMath .isDangerous(mergeWeightSum)) mergeWeightSum = 1; scores[js][ks][0] = scores[js][ks][1] = ((scores[js][ks][0] * mergeWeights[pS][partners[pS][i][0]]) + (scores[js][ks][1] * mergeWeights[pS][partners[pS][i][1]])) / mergeWeightSum; } } } for (int is = 0; is < parentSplit; is++) { for (int js = 0; js < leftSplit; js++) { for (int ks = 0; ks < rightSplit; ks++) { newScores[mapping[lcS][partners[lcS][j][js]]][mapping[rcS][partners[rcS][k][ks]]][mapping[pS][partners[pS][i][is]]] = scores[js][ks][is]; } } } } } } } } } BinaryRule newRule = new BinaryRule(oldRule); newRule.setScores2(newScores); grammar.addBinary(newRule); } // } else if (r instanceof UnaryRule) { for (UnaryRule oldRule : unaryRuleMap.keySet()) { // UnaryRule oldRule = (UnaryRule) r; short pS = oldRule.getParentState(), cS = oldRule.getChildState(); // merge unary rule double[][] newScores = new double[newNumSubStates[cS]][newNumSubStates[pS]]; double[][] oldScores = oldRule.getScores2(); boolean allZero = true; for (int i = 0; i < numSubStates[pS]; i++) { if (partners[pS][i][0] == i) { int parentSplit = partners[pS][i].length; for (int j = 0; j < numSubStates[cS]; j++) { if 
(partners[cS][j][0] == j) { int childSplit = partners[cS][j].length; double[][] scores = new double[childSplit][parentSplit]; for (int js = 0; js < childSplit; js++) { if (oldScores[partners[cS][j][js]] == null) continue; for (int is = 0; is < parentSplit; is++) { scores[js][is] = oldScores[partners[cS][j][js]][partners[pS][i][is]]; } } if (childSplit == 2) { for (int is = 0; is < parentSplit; is++) { scores[0][is] = scores[1][is] = scores[0][is] + scores[1][is]; } } if (parentSplit == 2) { for (int js = 0; js < childSplit; js++) { double mergeWeightSum = mergeWeights[pS][partners[pS][i][0]] + mergeWeights[pS][partners[pS][i][1]]; if (SloppyMath.isDangerous(mergeWeightSum)) mergeWeightSum = 1; scores[js][0] = scores[js][1] = ((scores[js][0] * mergeWeights[pS][partners[pS][i][0]]) + (scores[js][1] * mergeWeights[pS][partners[pS][i][1]])) / mergeWeightSum; } } for (int is = 0; is < parentSplit; is++) { for (int js = 0; js < childSplit; js++) { newScores[mapping[cS][partners[cS][j][js]]][mapping[pS][partners[pS][i][is]]] = scores[js][is]; allZero = allZero && (scores[js][is] == 0); } } } } } } // if (allZero){ // System.out.println("Maybe an underflow? 
Rule: "+oldRule); // System.out.println(ArrayUtil.toString(newScores)); // System.out.println(ArrayUtil.toString(oldScores)); // System.out.println(Arrays.toString(mergeWeights[pS])); // } UnaryRule newRule = new UnaryRule(oldRule); newRule.setScores2(newScores); grammar.addUnary(newRule); } grammar.pruneSplitTree(partners, mapping); grammar.isGrammarTag = this.isGrammarTag; grammar.closedSumRulesWithParent = grammar.closedViterbiRulesWithParent = grammar.unaryRulesWithParent; grammar.closedSumRulesWithChild = grammar.closedViterbiRulesWithChild = grammar.unaryRulesWithC; return grammar; } /** * @param mergeThesePairs * @param partners */ private void pruneSplitTree(short[][][] partners, short[][] mapping) { for (int tag = 0; tag < splitTrees.length; tag++) { Tree<Short> splitTree = splitTrees[tag]; int maxDepth = splitTree.getDepth(); for (Tree<Short> preTerminal : splitTree.getAtDepth(maxDepth - 2)) { List<Tree<Short>> children = preTerminal.getChildren(); ArrayList<Tree<Short>> newChildren = new ArrayList<Tree<Short>>( 2); for (int i = 0; i < children.size(); i++) { Tree<Short> child = children.get(i); int curLoc = child.getLabel(); if (partners[tag][curLoc][0] == curLoc) { newChildren.add(new Tree<Short>(mapping[tag][curLoc])); } } preTerminal.setChildren(newChildren); } } } public static void checkNormalization(Grammar grammar) { double[][] psum = new double[grammar.numSubStates.length][]; for (int pS = 0; pS < grammar.numSubStates.length; pS++) { psum[pS] = new double[grammar.numSubStates[pS]]; } boolean[] sawPS = new boolean[grammar.numSubStates.length]; for (UnaryRule ur : grammar.unaryRuleMap.values()) { int pS = ur.getParentState(); sawPS[pS] = true; int cS = ur.getChildState(); double[][] scores = ur.getScores2(); for (int ci = 0; ci < grammar.numSubStates[cS]; ci++) { if (scores[ci] == null) continue; for (int pi = 0; pi < grammar.numSubStates[pS]; pi++) { psum[pS][pi] += scores[ci][pi]; } } } for (BinaryRule br : grammar.binaryRuleMap.values()) { int 
pS = br.getParentState(); sawPS[pS] = true; int lcS = br.getLeftChildState(); int rcS = br.getRightChildState(); double[][][] scores = br.getScores2(); for (int lci = 0; lci < grammar.numSubStates[lcS]; lci++) { for (int rci = 0; rci < grammar.numSubStates[rcS]; rci++) { if (scores[lci][rci] == null) continue; for (int pi = 0; pi < grammar.numSubStates[pS]; pi++) { psum[pS][pi] += scores[lci][rci][pi]; } } } } System.out.println(); System.out.println("Checking for substates whose probs don't sum to 1"); for (int pS = 0; pS < grammar.numSubStates.length; pS++) { if (!sawPS[pS]) continue; for (int pi = 0; pi < grammar.numSubStates[pS]; pi++) { if (Math.abs(1 - psum[pS][pi]) > 0.001) System.out.println(" state " + pS + " substate " + pi + " gives bad psum: " + psum[pS][pi]); } } } /** * @param mergeThesePairs * @param newNumSubStates * @param mapping * @param partners */ public static void calculateMergeArrays(boolean[][][] mergeThesePairs, short[] newNumSubStates, short[][] mapping, short[][][] partners, short[] numSubStates) { for (short state = 0; state < numSubStates.length; state++) { short mergeTarget[] = new short[mergeThesePairs[state].length]; Arrays.fill(mergeTarget, (short) -1); short count = 0; mapping[state] = new short[numSubStates[state]]; partners[state] = new short[numSubStates[state]][]; for (short j = 0; j < numSubStates[state]; j++) { if (mergeTarget[j] != -1) { mapping[state][j] = mergeTarget[j]; } else { partners[state][j] = new short[1]; partners[state][j][0] = j; mapping[state][j] = count; count++; // assume we're only merging pairs, so we only see things to // merge // with this substate when this substate isn't being merged // with anything // earlier for (short k = (short) (j + 1); k < numSubStates[state]; k++) { if (mergeThesePairs[state][j][k]) { mergeTarget[k] = mapping[state][j]; partners[state][j] = new short[2]; partners[state][j][0] = j; partners[state][j][1] = k; partners[state][k] = partners[state][j]; } } } } newNumSubStates[state] 
= count; } newNumSubStates[0] = 1; // never split or merge ROOT } public void fixMergeWeightsEtc(boolean[][][] mergeThesePairs, double[][] mergeWeights, boolean[][][] complexMergePairs) { short[] newNumSubStates = new short[numSubStates.length]; short[][] mapping = new short[numSubStates.length][]; // invariant: if partners[state][substate][0] == substate, it's the 1st // one short[][][] partners = new short[numSubStates.length][][]; calculateMergeArrays(mergeThesePairs, newNumSubStates, mapping, partners, numSubStates); for (int tag = 0; tag < numSubStates.length; tag++) { double[] newMergeWeights = new double[newNumSubStates[tag]]; for (int i = 0; i < numSubStates[tag]; i++) { newMergeWeights[mapping[tag][i]] += mergeWeights[tag][i]; } mergeWeights[tag] = newMergeWeights; boolean[][] newComplexMergePairs = new boolean[newNumSubStates[tag]][newNumSubStates[tag]]; boolean[][] newMergeThesePairs = new boolean[newNumSubStates[tag]][newNumSubStates[tag]]; for (int i = 0; i < complexMergePairs[tag].length; i++) { for (int j = 0; j < complexMergePairs[tag].length; j++) { newComplexMergePairs[mapping[tag][i]][mapping[tag][j]] = newComplexMergePairs[mapping[tag][i]][mapping[tag][j]] || complexMergePairs[tag][i][j]; newMergeThesePairs[mapping[tag][i]][mapping[tag][j]] = newMergeThesePairs[mapping[tag][i]][mapping[tag][j]] || mergeThesePairs[tag][i][j]; } } complexMergePairs[tag] = newComplexMergePairs; mergeThesePairs[tag] = newMergeThesePairs; } } public void logarithmMode() { // System.out.println("The gramar is in logarithmMode!"); if (logarithmMode) return; logarithmMode = true; for (UnaryRule r : unaryRuleMap.keySet()) { logarithmModeRule(unaryRuleMap.get(r)); } for (BinaryRule r : binaryRuleMap.keySet()) { logarithmModeRule(binaryRuleMap.get(r)); } // Leon thinks the following sets of rules are already covered above, // but he wants to take no chances logarithmModeBRuleListArray(binaryRulesWithParent); logarithmModeBRuleListArray(binaryRulesWithLC); 
logarithmModeBRuleListArray(binaryRulesWithRC); logarithmModeBRuleArrayArray(splitRulesWithLC); logarithmModeBRuleArrayArray(splitRulesWithRC); logarithmModeBRuleArrayArray(splitRulesWithP); logarithmModeURuleListArray(unaryRulesWithParent); logarithmModeURuleListArray(unaryRulesWithC); logarithmModeURuleListArray(sumProductClosedUnaryRulesWithParent); logarithmModeURuleListArray(closedSumRulesWithParent); logarithmModeURuleListArray(closedSumRulesWithChild); logarithmModeURuleListArray(closedViterbiRulesWithParent); logarithmModeURuleListArray(closedViterbiRulesWithChild); logarithmModeURuleArrayArray(closedSumRulesWithP); logarithmModeURuleArrayArray(closedSumRulesWithC); logarithmModeURuleArrayArray(closedViterbiRulesWithP); logarithmModeURuleArrayArray(closedViterbiRulesWithC); } /** * */ private void logarithmModeBRuleListArray(List<BinaryRule>[] a) { if (a != null) { for (List<BinaryRule> l : a) { if (l == null) continue; for (BinaryRule r : l) { logarithmModeRule(r); } } } } /** * */ private void logarithmModeURuleListArray(List<UnaryRule>[] a) { if (a != null) { for (List<UnaryRule> l : a) { if (l == null) continue; for (UnaryRule r : l) { logarithmModeRule(r); } } } } /** * */ private void logarithmModeBRuleArrayArray(BinaryRule[][] a) { if (a != null) { for (BinaryRule[] l : a) { if (l == null) continue; for (BinaryRule r : l) { logarithmModeRule(r); } } } } /** * */ private void logarithmModeURuleArrayArray(UnaryRule[][] a) { if (a != null) { for (UnaryRule[] l : a) { if (l == null) continue; for (UnaryRule r : l) { logarithmModeRule(r); } } } } /** * @param r */ private static void logarithmModeRule(BinaryRule r) { if (r == null || r.logarithmMode) return; r.logarithmMode = true; double[][][] scores = r.getScores2(); for (int i = 0; i < scores.length; i++) { for (int j = 0; j < scores[i].length; j++) { if (scores[i][j] == null) continue; for (int k = 0; k < scores[i][j].length; k++) { scores[i][j][k] = Math.log(scores[i][j][k]); } } } 
r.setScores2(scores); } /** * @param r */ private static void logarithmModeRule(UnaryRule r) { if (r == null || r.logarithmMode) return; r.logarithmMode = true; double[][] scores = r.getScores2(); for (int j = 0; j < scores.length; j++) { if (scores[j] == null) continue; for (int k = 0; k < scores[j].length; k++) { scores[j][k] = Math.log(scores[j][k]); } } r.setScores2(scores); } public boolean isLogarithmMode() { return logarithmMode; } public final boolean isGrammarTag(int n) { return isGrammarTag[n]; } public Grammar projectGrammar(double[] condProbs, int[][] fromMapping, int[][] toSubstateMapping) { short[] newNumSubStates = new short[numSubStates.length]; for (int state = 0; state < numSubStates.length; state++) { newNumSubStates[state] = (short) toSubstateMapping[state][0]; } Grammar grammar = new Grammar(newNumSubStates, findClosedPaths, smoother, this, threshold); for (BinaryRule oldRule : binaryRuleMap.keySet()) { short pcS = oldRule.getParentState(), lcS = oldRule .getLeftChildState(), rcS = oldRule.getRightChildState(); double[][][] oldScores = oldRule.getScores2(); // merge binary rule double[][][] newScores = new double[newNumSubStates[lcS]][newNumSubStates[rcS]][newNumSubStates[pcS]]; for (int lS = 0; lS < numSubStates[lcS]; lS++) { for (int rS = 0; rS < numSubStates[rcS]; rS++) { if (oldScores[lS][rS] == null) continue; for (int pS = 0; pS < numSubStates[pcS]; pS++) { newScores[toSubstateMapping[lcS][lS + 1]][toSubstateMapping[rcS][rS + 1]][toSubstateMapping[pcS][pS + 1]] += condProbs[fromMapping[pcS][pS]] * oldScores[lS][rS][pS]; } } } BinaryRule newRule = new BinaryRule(oldRule, newScores); grammar.addBinary(newRule); } for (UnaryRule oldRule : unaryRuleMap.keySet()) { short pcS = oldRule.getParentState(), ccS = oldRule.getChildState(); double[][] oldScores = oldRule.getScores2(); double[][] newScores = new double[newNumSubStates[ccS]][newNumSubStates[pcS]]; for (int cS = 0; cS < numSubStates[ccS]; cS++) { if (oldScores[cS] == null) continue; for 
(int pS = 0; pS < numSubStates[pcS]; pS++) { newScores[toSubstateMapping[ccS][cS + 1]][toSubstateMapping[pcS][pS + 1]] += condProbs[fromMapping[pcS][pS]] * oldScores[cS][pS]; } } UnaryRule newRule = new UnaryRule(oldRule, newScores); grammar.addUnary(newRule); // grammar.closedSumRulesWithParent[newRule.parentState].add(newRule); // grammar.closedSumRulesWithChild[newRule.childState].add(newRule); } grammar.computePairsOfUnaries(); // grammar.splitRules(); grammar.makeCRArrays(); grammar.isGrammarTag = this.isGrammarTag; // System.out.println(grammar.toString()); return grammar; } public Grammar copyGrammar(boolean noUnaryChains) { short[] newNumSubStates = numSubStates.clone(); Grammar grammar = new Grammar(newNumSubStates, findClosedPaths, smoother, this, threshold); for (BinaryRule oldRule : binaryRuleMap.keySet()) { BinaryRule newRule = new BinaryRule(oldRule); grammar.addBinary(newRule); } for (UnaryRule oldRule : unaryRuleMap.keySet()) { UnaryRule newRule = new UnaryRule(oldRule); grammar.addUnary(newRule); } if (noUnaryChains) { closedSumRulesWithParent = closedViterbiRulesWithParent = unaryRulesWithParent; closedSumRulesWithChild = closedViterbiRulesWithChild = unaryRulesWithC; } else grammar.computePairsOfUnaries(); grammar.makeCRArrays(); grammar.isGrammarTag = this.isGrammarTag; /* * grammar.ruleIndexer = ruleIndexer; grammar.startIndex = startIndex; * grammar.nEntries = nEntries; grammar.toBeIgnored = toBeIgnored; */ return grammar; } public Grammar projectTo0LevelGrammar(double[] condProbs, int[][] fromMapping, int[][] toMapping) { int newNumStates = fromMapping[fromMapping.length - 1][0]; // all rules have the same parent in this grammar double[][] newBinaryProbs = new double[newNumStates][newNumStates]; double[] newUnaryProbs = new double[newNumStates]; short[] newNumSubStates = new short[numSubStates.length]; Arrays.fill(newNumSubStates, (short) 1); Grammar grammar = new Grammar(newNumSubStates, findClosedPaths, smoother, this, threshold); // 
short[] newNumSubStates = new short[newNumStates]; // grammar.numSubStates = newNumSubStates; // grammar.numStates = (short)newNumStates; for (BinaryRule oldRule : binaryRuleMap.keySet()) { short pcS = oldRule.getParentState(), lcS = oldRule .getLeftChildState(), rcS = oldRule.getRightChildState(); double[][][] oldScores = oldRule.getScores2(); // merge binary rule // double[][][] newScores = new double[1][1][1]; for (int lS = 0; lS < numSubStates[lcS]; lS++) { for (int rS = 0; rS < numSubStates[rcS]; rS++) { if (oldScores[lS][rS] == null) continue; for (int pS = 0; pS < numSubStates[pcS]; pS++) { newBinaryProbs[toMapping[lcS][lS]][toMapping[rcS][rS]] += // newBinaryProbs[lcS][rcS] += condProbs[fromMapping[pcS][pS]] * oldScores[lS][rS][pS]; } } } // BinaryRule newRule = new BinaryRule(oldRule); // newRule.setScores2(newScores); // grammar.addBinary(newRule); } for (UnaryRule oldRule : unaryRuleMap.keySet()) { short pcS = oldRule.getParentState(), ccS = oldRule.getChildState(); double[][] oldScores = oldRule.getScores2(); for (int cS = 0; cS < numSubStates[ccS]; cS++) { if (oldScores[cS] == null) continue; for (int pS = 0; pS < numSubStates[pcS]; pS++) { // newScores[0][0] += // condProbs[fromMapping[pcS][pS]]*oldScores[cS][pS]; newUnaryProbs[toMapping[ccS][cS]] += // newUnaryProbs[ccS] += condProbs[fromMapping[pcS][pS]] * oldScores[cS][pS]; } } // UnaryRule newRule = new UnaryRule(oldRule); // newRule.setScores2(newScores); // grammar.addUnary(newRule); // grammar.closedSumRulesWithParent[newRule.parentState].add(newRule); // grammar.closedSumRulesWithChild[newRule.childState].add(newRule); } for (short lS = 0; lS < newBinaryProbs.length; lS++) { for (short rS = 0; rS < newBinaryProbs.length; rS++) { if (newBinaryProbs[lS][rS] > 0) { double[][][] newScores = new double[1][1][1]; newScores[0][0][0] = newBinaryProbs[lS][rS]; BinaryRule newRule = new BinaryRule((short) 0, lS, rS, newScores); // newRule.setScores2(newScores); grammar.addBinary(newRule); } } } for 
(short cS = 0; cS < newUnaryProbs.length; cS++) { if (newUnaryProbs[cS] > 0) { double[][] newScores = new double[1][1]; newScores[0][0] = newUnaryProbs[cS]; UnaryRule newRule = new UnaryRule((short) 0, cS, newScores); // newRule.setScores2(newScores); grammar.addUnary(newRule); } } grammar.computePairsOfUnaries(); grammar.makeCRArrays(); grammar.isGrammarTag = this.isGrammarTag; // System.out.println(grammar.toString()); return grammar; } public double[] computeConditionalProbabilities(int[][] fromMapping, int[][] toMapping) { double[][] transitionProbs = computeProductionProbabilities(fromMapping); // System.out.println(ArrayUtil.toString(transitionProbs)); double[] expectedCounts = computeExpectedCounts(transitionProbs); // System.out.println(Arrays.toString(expectedCounts)); /* * for (int state=0; state<mapping.length-1; state++){ for (int * substate=0; substate<mapping[state].length; substate++){ * System.out.println * ((String)tagNumberer.object(state)+"_"+substate+" "+ * expectedCounts[mapping[state][substate]]); } } */ double[] condProbs = new double[expectedCounts.length]; for (int projectedState = 0; projectedState < toMapping[toMapping.length - 1][0]; projectedState++) { double sum = 0; for (int state = 0; state < fromMapping.length - 1; state++) { for (int substate = 0; substate < fromMapping[state].length; substate++) { if (toMapping[state][substate] == projectedState) sum += expectedCounts[fromMapping[state][substate]]; } } for (int state = 0; state < fromMapping.length - 1; state++) { for (int substate = 0; substate < fromMapping[state].length; substate++) { if (toMapping[state][substate] == projectedState) condProbs[fromMapping[state][substate]] = expectedCounts[fromMapping[state][substate]] / sum; } } } return condProbs; } public int[][] computeToMapping(int level, int[][] toSubstateMapping) { if (level == -1) return computeMapping(-1); short[] numSubStates = this.numSubStates; int[][] mapping = new int[numSubStates.length + 1][]; int k = 0; for 
(int state = 0; state < numSubStates.length; state++) {
			mapping[state] = new int[numSubStates[state]];
			int oldVal = -1;
			for (int substate = 0; substate < numSubStates[state]; substate++) {
				if (substate != 0 && oldVal != toSubstateMapping[state][substate + 1])
					k++;
				mapping[state][substate] = k;
				oldVal = toSubstateMapping[state][substate + 1];
			}
			k++;
		}
		mapping[numSubStates.length] = new int[1];
		mapping[numSubStates.length][0] = k;
		// System.out.println("The merged grammar will have "+k+" substates.");
		return mapping;
	}

	/**
	 * Maps every (state, substate) pair to an index at the requested level of
	 * granularity.
	 * <ul>
	 * <li>{@code level >= 1}: each pair receives its own consecutive index;</li>
	 * <li>{@code level == -1}: substates of states for which
	 * {@code isGrammarTag(state)} holds map to 0, all others map to their state
	 * id;</li>
	 * <li>any other level (0, or below -1): each substate maps to its state
	 * id.</li>
	 * </ul>
	 *
	 * @param level
	 *            granularity selector, see above
	 * @return one row per state ({@code mapping[state][substate]}) plus a final
	 *         one-element row holding the index count:
	 *         {@code numSubStates.length} when {@code level < 1}, otherwise the
	 *         total number of (state, substate) pairs
	 */
	public int[][] computeMapping(int level) {
		short[] numSubStates = this.numSubStates;
		int[][] mapping = new int[numSubStates.length + 1][];
		int k = 0;
		for (int state = 0; state < numSubStates.length; state++) {
			mapping[state] = new int[numSubStates[state]];
			for (int substate = 0; substate < numSubStates[state]; substate++) {
				if (level >= 1) {
					mapping[state][substate] = k++;
				} else if (level == -1) {
					mapping[state][substate] = this.isGrammarTag(state) ? 0 : state;
				} else {
					mapping[state][substate] = state;
				}
			}
		}
		mapping[numSubStates.length] = new int[] { (level < 1) ? numSubStates.length : k };
		return mapping;
	}

	/**
	 * Merges substates according to the split hierarchy: every substate is
	 * mapped to the node of {@code splitTrees[state]} at depth {@code level}
	 * whose yield contains it.
	 *
	 * @param level
	 *            depth in the split trees; a negative value merges all
	 *            substates of a state into a single one
	 * @return per state, an array of length {@code numSubStates[state] + 1}:
	 *         entry 0 holds the number of merged substates and entry
	 *         {@code s + 1} the merged index of substate {@code s} (-1 if no
	 *         node at that depth covers it)
	 * @throws ArrayIndexOutOfBoundsException
	 *             if a split tree contains a leaf label that is not a valid
	 *             substate of its state
	 */
	public int[][] computeSubstateMapping(int level) {
		short[] numSubStates = this.numSubStates;
		int[][] mapping = new int[numSubStates.length][];
		for (int state = 0; state < numSubStates.length; state++) {
			mapping[state] = new int[numSubStates[state] + 1];
			int k = 0;
			if (level >= 0) {
				Arrays.fill(mapping[state], -1);
				List<Tree<Short>> subTrees = splitTrees[state].getAtDepth(level);
				for (Tree<Short> subTree : subTrees) {
					for (Short substate : subTree.getYield()) {
						// Fail with context instead of an anonymous out-of-bounds write below.
						if (substate >= numSubStates[state])
							throw new ArrayIndexOutOfBoundsException("split tree of state " + state
									+ " has leaf substate " + substate + " but the state only has "
									+ numSubStates[state] + " substates");
						mapping[state][substate + 1] = k;
					}
					k++;
				}
			} else {
				k = 1;
			}
			mapping[state][0] = k;
		}
		return mapping;
	}

	/**
	 * Records, for every node at depth {@code level} of each state's split
	 * tree, the nodes it expands into: the label of its first child goes to
	 * {@code lChildMap}, the label of its last child to {@code rChildMap}. A
	 * node without children maps to its own label in both; a node with a
	 * single child maps to that child in both.
	 *
	 * @param level
	 *            depth of the nodes being expanded
	 * @param lChildMap
	 *            output, one row per state; rows are (re)allocated here
	 * @param rChildMap
	 *            output, one row per state; rows are (re)allocated here
	 */
	public void computeReverseSubstateMapping(int level, int[][] lChildMap, int[][] rChildMap) {
		for (int state = 0; state < numSubStates.length; state++) {
			List<Tree<Short>> subTrees = splitTrees[state].getAtDepth(level);
			lChildMap[state] = new int[subTrees.size()];
			rChildMap[state] = new int[subTrees.size()];
			for (Tree<Short> subTree : subTrees) {
				// NOTE(review): rows are sized by the node count but indexed by node
				// label, which assumes labels at this depth are 0..size-1 — confirm.
				int substate = subTree.getLabel();
				if (subTree.isLeaf()) {
					lChildMap[state][substate] = substate;
					rChildMap[state][substate] = substate;
					continue;
				}
				int nChildren = subTree.getChildren().size();
				boolean first = true;
				for (Tree<Short> child : subTree.getChildren()) {
					if (first) {
						lChildMap[state][substate] = child.getLabel();
						first = false;
					} else {
						rChildMap[state][substate] = child.getLabel();
					}
					// A single child is both the left and the right expansion.
					if (nChildren == 1)
						rChildMap[state][substate] = child.getLabel();
				}
			}
		}
	}

	/**
	 * Iteratively computes expected visit counts under a transition matrix,
	 * with state 0 (the root) pinned to a count of 1:
	 * {@code c[s] = sum_p c[p] * transitionProbs[p][s]} for every {@code s > 0}.
	 * Stops when the summed absolute change drops to 1e-10 or after 50
	 * iterations.
	 *
	 * @param transitionProbs
	 *            square matrix; {@code transitionProbs[p][s]} is the weight of
	 *            moving from p to s
	 * @return the expected count of every state (empty for empty input)
	 */
	private double[] computeExpectedCounts(double[][] transitionProbs) {
		int n = transitionProbs.length;
		double[] expectedCounts = new double[n];
		if (n == 0)
			return expectedCounts;
		double[] nextCounts = new double[n];
		expectedCounts[0] = 1; // the root; never updated below
		double diff = 1;
		for (int iter = 0; diff > 1.0e-10 && iter < 50; iter++) {
			// Jacobi step: read only the previous iterate, write into nextCounts.
			for (int state = 1; state < n; state++) {
				for (int pState = 0; pState < n; pState++) {
					nextCounts[state] += expectedCounts[pState] * transitionProbs[pState][state];
				}
			}
			diff = 0;
			for (int state = 1; state < n; state++) {
				diff += Math.abs(expectedCounts[state] - nextCounts[state]);
				expectedCounts[state] = nextCounts[state];
				nextCounts[state] = 0;
			}
		}
		return expectedCounts;
	}

	/**
	 * Builds a matrix W where {@code W[i][j]} is the summed rule score of
	 * parent index i producing child index j, with indices taken from
	 * {@code mapping}. Both children of a binary rule receive the rule's score;
	 * unary self-loops (child state == parent state) are skipped.
	 *
	 * @param mapping
	 *            (state, substate) to index map in the layout returned by
	 *            {@link #computeMapping(int)}; its last row holds the index
	 *            count
	 * @return the {@code count x count} production matrix
	 */
	private double[][] computeProductionProbabilities(int[][] mapping) {
		short[] numSubStates = this.numSubStates;
		int totalStates = mapping[numSubStates.length][0];
		double[][] W = new double[totalStates][totalStates];
		for (int state = 0; state < numSubStates.length; state++) {
			for (BinaryRule r : this.splitRulesWithP(state)) {
				int lState = r.leftChildState;
				int rState = r.rightChildState;
				double[][][] scores = r.getScores2();
				for (int lS = 0; lS < numSubStates[lState]; lS++) {
					for (int rS = 0; rS < numSubStates[rState]; rS++) {
						if (scores[lS][rS] == null)
							continue;
						for (int pS = 0; pS < numSubStates[state]; pS++) {
							int parent = mapping[state][pS];
							W[parent][mapping[lState][lS]] += scores[lS][rS][pS];
							W[parent][mapping[rState][rS]] += scores[lS][rS][pS];
						}
					}
				}
			}
			for (UnaryRule r : this.getUnaryRulesByParent(state)) {
				int cState = r.childState;
				if (cState == state)
					continue;
				double[][] scores = r.getScores2();
				for (int cS = 0; cS < numSubStates[cState]; cS++) {
					if (scores[cS] == null)
						continue;
					for (int pS = 0; pS < numSubStates[state]; pS++) {
						W[mapping[state][pS]][mapping[cState][cS]] += scores[cS][pS];
					}
				}
			}
		}
		return W;
	}

	/** Chain length used by the no-argument {@link #computeProperClosures()}. */
	private static final int DEFAULT_MAX_UNARY_CLOSURE_LENGTH = 10;

	/**
	 * Computes the sum-closure of the unary rules over chains of up to
	 * {@value #DEFAULT_MAX_UNARY_CLOSURE_LENGTH} rules; see
	 * {@link #computeProperClosures(int)}.
	 */
	public void computeProperClosures() {
		computeProperClosures(DEFAULT_MAX_UNARY_CLOSURE_LENGTH);
	}

	/**
	 * Computes the sum-closure of the unary rules. For every pair of
	 * (state, substate) nodes, it sums the product of rule scores over all unary
	 * chains of 1..{@code maxPathLength} rules leading from one to the other.
	 * For every ordered pair of distinct states with at least one positive
	 * closure score, a new {@link UnaryRule} is created and registered via
	 * {@code addUnary}. The new rule is also recorded in
	 * {@code closedSumRulesWithParent}/{@code closedSumRulesWithChild} and in
	 * the array views {@code closedSumRulesWithP}/{@code closedSumRulesWithC}.
	 * Chains from a state back to itself are summed but produce no rule.
	 *
	 * @param maxPathLength
	 *            maximum number of unary rules in a chain (the no-argument
	 *            overload uses 10)
	 * @throws IllegalArgumentException
	 *             if {@code maxPathLength < 1}
	 */
	@SuppressWarnings("unchecked")
	public void computeProperClosures(int maxPathLength) {
		if (maxPathLength < 1)
			throw new IllegalArgumentException("maxPathLength must be >= 1, got " + maxPathLength);
		// Give every (state, substate) node a dense index.
		int[][] map = new int[numStates][];
		int index = 0;
		for (int state = 0; state < numStates; state++) {
			map[state] = new int[numSubStates[state]];
			for (int substate = 0; substate < numSubStates[state]; substate++) {
				map[state][substate] = index++;
			}
		}
		// Chains consisting of a single unary rule.
		double[][] pathScores = new double[index][index];
		for (int parentState = 0; parentState < numStates; parentState++) {
			for (UnaryRule rule : unaryRulesWithParent[parentState]) {
				short childState = rule.getChildState();
				double[][] scores = rule.getScores2();
				for (int childSubState = 0; childSubState < numSubStates[childState]; childSubState++) {
					if (scores[childSubState] == null)
						continue;
					for (int parentSubState = 0; parentSubState < numSubStates[parentState]; parentSubState++) {
						pathScores[map[parentState][parentSubState]][map[childState][childSubState]] = scores[childSubState][parentSubState];
					}
				}
			}
		}
		double[][] sumClosureScores = new double[index][index];
		addClosureScores(sumClosureScores, pathScores);
		// Extend chains one rule at a time; only the previous length is needed,
		// so two N x N buffers replace the former per-length 3-D array.
		for (int length = 1; length < maxPathLength; length++) {
			double[][] extended = new double[index][index];
			for (int interState = 0; interState < numStates; interState++) {
				for (UnaryRule rule : unaryRulesWithParent[interState]) {
					short endState = rule.getChildState();
					double[][] scores = rule.getScores2();
					for (int startState = 0; startState < numStates; startState++) {
						for (int startSubState = 0; startSubState < numSubStates[startState]; startSubState++) {
							int start = map[startState][startSubState];
							for (int endSubState = 0; endSubState < numSubStates[endState]; endSubState++) {
								if (scores[endSubState] == null)
									continue;
								double ruleScore = 0;
								for (int interSubState = 0; interSubState < numSubStates[interState]; interSubState++) {
									ruleScore += pathScores[start][map[interState][interSubState]]
											* scores[endSubState][interSubState];
								}
								extended[start][map[endState][endSubState]] += ruleScore;
							}
						}
					}
				}
			}
			addClosureScores(sumClosureScores, extended);
			pathScores = extended;
		}
		// reset the lists of unaries
		closedSumRulesWithParent = new List[numStates];
		closedSumRulesWithChild = new List[numStates];
		for (short startState = 0; startState < numStates; startState++) {
			closedSumRulesWithParent[startState] = new ArrayList<UnaryRule>();
			closedSumRulesWithChild[startState] = new ArrayList<UnaryRule>();
		}
		// finally create rules and add them to the lists
		for (short startState = 0; startState < numStates; startState++) {
			for (short endState = 0; endState < numStates; endState++) {
				if (startState == endState)
					continue;
				boolean atLeastOneNonZero = false;
				double[][] scores = new double[numSubStates[endState]][numSubStates[startState]];
				for (int startSubState = 0; startSubState < numSubStates[startState]; startSubState++) {
					for (int endSubState = 0; endSubState < numSubStates[endState]; endSubState++) {
						double score = sumClosureScores[map[startState][startSubState]][map[endState][endSubState]];
						if (score > 0) {
							scores[endSubState][startSubState] = score;
							atLeastOneNonZero = true;
						}
					}
				}
				if (atLeastOneNonZero) {
					UnaryRule newUnary = new UnaryRule(startState, endState, scores);
					addUnary(newUnary);
					closedSumRulesWithParent[startState].add(newUnary);
					closedSumRulesWithChild[endState].add(newUnary);
				}
			}
		}
		// Reallocate the array views when missing or sized for a different state count.
		if (closedSumRulesWithP == null || closedSumRulesWithP.length != numStates
				|| closedSumRulesWithC == null || closedSumRulesWithC.length != numStates) {
			closedSumRulesWithP = new UnaryRule[numStates][];
			closedSumRulesWithC = new UnaryRule[numStates][];
		}
		for (int i = 0; i < numStates; i++) {
			closedSumRulesWithP[i] = closedSumRulesWithParent[i].toArray(new UnaryRule[0]);
			closedSumRulesWithC[i] = closedSumRulesWithChild[i].toArray(new UnaryRule[0]);
		}
	}

	/** Adds {@code source} element-wise into {@code target}; both have the same shape. */
	private static void addClosureScores(double[][] target, double[][] source) {
		for (int i = 0; i < target.length; i++) {
			for (int j = 0; j < target[i].length; j++) {
				target[i][j] += source[i][j];
			}
		}
	}

	/**
	 * Writes the split tree of every state except state 0, one state per line
	 * as {@code tag<TAB>splitTree}. A trailing {@code "^g"} is stripped from
	 * the tag of states flagged in {@code isGrammarTag}.
	 *
	 * @param w
	 *            destination; it is flushed and closed before returning
	 */
	public void writeSplitTrees(Writer w) {
		PrintWriter out = new PrintWriter(w);
		try {
			for (int state = 1; state < numStates; state++) {
				String tag = (String) tagNumberer.object(state);
				if (isGrammarTag[state] && tag.endsWith("^g"))
					tag = tag.substring(0, tag.length() - 2);
				out.write(tag + "\t" + splitTrees[state].toString() + "\n");
			}
			out.flush();
		} finally {
			out.close();
		}
	}

	/**
	 * Returns the {@code closedSumPaths} field as-is (not copied). It is
	 * initialized to {@code null} and may still be {@code null} if never
	 * assigned.
	 */
	public int[][] getClosedSumPaths() {
		return closedSumPaths;
	}

	/**
	 * Re-estimates the sub-rule values stored in {@code binaryRuleCounter} and
	 * {@code unaryRuleCounter} with maximum-entropy (GIS) models.
	 *
	 * Each sub-rule becomes a maxent outcome. It is emitted
	 * {@code round(1 + count)} times as a training event, with the context
	 * produced by its event stream's {@code getContext}. Its stored value is
	 * then reset to 0, and afterwards incremented by the probabilities the
	 * trained model assigns to it.
	 */
	public void overwriteWithMaxent() {
		overwriteBinaryRulesWithMaxent();
		overwriteUnaryRulesWithMaxent();
	}

	/** Binary-rule half of {@link #overwriteWithMaxent()}. */
	private void overwriteBinaryRulesWithMaxent() {
		// Maps each sub-rule's string form back to its rule and sub-rule indices.
		HashMap<String, Pair<BinaryRule, Triple>> str2brule = new HashMap<String, Pair<BinaryRule, Triple>>();
		BinaryRuleEventStream binaryRuleEventStream = new BinaryRuleEventStream();
		for (BinaryRule binaryRule : binaryRuleCounter.keySet()) {
			double[][][] score2 = binaryRuleCounter.getCount(binaryRule);
			for (Triple triple : binaryRule.getAllSubRules(score2)) {
				String bruleStr = binaryRule.getStrSubRule(triple);
				str2brule.put(bruleStr, new Pair<BinaryRule, Triple>(binaryRule, triple));
				long repetitions = Math.round(1 + binaryRule.getCountForSubRule(score2, triple));
				for (long i = 0; i < repetitions; i++) {
					binaryRuleEventStream.addEvent(bruleStr, BinaryRuleEventStream.getContext(bruleStr));
				}
				// The count is replaced by the maxent probability below.
				binaryRule.setProbForSubRule(score2, triple, 0.0);
			}
		}
		// NOTE(review): trained with (10, 0) — iterations and cutoff in the
		// opennlp.maxent API — while the unary model uses the GIS defaults;
		// confirm the asymmetry is intended.
		GISModel bruleMaxentModel = GIS.trainModel(binaryRuleEventStream, 10, 0);
		for (BinaryRule binaryRule : binaryRuleCounter.keySet()) {
			double[][][] score2 = binaryRuleCounter.getCount(binaryRule);
			// NOTE(review): the counts were zeroed above; confirm getAllSubRules
			// still enumerates zero-valued sub-rules, otherwise this loop is a no-op.
			for (Triple triple : binaryRule.getAllSubRules(score2)) {
				String bruleStr = binaryRule.getStrSubRule(triple);
				double[] outcomes = bruleMaxentModel.eval(BinaryRuleEventStream.getContext(bruleStr));
				// NOTE(review): sibling sub-rules sharing a context each add the same
				// outcome probabilities again — confirm this accumulation is intended.
				for (int i = 0; i < outcomes.length; i++) {
					Pair<BinaryRule, Triple> pair = str2brule.get(bruleMaxentModel.getOutcome(i));
					// Select the outcomes of this rule only: another rule's sub-rule
					// indices do not address this rule's score array.
					if (pair == null || pair.getFirst() != binaryRule)
						continue;
					binaryRule.incProbForSubRule(score2, pair.getSecond(), outcomes[i]);
				}
			}
		}
	}

	/** Unary-rule half of {@link #overwriteWithMaxent()}. */
	private void overwriteUnaryRulesWithMaxent() {
		// Maps each sub-rule's string form back to its rule and (child, parent) substate pair.
		HashMap<String, Pair<UnaryRule, Pair<Integer, Integer>>> str2urule = new HashMap<String, Pair<UnaryRule, Pair<Integer, Integer>>>();
		UnaryRuleEventStream unaryRuleEventStream = new UnaryRuleEventStream();
		for (UnaryRule unaryRule : unaryRuleCounter.keySet()) {
			double[][] scores2 = unaryRuleCounter.getCount(unaryRule);
			for (Pair<Integer, Integer> pair : unaryRule.getAllSubRules(scores2)) {
				String uruleStr = unaryRule.getStrSubRule(pair);
				str2urule.put(uruleStr, new Pair<UnaryRule, Pair<Integer, Integer>>(unaryRule, pair));
				long repetitions = Math.round(1 + unaryRule.getCountForSubRule(scores2, pair));
				for (long i = 0; i < repetitions; i++) {
					unaryRuleEventStream.addEvent(uruleStr, UnaryRuleEventStream.getContext(uruleStr));
				}
				// The count is replaced by the maxent probability below.
				unaryRule.setProbForSubRule(scores2, pair, 0.0);
			}
		}
		GISModel uruleMaxentModel = GIS.trainModel(unaryRuleEventStream);
		for (UnaryRule unaryRule : unaryRuleCounter.keySet()) {
			double[][] scores2 = unaryRuleCounter.getCount(unaryRule);
			for (Pair<Integer, Integer> pair : unaryRule.getAllSubRules(scores2)) {
				String uruleStr = unaryRule.getStrSubRule(pair);
				double[] outcomes = uruleMaxentModel.eval(UnaryRuleEventStream.getContext(uruleStr));
				for (int i = 0; i < outcomes.length; i++) {
					Pair<UnaryRule, Pair<Integer, Integer>> pair2 = str2urule.get(uruleMaxentModel.getOutcome(i));
					// Select the outcomes of this unary rule only: another rule's
					// sub-rule indices do not address this rule's score array.
					if (pair2 == null || pair2.getFirst() != unaryRule)
						continue;
					unaryRule.incProbForSubRule(scores2, pair2.getSecond(), outcomes[i]);
				}
			}
		}
	}
}