package cc.mallet.fst; import java.util.ArrayList; import java.util.logging.Level; import java.util.logging.Logger; import cc.mallet.fst.Transducer.State; import cc.mallet.fst.Transducer.TransitionIterator; import cc.mallet.types.DenseVector; import cc.mallet.types.LabelAlphabet; import cc.mallet.types.LabelVector; import cc.mallet.types.MatrixOps; import cc.mallet.types.Sequence; import cc.mallet.types.SequencePair; import cc.mallet.util.MalletLogger; //****************************************************************************** //CPAL - NEW "BEAM" Version of Forward Backward //****************************************************************************** public class SumLatticeBeam implements SumLattice // CPAL - like Lattice but using max-product to get the viterbiPath { // CPAL - these worked well for nettalk //private int beamWidth = 10; //private double KLeps = .005; boolean UseForwardBackwardBeam = false; protected static int beamWidth = 3; private double KLeps = 0; private double Rmin = 0.1; private double nstatesExpl[]; private int curIter = 0; int tctIter = 0; // The number of times we have been called this iteration private double curAvgNstatesExpl; public int getBeamWidth () { return beamWidth; } public void setBeamWidth (int beamWidth) { this.beamWidth = beamWidth; } public int getTctIter(){ return this.tctIter; } public void setCurIter (int curIter) { this.curIter = curIter; this.tctIter = 0; } public void incIter () { this.tctIter++; } public void setKLeps (double KLeps) { this.KLeps = KLeps; } public void setRmin (double Rmin) { this.Rmin = Rmin; } public double[] getNstatesExpl() { return nstatesExpl; } public boolean getUseForwardBackwardBeam(){ return this.UseForwardBackwardBeam; } public void setUseForwardBackwardBeam (boolean state) { this.UseForwardBackwardBeam = state; } private static Logger logger = MalletLogger.getLogger(SumLatticeBeam.class.getName()); // "ip" == "input position", "op" == "output position", "i" == "state index" Transducer t; double weight; Sequence input, output; LatticeNode[][] nodes; // indexed by ip,i int latticeLength; int curBeamWidth; // CPAL - can be adapted if maximizer is confused // xxx Now that we are incrementing here directly, there isn't // necessarily a need to save all these arrays... // log(probability) of being in state "i" at input position "ip" double[][] gammas; // indexed by ip,i double[][][] xis; // indexed by ip,i,j; saved only if saveXis is true; LabelVector labelings[]; // indexed by op, created only if "outputAlphabet" is non-null in constructor private LatticeNode getLatticeNode (int ip, int stateIndex) { if (nodes[ip][stateIndex] == null) nodes[ip][stateIndex] = new LatticeNode (ip, t.getState (stateIndex)); return nodes[ip][stateIndex]; } // You may pass null for output, meaning that the lattice // is not constrained to match the output public SumLatticeBeam (Transducer t, Sequence input, Sequence output, Transducer.Incrementor incrementor) { this (t, input, output, incrementor, false, null); } // You may pass null for output, meaning that the lattice // is not constrained to match the output public SumLatticeBeam (Transducer t, Sequence input, Sequence output, Transducer.Incrementor incrementor, boolean saveXis) { this (t, input, output, incrementor, saveXis, null); } // If outputAlphabet is non-null, this will create a LabelVector // for each position in the output sequence indicating the // probability distribution over possible outputs at that time // index public SumLatticeBeam (Transducer t, Sequence input, Sequence output, Transducer.Incrementor incrementor, boolean saveXis, LabelAlphabet outputAlphabet) { this.t = t; if (false && logger.isLoggable (Level.FINE)) { logger.fine ("Starting Lattice"); logger.fine ("Input: "); for (int ip = 0; ip < input.size(); ip++) logger.fine (" " + input.get(ip)); logger.fine ("\nOutput: "); if (output == null) logger.fine ("null"); else for (int op = 0; op < output.size(); op++) logger.fine (" " + output.get(op)); logger.fine ("\n"); } // Initialize some structures this.input = input; this.output = output; // xxx Not very efficient when the lattice is actually sparse, // especially when the number of states is large and the // sequence is long. latticeLength = input.size()+1; int numStates = t.numStates(); nodes = new LatticeNode[latticeLength][numStates]; // xxx Yipes, this could get big; something sparse might be better? gammas = new double[latticeLength][numStates]; if (saveXis) xis = new double[latticeLength][numStates][numStates]; double outputCounts[][] = null; if (outputAlphabet != null) outputCounts = new double[latticeLength][outputAlphabet.size()]; for (int i = 0; i < numStates; i++) { for (int ip = 0; ip < latticeLength; ip++) gammas[ip][i] = Transducer.IMPOSSIBLE_WEIGHT; if (saveXis) for (int j = 0; j < numStates; j++) for (int ip = 0; ip < latticeLength; ip++) xis[ip][i][j] = Transducer.IMPOSSIBLE_WEIGHT; } // Forward pass logger.fine ("Starting Foward pass"); boolean atLeastOneInitialState = false; for (int i = 0; i < numStates; i++) { double initialWeight = t.getState(i).getInitialWeight(); //System.out.println ("Forward pass initialWeight = "+initialWeight); if (initialWeight < Transducer.IMPOSSIBLE_WEIGHT) { getLatticeNode(0, i).alpha = initialWeight; //System.out.println ("nodes[0][i].alpha="+nodes[0][i].alpha); atLeastOneInitialState = true; } } if (atLeastOneInitialState == false) logger.warning ("There are no starting states!"); // CPAL - a sorted list for our beam experiments NBestSlist[] slists = new NBestSlist[latticeLength]; // CPAL - used for stats nstatesExpl = new double[latticeLength]; // CPAL - used to adapt beam if optimizer is getting confused // tctIter++; if(curIter == 0) { curBeamWidth = numStates; } else if(tctIter > 1 && curIter != 0) { //curBeamWidth = Math.min((int)Math.round(curAvgNstatesExpl*2),numStates); //System.out.println ("Doubling Minimum Beam Size to: "+curBeamWidth); curBeamWidth = beamWidth; } else { curBeamWidth = beamWidth; } // ************************************************************ for (int ip = 0; ip < latticeLength-1; ip++) { // CPAL - add this to construct the beam // *************************************************** // CPAL - sets up the sorted list slists[ip] = new NBestSlist(numStates); // CPAL - set the slists[ip].setKLMinE(curBeamWidth); slists[ip].setKLeps(KLeps); slists[ip].setRmin(Rmin); for(int i = 0 ; i< numStates ; i++){ if (nodes[ip][i] == null || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT) continue; //State s = t.getState(i); // CPAL - give the NB viterbi node the (Weight, position) NBForBackNode cnode = new NBForBackNode(nodes[ip][i].alpha, i); slists[ip].push(cnode); } // CPAL - unlike std. n-best beam we now filter the list based // on a KL divergence like measure // *************************************************** // use method which computes the cumulative log sum and // finds the point at which the sum is within KLeps int KLMaxPos=1; int RminPos=1; if(KLeps > 0) { KLMaxPos = slists[ip].getKLpos(); nstatesExpl[ip]=(double)KLMaxPos; } else if(KLeps == 0) { if(Rmin > 0) { RminPos = slists[ip].getTHRpos(); } else { slists[ip].setRmin(-Rmin); RminPos = slists[ip].getTHRposSTRAWMAN(); } nstatesExpl[ip]=(double)RminPos; } else { // Trick, negative values for KLeps mean use the max of KL an Rmin slists[ip].setKLeps(-KLeps); KLMaxPos = slists[ip].getKLpos(); //RminPos = slists[ip].getTHRpos(); if(Rmin > 0) { RminPos = slists[ip].getTHRpos(); } else { slists[ip].setRmin(-Rmin); RminPos = slists[ip].getTHRposSTRAWMAN(); } if(KLMaxPos > RminPos) { nstatesExpl[ip]=(double)KLMaxPos; } else { nstatesExpl[ip]=(double)RminPos; } } //System.out.println(nstatesExpl[ip] + " "); // CPAL - contemplating setting values to something else int tmppos; for (int i = (int) nstatesExpl[ip]+1; i < slists[ip].size(); i++) { tmppos = slists[ip].getPosByIndex(i); nodes[ip][tmppos].alpha = Transducer.IMPOSSIBLE_WEIGHT; nodes[ip][tmppos] = null; // Null is faster and seems to work the same } // - done contemplation //for (int i = 0; i < numStates; i++) { for(int jj=0 ; jj< nstatesExpl[ip]; jj++) { int i = slists[ip].getPosByIndex(jj); // CPAL - dont need this anymore // should be taken care of in the lists //if (nodes[ip][i] == null || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT) // xxx if we end up doing this a lot, // we could save a list of the non-null ones // continue; State s = t.getState(i); TransitionIterator iter = s.transitionIterator (input, ip, output, ip); if (logger.isLoggable (Level.FINE)) logger.fine (" Starting Foward transition iteration from state " + s.getName() + " on input " + input.get(ip).toString() + " and output " + (output==null ? "(null)" : output.get(ip).toString())); while (iter.hasNext()) { State destination = iter.nextState(); if (logger.isLoggable (Level.FINE)) logger.fine ("Forward Lattice[inputPos="+ip +"][source="+s.getName() +"][dest="+destination.getName()+"]"); LatticeNode destinationNode = getLatticeNode (ip+1, destination.getIndex()); destinationNode.output = iter.getOutput(); double transitionWeight = iter.getWeight(); if (logger.isLoggable (Level.FINE)) logger.fine ("transitionWeight="+transitionWeight +" nodes["+ip+"]["+i+"].alpha="+nodes[ip][i].alpha +" destinationNode.alpha="+destinationNode.alpha); destinationNode.alpha = Transducer.sumLogProb (destinationNode.alpha, nodes[ip][i].alpha + transitionWeight); //System.out.println ("destinationNode.alpha <- "+destinationNode.alpha); } } } //System.out.println("Mean Nodes Explored: " + MatrixOps.mean(nstatesExpl)); curAvgNstatesExpl = MatrixOps.mean(nstatesExpl); // Calculate total cost of Lattice. This is the normalizer weight = Transducer.IMPOSSIBLE_WEIGHT; for (int i = 0; i < numStates; i++) if (nodes[latticeLength-1][i] != null) { // Note: actually we could sum at any ip index, // the choice of latticeLength-1 is arbitrary //System.out.println ("Ending alpha, state["+i+"] = "+nodes[latticeLength-1][i].alpha); //System.out.println ("Ending beta, state["+i+"] = "+t.getState(i).finalWeight); weight = Transducer.sumLogProb (weight, (nodes[latticeLength-1][i].alpha + t.getState(i).getFinalWeight())); } // Weight is now an "unnormalized weight" of the entire Lattice //assert (weight >= 0) : "weight = "+weight; // If the sequence has -infinite weight, just return. // Usefully this avoids calling any incrementX methods. // It also relies on the fact that the gammas[][] and .alpha and .beta values // are already initialized to values that reflect -infinite weight // xxx Although perhaps not all (alphas,betas) exactly correctly reflecting? if (weight == Transducer.IMPOSSIBLE_WEIGHT) return; // Backward pass for (int i = 0; i < numStates; i++) if (nodes[latticeLength-1][i] != null) { State s = t.getState(i); nodes[latticeLength-1][i].beta = s.getFinalWeight(); gammas[latticeLength-1][i] = nodes[latticeLength-1][i].alpha + nodes[latticeLength-1][i].beta - weight; if (incrementor != null) { double p = Math.exp(gammas[latticeLength-1][i]); assert (p > Transducer.IMPOSSIBLE_WEIGHT && !Double.isNaN(p)) : "p="+p+" gamma="+gammas[latticeLength-1][i]; incrementor.incrementFinalState(s, p); } } for (int ip = latticeLength-2; ip >= 0; ip--) { for (int i = 0; i < numStates; i++) { if (nodes[ip][i] == null || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT) // Note that skipping here based on alpha means that beta values won't // be correct, but since alpha is infinite anyway, it shouldn't matter. continue; State s = t.getState(i); TransitionIterator iter = s.transitionIterator (input, ip, output, ip); while (iter.hasNext()) { State destination = iter.nextState(); if (logger.isLoggable (Level.FINE)) logger.fine ("Backward Lattice[inputPos="+ip +"][source="+s.getName() +"][dest="+destination.getName()+"]"); int j = destination.getIndex(); LatticeNode destinationNode = nodes[ip+1][j]; if (destinationNode != null) { double transitionWeight = iter.getWeight(); assert (!Double.isNaN(transitionWeight)); // assert (transitionWeight >= 0); Not necessarily double oldBeta = nodes[ip][i].beta; assert (!Double.isNaN(nodes[ip][i].beta)); nodes[ip][i].beta = Transducer.sumLogProb (nodes[ip][i].beta, destinationNode.beta + transitionWeight); assert (!Double.isNaN(nodes[ip][i].beta)) : "dest.beta="+destinationNode.beta+" trans="+transitionWeight+" sum="+(destinationNode.beta+transitionWeight) + " oldBeta="+oldBeta; double xi = nodes[ip][i].alpha + transitionWeight + nodes[ip+1][j].beta - weight; if (saveXis) xis[ip][i][j] = xi; assert (!Double.isNaN(nodes[ip][i].alpha)); assert (!Double.isNaN(transitionWeight)); assert (!Double.isNaN(nodes[ip+1][j].beta)); assert (!Double.isNaN(weight)); if (incrementor != null || outputAlphabet != null) { double p = Math.exp(xi); assert (p > Transducer.IMPOSSIBLE_WEIGHT && !Double.isNaN(p)) : "xis["+ip+"]["+i+"]["+j+"]="+xi; if (incrementor != null) incrementor.incrementTransition(iter, p); if (outputAlphabet != null) { int outputIndex = outputAlphabet.lookupIndex (iter.getOutput(), false); assert (outputIndex >= 0); // xxx This assumes that "ip" == "op"! outputCounts[ip][outputIndex] += p; //System.out.println ("CRF Lattice outputCounts["+ip+"]["+outputIndex+"]+="+p); } } } } gammas[ip][i] = nodes[ip][i].alpha + nodes[ip][i].beta - weight; } if(true){ // CPAL - check the normalization double checknorm = Transducer.IMPOSSIBLE_WEIGHT; for (int i = 0; i < numStates; i++) if (nodes[ip][i] != null) { // Note: actually we could sum at any ip index, // the choice of latticeLength-1 is arbitrary //System.out.println ("Ending alpha, state["+i+"] = "+nodes[latticeLength-1][i].alpha); //System.out.println ("Ending beta, state["+i+"] = "+t.getState(i).finalWeight); checknorm = Transducer.sumLogProb (checknorm, gammas[ip][i]); } // System.out.println ("Check Gamma, sum="+checknorm); // CPAL - done check of normalization // CPAL - normalize for (int i = 0; i < numStates; i++) if (nodes[ip][i] != null) { gammas[ip][i] = gammas[ip][i] - checknorm; } //System.out.println ("Check Gamma, sum="+checknorm); // CPAL - normalization } } if (incrementor != null) for (int i = 0; i < numStates; i++) { double p = Math.exp(gammas[0][i]); assert (p > Transducer.IMPOSSIBLE_WEIGHT && !Double.isNaN(p)); incrementor.incrementInitialState(t.getState(i), p); } if (outputAlphabet != null) { labelings = new LabelVector[latticeLength]; for (int ip = latticeLength-2; ip >= 0; ip--) { assert (Math.abs(1.0-MatrixOps.sum (outputCounts[ip])) < 0.000001);; labelings[ip] = new LabelVector (outputAlphabet, outputCounts[ip]); } } } public Sequence getInput() { return input; } // CPAL - a simple node holding a weight and position of the state private class NBForBackNode { double weight; int pos; NBForBackNode(double weight, int pos) { this.weight = weight; this.pos = pos; } } private class NBestSlist { ArrayList list = new ArrayList(); int MaxElements; int KLMinElements; int KLMaxPos; double KLeps; double Rmin; NBestSlist(int MaxElements) { this.MaxElements = MaxElements; } boolean setKLMinE(int KLMinElements){ this.KLMinElements = KLMinElements; return true; } int size() { return list.size(); } boolean empty() { return list.isEmpty(); } Object pop() { return list.remove(0); } int getPosByIndex(int ii){ NBForBackNode tn = (NBForBackNode)list.get(ii); return tn.pos; } double getWeightByIndex(int ii){ NBForBackNode tn = (NBForBackNode)list.get(ii); return tn.weight; } void setKLeps(double KLeps){ this.KLeps = KLeps; } void setRmin(double Rmin){ this.Rmin = Rmin; } int getTHRpos(){ NBForBackNode tn; double lc1, lc2; tn = (NBForBackNode)list.get(0); lc1 = tn.weight; tn = (NBForBackNode)list.get(list.size()-1); lc2 = tn.weight; double minc = lc1 - lc2; double mincTHR = minc - minc*Rmin; for(int i=1;i<list.size();i++){ tn = (NBForBackNode)list.get(i); lc1 = tn.weight - lc2; if(lc1 > mincTHR){ return i+1; } } return list.size(); } int getTHRposSTRAWMAN(){ NBForBackNode tn; double lc1, lc2; tn = (NBForBackNode)list.get(0); lc1 = tn.weight; double mincTHR = -lc1*Rmin; //double minc = lc1 - lc2; //double mincTHR = minc - minc*Rmin; for(int i=1;i<list.size();i++){ tn = (NBForBackNode)list.get(i); lc1 = -tn.weight; if(lc1 < mincTHR){ return i+1; } } return list.size(); } int getKLpos(){ //double KLeps = 0.1; double CSNLP[]; CSNLP = new double[MaxElements]; double worstc; NBForBackNode tn; tn = (NBForBackNode)list.get(list.size()-1); worstc = tn.weight; for(int i=0;i<list.size();i++){ tn = (NBForBackNode)list.get(i); // NOTE: sometimes we can have positive numbers ! double lc = tn.weight; //double lc = tn.weight-worstc; //if(lc >0){ // int asdf=1; //} if (i==0) { CSNLP[i] = lc; } else { CSNLP[i] = Transducer.sumLogProb(CSNLP[i-1], lc); } } // normalize for(int i=0;i<list.size();i++){ CSNLP[i]=CSNLP[i]-CSNLP[list.size()-1]; if(CSNLP[i] < KLeps){ KLMaxPos = i+1; if(KLMaxPos >= KLMinElements) { return KLMaxPos; } else if(list.size() >= KLMinElements){ return KLMinElements; } } } KLMaxPos = list.size(); return KLMaxPos; } ArrayList push(NBForBackNode vn) { double tc = vn.weight; boolean atEnd = true; for(int i=0;i<list.size();i++){ NBForBackNode tn = (NBForBackNode)list.get(i); double lc = tn.weight; if(tc < lc){ list.add(i,vn); atEnd = false; break; } } if(atEnd) { list.add(vn); } // CPAL - if the list is too big, // remove the first, largest weight element if(list.size()>MaxElements) { list.remove(MaxElements); } //double f = o.totalWeight[o.nextBestStateIndex]; //boolean atEnd = true; //for(int i=0; i<list.size(); i++){ // ASearchNode_NBest tempNode = (ASearchNode_NBest)list.get(i); // double f1 = tempNode.totalWeight[tempNode.nextBestStateIndex]; // if(f < f1) { // list.add(i, o); // atEnd = false; // break; // } //} //if(atEnd) list.add(o); return list; } } // CPAL - end NBestSlist // culotta: interface for constrained lattice /** Create constrained lattice such that all paths pass through the the labeling of <code> requiredSegment </code> as indicated by <code> constrainedSequence </code> @param inputSequence input sequence @param outputSequence output sequence @param requiredSegment segment of sequence that must be labelled @param constrainedSequence lattice must have labels of this sequence from <code> requiredSegment.start </code> to <code> requiredSegment.end </code> correctly */ SumLatticeBeam (Transducer t, Sequence inputSequence, Sequence outputSequence, Segment requiredSegment, Sequence constrainedSequence) { this (t, inputSequence, outputSequence, (Transducer.Incrementor)null, null, makeConstraints(t, inputSequence, outputSequence, requiredSegment, constrainedSequence)); } private static int[] makeConstraints (Transducer t, Sequence inputSequence, Sequence outputSequence, Segment requiredSegment, Sequence constrainedSequence) { if (constrainedSequence.size () != inputSequence.size ()) throw new IllegalArgumentException ("constrainedSequence.size [" + constrainedSequence.size () + "] != inputSequence.size [" + inputSequence.size () + "]"); // constraints tells the lattice which states must emit which // observations. positive values say all paths must pass through // this state index, negative values say all paths must _not_ // pass through this state index. 0 means we don't // care. initialize to 0. include 1 extra node for start state. int [] constraints = new int [constrainedSequence.size() + 1]; for (int c = 0; c < constraints.length; c++) constraints[c] = 0; for (int i=requiredSegment.getStart (); i <= requiredSegment.getEnd(); i++) { int si = t.stateIndexOfString ((String)constrainedSequence.get (i)); if (si == -1) logger.warning ("Could not find state " + constrainedSequence.get (i) + ". Check that state labels match startTages and inTags, and that all labels are seen in training data."); // throw new IllegalArgumentException ("Could not find state " + constrainedSequence.get(i) + ". Check that state labels match startTags and InTags."); constraints[i+1] = si + 1; } // set additional negative constraint to ensure state after // segment is not a continue tag // xxx if segment length=1, this actually constrains the sequence // to B-tag (B-tag)', instead of the intended constraint of B-tag // (I-tag)' // the fix below is unsafe, but will have to do for now. // FIXED BELOW /* String endTag = (String) constrainedSequence.get (requiredSegment.getEnd ()); if (requiredSegment.getEnd()+2 < constraints.length) { if (requiredSegment.getStart() == requiredSegment.getEnd()) { // segment has length 1 if (endTag.startsWith ("B-")) { endTag = "I" + endTag.substring (1, endTag.length()); } else if (!(endTag.startsWith ("I-") || endTag.startsWith ("0"))) throw new IllegalArgumentException ("Constrained Lattice requires that states are tagged in B-I-O format."); } int statei = stateIndexOfString (endTag); if (statei == -1) // no I- tag for this B- tag statei = stateIndexOfString ((String)constrainedSequence.get (requiredSegment.getStart ())); constraints[requiredSegment.getEnd() + 2] = - (statei + 1); } */ if (requiredSegment.getEnd() + 2 < constraints.length) { // if String endTag = requiredSegment.getInTag().toString(); int statei = t.stateIndexOfString (endTag); if (statei == -1) throw new IllegalArgumentException ("Could not find state " + endTag + ". Check that state labels match startTags and InTags."); constraints[requiredSegment.getEnd() + 2] = - (statei + 1); } // printStates (); logger.fine ("Segment:\n" + requiredSegment.sequenceToString () + "\nconstrainedSequence:\n" + constrainedSequence + "\nConstraints:\n"); for (int i=0; i < constraints.length; i++) { logger.fine (constraints[i] + "\t"); } logger.fine (""); return constraints; } // culotta: constructor for constrained lattice /** Create a lattice that constrains its transitions such that the * <position,label> pairs in "constraints" are adhered * to. constraints is an array where each entry is the index of * the required label at that position. An entry of 0 means there * are no constraints on that <position, label>. Positive values * mean the path must pass through that state. Negative values * mean the path must _not_ pass through that state. NOTE - * constraints.length must be equal to output.size() + 1. A * lattice has one extra position for the initial * state. Generally, this should be unconstrained, since it does * not produce an observation. */ public SumLatticeBeam (Transducer t, Sequence input, Sequence output, Transducer.Incrementor incrementor, LabelAlphabet outputAlphabet, int [] constraints) { this.t = t; if (false && logger.isLoggable (Level.FINE)) { logger.fine ("Starting Lattice"); logger.fine ("Input: "); for (int ip = 0; ip < input.size(); ip++) logger.fine (" " + input.get(ip)); logger.fine ("\nOutput: "); if (output == null) logger.fine ("null"); else for (int op = 0; op < output.size(); op++) logger.fine (" " + output.get(op)); logger.fine ("\n"); } // Initialize some structures this.input = input; this.output = output; // xxx Not very efficient when the lattice is actually sparse, // especially when the number of states is large and the // sequence is long. latticeLength = input.size()+1; int numStates = t.numStates(); nodes = new LatticeNode[latticeLength][numStates]; // xxx Yipes, this could get big; something sparse might be better? gammas = new double[latticeLength][numStates]; // xxx Move this to an ivar, so we can save it? But for what? // Commenting this out, because it's a memory hog and not used right now. // Uncomment and conditionalize under a flag if ever needed. -cas // double xis[][][] = new double[latticeLength][numStates][numStates]; double outputCounts[][] = null; if (outputAlphabet != null) outputCounts = new double[latticeLength][outputAlphabet.size()]; for (int i = 0; i < numStates; i++) { for (int ip = 0; ip < latticeLength; ip++) gammas[ip][i] = Transducer.IMPOSSIBLE_WEIGHT; /* Commenting out xis -cas for (int j = 0; j < numStates; j++) for (int ip = 0; ip < latticeLength; ip++) xis[ip][i][j] = Transducer.IMPOSSIBLE_WEIGHT; */ } // Forward pass logger.fine ("Starting Constrained Foward pass"); // ensure that at least one state has initial weight less than Infinity // so we can start from there boolean atLeastOneInitialState = false; for (int i = 0; i < numStates; i++) { double initialWeight = t.getState(i).getInitialWeight(); //System.out.println ("Forward pass initialWeight = "+initialWeight); if (initialWeight > Transducer.IMPOSSIBLE_WEIGHT) { getLatticeNode(0, i).alpha = initialWeight; //System.out.println ("nodes[0][i].alpha="+nodes[0][i].alpha); atLeastOneInitialState = true; } } if (atLeastOneInitialState == false) logger.warning ("There are no starting states!"); for (int ip = 0; ip < latticeLength-1; ip++) for (int i = 0; i < numStates; i++) { logger.fine ("ip=" + ip+", i=" + i); // check if this node is possible at this <position, // label>. if not, skip it. if (constraints[ip] > 0) { // must be in state indexed by constraints[ip] - 1 if (constraints[ip]-1 != i) { logger.fine ("Current state does not match positive constraint. position="+ip+", constraint="+(constraints[ip]-1)+", currState="+i); continue; } } else if (constraints[ip] < 0) { // must _not_ be in state indexed by constraints[ip] if (constraints[ip]+1 == -i) { logger.fine ("Current state does not match negative constraint. position="+ip+", constraint="+(constraints[ip]+1)+", currState="+i); continue; } } if (nodes[ip][i] == null || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT) { // xxx if we end up doing this a lot, // we could save a list of the non-null ones if (nodes[ip][i] == null) logger.fine ("nodes[ip][i] is NULL"); else if (nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT) logger.fine ("nodes[ip][i].alpha is Inf"); logger.fine ("-INFINITE weight or NULL...skipping"); continue; } State s = t.getState(i); TransitionIterator iter = s.transitionIterator (input, ip, output, ip); if (logger.isLoggable (Level.FINE)) logger.fine (" Starting Forward transition iteration from state " + s.getName() + " on input " + input.get(ip).toString() + " and output " + (output==null ? "(null)" : output.get(ip).toString())); while (iter.hasNext()) { State destination = iter.nextState(); boolean legalTransition = true; // check constraints to see if node at <ip,i> can transition to destination if (ip+1 < constraints.length && constraints[ip+1] > 0 && ((constraints[ip+1]-1) != destination.getIndex())) { logger.fine ("Destination state does not match positive constraint. Assigning -infinite weight. position="+(ip+1)+", constraint="+(constraints[ip+1]-1)+", source ="+i+", destination="+destination.getIndex()); legalTransition = false; } else if (((ip+1) < constraints.length) && constraints[ip+1] < 0 && (-(constraints[ip+1]+1) == destination.getIndex())) { logger.fine ("Destination state does not match negative constraint. Assigning -infinite weight. position="+(ip+1)+", constraint="+(constraints[ip+1]+1)+", destination="+destination.getIndex()); legalTransition = false; } if (logger.isLoggable (Level.FINE)) logger.fine ("Forward Lattice[inputPos="+ip +"][source="+s.getName() +"][dest="+destination.getName()+"]"); LatticeNode destinationNode = getLatticeNode (ip+1, destination.getIndex()); destinationNode.output = iter.getOutput(); double transitionWeight = iter.getWeight(); if (legalTransition) { //if (logger.isLoggable (Level.FINE)) logger.fine ("transitionWeight="+transitionWeight +" nodes["+ip+"]["+i+"].alpha="+nodes[ip][i].alpha +" destinationNode.alpha="+destinationNode.alpha); destinationNode.alpha = Transducer.sumLogProb (destinationNode.alpha, nodes[ip][i].alpha + transitionWeight); //System.out.println ("destinationNode.alpha <- "+destinationNode.alpha); logger.fine ("Set alpha of latticeNode at ip = "+ (ip+1) + " stateIndex = " + destination.getIndex() + ", destinationNode.alpha = " + destinationNode.alpha); } else { // this is an illegal transition according to our // constraints, so set its prob to 0 . NO, alpha's are // unnormalized weights...set to Inf // // destinationNode.alpha = 0.0; // destinationNode.alpha = Transducer.IMPOSSIBLE_WEIGHT; logger.fine ("Illegal transition from state " + i + " to state " + destination.getIndex() + ". Setting alpha to Inf"); } } } // Calculate total weight of Lattice. This is the normalizer weight = Transducer.IMPOSSIBLE_WEIGHT; for (int i = 0; i < numStates; i++) if (nodes[latticeLength-1][i] != null) { // Note: actually we could sum at any ip index, // the choice of latticeLength-1 is arbitrary //System.out.println ("Ending alpha, state["+i+"] = "+nodes[latticeLength-1][i].alpha); //System.out.println ("Ending beta, state["+i+"] = "+t.getState(i).finalWeight); if (constraints[latticeLength-1] > 0 && i != constraints[latticeLength-1]-1) continue; if (constraints[latticeLength-1] < 0 && -i == constraints[latticeLength-1]+1) continue; logger.fine ("Summing final lattice weight. state="+i+", alpha="+nodes[latticeLength-1][i].alpha + ", final weight = "+t.getState(i).getFinalWeight()); weight = Transducer.sumLogProb (weight, (nodes[latticeLength-1][i].alpha + t.getState(i).getFinalWeight())); } // Weight is now an "unnormalized weight" of the entire Lattice //assert (weight >= 0) : "weight = "+weight; // If the sequence has -infinite weight, just return. // Usefully this avoids calling any incrementX methods. // It also relies on the fact that the gammas[][] and .alpha and .beta values // are already initialized to values that reflect -infinite weight // xxx Although perhaps not all (alphas,betas) exactly correctly reflecting? if (weight == Transducer.IMPOSSIBLE_WEIGHT) return; // Backward pass for (int i = 0; i < numStates; i++) if (nodes[latticeLength-1][i] != null) { State s = t.getState(i); nodes[latticeLength-1][i].beta = s.getFinalWeight(); gammas[latticeLength-1][i] = nodes[latticeLength-1][i].alpha + nodes[latticeLength-1][i].beta - weight; if (incrementor != null) { double p = Math.exp(gammas[latticeLength-1][i]); assert (p >= 0 && p <= 1.0 && !Double.isNaN(p)) : "p="+p+" gamma="+gammas[latticeLength-1][i]; incrementor.incrementFinalState(s, p); } } for (int ip = latticeLength-2; ip >= 0; ip--) { for (int i = 0; i < numStates; i++) { if (nodes[ip][i] == null || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT) // Note that skipping here based on alpha means that beta values won't // be correct, but since alpha is infinite anyway, it shouldn't matter. continue; State s = t.getState(i); TransitionIterator iter = s.transitionIterator (input, ip, output, ip); while (iter.hasNext()) { State destination = iter.nextState(); if (logger.isLoggable (Level.FINE)) logger.fine ("Backward Lattice[inputPos="+ip +"][source="+s.getName() +"][dest="+destination.getName()+"]"); int j = destination.getIndex(); LatticeNode destinationNode = nodes[ip+1][j]; if (destinationNode != null) { double transitionWeight = iter.getWeight(); assert (!Double.isNaN(transitionWeight)); // assert (transitionWeight >= 0); Not necessarily double oldBeta = nodes[ip][i].beta; assert (!Double.isNaN(nodes[ip][i].beta)); nodes[ip][i].beta = Transducer.sumLogProb (nodes[ip][i].beta, destinationNode.beta + transitionWeight); assert (!Double.isNaN(nodes[ip][i].beta)) : "dest.beta="+destinationNode.beta+" trans="+transitionWeight+" sum="+(destinationNode.beta+transitionWeight) + " oldBeta="+oldBeta; // xis[ip][i][j] = nodes[ip][i].alpha + transitionWeight + nodes[ip+1][j].beta - weight; assert (!Double.isNaN(nodes[ip][i].alpha)); assert (!Double.isNaN(transitionWeight)); assert (!Double.isNaN(nodes[ip+1][j].beta)); assert (!Double.isNaN(weight)); if (incrementor != null || outputAlphabet != null) { double xi = nodes[ip][i].alpha + transitionWeight + nodes[ip+1][j].beta - weight; double p = Math.exp(xi); assert (p >= 0 && p <= 1.0 && !Double.isNaN(p)) : "xis["+ip+"]["+i+"]["+j+"]="+-xi; if (incrementor != null) incrementor.incrementTransition(iter, p); if (outputAlphabet != null) { int outputIndex = outputAlphabet.lookupIndex (iter.getOutput(), false); assert (outputIndex >= 0); // xxx This assumes that "ip" == "op"! outputCounts[ip][outputIndex] += p; //System.out.println ("CRF Lattice outputCounts["+ip+"]["+outputIndex+"]+="+p); } } } } gammas[ip][i] = nodes[ip][i].alpha + nodes[ip][i].beta - weight; } } if (incrementor != null) for (int i = 0; i < numStates; i++) { double p = Math.exp(gammas[0][i]); assert (p >= 0.0 && p <= 1.0 && !Double.isNaN(p)); incrementor.incrementInitialState(t.getState(i), p); } if (outputAlphabet != null) { labelings = new LabelVector[latticeLength]; for (int ip = latticeLength-2; ip >= 0; ip--) { assert (Math.abs(1.0-MatrixOps.sum (outputCounts[ip])) < 0.000001);; labelings[ip] = new LabelVector (outputAlphabet, outputCounts[ip]); } } } public double getTotalWeight () { assert (!Double.isNaN(weight)); return weight; } // No, this.weight is an "unnormalized weight" //public double getProbability () { return Math.exp (weight); } public double getGammaWeight (int inputPosition, State s) { return gammas[inputPosition][s.getIndex()]; } public double getGammaProbability (int inputPosition, State s) { return Math.exp (gammas[inputPosition][s.getIndex()]); } public double[][][] getXis() { return xis; } public double[][] getGammas () { return gammas; } public double getXiProbability (int ip, State s1, State s2) { if (xis == null) throw new IllegalStateException ("xis were not saved."); int i = s1.getIndex (); int j = s2.getIndex (); return Math.exp (xis[ip][i][j]); } public double getXiWeight (int ip, State s1, State s2) { if (xis == null) throw new IllegalStateException ("xis were not saved."); int i = s1.getIndex (); int j = s2.getIndex (); return xis[ip][i][j]; } public int length () { return latticeLength; } public double getAlpha (int ip, State s) { LatticeNode node = getLatticeNode (ip, s.getIndex ()); return node.alpha; } public double getBeta (int ip, State s) { LatticeNode node = getLatticeNode (ip, s.getIndex ()); return node.beta; } public LabelVector getLabelingAtPosition (int outputPosition) { if (labelings != null) return labelings[outputPosition]; return null; } public Transducer getTransducer () { return t; } // A container for some information about a particular input position and state private class LatticeNode { int inputPosition; // outputPosition not really needed until we deal with asymmetric epsilon. State state; Object output; double alpha = Transducer.IMPOSSIBLE_WEIGHT; double beta = Transducer.IMPOSSIBLE_WEIGHT; LatticeNode (int inputPosition, State state) { this.inputPosition = inputPosition; this.state = state; assert (this.alpha == Transducer.IMPOSSIBLE_WEIGHT); // xxx Remove this check } } public static class Factory extends SumLatticeFactory { int bw; public Factory (int beamWidth) { bw = beamWidth; } public SumLattice newSumLattice (Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor, boolean saveXis, LabelAlphabet outputAlphabet) { return new SumLatticeBeam (trans, input, output, incrementor, saveXis, outputAlphabet) {{ beamWidth = bw; }}; } } }