/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.fst.semi_supervised.pr; import cc.mallet.fst.SumLattice; import cc.mallet.fst.Transducer; import cc.mallet.fst.Transducer.State; import cc.mallet.fst.Transducer.TransitionIterator; import cc.mallet.types.LabelVector; import cc.mallet.types.Sequence; /** * Lattice for M-step/M-projection in PR. * * @author Kedar Bellare * @author Gregory Druck */ public class SumLatticeKL implements SumLattice { // "ip" == "input position", "op" == "output position", "i" == "state index" Transducer t; double totalWeight; int latticeLength; double[][][] xis; Sequence input; protected SumLatticeKL() {} // If outputAlphabet is non-null, this will create a LabelVector // for each position in the output sequence indicating the // probability distribution over possible outputs at that time // index public SumLatticeKL(Transducer trans, Sequence input, double[] initProbs, double[] finalProbs, double[][][] xis, double[][][] cachedDots, Transducer.Incrementor incrementor) { assert (xis != null) : "Need transition probabilities"; // Initialize some structures this.t = trans; this.input = input; latticeLength = input.size() + 1; int numStates = t.numStates(); this.xis = xis; totalWeight = 0; // increment initial states for (int i = 0; i < numStates; i++) { if (t.getState(i).getInitialWeight() == Transducer.IMPOSSIBLE_WEIGHT) continue; if (initProbs != null) { totalWeight += initProbs[i] * t.getState(i).getInitialWeight(); if (incrementor != null) incrementor.incrementInitialState(t.getState(i), initProbs[i]); } } for (int ip = 0; ip < latticeLength - 1; ip++) for (int i = 0; i < numStates; i++) { State s = t.getState(i); TransitionIterator iter = s.transitionIterator(input, ip); while (iter.hasNext()) { State destination = iter.next(); double weight = iter.getWeight(); double p = xis[ip][i][destination.getIndex()]; totalWeight += p * weight; if (cachedDots != null) { cachedDots[ip][i][destination.getIndex()] = weight; } if (incrementor != null) { // this is used to gather "constraints", // so only probabilities under q are used incrementor.incrementTransition(iter, p); } } } for (int i = 0; i < numStates; i++) { if (t.getState(i).getFinalWeight() == Transducer.IMPOSSIBLE_WEIGHT) continue; if (finalProbs != null) { totalWeight += finalProbs[i] * t.getState(i).getFinalWeight(); if (incrementor != null) incrementor.incrementFinalState(t.getState(i), finalProbs[i]); } } assert (totalWeight > Transducer.IMPOSSIBLE_WEIGHT) : "Total weight=" + totalWeight; } public double[][][] getXis() { return xis; } public double[][] getGammas() { throw new UnsupportedOperationException("Not handled!"); } public double getTotalWeight() { assert (!Double.isNaN(totalWeight)); return totalWeight; } public double getGammaWeight(int inputPosition, State s) { throw new UnsupportedOperationException("Not handled!"); } public double getGammaWeight(int inputPosition, int stateIndex) { throw new UnsupportedOperationException("Not handled!"); } public double getGammaProbability(int inputPosition, State s) { throw new UnsupportedOperationException("Not handled!"); } public double getGammaProbability(int inputPosition, int stateIndex) { throw new UnsupportedOperationException("Not handled!"); } public double getXiProbability(int ip, State s1, State s2) { throw new UnsupportedOperationException("Not handled!"); } public double getXiWeight(int ip, State s1, State s2) { throw new UnsupportedOperationException("Not handled!"); } public int length() { return latticeLength; } public double getAlpha(int ip, State s) { throw new UnsupportedOperationException("Not handled!"); } public double getBeta(int ip, State s) { throw new UnsupportedOperationException("Not handled!"); } public LabelVector getLabelingAtPosition(int outputPosition) { return null; } public Transducer getTransducer() { return t; } public Sequence getInput() { return input; } }