package edu.berkeley.nlp.PCFGLA;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import edu.berkeley.nlp.util.*;

/**
 * Unary Rules (with ints for parent and child).
 *
 * <p>A rule rewrites one parent state into one child state. Each state may have
 * several sub-states; the rule keeps one score per (child sub-state, parent
 * sub-state) combination in {@link #scores}.
 *
 * <p>Equality, hashing and ordering consider only the parent and child state
 * ids, never the scores.
 *
 * @author Dan Klein
 */
public class UnaryRule extends Rule implements java.io.Serializable, Comparable {

    private static final long serialVersionUID = 2L;

    /** Characters escaped by {@link #toString_old()}. */
    private static final char[] charsToEscape = new char[] { '\"' };

    /** Suffix removed from state names before they are printed. */
    private static final String STRIPPED_SUFFIX = "^g";

    /** Id of the child state; -1 until assigned. */
    public short childState = -1;

    /**
     * Score table, indexed as {@code scores[childSubState][parentSubState]}.
     * The table itself, or any individual row, may be {@code null}.
     */
    public double[][] scores;

    /**
     * Creates a rule that uses the given score table directly (no copy is made).
     *
     * @param pState id of the parent state
     * @param cState id of the child state
     * @param scores table indexed as {@code [childSubState][parentSubState]}
     */
    public UnaryRule(short pState, short cState, double[][] scores) {
        this.parentState = pState;
        this.childState = cState;
        this.scores = scores;
    }

    /**
     * Creates a rule without a score table ({@link #scores} stays {@code null}).
     *
     * @param pState id of the parent state
     * @param cState id of the child state
     */
    public UnaryRule(short pState, short cState) {
        this.parentState = pState;
        this.childState = cState;
    }

    /**
     * Copy constructor; the score table is duplicated via {@code ArrayUtil.copy}.
     *
     * @param u rule to copy
     */
    public UnaryRule(UnaryRule u) {
        this(u.parentState, u.childState, ArrayUtil.copy(u.scores));
    }

    /**
     * Creates a rule with the states of {@code u} and the given score table.
     *
     * @param u         rule supplying the parent and child state
     * @param newScores table used directly by the new rule
     */
    public UnaryRule(UnaryRule u, double[][] newScores) {
        this(u.parentState, u.childState, newScores);
    }

    /**
     * Creates a rule with a zero-filled score table of the given dimensions.
     *
     * @param pState     id of the parent state
     * @param cState     id of the child state
     * @param pSubStates number of parent sub-states (second table dimension)
     * @param cSubStates number of child sub-states (first table dimension)
     */
    public UnaryRule(short pState, short cState, short pSubStates, short cSubStates) {
        this.parentState = pState;
        this.childState = cState;
        this.scores = new double[cSubStates][pSubStates];
    }

    /** @return always {@code true} */
    public boolean isUnary() {
        return true;
    }

    /** Hash over the parent and child state ids only; consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return ((int) parentState << 18) ^ ((int) childState);
    }

    /**
     * Two unary rules are equal when their parent and child state ids match;
     * scores are not compared.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof UnaryRule)) {
            return false;
        }
        UnaryRule other = (UnaryRule) o;
        return parentState == other.parentState && childState == other.childState;
    }

    /**
     * Orders rules by parent state id, then by child state id.
     *
     * @param o the rule to compare with; must be a {@code UnaryRule}
     * @throws ClassCastException if {@code o} is not a {@code UnaryRule}
     */
    public int compareTo(Object o) {
        UnaryRule other = (UnaryRule) o;
        if (parentState != other.parentState) {
            return parentState < other.parentState ? -1 : 1;
        }
        if (childState != other.childState) {
            return childState < other.childState ? -1 : 1;
        }
        return 0;
    }

    /**
     * Renders one line {@code parent_pS -> child_cS score} for every strictly
     * positive score, or {@code parent -> child} followed by a newline when the
     * rule has no score table. State names come from the global "tags" numberer.
     */
    @Override
    public String toString() {
        Numberer n = Numberer.getGlobalNumberer("tags");
        String cState = stateName(n, childState);
        String pState = stateName(n, parentState);
        if (scores == null) {
            return pState + " -> " + cState + "\n";
        }
        StringBuilder sb = new StringBuilder();
        for (int cS = 0; cS < scores.length; cS++) {
            if (scores[cS] == null) {
                continue;
            }
            for (int pS = 0; pS < scores[cS].length; pS++) {
                double p = scores[cS][pS];
                if (p > 0) {
                    sb.append(pState).append('_').append(pS)
                            .append(" -> ")
                            .append(cState).append('_').append(cS)
                            .append(' ').append(p).append('\n');
                }
            }
        }
        return sb.toString();
    }

    // TODO : fix to create only for different LHS
    /**
     * Enumerates every cell of {@code scores2} whose row is non-null.
     *
     * @param scores2 table indexed as {@code [childSubState][parentSubState]};
     *                this rule's own {@link #scores} field is not consulted
     * @return pairs of (child sub-state, parent sub-state), in row-major order
     */
    public List<Pair<Integer, Integer>> getAllSubRules(double[][] scores2) {
        List<Pair<Integer, Integer>> subrules = new ArrayList<Pair<Integer, Integer>>();
        for (int cS = 0; cS < scores2.length; cS++) {
            double[] row = scores2[cS];
            if (row == null) {
                continue;
            }
            for (int pS = 0; pS < row.length; pS++) {
                subrules.add(new Pair<Integer, Integer>(cS, pS));
            }
        }
        return subrules;
    }

    /**
     * Renders a single sub-rule as {@code parent_pS -> child_cS}.
     *
     * @param pair (child sub-state, parent sub-state)
     * @return the rendered sub-rule, or {@code parent -> child} followed by a
     *         newline when this rule has no score table
     */
    public String getStrSubRule(Pair<Integer, Integer> pair) {
        Numberer n = Numberer.getGlobalNumberer("tags");
        String cState = stateName(n, childState);
        String pState = stateName(n, parentState);
        if (scores == null) {
            return pState + " -> " + cState + "\n";
        }
        return pState + "_" + pair.getSecond() + " -> " + cState + "_" + pair.getFirst();
    }

    /**
     * Reads the cell addressed by {@code pair}. Unlike the setters below, no
     * bounds or null checks are performed.
     *
     * @param scores2 table indexed as {@code [childSubState][parentSubState]}
     * @param pair    (child sub-state, parent sub-state)
     */
    public double getCountForSubRule(double[][] scores2, Pair<Integer, Integer> pair) {
        return scores2[pair.getFirst()][pair.getSecond()];
    }

    /**
     * Stores {@code prob} in the cell addressed by {@code pair}; silently does
     * nothing when the table or row is null or an index is past the end.
     */
    public void setProbForSubRule(double[][] scores2, Pair<Integer, Integer> pair, double prob) {
        if (addressesExistingCell(scores2, pair)) {
            scores2[pair.getFirst()][pair.getSecond()] = prob;
        }
    }

    /**
     * Adds {@code prob} to the cell addressed by {@code pair}; silently does
     * nothing when the table or row is null or an index is past the end.
     */
    public void incProbForSubRule(double[][] scores2, Pair<Integer, Integer> pair, double prob) {
        if (addressesExistingCell(scores2, pair)) {
            scores2[pair.getFirst()][pair.getSecond()] += prob;
        }
    }

    /**
     * Renders every sub-rule whose row is non-null as {@code parent_pS -> child_cS}
     * followed by a single space, regardless of its score.
     */
    public String toStringSEIE() {
        Numberer n = Numberer.getGlobalNumberer("tags");
        String cState = stateName(n, childState);
        String pState = stateName(n, parentState);
        if (scores == null) {
            return pState + " -> " + cState + "\n";
        }
        StringBuilder sb = new StringBuilder();
        for (int cS = 0; cS < scores.length; cS++) {
            if (scores[cS] == null) {
                continue;
            }
            for (int pS = 0; pS < scores[cS].length; pS++) {
                sb.append(pState).append('_').append(pS)
                        .append(" -> ")
                        .append(cState).append('_').append(cS)
                        .append(' ');
            }
        }
        return sb.toString();
    }

    /**
     * Older rendering: quoted, escaped state names (suffix not stripped)
     * followed by the whole score table as formatted by {@code ArrayUtil.toString}.
     */
    public String toString_old() {
        Numberer n = Numberer.getGlobalNumberer("tags");
        String parentName = StringUtils.escapeString(n.object(parentState).toString(), charsToEscape, '\\');
        String childName = StringUtils.escapeString(n.object(childState).toString(), charsToEscape, '\\');
        return "\"" + parentName + "\" -> \"" + childName + "\" " + ArrayUtil.toString(scores);
    }

    /** @return id of the child state */
    public short getChildState() {
        return childState;
    }

    /**
     * Sets the score for a particular combination of sub-states.
     *
     * @param pS    parent sub-state
     * @param cS    child sub-state
     * @param score value to store
     */
    public void setScore(int pS, int cS, double score) {
        scores[cS][pS] = score;
    }

    /**
     * Gets the score for a particular combination of sub-states.
     *
     * @param pS parent sub-state
     * @param cS child sub-state
     * @return the stored score; when the row for {@code cS} is null, negative
     *         infinity if {@code logarithmMode} (declared outside this class)
     *         is set, otherwise 0
     */
    public double getScore(int pS, int cS) {
        double[] row = scores[cS];
        if (row == null) {
            return logarithmMode ? Double.NEGATIVE_INFINITY : 0;
        }
        return row[pS];
    }

    /** Replaces the score table; the array is stored as-is, not copied. */
    public void setScores2(double[][] scores) {
        this.scores = scores;
    }

    /**
     * @return the live score table, indexed as
     *         {@code scores[childSubState][parentSubState]}
     */
    public double[][] getScores2() {
        return scores;
    }

    /** Reassigns the parent and child state ids. */
    public void setNodes(short pState, short cState) {
        this.parentState = pState;
        this.childState = cState;
    }

    /**
     * Builds the version of this rule that results from splitting sub-states.
     *
     * <p>A state is treated as split (factor 2) when its entry in
     * {@code newNumSubStates} differs from its entry in {@code numSubStates};
     * parent state 0 is never split. Old cell {@code [cS][pS]} is expanded into
     * the cells {@code [childFactor*cS + c][parentFactor*pS + p]}. Each new cell
     * receives the old score divided by the child split factor (unless
     * {@code doNotNormalize}), plus a random perturbation that has opposite sign
     * for the two child halves and is zero when the child is not split.
     *
     * @param numSubStates    sub-state counts before the split, indexed by state id
     * @param newNumSubStates sub-state counts after the split, indexed by state id
     * @param random          source of randomness
     * @param randomness      perturbation size in percent of the (normalized) score
     * @param doNotNormalize  when true, scores are not divided by the child split factor
     * @param mode            when 2, every new cell is overwritten with
     *                        {@code 1.0 + random.nextDouble() / 100.0}
     * @return a new rule with the same states and the expanded score table;
     *         this rule is not modified
     */
    public UnaryRule splitRule(short[] numSubStates, short[] newNumSubStates, Random random, double randomness,
            boolean doNotNormalize, int mode) {
        // when splitting on parent, never split on ROOT parent
        short parentSplitFactor = this.getParentState() == 0 ? (short) 1 : (short) 2;
        if (newNumSubStates[this.parentState] == numSubStates[this.parentState]) {
            parentSplitFactor = 1;
        }
        int childSplitFactor = 2;
        if (newNumSubStates[this.childState] == numSubStates[this.childState]) {
            childSplitFactor = 1;
        }

        double[][] oldScores = this.getScores2();
        double[][] newScores = new double[newNumSubStates[this.childState]][];

        // divide score by divFactor because we're splitting each rule in 1/divFactor
        final double divFactor = doNotNormalize ? 1.0 : childSplitFactor;

        // for all current substates
        for (short cS = 0; cS < oldScores.length; cS++) {
            if (oldScores[cS] == null) {
                continue;
            }
            // allocate the rows that replace the old child row
            for (short c = 0; c < childSplitFactor; c++) {
                short newCS = (short) (childSplitFactor * cS + c);
                newScores[newCS] = new double[newNumSubStates[this.parentState]];
            }
            for (short pS = 0; pS < oldScores[cS].length; pS++) {
                double score = oldScores[cS][pS];
                // split on parent
                for (short p = 0; p < parentSplitFactor; p++) {
                    double randomComponent = score / divFactor * randomness / 100 * (random.nextDouble() - 0.5);
                    // split on child
                    for (short c = 0; c < childSplitFactor; c++) {
                        if (c == 1) {
                            // second child half gets the mirrored perturbation
                            randomComponent *= -1;
                        }
                        if (childSplitFactor == 1) {
                            randomComponent = 0;
                        }
                        short newPS = (short) (parentSplitFactor * pS + p);
                        short newCS = (short) (childSplitFactor * cS + c);
                        newScores[newCS][newPS] = score / divFactor + randomComponent;
                        if (mode == 2) {
                            newScores[newCS][newPS] = 1.0 + random.nextDouble() / 100.0;
                        }
                    }
                }
            }
        }
        return new UnaryRule(this, newScores);
    }

    /** Looks up the name of {@code state} in {@code n} and drops a trailing {@code "^g"}. */
    private static String stateName(Numberer n, int state) {
        String name = (String) n.object(state);
        if (name.endsWith(STRIPPED_SUFFIX)) {
            name = name.substring(0, name.length() - STRIPPED_SUFFIX.length());
        }
        return name;
    }

    /** True when {@code table} is non-null, the addressed row exists, and neither index is past the end. */
    private static boolean addressesExistingCell(double[][] table, Pair<Integer, Integer> pair) {
        return table != null
                && pair.getFirst() < table.length
                && table[pair.getFirst()] != null
                && pair.getSecond() < table[pair.getFirst()].length;
    }
}