package cc.mallet.fst; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.Serializable; import java.util.logging.Level; import java.util.logging.Logger; import cc.mallet.fst.Transducer.State; import cc.mallet.fst.Transducer.TransitionIterator; import cc.mallet.types.LabelAlphabet; import cc.mallet.types.LabelVector; import cc.mallet.types.Sequence; import cc.mallet.util.MalletLogger; public class SumLatticeScaling implements SumLattice { private static Logger logger = MalletLogger .getLogger(SumLatticeScaling.class.getName()); protected static boolean saveXis = false; // "ip" == "input position", "op" == "output position", "i" == "state index" @SuppressWarnings("unchecked") Sequence input, output; Transducer t; double totalWeight; LatticeNode[][] nodes; // indexed by ip,i double[] alphaLogScaling, betaLogScaling; double zLogScaling; int latticeLength; double[][] gammas; // indexed by ip,i double[][][] xis; // indexed by ip,i,j; saved only if saveXis is true; // Ensure that instances cannot easily be created by a zero arg constructor. protected SumLatticeScaling() { } protected LatticeNode getLatticeNode(int ip, int stateIndex) { if (nodes[ip][stateIndex] == null) nodes[ip][stateIndex] = new LatticeNode(ip, t.getState(stateIndex)); return nodes[ip][stateIndex]; } @SuppressWarnings("unchecked") public SumLatticeScaling(Transducer trans, Sequence input) { this(trans, input, null, (Transducer.Incrementor) null, saveXis, null); } @SuppressWarnings("unchecked") public SumLatticeScaling(Transducer trans, Sequence input, boolean saveXis) { this(trans, input, null, (Transducer.Incrementor) null, saveXis, null); } @SuppressWarnings("unchecked") public SumLatticeScaling(Transducer trans, Sequence input, Transducer.Incrementor incrementor) { this(trans, input, null, incrementor, saveXis, null); } @SuppressWarnings("unchecked") public SumLatticeScaling(Transducer trans, Sequence input, Sequence output) { this(trans, input, output, (Transducer.Incrementor) null, saveXis, null); } // You may pass null for output, meaning that the lattice // is not constrained to match the output @SuppressWarnings("unchecked") public SumLatticeScaling(Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor) { this(trans, input, output, incrementor, saveXis, null); } @SuppressWarnings("unchecked") public SumLatticeScaling(Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor, LabelAlphabet outputAlphabet) { this(trans, input, output, incrementor, saveXis, outputAlphabet); } // You may pass null for output, meaning that the lattice // is not constrained to match the output @SuppressWarnings("unchecked") public SumLatticeScaling(Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor, boolean saveXis) { this(trans, input, output, incrementor, saveXis, null); } @SuppressWarnings("unchecked") public SumLatticeScaling(Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor, boolean saveXis, LabelAlphabet outputAlphabet) { assert (output == null || input.size() == output.size()); // Initialize some structures this.t = trans; this.input = input; this.output = output; latticeLength = input.size() + 1; int numStates = t.numStates(); nodes = new LatticeNode[latticeLength][numStates]; alphaLogScaling = new double[latticeLength]; betaLogScaling = new double[latticeLength]; gammas = new double[latticeLength][numStates]; if (saveXis) xis = new double[latticeLength][numStates][numStates]; double outputCounts[][] = null; if (outputAlphabet != null) outputCounts = new double[latticeLength][outputAlphabet.size()]; for (int ip = 0; ip < latticeLength; ip++) { alphaLogScaling[ip] = 0.0; betaLogScaling[ip] = 0.0; for (int i = 0; i < numStates; i++) { gammas[ip][i] = Transducer.IMPOSSIBLE_WEIGHT; if (saveXis) for (int j = 0; j < numStates; j++) xis[ip][i][j] = Transducer.IMPOSSIBLE_WEIGHT; } } // Forward pass logger.fine("Starting Foward pass"); boolean atLeastOneInitialState = false; for (int i = 0; i < numStates; i++) { double initialWeight = t.getState(i).getInitialWeight(); if (initialWeight > Transducer.IMPOSSIBLE_WEIGHT) { getLatticeNode(0, i).alpha = Math.exp(initialWeight); atLeastOneInitialState = true; } } rescaleAlphas(0); if (atLeastOneInitialState == false) logger.warning("There are no starting states!"); for (int ip = 0; ip < latticeLength - 1; ip++) { for (int i = 0; i < numStates; i++) { if (isInvalidNode(ip, i)) continue; State s = t.getState(i); TransitionIterator iter = s.transitionIterator(input, ip, output, ip); while (iter.hasNext()) { State destination = iter.next(); LatticeNode destinationNode = getLatticeNode(ip + 1, destination.getIndex()); if (Double.isNaN(destinationNode.alpha)) destinationNode.alpha = 0; destinationNode.output = iter.getOutput(); double transitionWeight = iter.getWeight(); destinationNode.alpha += nodes[ip][i].alpha * Math.exp(transitionWeight); } } // re-scale alphas to so that \sum_i \alpha[ip][i] = 1 rescaleAlphas(ip + 1); } // Calculate total weight of Lattice. This is the normalizer double Z = Double.NaN; for (int i = 0; i < numStates; i++) if (nodes[latticeLength - 1][i] != null) { if (Double.isNaN(Z)) Z = 0; Z += nodes[latticeLength - 1][i].alpha * Math.exp(t.getState(i).getFinalWeight()); } zLogScaling = alphaLogScaling[latticeLength - 1]; if (Double.isNaN(Z)) { totalWeight = Transducer.IMPOSSIBLE_WEIGHT; return; } else totalWeight = Math.log(Z) + zLogScaling; // Backward pass for (int i = 0; i < numStates; i++) if (nodes[latticeLength - 1][i] != null) { State s = t.getState(i); nodes[latticeLength - 1][i].beta = Math.exp(s.getFinalWeight()); double gamma = nodes[latticeLength - 1][i].alpha * nodes[latticeLength - 1][i].beta / Z; gammas[latticeLength - 1][i] = Math.log(gamma); if (incrementor != null) { double p = gamma; assert (p >= 0.0 && p <= 1.0 + 1e-6) : "p=" + p + ", gamma=" + gammas[latticeLength - 1][i]; incrementor.incrementFinalState(s, p); } } rescaleBetas(latticeLength - 1); for (int ip = latticeLength - 2; ip >= 0; ip--) { for (int i = 0; i < numStates; i++) { if (isInvalidNode(ip, i)) continue; State s = t.getState(i); TransitionIterator iter = s.transitionIterator(input, ip, output, ip); double logScaling = alphaLogScaling[ip] + betaLogScaling[ip + 1] - zLogScaling; double pscaling = Math.exp(logScaling); while (iter.hasNext()) { State destination = iter.next(); int j = destination.getIndex(); LatticeNode destinationNode = nodes[ip + 1][j]; if (destinationNode != null) { double transitionWeight = iter.getWeight(); if (Double.isNaN(nodes[ip][i].beta)) nodes[ip][i].beta = 0; double transitionProb = Math.exp(transitionWeight); nodes[ip][i].beta += destinationNode.beta * transitionProb; double xi = nodes[ip][i].alpha * transitionProb * nodes[ip + 1][j].beta / Z; if (saveXis) xis[ip][i][j] = Math.log(xi) + logScaling; if (incrementor != null || outputAlphabet != null) { double p = xi * pscaling; assert (p >= 0.0 && p <= 1.0 + 1e-6) : "p=" + p + ", xis[" + ip + "][" + i + "][" + j + "]=" + xi; if (incrementor != null) incrementor.incrementTransition(iter, p); if (outputAlphabet != null) { int outputIndex = outputAlphabet.lookupIndex( iter.getOutput(), false); assert (outputIndex >= 0); outputCounts[ip][outputIndex] += p; } } } } gammas[ip][i] = Math.log(nodes[ip][i].alpha * nodes[ip][i].beta / Z) + logScaling; } // re-scale betas so that they are normalized rescaleBetas(ip); } if (incrementor != null) for (int i = 0; i < numStates; i++) { double p = Math.exp(gammas[0][i]); assert (p >= 0.0 && p <= 1.0 + 1e-6) : "p=" + p; incrementor.incrementInitialState(t.getState(i), p); } } private boolean isInvalidNode(int ip, int i) { return nodes[ip][i] == null || Double.isNaN(nodes[ip][i].alpha); } private void rescaleAlphas(int ip) { double sumAlpha = 0; for (int i = 0; i < t.numStates(); i++) { if (!isInvalidNode(ip, i)) sumAlpha += nodes[ip][i].alpha; } assert sumAlpha > 0 : "Invalid sum over alphas for ip=" + ip; alphaLogScaling[ip] = Math.log(sumAlpha) + (ip == 0 ? 0 : alphaLogScaling[ip - 1]); for (int i = 0; i < t.numStates(); i++) { if (!isInvalidNode(ip, i)) nodes[ip][i].alpha /= sumAlpha; } } private void rescaleBetas(int ip) { double sumBeta = 0; for (int i = 0; i < t.numStates(); i++) { if (!isInvalidNode(ip, i)) sumBeta += nodes[ip][i].beta; } assert sumBeta > 0 : "Invalid sum over betas for ip=" + ip; betaLogScaling[ip] = Math.log(sumBeta) + (ip == latticeLength - 1 ? 0 : betaLogScaling[ip + 1]); for (int i = 0; i < t.numStates(); i++) { if (!isInvalidNode(ip, i)) nodes[ip][i].beta /= sumBeta; } } public double[][][] getXis() { return xis; } public double[][] getGammas() { return gammas; } public double getTotalWeight() { return totalWeight; } public double getGammaWeight(int inputPosition, State s) { return gammas[inputPosition][s.getIndex()]; } public double getGammaWeight(int inputPosition, int stateIndex) { return gammas[inputPosition][stateIndex]; } public double getGammaProbability(int inputPosition, State s) { return Math.exp(gammas[inputPosition][s.getIndex()]); } public double getGammaProbability(int inputPosition, int stateIndex) { return getGammaProbability(inputPosition, t.getState(stateIndex)); } public double getXiProbability(int ip, State s1, State s2) { return Math.exp(getXiWeight(ip, s1, s2)); } public double getXiWeight(int ip, State s1, State s2) { if (xis == null) throw new IllegalStateException("xis were not saved."); int i = s1.getIndex(); int j = s2.getIndex(); return xis[ip][i][j]; } public int length() { return latticeLength; } public double getAlpha(int ip, State s) { LatticeNode node = getLatticeNode(ip, s.getIndex()); return node.alpha * Math.exp(alphaLogScaling[ip]); } public double getBeta(int ip, State s) { LatticeNode node = getLatticeNode(ip, s.getIndex()); return node.beta * Math.exp(betaLogScaling[ip]); } public LabelVector getLabelingAtPosition(int outputPosition) { throw new RuntimeException("Not implemented for SumLatticeScaling!"); } public Sequence getInput() { return input; } public Transducer getTransducer() { return t; } protected class LatticeNode { int inputPosition; State state; Object output; double alpha = Double.NaN; double beta = Double.NaN; LatticeNode(int inputPosition, State state) { this.inputPosition = inputPosition; this.state = state; } } public static class Factory extends SumLatticeFactory implements Serializable { @SuppressWarnings("unchecked") public SumLattice newSumLattice(Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor, boolean saveXis, LabelAlphabet outputAlphabet) { return new SumLatticeScaling(trans, input, output, incrementor, saveXis, outputAlphabet); } private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 1; private void writeObject(ObjectOutputStream out) throws IOException { out.writeInt(CURRENT_SERIAL_VERSION); } private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { @SuppressWarnings("unused") int version = in.readInt(); } } }