/** * */ package edu.berkeley.nlp.PCFGLA.smoothing; import java.io.Serializable; import java.util.List; import edu.berkeley.nlp.PCFGLA.BinaryCounterTable; import edu.berkeley.nlp.PCFGLA.BinaryRule; import edu.berkeley.nlp.PCFGLA.UnaryCounterTable; import edu.berkeley.nlp.PCFGLA.UnaryRule; import edu.berkeley.nlp.syntax.Tree; import edu.berkeley.nlp.util.ArrayUtil; import edu.berkeley.nlp.util.Numberer; /** * @author leon * */ public class SmoothAcrossParentBits implements Smoother, Serializable { /** * */ private static final long serialVersionUID = 1L; double same; double[][][] diffWeights; double weightBasis = 0.5; double totalWeight; public SmoothAcrossParentBits copy(){ return new SmoothAcrossParentBits(same,diffWeights,weightBasis,totalWeight); } public SmoothAcrossParentBits(double smooth, Tree<Short>[] splitTrees) { // does not smooth across top-level split, otherwise smooths uniformly same = 1-smooth; //int maxNBits = (int)Math.round(Math.log(maxSubstates)/Math.log(2)); int nStates = splitTrees.length; diffWeights = new double [nStates][][]; for (short state=0; state<nStates; state++){ Tree<Short> splitTree = splitTrees[state]; List<Short> allSubstates = splitTree.getYield(); int nSubstates = 1; for (int i=0; i<allSubstates.size(); i++){ if (allSubstates.get(i)>=nSubstates) nSubstates = allSubstates.get(i)+1; } diffWeights[state] = new double[nSubstates][nSubstates]; if (nSubstates==1){ // state has only one substate -> no smoothing diffWeights[state][0][0] = 1.0; } else { // smooth only with ones in the same top-level branch // TODO: weighted smoothing // descend down to first split first while (splitTree.getChildren().size()==1) { splitTree = splitTree.getChildren().get(0); } // for (short substate=0; substate<nSubstates; substate++){ // for (int branch=0; branch<2; branch++){ // List<Short> substatesInBranch = splitTree.getChildren().get(branch).getYield(); // if (substatesInBranch.contains(substate)){ // totalWeight = 0; // fillWeightsArray(state,substate,1.0,splitTree.getChildren().get(branch)); // // normalize the weights // if (totalWeight==0) continue; // for (short substate2 = 0; substate2<nSubstates; substate2++){ // if (substate==substate2) continue; // diffWeights[state][substate][substate2] /= totalWeight; // diffWeights[state][substate][substate2] *= smooth; // } // } // //else - dont smooth across top-level branch // } // } for (int branch=0; branch<2; branch++){ // compute weights for substates in top-level branch List<Short> substatesInBranch = splitTree.getChildren().get(branch).getYield(); int total = substatesInBranch.size(); double normalizedSmooth = smooth/(total-1); for (short i : substatesInBranch) { for (short j : substatesInBranch) { if (i==j) { diffWeights[state][i][j] = same; } else { diffWeights[state][i][j] = normalizedSmooth; } } } } } } /* diffWeights = new double[maxNBits+1]; for (int i=0; i<=maxNBits; i++) { diffWeights[i] = Math.pow(2,-i+1)*smooth/maxNBits; }*/ } /** * @param same2 * @param diffWeights2 * @param weightBasis2 * @param totalWeight2 */ public SmoothAcrossParentBits(double same2, double[][][] diffWeights2, double weightBasis2, double totalWeight2) { this.same = same2; this.diffWeights = diffWeights2; this.weightBasis = weightBasis2; this.totalWeight = totalWeight2; } /* (non-Javadoc) * @see edu.berkeley.nlp.PCFGLA.smoothing.Smoother#smooth(edu.berkeley.nlp.util.UnaryCounterTable, edu.berkeley.nlp.util.BinaryCounterTable) */ public void smooth(UnaryCounterTable unaryCounter, BinaryCounterTable binaryCounter) { for (UnaryRule r : unaryCounter.keySet()) { double[][] scores = unaryCounter.getCount(r); double[][] scopy = new double[scores.length][]; short pState = r.parentState; for (int j=0; j<scores.length; j++) { if( scores[j]==null ) continue; // nothing to smooth scopy[j] = new double[scores[j].length]; for (int i=0; i<scores[j].length; i++) { for (int k=0; k<scores[j].length; k++) { scopy[j][i] += diffWeights[pState][i][k] * scores[j][k]; } } } unaryCounter.setCount(r,scopy); } for (BinaryRule r : binaryCounter.keySet()) { double[][][] scores = binaryCounter.getCount(r); double[][][] scopy = new double[scores.length][scores[0].length][]; short pState = r.parentState; for (int j=0; j<scores.length; j++) { for (int l=0; l<scores[j].length; l++) { if (scores[j][l]==null) continue; //nothing to smooth scopy[j][l] = new double[scores[j][l].length]; for (int i=0; i<scores[j][l].length; i++) { for (int k=0; k<scores[j][l].length; k++) { scopy[j][l][i] += diffWeights[pState][i][k] * scores[j][l][k]; } } } } binaryCounter.setCount(r,scopy); } } private void fillWeightsArray(short state, short substate, double weight, Tree<Short> subTree){ if (subTree.isLeaf()){ if (subTree.getLabel()==substate) diffWeights[state][substate][substate] = same; else { diffWeights[state][substate][subTree.getLabel()] = weight; totalWeight+=weight;} return; } if (subTree.getChildren().size()==1) { fillWeightsArray(state,substate,weight,subTree.getChildren().get(0)); return; } for (int branch=0; branch<2; branch++) { Tree<Short> branchTree = subTree.getChildren().get(branch); List<Short> substatesInBranch = branchTree.getYield(); //int nSubstatesInBranch = substatesInBranch.size(); if (substatesInBranch.contains(substate)) fillWeightsArray(state,substate,weight,branchTree); else fillWeightsArray(state,substate,weight*weightBasis/2.0,branchTree); } } /* (non-Javadoc) * @see edu.berkeley.nlp.PCFGLA.smoothing.Smoother#smooth(short, float[]) */ public void smooth(short tag, double[] scores) { double[] scopy = new double[scores.length]; for (int i=0; i<scores.length; i++) { for (int k=0; k<scores.length; k++) { scopy[i] += diffWeights[tag][i][k] * scores[k]; } } for (int i=0; i<scores.length; i++) { // if (scores[i]==0) continue; scores[i] = scopy[i]; } } /* (non-Javadoc) * @see edu.berkeley.nlp.PCFGLA.smoothing.Smoother#updateWeights(int[][]) */ public void updateWeights(int[][] toSubstateMapping) { double[][][] newWeights = new double[toSubstateMapping.length][][]; for (int state=0; state<toSubstateMapping.length; state++){ int nSub = toSubstateMapping[state][0]; newWeights[state] = new double[nSub][nSub]; if (nSub==1) { newWeights[state][0][0] = 1.0; continue; } double[] total = new double[nSub]; for (int substate1=0; substate1<diffWeights[state].length; substate1++){ for (int substate2=0; substate2<diffWeights[state].length; substate2++){ newWeights[state][toSubstateMapping[state][substate1+1]][toSubstateMapping[state][substate2+1]] += diffWeights[state][substate1][substate2]; total[toSubstateMapping[state][substate1+1]] += diffWeights[state][substate1][substate2]; } } for (int substate1=0; substate1<nSub; substate1++){ for (int substate2=0; substate2<nSub; substate2++){ newWeights[state][substate1][substate2] /= total[substate1]; } } } diffWeights = newWeights; } /* (non-Javadoc) * @see edu.berkeley.nlp.PCFGLA.smoothing.Smoother#remapStates(edu.berkeley.nlp.util.Numberer, edu.berkeley.nlp.util.Numberer) */ public Smoother remapStates(Numberer thisNumberer, Numberer newNumberer) { SmoothAcrossParentBits remappedSmoother = copy(); remappedSmoother.diffWeights = new double[newNumberer.size()][][]; for (int s=0; s<newNumberer.size(); s++) { int translatedState = translateState(s, newNumberer, thisNumberer); if (translatedState >= 0) { remappedSmoother.diffWeights[s] = diffWeights[translatedState]; } else { remappedSmoother.diffWeights[s] = new double[1][1]; } } return remappedSmoother; } private short translateState(int state, Numberer baseNumberer, Numberer translationNumberer) { Object object = baseNumberer.object(state); if (translationNumberer.hasSeen(object)) { return (short)translationNumberer.number(object); } else { return (short)-1; } } }