/* Copyright (C) 2010 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.semi_supervised;

import java.util.ArrayList;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.LogNumber;

/**
 * Runs the dynamic programming algorithm of [Mann and McCallum 08] for
 * computing the gradient of a Generalized Expectation constraint that
 * considers a single label of a linear chain CRF.
 *
 * See:
 * "Generalized Expectation Criteria for Semi-Supervised Learning of
 * Conditional Random Fields"
 * Gideon Mann and Andrew McCallum
 * ACL 2008
 *
 * gdruck NOTE: this version of the GE lattice computes the gradient for all
 * constraints simultaneously.
 *
 * @author Gregory Druck
 * @author Gaurav Chandalia
 * @author Gideon Mann
 */
public class GELattice {
  // input length + 1
  protected int latticeLength;
  // the model
  protected Transducer transducer;
  // number of states in the FST
  protected int numStates;
  // dynamic programming lattice
  protected LatticeNode[][] lattice;
  // cache of dot products between violation and
  // constraint features
  protected LogNumber[][][] dotCache;

  /**
   * @param fvs Input FeatureVectorSequence
   * @param gammas Marginals over single states
   * @param xis Marginals over pairs of states
   * @param transducer Transducer
   * @param reverseTrans Source state indices for each destination state
   * @param reverseTransIndices Transition indices for each destination state
   * @param gradient Gradient to increment
   * @param constraints List of constraints
   * @param check Whether to run the debugging test to verify correctness
   *        (will be much slower if true)
   */
  public GELattice(
      FeatureVectorSequence fvs, double[][] gammas, double[][][] xis,
      Transducer transducer, int[][] reverseTrans, int[][] reverseTransIndices,
      CRF.Factors gradient, ArrayList<GEConstraint> constraints, boolean check) {
    assert(gradient != null);

    latticeLength = fvs.size() + 1;
    this.transducer = transducer;
    numStates = transducer.numStates();

    // lattice
    lattice = new LatticeNode[latticeLength][numStates];
    for (int ip = 0; ip < latticeLength; ++ip) {
      for (int a = 0; a < numStates; ++a) {
        lattice[ip][a] = new LatticeNode();
      }
    }
    dotCache = new LogNumber[latticeLength][numStates][numStates];

    // TODO maybe this should be cached?
    // Separate lists for constraints that look at one vs. two states.
    ArrayList<GEConstraint> constraints1 = new ArrayList<GEConstraint>();
    ArrayList<GEConstraint> constraints2 = new ArrayList<GEConstraint>();
    for (GEConstraint constraint : constraints) {
      if (constraint.isOneStateConstraint()) {
        constraints1.add(constraint);
      }
      else {
        constraints2.add(constraint);
      }
    }

    CRF crf = (CRF)transducer;
    double dotEx = this.runForward(crf, constraints1, constraints2, gammas, xis,
        reverseTrans, fvs);
    this.runBackward(crf, gammas, xis, reverseTrans, reverseTransIndices, fvs,
        dotEx, gradient);
    //check(constraints,gammas,xis,fvs);
  }

  /**
   * Run forward pass of dynamic programming algorithm.
   *
   * @param crf CRF
   * @param constraints1 Constraints that consider one state.
   * @param constraints2 Constraints that consider two states.
   * @param gammas Marginals over single states
   * @param xis Marginals over pairs of states
   * @param reverseTrans Source state indices for each destination state
   * @param fvs Input FeatureVectorSequence
   * @return Expectation of constraint features dotted with violation terms
   */
  private double runForward(CRF crf, ArrayList<GEConstraint> constraints1,
      ArrayList<GEConstraint> constraints2, double[][] gammas, double[][][] xis,
      int[][] reverseTrans, FeatureVectorSequence fvs) {
    double dotEx = 0;

    LogNumber[] oneStateValueCache = new LogNumber[numStates];
    LogNumber nuAlpha = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
    LogNumber temp = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);

    for (int ip = 0; ip < latticeLength-1; ++ip) {
      FeatureVector fv = fvs.get(ip);

      // speed things up by giving the constraints an
      // opportunity to cache, for example, which
      // constrained input features appear in this
      // FeatureVector
      for (GEConstraint constraint : constraints1) {
        constraint.preProcess(fv);
      }
      for (GEConstraint constraint : constraints2) {
        constraint.preProcess(fv);
      }

      boolean[] oneStateValComputed = new boolean[numStates];
      for (int prev = 0; prev < numStates; prev++) {
        nuAlpha.set(Transducer.IMPOSSIBLE_WEIGHT,true);
        if (ip != 0) {
          int[] prevPrevs = reverseTrans[prev];
          // calculate only once: \sum_y_{i-1} w_a(y_{i-1},y_i)
          for (int ppi = 0; ppi < prevPrevs.length; ppi++) {
            nuAlpha.plusEquals(lattice[ip-1][prevPrevs[ppi]].alpha[prev]);
          }
        }
        assert (!Double.isNaN(nuAlpha.logVal));

        CRF.State prevState = (CRF.State)crf.getState(prev);
        LatticeNode node = lattice[ip][prev];
        double[] xi = xis[ip][prev];
        double gamma = gammas[ip][prev];

        for (int ci = 0; ci < prevState.numDestinations(); ci++) {
          int curr = prevState.getDestinationState(ci).getIndex();

          double dot = 0;
          for (GEConstraint constraint : constraints2) {
            dot += constraint.getCompositeConstraintFeatureValue(fv, ip, prev, curr);
          }

          // avoid recomputing one-state constraint features #labels times
          if (!oneStateValComputed[curr]) {
            double osVal = 0;
            for (GEConstraint constraint : constraints1) {
              osVal += constraint.getCompositeConstraintFeatureValue(fv, ip, prev, curr);
            }
            if (osVal < 0) {
              dotEx += Math.exp(gammas[ip+1][curr]) * osVal;
              oneStateValueCache[curr] = new LogNumber(Math.log(-osVal),false);
            }
            else if (osVal > 0) {
              dotEx += Math.exp(gammas[ip+1][curr]) * osVal;
              oneStateValueCache[curr] = new LogNumber(Math.log(osVal),true);
            }
            else {
              oneStateValueCache[curr] = null;
            }
            oneStateValComputed[curr] = true;
          }

          // combine the one and two state constraint feature values
          if (dot == 0 && oneStateValueCache[curr] == null) {
            dotCache[ip][prev][curr] = null;
          }
          else if (dot == 0 && oneStateValueCache[curr] != null) {
            dotCache[ip][prev][curr] = oneStateValueCache[curr];
          }
          else {
            dotEx += Math.exp(xi[curr]) * dot;
            if (dot < 0) {
              dotCache[ip][prev][curr] = new LogNumber(Math.log(-dot),false);
            }
            else {
              dotCache[ip][prev][curr] = new LogNumber(Math.log(dot),true);
            }
            if (oneStateValueCache[curr] != null) {
              dotCache[ip][prev][curr].plusEquals(oneStateValueCache[curr]);
            }
          }

          // update the dynamic programming table
          if (dotCache[ip][prev][curr] != null) {
            temp.set(xi[curr],true);
            temp.timesEquals(dotCache[ip][prev][curr]);
            node.alpha[curr].plusEquals(temp);
          }
          if (gamma == Transducer.IMPOSSIBLE_WEIGHT) {
            node.alpha[curr] = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
          }
          else {
            temp.set(xi[curr] - gamma,true);
            temp.timesEquals(nuAlpha);
            node.alpha[curr].plusEquals(temp);
          }
          assert (!Double.isNaN(node.alpha[curr].logVal)) :
            "xi: " + xi[curr] + ", gamma: " + gamma + ", constraint feature: " +
            dotCache[ip][prev][curr] + ", nuAlpha: " +
            nuAlpha + " dot: " + dot;
        }
      }
    }
    return dotEx;
  }

  /**
   * Run backward pass of dynamic programming algorithm.
   *
   * @param crf CRF
   * @param gammas Marginals over single states
   * @param xis Marginals over pairs of states
   * @param reverseTrans Source state indices for each destination state
   * @param reverseTransIndices Transition indices for each destination state
   * @param fvs Input FeatureVectorSequence
   * @param dotEx Expectation of constraint features dotted with violation terms
   * @param gradient Gradient to increment
   */
  private void runBackward(CRF crf, double[][] gammas, double[][][] xis,
      int[][] reverseTrans, int[][] reverseTransIndices, FeatureVectorSequence fvs,
      double dotEx, CRF.Factors gradient) {
    LogNumber nuBeta = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
    LogNumber dot = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
    LogNumber temp = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
    LogNumber temp2 = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
    LogNumber nextDot;

    for (int ip = latticeLength-2; ip >= 0; --ip) {
      for (int curr = 0; curr < numStates; ++curr) {
        nuBeta.set(Transducer.IMPOSSIBLE_WEIGHT,true);
        dot.set(Transducer.IMPOSSIBLE_WEIGHT,true);

        // calculate only once: \sum_y_{i+1} w_b(y_i,y_{i+1})
        CRF.State currState = (CRF.State)crf.getState(curr);
        for (int ni = 0; ni < currState.numDestinations(); ni++) {
          int next = currState.getDestinationState(ni).getIndex();
          nuBeta.plusEquals(lattice[ip+1][curr].beta[next]);
          assert(!Double.isNaN(nuBeta.logVal));

          nextDot = dotCache[ip+1][curr][next];
          if (nextDot != null) {
            double xi = xis[ip+1][curr][next];
            temp.set(xi,true);
            temp.timesEquals(nextDot);
            dot.plusEquals(temp);
          }
        }

        double gamma = gammas[ip+1][curr];
        int[] prevStates = reverseTrans[curr];
        for (int pi = 0; pi < prevStates.length; pi++) {
          int prev = prevStates[pi];
          CRF.State crfState = (CRF.State)crf.getState(prev);
          LatticeNode node = lattice[ip][prev];
          double xi = xis[ip][prev][curr];

          if (gamma == Transducer.IMPOSSIBLE_WEIGHT) {
            node.beta[curr] = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
          }
          else {
            // constraint feature values cached in forward pass
            temp.set(dot.logVal,dot.sign);
            temp.plusEquals(nuBeta);
            temp2.set(xi-gamma,true);
            temp.timesEquals(temp2);
            node.beta[curr].plusEquals(temp);
          }
          assert(!Double.isNaN(node.beta[curr].logVal)) :
            "xi: " + xi + ", gamma: " + gamma +
            ", log(indicatorFeat): " + dotCache[ip][curr];

          // compute and update gradient!
          double transProb = Math.exp(xi);
          double covFirstTerm = node.alpha[curr].exp() + node.beta[curr].exp();
          double contribution = (covFirstTerm - (transProb * dotEx));

          int nwi = crfState.getWeightNames(reverseTransIndices[curr][pi]).length;
          int weightsIndex;
          for (int wi = 0; wi < nwi; wi++) {
            weightsIndex = ((CRF)transducer).getWeightsIndex(
                crfState.getWeightNames(reverseTransIndices[curr][pi])[wi]);
            gradient.weights[weightsIndex].plusEqualsSparse(fvs.get(ip), contribution);
            gradient.defaultWeights[weightsIndex] += contribution;
          }
        }
      }
    }
  }

  /**
   * Verifies the correctness of the lattice computations.
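   * Recomputes the expected value of the composite constraint features
   * directly from the xi marginals and asserts that it matches the summed
   * alpha and beta values at each input position (and their average over
   * all positions).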
   */
  public void check(ArrayList<GEConstraint> constraints, double[][] gammas,
      double[][][] xis, FeatureVectorSequence fvs) {
    // expected value of the composite constraint features under the
    // transition marginals
    double ex1 = 0.0;
    for (int ip = 0; ip < latticeLength-1; ++ip) {
      for (int si1 = 0; si1 < numStates; si1++) {
        for (int si2 = 0; si2 < numStates; si2++) {
          double dot = 0;
          for (GEConstraint constraint : constraints) {
            dot += constraint.getCompositeConstraintFeatureValue(fvs.get(ip), ip, si1, si2);
          }
          double prob = Math.exp(xis[ip][si1][si2]);
          ex1 += prob * dot;
        }
      }
    }

    double ex2 = 0.0;
    for (int ip = 0; ip < latticeLength-1; ++ip) {
      double ex3 = 0.0;
      for (int s1 = 0; s1 < numStates; ++s1) {
        LatticeNode node = lattice[ip][s1];
        for (int s2 = 0; s2 < numStates; ++s2) {
          ex3 += node.alpha[s2].exp() + node.beta[s2].exp();
        }
      }
      // should equal the expected constraint value ex1
      assert(ex1 - ex3 < 1e-6) : ex1 + " " + ex3;
      ex2 += ex3;
    }
    ex2 = ex2 / (latticeLength - 1);
    // should equal the expected constraint value ex1
    assert(ex1 - ex2 < 1e-6) : ex1 + " " + ex2;
  }

  public LogNumber getAlpha(int ip, int s1, int s2) {
    return lattice[ip][s1].alpha[s2];
  }

  public LogNumber getBeta(int ip, int s1, int s2) {
    return lattice[ip][s1].beta[s2];
  }

  /**
   * Contains forward-backward vectors corresponding to an input position and a
   * state index.
   */
  protected class LatticeNode {
    // ip -> input position; a vector of LogNumbers since for each node we need
    // to keep track of the alpha, beta values of state@(ip+1)
    protected LogNumber[] alpha;
    protected LogNumber[] beta;

    public LatticeNode() {
      alpha = new LogNumber[numStates];
      beta = new LogNumber[numStates];
      for (int si = 0; si < numStates; ++si) {
        alpha[si] = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
        beta[si] = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
      }
    }
  }
}
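
/*
 * Rough usage sketch (an assumption, not the exact wiring used by MALLET's GE
 * trainers): the caller supplies log-domain marginals for this instance --
 * gammas[ip][s] = log P(y_ip = s | x) with length fvs.size()+1, and
 * xis[ip][s][s'] = log P(y_ip = s, y_ip+1 = s' | x) -- plus the reverse
 * transition tables and a CRF.Factors accumulator to increment.  In MALLET this
 * setup normally lives inside the GE training code (e.g. CRFOptimizableByGE)
 * rather than being assembled by hand:
 *
 *   // assumed in scope: CRF crf; FeatureVectorSequence fvs;
 *   //   ArrayList<GEConstraint> constraints;
 *   //   int[][] reverseTrans, reverseTransIndices;
 *   double[][] gammas = ...;    // from a forward-backward (sum) lattice
 *   double[][][] xis = ...;     // pairwise marginals from the same lattice
 *   CRF.Factors gradient = ...; // accumulator shaped like the CRF weights
 *   new GELattice(fvs, gammas, xis, crf, reverseTrans, reverseTransIndices,
 *                 gradient, constraints, false);   // increments `gradient`
 */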