package edu.berkeley.nlp.PCFGLA; import edu.berkeley.nlp.util.*; import java.io.Serializable; import java.util.ArrayList; import java.util.List; import java.util.Random; /** * Binary rules (ints for parent, left and right children) * * @author Dan Klein */ public class BinaryRule extends Rule implements Serializable, java.lang.Comparable { public short leftChildState = -1; public short rightChildState = -1; /** * NEW: scores[leftSubState][rightSubState][parentSubState] gives score for * this rule */ public double[][][] scores; /** * Creates a BinaryRule from String s, assuming it was created using * toString(). * * @param s */ /* * public BinaryRule(String s, Numberer n) { String[] fields = * StringUtils.splitOnCharWithQuoting(s, ' ', '\"', '\\'); // * System.out.println("fields:\n" + fields[0] + "\n" + fields[2] + "\n" + * fields[3] + "\n" + fields[4]); this.parent = n.number(fields[0]); * this.leftChild = n.number(fields[2]); this.rightChild = * n.number(fields[3]); this.score = Double.parseDouble(fields[4]); } */ public BinaryRule(short pState, short lState, short rState, double[][][] scores) { this.parentState = pState; this.leftChildState = lState; this.rightChildState = rState; this.scores = scores; } public BinaryRule(short pState, short lState, short rState) { this.parentState = pState; this.leftChildState = lState; this.rightChildState = rState; // this.scores = new double[1][1][1]; } /** Copy constructor */ public BinaryRule(BinaryRule b) { this(b.parentState, b.leftChildState, b.rightChildState, ArrayUtil .copy(b.scores)); } public BinaryRule(BinaryRule b, double[][][] newScores) { this(b.parentState, b.leftChildState, b.rightChildState, newScores); } // public BinaryRule(short pState, short lState, short rState, short // pSubStates, int lSubStates, int rSubStates) { // this.parentState = pState; // this.leftChildState = lState; // this.rightChildState = rState; // this.scores = new double[lSubStates][rSubStates][pSubStates]; // } public int hashCode() { return ((int) parentState << 16) ^ ((int) leftChildState << 8) ^ ((int) rightChildState); } public boolean equals(Object o) { if (this == o) { return true; } if (o instanceof BinaryRule) { BinaryRule br = (BinaryRule) o; if (parentState == br.parentState && leftChildState == br.leftChildState && rightChildState == br.rightChildState) { return true; } } return false; } private static final char[] charsToEscape = new char[] { '\"' }; public String toString() { Numberer n = Numberer.getGlobalNumberer("tags"); String lState = (String) n.object(leftChildState); if (lState.endsWith("^g")) lState = lState.substring(0, lState.length() - 2); String rState = (String) n.object(rightChildState); if (rState.endsWith("^g")) rState = rState.substring(0, rState.length() - 2); String pState = (String) n.object(parentState); if (pState.endsWith("^g")) pState = pState.substring(0, pState.length() - 2); StringBuilder sb = new StringBuilder(); if (scores == null) return pState + " -> " + lState + " " + rState + "\n"; // sb.append(pState+ " -> "+lState+ " "+rState+ "\n"); for (int lS = 0; lS < scores.length; lS++) { for (int rS = 0; rS < scores[lS].length; rS++) { if (scores[lS][rS] == null) continue; for (int pS = 0; pS < scores[lS][rS].length; pS++) { double p = scores[lS][rS][pS]; if (p > 0) sb .append(pState + "_" + pS + " -> " + lState + "_" + lS + " " + rState + "_" + rS + " " + p + "\n"); } } } return sb.toString(); } //TODO : fix to create only for different LHS public List<Triple> getAllSubRules(double[][][] score2) { List<Triple> subrules = new ArrayList<Triple>(); for (int lS = 0; lS < score2.length; lS++) { for (int rS = 0; rS < score2[lS].length; rS++) { if (score2[lS][rS] == null) continue; for (int pS = 0; pS < score2[lS][rS].length; pS++) { subrules.add(new Triple(lS, rS, pS)); } } } return subrules; } public String getStrSubRule(Triple triple) { Numberer n = Numberer.getGlobalNumberer("tags"); String lState = (String) n.object(leftChildState); if (lState.endsWith("^g")) lState = lState.substring(0, lState.length() - 2); String rState = (String) n.object(rightChildState); if (rState.endsWith("^g")) rState = rState.substring(0, rState.length() - 2); String pState = (String) n.object(parentState); if (pState.endsWith("^g")) pState = pState.substring(0, pState.length() - 2); StringBuilder sb = new StringBuilder(); sb.append(pState + "_" + triple.parent + " -> " + lState + "_" + triple.leftchild + " " + rState + "_" + triple.rightchild); return sb.toString(); } public double getCountForSubRule(double[][][] score2, Triple triple) { return score2[triple.leftchild][triple.rightchild][triple.parent]; } public void setProbForSubRule(double[][][] score2, Triple triple, double prob) { if (score2 != null && triple.leftchild < score2.length && score2[triple.leftchild] != null && triple.rightchild < score2[triple.leftchild].length && score2[triple.leftchild][triple.rightchild] != null && triple.parent < score2[triple.leftchild][triple.rightchild].length) score2[triple.leftchild][triple.rightchild][triple.parent] = prob; } public void incProbForSubRule(double[][][] score2, Triple triple, double prob) { if (score2 != null && triple.leftchild < score2.length && score2[triple.leftchild] != null && triple.rightchild < score2[triple.leftchild].length && score2[triple.leftchild][triple.rightchild] != null && triple.parent < score2[triple.leftchild][triple.rightchild].length) score2[triple.leftchild][triple.rightchild][triple.parent] += prob; } public String toStringSEIE() { Numberer n = Numberer.getGlobalNumberer("tags"); String lState = (String) n.object(leftChildState); if (lState.endsWith("^g")) lState = lState.substring(0, lState.length() - 2); String rState = (String) n.object(rightChildState); if (rState.endsWith("^g")) rState = rState.substring(0, rState.length() - 2); String pState = (String) n.object(parentState); if (pState.endsWith("^g")) pState = pState.substring(0, pState.length() - 2); StringBuilder sb = new StringBuilder(); if (scores == null) return pState + " -> " + lState + " " + rState + "\n"; // sb.append(pState+ " -> "+lState+ " "+rState+ "\n"); for (int lS = 0; lS < scores.length; lS++) { for (int rS = 0; rS < scores[lS].length; rS++) { if (scores[lS][rS] == null) continue; for (int pS = 0; pS < scores[lS][rS].length; pS++) { sb.append(pState + "_" + pS + " -> " + lState + "_" + lS + " " + rState + "_" + rS + " "); } } } return sb.toString(); } public String toString_old() { Numberer n = Numberer.getGlobalNumberer("tags"); return "\"" + StringUtils.escapeString(n.object(parentState).toString(), charsToEscape, '\\') + "\" -> \"" + StringUtils.escapeString(n.object(leftChildState).toString(), charsToEscape, '\\') + "\" \"" + StringUtils.escapeString( n.object(rightChildState).toString(), charsToEscape, '\\') + "\" " + ArrayUtil.toString(scores); } public int compareTo(Object o) { BinaryRule ur = (BinaryRule) o; if (parentState < ur.parentState) { return -1; } if (parentState > ur.parentState) { return 1; } if (leftChildState < ur.leftChildState) { return -1; } if (leftChildState > ur.leftChildState) { return 1; } if (rightChildState < ur.rightChildState) { return -1; } if (rightChildState > ur.rightChildState) { return 1; } return 0; } public short getLeftChildState() { return leftChildState; } public short getRightChildState() { return rightChildState; } // public void setScore(int pS, int lS, int rS, double score){ // // sets the score for a particular combination of substates // scores[lS][rS][pS] = score; // } public double getScore(int pS, int lS, int rS) { // gets the score for a particular combination of substates if (scores[lS][rS] == null) { if (logarithmMode) return Double.NEGATIVE_INFINITY; return 0; } return scores[lS][rS][pS]; } public void setScores2(double[][][] scores) { this.scores = scores; } /** * scores[parentSubState][leftSubState][rightSubState] gives score for this * rule */ public double[][][] getScores2() { return scores; } public void setNodes(short pState, short lState, short rState) { this.parentState = pState; this.leftChildState = lState; this.rightChildState = rState; } private static final long serialVersionUID = 2L; public BinaryRule splitRule(short[] numSubStates, short[] newNumSubStates, Random random, double randomness, boolean doNotNormalize, int mode) { // when splitting on parent, never split on ROOT int parentSplitFactor = this.getParentState() == 0 ? 1 : 2; // should if (newNumSubStates[this.parentState] == numSubStates[this.parentState]) { parentSplitFactor = 1; } int lChildSplitFactor = 2; if (newNumSubStates[this.leftChildState] == numSubStates[this.leftChildState]) { lChildSplitFactor = 1; } int rChildSplitFactor = 2; if (newNumSubStates[this.rightChildState] == numSubStates[this.rightChildState]) { rChildSplitFactor = 1; } double[][][] oldScores = this.getScores2(); double[][][] newScores = new double[oldScores.length * lChildSplitFactor][oldScores[0].length * rChildSplitFactor][]; // [oldScores[0][0].length * parentSplitFactor]; // Arrays.fill(newScores,Double.NEGATIVE_INFINITY); // for all current substates for (short lcS = 0; lcS < oldScores.length; lcS++) { for (short rcS = 0; rcS < oldScores[0].length; rcS++) { if (oldScores[lcS][rcS] == null) continue; for (short lc = 0; lc < lChildSplitFactor; lc++) { for (short rc = 0; rc < rChildSplitFactor; rc++) { short newLCS = (short) (lChildSplitFactor * lcS + lc); short newRCS = (short) (rChildSplitFactor * rcS + rc); newScores[newLCS][newRCS] = new double[newNumSubStates[this.parentState]]; } } for (short pS = 0; pS < oldScores[lcS][rcS].length; pS++) { double score = oldScores[lcS][rcS][pS]; // split on parent for (short p = 0; p < parentSplitFactor; p++) { double divFactor = (doNotNormalize) ? 1.0 : lChildSplitFactor * rChildSplitFactor; double randomComponentLC = score / divFactor * randomness / 100 * (random.nextDouble() - 0.5); // split on left child for (short lc = 0; lc < lChildSplitFactor; lc++) { // reverse the random component for half of the // rules if (lc == 1) { randomComponentLC *= -1; } // don't add randomness if we're not splitting if (lChildSplitFactor == 1) { randomComponentLC = 0; } double randomComponentRC = score / divFactor * randomness / 100 * (random.nextDouble() - 0.5); // split on right child for (short rc = 0; rc < rChildSplitFactor; rc++) { // reverse the random component for half of the // rules if (rc == 1) { randomComponentRC *= -1; } // don't add randomness if we're not splitting if (rChildSplitFactor == 1) { randomComponentRC = 0; } // set new score; divide score by 4 because // we're dividing each // binary rule under a parent into 4 short newPS = (short) (parentSplitFactor * pS + p); short newLCS = (short) (lChildSplitFactor * lcS + lc); short newRCS = (short) (rChildSplitFactor * rcS + rc); double splitFactor = (doNotNormalize) ? 1.0 : lChildSplitFactor * rChildSplitFactor; newScores[newLCS][newRCS][newPS] = (score / (splitFactor) + randomComponentLC + randomComponentRC); // sparsifier // .splitBinaryWeight(oldRule.getParentState(), // pS, // oldRule.getLeftChildState(), lcS, oldRule // .getRightChildState(), rcS, newPS, newLCS, // newRCS, lChildSplitFactor, rChildSplitFactor, // randomComponentLC, randomComponentRC, , // tagNumberer); if (mode == 2) newScores[newLCS][newRCS][newPS] = 1.0 + random .nextDouble() / 100.0; } } } } } } BinaryRule newRule = new BinaryRule(this, newScores); return newRule; } }