/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.fst.semi_supervised.pr.constraints; import gnu.trove.TIntArrayList; import gnu.trove.TIntIntHashMap; import gnu.trove.TIntObjectHashMap; import java.util.BitSet; import cc.mallet.fst.semi_supervised.StateLabelMap; import cc.mallet.types.FeatureVector; import cc.mallet.types.FeatureVectorSequence; import cc.mallet.types.Instance; import cc.mallet.types.InstanceList; /** * A set of constraints on distributions over single * labels conditioned on the presence of input features. * * This is to be used with PR, and penalizes * L_2^2 difference from target expectations. * * Multiple constraints are grouped together here * to make things more efficient. * * @author Gregory Druck */ public class OneLabelL2PRConstraints implements PRConstraint { // maps between input feature indices and constraints protected TIntObjectHashMap<OneLabelPRConstraint> constraints; // maps between input feature indices and constraint indices protected TIntIntHashMap constraintIndices; protected StateLabelMap map; protected boolean normalized; // cache of set of constrained features that fire at last FeatureVector // provided in preprocess call protected TIntArrayList cache; public OneLabelL2PRConstraints(boolean normalized) { this.constraints = new TIntObjectHashMap<OneLabelPRConstraint>(); this.constraintIndices = new TIntIntHashMap(); this.cache = new TIntArrayList(); this.normalized = normalized; } protected OneLabelL2PRConstraints(TIntObjectHashMap<OneLabelPRConstraint> constraints, TIntIntHashMap constraintIndices, StateLabelMap map, boolean normalized) { this.constraints = new TIntObjectHashMap<OneLabelPRConstraint>(); for (int key : constraints.keys()) { this.constraints.put(key, constraints.get(key).copy()); } //this.constraints = constraints; this.constraintIndices = constraintIndices; this.map = map; this.cache = new TIntArrayList(); this.normalized = normalized; } public PRConstraint copy() { return new OneLabelL2PRConstraints(this.constraints, this.constraintIndices, this.map, this.normalized); } public void addConstraint(int fi, double[] target, double weight) { constraints.put(fi,new OneLabelPRConstraint(target,weight)); constraintIndices.put(fi, constraintIndices.size()); } public int numDimensions() { assert(map != null); return map.getNumLabels() * constraints.size(); } public boolean isOneStateConstraint() { return true; } public void setStateLabelMap(StateLabelMap map) { this.map = map; } public void preProcess(FeatureVector fv) { cache.resetQuick(); int fi; // cache constrained input features for (int loc = 0; loc < fv.numLocations(); loc++) { fi = fv.indexAtLocation(loc); if (constraints.containsKey(fi)) { cache.add(fi); } } } // find examples that contain constrained input features public BitSet preProcess(InstanceList data) { // count int ii = 0; int fi; FeatureVector fv; BitSet bitSet = new BitSet(data.size()); for (Instance instance : data) { FeatureVectorSequence fvs = (FeatureVectorSequence)instance.getData(); for (int ip = 0; ip < fvs.size(); ip++) { fv = fvs.get(ip); for (int loc = 0; loc < fv.numLocations(); loc++) { fi = fv.indexAtLocation(loc); if (constraints.containsKey(fi)) { constraints.get(fi).count += 1; bitSet.set(ii); } } } ii++; } return bitSet; } public double getScore(FeatureVector input, int inputPosition, int srcIndex, int destIndex, double[] parameters) { double dot = 0; int li2 = map.getLabelIndex(destIndex); for (int i = 0; i < cache.size(); i++) { int j = constraintIndices.get(cache.getQuick(i)); // TODO binary features if (normalized) { dot += parameters[j + constraints.size() * li2] / constraints.get(cache.getQuick(i)).count; } else { dot += parameters[j + constraints.size() * li2]; } } return dot; } public void incrementExpectations(FeatureVector input, int inputPosition, int srcIndex, int destIndex, double prob) { int li2 = map.getLabelIndex(destIndex); for (int i = 0; i < cache.size(); i++) { constraints.get(cache.getQuick(i)).expectation[li2] += prob; } } public void getExpectations(double[] expectations) { assert(expectations.length == numDimensions()); for (int fi : constraintIndices.keys()) { int ci = constraintIndices.get(fi); OneLabelPRConstraint constraint = constraints.get(fi); for (int li = 0; li < constraint.expectation.length; li++) { expectations[ci + li * constraints.size()] = constraint.expectation[li]; } } } public void addExpectations(double[] expectations) { assert(expectations.length == numDimensions()); for (int fi : constraintIndices.keys()) { int ci = constraintIndices.get(fi); OneLabelPRConstraint constraint = constraints.get(fi); for (int li = 0; li < constraint.expectation.length; li++) { constraint.expectation[li] += expectations[ci + li * constraints.size()]; } } } public void zeroExpectations() { for (int fi : constraints.keys()) { constraints.get(fi).expectation = new double[map.getNumLabels()]; } } public double getAuxiliaryValueContribution(double[] parameters) { double value = 0; for (int fi : constraints.keys()) { int ci = constraintIndices.get(fi); for (int li = 0; li < map.getNumLabels(); li++) { double param = parameters[ci + li * constraints.size()]; value += constraints.get(fi).target[li] * param - (param * param) / (2 * constraints.get(fi).weight); } } return value; } // TODO public double getCompleteValueContribution(double[] parameters) { double value = 0; for (int fi : constraints.keys()) { OneLabelPRConstraint constraint = constraints.get(fi); for (int li = 0; li < map.getNumLabels(); li++) { if (normalized) { value += constraint.weight * Math.pow(constraint.target[li] - constraint.expectation[li]/constraint.count,2) / 2; } else { value += constraint.weight * Math.pow(constraint.target[li] - constraint.expectation[li],2) / 2; } } } return value; } public void getGradient(double[] parameters, double[] gradient) { for (int fi : constraints.keys()) { int ci = constraintIndices.get(fi); OneLabelPRConstraint constraint = constraints.get(fi); for (int li = 0; li < map.getNumLabels(); li++) { if (normalized) { gradient[ci + li * constraints.size()] = constraint.target[li] - constraint.expectation[li] / constraint.count - parameters[ci + li * constraints.size()] / constraint.weight; } else { gradient[ci + li * constraints.size()] = constraint.target[li] - constraint.expectation[li] - parameters[ci + li * constraints.size()] / constraint.weight; } } } } protected class OneLabelPRConstraint { protected double[] target; protected double[] expectation; protected double count; protected double weight; public OneLabelPRConstraint(double[] target, double weight) { this.target = target; this.weight = weight; this.expectation = null; this.count = 0; } public OneLabelPRConstraint copy() { OneLabelPRConstraint copy = new OneLabelPRConstraint(target,weight); copy.count = count; copy.expectation = new double[target.length]; return copy; } } }