/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.semi_supervised.pr;

import java.util.logging.Logger;

import cc.mallet.fst.CRF;
import cc.mallet.fst.SumLattice;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.Transducer.State;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;

/**
 * Lattice for E-step/I-projection in PR.
 *
 * <p>All of the work happens in the constructor: a forward pass fills in the
 * {@code alpha} of every reachable lattice node, the total weight is
 * accumulated over the nodes at the last position, and a backward pass fills
 * in {@code beta}, the per-state {@code gammas} and (optionally) the
 * per-transition {@code xis}. Every transition weight is the weight reported
 * by a {@link CachedDotTransitionIterator} plus the weight reported by the
 * {@link PRAuxiliaryModel}. All stored quantities are weights that are
 * combined with {@link Transducer#sumLogProb} and turned into probabilities
 * with {@link Math#exp}.
 *
 * @author Gregory Druck
 * @author Kedar Bellare
 */
public class SumLatticePR implements SumLattice {

  private static Logger logger = MalletLogger.getLogger(SumLatticePR.class.getName());

  /** Accumulated over last-position nodes as alpha plus the state's final weight. */
  protected double totalWeight;

  /** Number of lattice positions: {@code input.size() + 1}. */
  protected int latticeLength;

  /** {@code gammas[ip][i]} is alpha + beta - totalWeight of state i at position ip. */
  protected double[][] gammas;

  /** {@code xis[ip][i][j]} is the weight of transition i -> j leaving position ip; null unless saved. */
  protected double[][][] xis;

  /** Per-position output label distributions; null unless an output alphabet was supplied. */
  protected LabelVector[] labelings;

  protected Transducer transducer;

  /** Lazily populated lattice nodes, indexed by position and state index. */
  protected LatticeNode[][] nodes;

  private Sequence input;

  /**
   * Builds the lattice for one input sequence and, as a side effect, reports
   * the resulting probabilities to the supplied incrementor and auxiliary
   * model.
   *
   * @param trans transducer whose states and transitions define the lattice
   * @param index passed unchanged to every {@code auxModel} call
   *     (presumably the index of the training instance -- confirm against callers)
   * @param input input sequence; the lattice has {@code input.size() + 1} positions
   * @param output only checked (via assertion) to have the same size as
   *     {@code input}; may be null
   * @param auxModel auxiliary model whose weight is added to every transition weight
   * @param cachedDots {@code cachedDots[ip][i]} is handed to the transition
   *     iterator created for state {@code i} at position {@code ip}
   * @param incrementConstraints whether transition probabilities are reported
   *     to {@code auxModel.incrementTransition}
   * @param incrementor receives initial-state, transition and final-state
   *     probabilities; may be null
   * @param outputAlphabet if non-null, per-position output label
   *     distributions are accumulated and exposed through
   *     {@link #getLabelingAtPosition(int)}
   * @param saveXis whether to allocate and fill the {@code xis} array
   */
  public SumLatticePR(Transducer trans, int index, Sequence input, Sequence output,
      PRAuxiliaryModel auxModel, double[][][] cachedDots, boolean incrementConstraints,
      Transducer.Incrementor incrementor, LabelAlphabet outputAlphabet, boolean saveXis) {
    assert (output == null || input.size() == output.size());

    // Allocate the lattice structures.
    this.input = input;
    this.transducer = trans;
    this.latticeLength = input.size() + 1;
    int numStates = transducer.numStates();
    this.nodes = new LatticeNode[latticeLength][numStates];
    this.gammas = new double[latticeLength][numStates];
    if (saveXis) {
      xis = new double[latticeLength][numStates][numStates];
    }
    double[][] outputCounts = null;
    if (outputAlphabet != null) {
      outputCounts = new double[latticeLength][outputAlphabet.size()];
    }

    // Everything starts out as impossible; reachable entries are overwritten below.
    for (int pos = 0; pos < latticeLength; pos++) {
      for (int from = 0; from < numStates; from++) {
        gammas[pos][from] = Transducer.IMPOSSIBLE_WEIGHT;
        if (saveXis) {
          for (int to = 0; to < numStates; to++) {
            xis[pos][from][to] = Transducer.IMPOSSIBLE_WEIGHT;
          }
        }
      }
    }

    // Forward pass: seed position 0 with the initial weights.
    boolean foundInitialState = false;
    for (int i = 0; i < numStates; i++) {
      double initialWeight = transducer.getState(i).getInitialWeight();
      if (initialWeight > Transducer.IMPOSSIBLE_WEIGHT) {
        getLatticeNode(0, i).alpha = initialWeight;
        foundInitialState = true;
      }
    }
    if (!foundInitialState) {
      logger.warning("There are no starting states!");
    }

    // Forward pass: push alpha from each reachable node to its successors.
    for (int ip = 0; ip < latticeLength - 1; ip++) {
      for (int i = 0; i < numStates; i++) {
        LatticeNode sourceNode = nodes[ip][i];
        if (sourceNode == null || sourceNode.alpha == Transducer.IMPOSSIBLE_WEIGHT) {
          continue;
        }
        State s = transducer.getState(i);
        CachedDotTransitionIterator iter =
            new CachedDotTransitionIterator((CRF.State) s, input, ip, null, cachedDots[ip][i]);
        auxModel.preProcess(index, ip, input);
        while (iter.hasNext()) {
          State destination = iter.next();
          LatticeNode destinationNode = getLatticeNode(ip + 1, destination.getIndex());
          destinationNode.output = iter.getOutput();
          double transitionWeight = iter.getWeight() + auxModel.getWeight(index, ip, input, iter);
          destinationNode.alpha =
              Transducer.sumLogProb(destinationNode.alpha, sourceNode.alpha + transitionWeight);
        }
      }
    }

    // Total weight over all nodes that were reached at the last position.
    final int lastPosition = latticeLength - 1;
    totalWeight = Transducer.IMPOSSIBLE_WEIGHT;
    for (int i = 0; i < numStates; i++) {
      LatticeNode finalNode = nodes[lastPosition][i];
      if (finalNode != null) {
        totalWeight = Transducer.sumLogProb(
            totalWeight, (finalNode.alpha + transducer.getState(i).getFinalWeight()));
      }
    }
    if (totalWeight == Transducer.IMPOSSIBLE_WEIGHT) {
      // No complete path: leave gammas/xis impossible and skip the backward pass.
      return;
    }

    // Backward pass: seed the last position with the final weights.
    for (int i = 0; i < numStates; i++) {
      LatticeNode finalNode = nodes[lastPosition][i];
      if (finalNode == null) {
        continue;
      }
      State s = transducer.getState(i);
      finalNode.beta = s.getFinalWeight();
      gammas[lastPosition][i] = finalNode.alpha + finalNode.beta - totalWeight;
      if (incrementor != null) {
        double p = Math.exp(gammas[lastPosition][i]);
        assert (p >= 0.0 && p <= 1.0 + 1e-6)
            : "p=" + p + ", gamma=" + gammas[lastPosition][i];
        incrementor.incrementFinalState(s, p);
      }
    }

    // Backward pass: pull beta from successors and report transition probabilities.
    for (int ip = latticeLength - 2; ip >= 0; ip--) {
      for (int i = 0; i < numStates; i++) {
        LatticeNode sourceNode = nodes[ip][i];
        if (sourceNode == null || sourceNode.alpha == Transducer.IMPOSSIBLE_WEIGHT) {
          continue;
        }
        State s = transducer.getState(i);
        CachedDotTransitionIterator iter =
            new CachedDotTransitionIterator((CRF.State) s, input, ip, null, cachedDots[ip][i]);
        auxModel.preProcess(index, ip, input);
        while (iter.hasNext()) {
          State destination = iter.next();
          int j = destination.getIndex();
          LatticeNode destinationNode = nodes[ip + 1][j];
          if (destinationNode == null) {
            continue;
          }

          double transitionWeight = iter.getWeight() + auxModel.getWeight(index, ip, input, iter);
          sourceNode.beta =
              Transducer.sumLogProb(sourceNode.beta, destinationNode.beta + transitionWeight);
          double xi = sourceNode.alpha + transitionWeight + destinationNode.beta - totalWeight;
          if (saveXis) {
            xis[ip][i][j] = xi;
          }

          boolean probabilityNeeded =
              incrementor != null || auxModel.numParameters() > 0 || outputAlphabet != null;
          if (probabilityNeeded) {
            double p = Math.exp(xi);
            assert (p >= 0.0 && p <= 1.0 + 1e-6)
                : "p=" + p + ", xis[" + ip + "][" + i + "][" + j + "]=" + xi;
            if (incrementor != null) {
              incrementor.incrementTransition(iter, p);
            }
            if (incrementConstraints) {
              // The preProcess call made before this loop still applies here.
              auxModel.incrementTransition(index, ip, input, iter, p);
            }
            if (outputAlphabet != null) {
              int outputIndex = outputAlphabet.lookupIndex(iter.getOutput(), false);
              assert (outputIndex >= 0);
              outputCounts[ip][outputIndex] += p;
            }
          }
        }
        gammas[ip][i] = sourceNode.alpha + sourceNode.beta - totalWeight;
      }
    }

    // Report the probability of starting in each state.
    if (incrementor != null) {
      for (int i = 0; i < numStates; i++) {
        double p = Math.exp(gammas[0][i]);
        assert (p >= 0.0 && p <= 1.0 + 1e-6) : "p=" + p;
        incrementor.incrementInitialState(transducer.getState(i), p);
      }
    }

    // Wrap the accumulated output counts; the last position has no outgoing
    // transitions, so its labeling stays null.
    if (outputAlphabet != null) {
      labelings = new LabelVector[latticeLength];
      for (int ip = latticeLength - 2; ip >= 0; ip--) {
        assert (Math.abs(1.0 - MatrixOps.sum(outputCounts[ip])) < 0.000001);
        labelings[ip] = new LabelVector(outputAlphabet, outputCounts[ip]);
      }
    }
  }

  /**
   * Returns the node for the given position and state, creating it (with
   * impossible alpha and beta) if it does not exist yet.
   */
  protected LatticeNode getLatticeNode(int ip, int stateIndex) {
    LatticeNode existing = nodes[ip][stateIndex];
    if (existing != null) {
      return existing;
    }
    nodes[ip][stateIndex] = new LatticeNode(ip, transducer.getState(stateIndex));
    return nodes[ip][stateIndex];
  }

  /** Returns the internal xi array (not a copy); null if xis were not saved. */
  public double[][][] getXis() {
    return xis;
  }

  /** Returns the internal gamma array (not a copy). */
  public double[][] getGammas() {
    return gammas;
  }

  /** Returns the total weight of the lattice; asserts that it is not NaN. */
  public double getTotalWeight() {
    assert (!Double.isNaN(totalWeight));
    return totalWeight;
  }

  /** Returns the gamma weight of state {@code s} at {@code inputPosition}. */
  public double getGammaWeight(int inputPosition, State s) {
    double[] gammasAtPosition = gammas[inputPosition];
    return gammasAtPosition[s.getIndex()];
  }

  /** Returns the gamma weight of the state with index {@code stateIndex} at {@code inputPosition}. */
  public double getGammaWeight(int inputPosition, int stateIndex) {
    double[] gammasAtPosition = gammas[inputPosition];
    return gammasAtPosition[stateIndex];
  }

  /** Returns {@code exp} of the gamma weight of state {@code s} at {@code inputPosition}. */
  public double getGammaProbability(int inputPosition, State s) {
    double weight = gammas[inputPosition][s.getIndex()];
    return Math.exp(weight);
  }

  /** Returns {@code exp} of the gamma weight of state {@code stateIndex} at {@code inputPosition}. */
  public double getGammaProbability(int inputPosition, int stateIndex) {
    double weight = gammas[inputPosition][stateIndex];
    return Math.exp(weight);
  }

  /**
   * Returns {@code exp} of the xi weight of the transition {@code s1 -> s2}
   * leaving position {@code ip}.
   *
   * @throws IllegalStateException if the lattice was built without saving xis
   */
  public double getXiProbability(int ip, State s1, State s2) {
    if (xis == null) {
      throw new IllegalStateException("xis were not saved.");
    }
    int from = s1.getIndex();
    int to = s2.getIndex();
    return Math.exp(xis[ip][from][to]);
  }

  /**
   * Returns the xi weight of the transition {@code s1 -> s2} leaving position
   * {@code ip}.
   *
   * @throws IllegalStateException if the lattice was built without saving xis
   */
  public double getXiWeight(int ip, State s1, State s2) {
    if (xis == null) {
      throw new IllegalStateException("xis were not saved.");
    }
    int from = s1.getIndex();
    int to = s2.getIndex();
    return xis[ip][from][to];
  }

  /** Returns the number of lattice positions ({@code input.size() + 1}). */
  public int length() {
    return latticeLength;
  }

  /** Returns alpha of state {@code s} at position {@code ip}, creating the node if absent. */
  public double getAlpha(int ip, State s) {
    return getLatticeNode(ip, s.getIndex()).alpha;
  }

  /** Returns beta of state {@code s} at position {@code ip}, creating the node if absent. */
  public double getBeta(int ip, State s) {
    return getLatticeNode(ip, s.getIndex()).beta;
  }

  /**
   * Returns the output label distribution at {@code outputPosition}, or null
   * when no labelings were computed.
   */
  public LabelVector getLabelingAtPosition(int outputPosition) {
    return (labelings == null) ? null : labelings[outputPosition];
  }

  public Transducer getTransducer() {
    return transducer;
  }

  /** One (position, state) cell of the lattice with its forward and backward weights. */
  protected class LatticeNode {
    int inputPosition;
    State state;
    /** Output of the last transition that reached this node during the forward pass. */
    Object output;
    double alpha = Transducer.IMPOSSIBLE_WEIGHT;
    double beta = Transducer.IMPOSSIBLE_WEIGHT;

    LatticeNode(int inputPosition, State state) {
      this.inputPosition = inputPosition;
      this.state = state;
      assert (this.alpha == Transducer.IMPOSSIBLE_WEIGHT);
    }
  }

  public Sequence getInput() {
    return input;
  }
}