/* Copyright (C) 2010 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.fst.semi_supervised.constraints; import gnu.trove.TIntObjectHashMap; import cc.mallet.fst.semi_supervised.StateLabelMap; /** * A set of constraints on distributions over consecutive * labels conditioned an input features. * * This is to be used with GE, and penalizes the * L_2^2 difference between model and target distributions. * * Multiple constraints are grouped together here * to make things more efficient. * * @author Gregory Druck */ public class OneLabelL2GEConstraints extends OneLabelGEConstraints { public OneLabelL2GEConstraints() { super(); } private OneLabelL2GEConstraints(TIntObjectHashMap<OneLabelGEConstraint> constraints, StateLabelMap map) { super(constraints,map); } public GEConstraint copy() { return new OneLabelL2GEConstraints(this.constraints, this.map); } @Override public void addConstraint(int fi, double[] target, double weight) { constraints.put(fi,new OneLabelGEL2Constraint(target,weight)); } @Override public double getValue() { double value = 0.0; for (int fi : constraints.keys()) { OneLabelGEConstraint constraint = constraints.get(fi); if ( constraint.count > 0.0) { // value due to current constraint double featureValue = 0.0; for (int labelIndex = 0; labelIndex < map.getNumLabels(); ++labelIndex) { double ex = constraint.expectation[labelIndex]/constraint.count; featureValue -= Math.pow(constraint.target[labelIndex] - ex,2); } assert(!Double.isNaN(featureValue) && !Double.isInfinite(featureValue)); value += featureValue * constraint.weight; } } return value; } protected class OneLabelGEL2Constraint extends OneLabelGEConstraint { public OneLabelGEL2Constraint(double[] target, double weight) { super(target,weight); } public double getValue(int li) { assert(this.count != 0); return 2 * this.weight * (target[li] / count - expectation[li] / (count * count)); } } }