/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify.constraints.ge;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.Maths;
/**
 * Expectation constraints for use with GE.
 * Penalizes KL divergence of the (normalized) model expectation
 * from the target distribution.
 *
 * Multiple constraints are grouped together here
 * to make things more efficient.
 *
 * @author Gregory Druck
 */
public class MaxEntKLFLGEConstraints extends MaxEntFLGEConstraints {

  /**
   * Creates an empty constraint set; all arguments are forwarded unchanged
   * to the superclass constructor.
   *
   * @param numFeatures passed through to {@code MaxEntFLGEConstraints}
   * @param numLabels passed through to {@code MaxEntFLGEConstraints}
   * @param useValues passed through to {@code MaxEntFLGEConstraints}
   */
  public MaxEntKLFLGEConstraints(int numFeatures, int numLabels, boolean useValues) {
    super(numFeatures, numLabels, useValues);
  }

  /**
   * Computes the weighted sum, over all registered constraints, of the
   * negative KL divergence between each constraint's target distribution
   * and its normalized expectation ({@code expectation / count}).
   *
   * Constraints whose {@code count} is not positive are skipped.
   *
   * @return the weighted sum of negative KL divergences, or
   *     {@link Double#NEGATIVE_INFINITY} as soon as a contributing constraint
   *     has a label with non-zero target but zero expectation
   */
  public double getValue() {
    double value = 0.0;
    for (int fi : constraints.keys()) {
      MaxEntFLGEConstraint constraint = constraints.get(fi);
      if (constraint.count > 0.0) {
        double constraintValue = 0.0;
        for (int labelIndex = 0; labelIndex < numLabels; ++labelIndex) {
          // labels with zero target mass contribute nothing to p*log(q/p)
          if (constraint.target[labelIndex] > 0.0) {
            // if target is non-zero and expectation is 0, infinite penalty
            if (constraint.expectation[labelIndex] == 0.0) {
              return Double.NEGATIVE_INFINITY;
            }
            else {
              // p*log(q) - p*log(p), with q = expectation / count
              // negative KL
              constraintValue += constraint.target[labelIndex] *
                (Math.log(constraint.expectation[labelIndex] / constraint.count) -
                    Math.log(constraint.target[labelIndex]));
            }
          }
        }
        assert(!Double.isNaN(constraintValue) &&
               !Double.isInfinite(constraintValue));

        value += constraintValue * constraint.weight;
      }
    }
    return value;
  }

  /**
   * Registers a KL constraint for feature {@code fi}.
   *
   * @param fi feature index used as the key in {@code constraints}
   * @param ex target label distribution; asserted to sum to (almost) 1
   * @param weight weight applied to this constraint's value and gradient term
   */
  @Override
  public void addConstraint(int fi, double[] ex, double weight) {
    assert(Maths.almostEquals(MatrixOps.sum(ex),1));
    constraints.put(fi,new MaxEntKLFLGEConstraint(ex,weight));
  }

  /**
   * A single KL-divergence constraint. The {@code target}, {@code expectation},
   * {@code count} and {@code weight} fields are inherited from
   * {@code MaxEntFLGEConstraint}.
   */
  protected class MaxEntKLFLGEConstraint extends MaxEntFLGEConstraint {

    /**
     * @param target target label distribution, forwarded to the superclass
     * @param weight constraint weight, forwarded to the superclass
     */
    public MaxEntKLFLGEConstraint(double[] target, double weight) {
      super(target, weight);
    }

    /**
     * Returns the per-label term {@code weight * target[li] / q[li]}, where
     * {@code q[li] = expectation[li] / count} is the same normalized
     * expectation used by the enclosing class's {@code getValue()}. This is
     * the derivative of {@code weight * target[li] * log(q[li])} with respect
     * to {@code q[li]}.
     *
     * @param li label index
     * @return 0 when both target and expectation are 0 for this label;
     *     otherwise {@code weight * target[li] / (expectation[li] / count)},
     *     which is non-finite if the expectation is 0 while the target is not
     */
    @Override
    public double getValue(int li) {
      assert(this.count != 0);
      // 0/0 would be NaN; a label with no target mass and no expectation
      // contributes nothing
      if (this.target[li] == 0 && this.expectation[li] == 0) {
        return 0;
      }
      // normalize the raw expectation by count so this term is consistent
      // with the objective, which is defined on expectation / count
      return this.weight * (this.target[li] / (this.expectation[li] / this.count));
    }
  }
}