/* Copyright (C) 2009 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.fst.semi_supervised;
import cc.mallet.fst.Transducer;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.util.Maths;
/**
* Runs subsequence constrained forward-backward to compute the entropy of label
* sequences. <p>
*
* Reference:
* Gideon Mann, Andrew McCallum
* "Efficient Computation of Entropy Gradient for Semi-Supervised Conditional Random Fields"
* HLT/NAACL 2007
*
* @author Gideon Mann
* @author Gaurav Chandalia
* @author Gregory Druck
*/
public class EntropyLattice {
  // number of lattice positions: input_sequence_size + 1
  protected int latticeLength;
  // input_sequence_size (also the number of transition steps in the lattice)
  protected int inputLength;
  // the model
  protected Transducer transducer;
  // number of states in the lattice (or the model's finite state machine)
  protected int numStates;
  // ip: input position, each node has a forward and backward factor used in the
  // forward-backward algorithm, indexed by ip, state@ip (state index / si)
  protected LatticeNode[][] nodes;
  // subsequence constrained (forward) entropy
  protected double entropy;

  /**
   * Runs constrained forward-backward. <p>
   *
   * If <tt>incrementor</tt> is null then do not update expectations due to
   * these computations. <p>
   *
   * The contribution of entropy to the expectations is multiplied by the
   * scaling factor.
   *
   * @param fvs the input sequence; its size determines the lattice length
   * @param gammas log state marginals, indexed by [input position][state];
   *        rows 0..fvs.size() are read
   * @param xis log transition marginals, indexed by
   *        [input position][source state][destination state];
   *        rows 0..fvs.size()-1 are read
   * @param transducer the model whose states define the lattice
   * @param incrementor receives the entropy contribution of every transition,
   *        or null to skip the expectation update
   * @param scalingFactor multiplier applied to every value passed to the
   *        incrementor
   */
  public EntropyLattice(FeatureVectorSequence fvs, double[][] gammas,
      double[][][] xis, Transducer transducer,
      Transducer.Incrementor incrementor,
      double scalingFactor) {
    inputLength = fvs.size();
    latticeLength = inputLength + 1;
    this.transducer = transducer;
    numStates = transducer.numStates();
    nodes = new LatticeNode[latticeLength][numStates];

    // run forward-backward and compute the entropy; both directions must
    // agree, which is verified when assertions are enabled
    entropy = this.forwardLattice(gammas, xis);
    double backwardEntropy = this.backwardLattice(gammas, xis);
    assert(Maths.almostEquals(entropy, backwardEntropy)) : entropy + " " + backwardEntropy;

    if (incrementor != null) {
      // add the entropy to expectations
      this.updateCounts(fvs, gammas, xis, scalingFactor, incrementor);
    }
  }

  /**
   * Returns the value computed by the forward pass in the constructor.
   */
  public double getEntropy() {
    return entropy;
  }

  /**
   * Computes the forward entropies (H^alpha) and stores them in the
   * <tt>alpha</tt> field of every lattice node. <p>
   *
   * Entries of <tt>gammas</tt> and <tt>xis</tt> that are not greater than
   * <tt>Transducer.IMPOSSIBLE_WEIGHT</tt> are skipped.
   *
   * @param gammas log state marginals, indexed by [input position][state]
   * @param xis log transition marginals, indexed by
   *        [input position][source state][destination state]
   * @return sum over final states a of p(a) * (log p(a) + alpha(a))
   */
  public double forwardLattice(double[][] gammas, double[][][] xis) {
    // initialize entropy of start states to 0
    for (int a = 0; a < numStates; ++a) {
      this.getLatticeNode(0, a).alpha = 0;
    }

    for (int ip = 1; ip < latticeLength; ++ip) {
      for (int a = 0; a < numStates; ++a) {
        // lattice position ip, state a
        LatticeNode node = this.getLatticeNode(ip, a);
        // start from zero so that a repeated call does not accumulate on top
        // of the values left by a previous pass
        node.alpha = 0;
        double gamma = gammas[ip][a];
        if (gamma > Transducer.IMPOSSIBLE_WEIGHT) {
          for (int b = 0; b < numStates; ++b) {
            // state a at position ip, coming from state b at position ip-1
            double xi = xis[ip-1][b][a];
            if (xi > Transducer.IMPOSSIBLE_WEIGHT) {
              // log p(y_{ip-1}=b|y_{ip}=a); exponentiating the difference
              // (instead of dividing two exponentials) avoids 0/0 when both
              // log values are very negative
              double logCondProb = xi - gamma;
              double condProb = Math.exp(logCondProb);
              node.alpha += condProb * (logCondProb +
                  this.getLatticeNode(ip-1, b).alpha);
            }
          }
        }
      }
    }

    // combine the entropies at the last lattice position
    double total = 0.0;
    for (int a = 0; a < numStates; ++a) {
      double gamma = gammas[inputLength][a];
      if (gamma > Transducer.IMPOSSIBLE_WEIGHT) {
        double gammaProb = Math.exp(gamma);
        total += gammaProb * gamma;
        total += gammaProb * this.getLatticeNode(inputLength, a).alpha;
      }
    }
    return total;
  }

  /**
   * Computes the backward entropies (H^beta) and stores them in the
   * <tt>beta</tt> field of every lattice node. <p>
   *
   * Entries of <tt>gammas</tt> and <tt>xis</tt> that are not greater than
   * <tt>Transducer.IMPOSSIBLE_WEIGHT</tt> are skipped.
   *
   * @param gammas log state marginals, indexed by [input position][state]
   * @param xis log transition marginals, indexed by
   *        [input position][source state][destination state]
   * @return sum over initial states a of p(a) * (log p(a) + beta(a))
   */
  public double backwardLattice(double[][] gammas, double[][][] xis) {
    // initialize entropy of end states to 0
    for (int a = 0; a < numStates; ++a) {
      this.getLatticeNode(inputLength, a).beta = 0;
    }

    // The last lattice position (inputLength) has no outgoing transitions, so
    // the recursion starts one step before it. Starting at inputLength would
    // read xis[inputLength] and lattice position inputLength + 1, both of
    // which lie past the end of the lattice.
    for (int ip = inputLength - 1; ip >= 0; --ip) {
      for (int a = 0; a < numStates; ++a) {
        // lattice position ip, state a
        LatticeNode node = this.getLatticeNode(ip, a);
        // start from zero so that a repeated call does not accumulate on top
        // of the values left by a previous pass
        node.beta = 0;
        double gamma = gammas[ip][a];
        if (gamma > Transducer.IMPOSSIBLE_WEIGHT) {
          for (int b = 0; b < numStates; ++b) {
            // state a at position ip, going to state b at position ip+1
            double xi = xis[ip][a][b];
            if (xi > Transducer.IMPOSSIBLE_WEIGHT) {
              // log p(y_{ip+1}=b|y_{ip}=a); see forwardLattice for why the
              // difference is exponentiated
              double logCondProb = xi - gamma;
              double condProb = Math.exp(logCondProb);
              node.beta += condProb * (logCondProb +
                  this.getLatticeNode(ip+1, b).beta);
            }
          }
        }
      }
    }

    // combine the entropies at the first lattice position
    double total = 0.0;
    for (int a = 0; a < numStates; ++a) {
      double gamma = gammas[0][a];
      if (gamma > Transducer.IMPOSSIBLE_WEIGHT) {
        double gammaProb = Math.exp(gamma);
        total += gammaProb * gamma;
        total += gammaProb * this.getLatticeNode(0, a).beta;
      }
    }
    return total;
  }

  /**
   * Updates the expectations due to the entropy. <p>
   *
   * For every transition offered by the model at every input position, passes
   * the transition's entropy contribution (times <tt>scalingFactor</tt>) to
   * the incrementor. Transitions whose xi equals
   * <tt>Transducer.IMPOSSIBLE_WEIGHT</tt> are skipped.
   */
  private void updateCounts(FeatureVectorSequence fvs, double[][] gammas,
      double[][][] xis, double scalingFactor, Transducer.Incrementor incrementor) {
    for (int ip = 0; ip < inputLength; ++ip) {
      for (int a = 0 ; a < numStates; ++a) {
        if (nodes[ip][a] == null) {
          continue;
        }

        Transducer.State sourceState = transducer.getState(a);
        Transducer.TransitionIterator iter = sourceState.transitionIterator(fvs, ip, null, ip);
        while (iter.hasNext()) {
          int b = iter.next().getIndex();
          double xi = xis[ip][a][b];
          if (xi == Transducer.IMPOSSIBLE_WEIGHT) {
            continue;
          }
          double xiProb = Math.exp(xi);

          // This is obtained after substituting and re-arranging the equation
          // at the end of the third page of the paper into the equation of
          // d/d_theta -H(Y|x) at the end of the second page.
          // \sum_(y_i,y_{i+1})
          // f_k(y_i,y_{i+1},x) p(y_i, y_{i+1}) *
          // (log p(y_i,y_{i+1}) + H^a(Y_{1..(i-1)},y_i) +
          // H^b(Y_{(i+2)..T}|y_{i+1}))
          double constrEntropy = xiProb * (xi + nodes[ip][a].alpha + nodes[ip+1][b].beta);
          assert(constrEntropy <= 0) : "Negative entropy should be negative! " + constrEntropy;

          // full covariance, (note: it could be positive *or* negative)
          double covContribution = constrEntropy - xiProb * entropy;
          assert(!Double.isNaN(covContribution))
            : "xi: " + xi + ", nodes[" + ip + "][" + a + "].alpha: " +
            nodes[ip][a].alpha + ", nodes[" + (ip+1) + "][" + b +
            "].beta: " + nodes[ip+1][b].beta;

          incrementor.incrementTransition(iter, covContribution * scalingFactor);
        }
      }
    }
  }

  /**
   * Returns the node at input position <tt>ip</tt> and state index
   * <tt>si</tt>, creating it (with alpha and beta set to 0) on first access.
   */
  public LatticeNode getLatticeNode(int ip, int si) {
    if (nodes[ip][si] == null) {
      nodes[ip][si] = new LatticeNode(ip, transducer.getState(si));
    }
    return nodes[ip][si];
  }

  /**
   * Contains alpha, beta values at a particular input position and state pair.
   */
  public class LatticeNode {
    // input position of this node in the lattice
    public int ip;
    // model state this node stands for
    public Transducer.State state;
    // forward entropy, filled in by forwardLattice
    public double alpha;
    // backward entropy, filled in by backwardLattice
    public double beta;

    LatticeNode(int ip, Transducer.State state) {
      this.ip = ip;
      this.state = state;
      this.alpha = 0.0;
      this.beta = 0.0;
    }
  }
}