/* Copyright (C) 2010 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.semi_supervised.constraints;

import java.util.ArrayList;

import gnu.trove.TIntIntHashMap;

import cc.mallet.fst.semi_supervised.StateLabelMap;

/**
 * A set of constraints on distributions over consecutive
 * labels conditioned on input features.
 *
 * This is to be used with GE, and penalizes the
 * KL divergence between model and target distributions.
 *
 * Multiple constraints are grouped together here
 * to make things more efficient.
 *
 * @author Gregory Druck
 */
public class TwoLabelKLGEConstraints extends TwoLabelGEConstraints {

  /**
   * Creates a constraint set by delegating to the superclass
   * no-argument constructor.
   */
  public TwoLabelKLGEConstraints() {
    super();
  }

  /**
   * Creates a constraint set from existing state; used by {@link #copy()}.
   * All three arguments are handed unchanged to the superclass constructor.
   *
   * @param constraintsList the constraints
   * @param constraintsMap  mapping from feature index to position in
   *                        {@code constraintsList}
   * @param map             the state/label map
   */
  private TwoLabelKLGEConstraints(ArrayList<TwoLabelGEConstraint> constraintsList,
      TIntIntHashMap constraintsMap, StateLabelMap map) {
    super(constraintsList, constraintsMap, map);
  }

  /**
   * Builds a new {@code TwoLabelKLGEConstraints} from this instance's
   * constraint list, constraint map and state/label map.
   *
   * NOTE(review): whether the list and map end up shared or duplicated is
   * decided by the superclass constructor, which is not visible here.
   *
   * @return the new constraint set
   */
  public GEConstraint copy() {
    return new TwoLabelKLGEConstraints(this.constraintsList, this.constraintsMap, this.map);
  }

  /**
   * Registers a KL constraint for a feature: a new
   * {@link TwoLabelKLGEConstraint} is appended to the constraint list and the
   * feature index is mapped to the position of that new entry.
   *
   * @param fi     feature index the constraint is keyed by
   * @param target target values, indexed {@code [previous label][current label]}
   * @param weight weight of this constraint
   */
  @Override
  public void addConstraint(int fi, double[][] target, double weight) {
    constraintsList.add(new TwoLabelKLGEConstraint(target, weight));
    constraintsMap.put(fi, constraintsList.size() - 1);
  }

  /**
   * Computes the weighted sum, over all constraints whose count is positive,
   * of the negative KL divergence between the target and the count-normalized
   * expectation.
   *
   * Only label pairs with a positive target contribute. Constraints whose
   * count is not positive are skipped.
   *
   * @return the summed value, or {@link Double#NEGATIVE_INFINITY} as soon as
   *         a counted constraint has a positive target entry whose
   *         expectation is exactly zero
   */
  @Override
  public double getValue() {
    double value = 0.0;
    for (int fi : constraintsMap.keys()) {
      TwoLabelGEConstraint constraint = constraintsList.get(constraintsMap.get(fi));
      if (constraint.count > 0.0) {
        double constraintValue = 0.0;
        for (int prevLi = 0; prevLi < map.getNumLabels(); prevLi++) {
          for (int currLi = 0; currLi < map.getNumLabels(); currLi++) {
            if (constraint.target[prevLi][currLi] > 0.0) {
              if (constraint.expectation[prevLi][currLi] == 0.0) {
                // The target puts mass where the expectation has none, so
                // the log term is unbounded below; stop immediately.
                return Double.NEGATIVE_INFINITY;
              } else {
                // p*log(q) - p*log(p)
                // negative KL
                constraintValue += constraint.target[prevLi][currLi] * (
                    Math.log(constraint.expectation[prevLi][currLi] / constraint.count)
                    - Math.log(constraint.target[prevLi][currLi]));
              }
            }
          }
        }
        assert (!Double.isNaN(constraintValue) && !Double.isInfinite(constraintValue));

        value += constraintValue * constraint.weight;
      }
    }
    return value;
  }

  /**
   * A single KL constraint over pairs of consecutive labels.
   */
  protected class TwoLabelKLGEConstraint extends TwoLabelGEConstraint {

    /**
     * Creates the constraint by passing both arguments to the superclass
     * constructor.
     *
     * @param target target values, indexed {@code [previous label][current label]}
     * @param weight weight of this constraint
     */
    public TwoLabelKLGEConstraint(double[][] target, double weight) {
      super(target, weight);
    }

    /**
     * Returns {@code weight * target / expectation} for one label pair.
     *
     * When both the target and the expectation of the pair are zero the
     * result is defined as 0 rather than the NaN that 0/0 would produce. A
     * non-zero target with a zero expectation is not guarded and yields a
     * non-finite result. The expectation is used as stored; it is not
     * divided by {@code count} here.
     *
     * NOTE(review): presumably this is the per-pair gradient term consumed
     * by the GE trainer -- confirm against the superclass and its callers.
     *
     * @param liPrev index of the previous label
     * @param liCurr index of the current label
     * @return the weighted target/expectation ratio, or 0 if both are zero
     */
    @Override
    public double getValue(int liPrev, int liCurr) {
      assert (this.count != 0);
      if (this.target[liPrev][liCurr] == 0 && this.expectation[liPrev][liCurr] == 0) {
        return 0;
      }
      return this.weight * (this.target[liPrev][liCurr] / (this.expectation[liPrev][liCurr]));
    }
  }
}