/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.fst.semi_supervised.pr.constraints; import gnu.trove.TIntArrayList; import gnu.trove.TIntObjectHashMap; import java.util.ArrayList; import java.util.BitSet; import java.util.HashMap; import cc.mallet.fst.semi_supervised.StateLabelMap; import cc.mallet.types.FeatureVector; import cc.mallet.types.FeatureVectorSequence; import cc.mallet.types.Instance; import cc.mallet.types.InstanceList; /** * A set of constraints on individual input feature label pairs. * * This is to be used with PR, and penalizes * L_2^2 difference from target expectations. * * Multiple constraints are grouped together here * to make things more efficient. * * @author Gregory Druck */ public class OneLabelL2IndPRConstraints implements PRConstraint { protected boolean normalized; protected int numDimensions; // maps between input feature indices and constraints protected TIntObjectHashMap<OneLabelL2IndPRConstraint> constraints; protected StateLabelMap map; // cache of set of constrained features that fire at last FeatureVector // provided in preprocess call protected TIntArrayList cache; public OneLabelL2IndPRConstraints(boolean normalized) { this.normalized = normalized; this.numDimensions = 0; this.constraints = new TIntObjectHashMap<OneLabelL2IndPRConstraint>(); // this will be set by the PRTrainer this.map = null; this.cache = new TIntArrayList(); } protected OneLabelL2IndPRConstraints(TIntObjectHashMap<OneLabelL2IndPRConstraint> constraints, StateLabelMap map, boolean normalized) { this.normalized = normalized; this.numDimensions = 0; // copy constraints this.constraints = new TIntObjectHashMap<OneLabelL2IndPRConstraint>(); for (int key : constraints.keys()) { this.constraints.put(key, constraints.get(key).copy()); numDimensions += constraints.get(key).getNumConstrainedLabels(); } this.map = map; this.cache = new TIntArrayList(); } public PRConstraint copy() { return new OneLabelL2IndPRConstraints(this.constraints, this.map, this.normalized); } public void addConstraint(int fi, int li, double target, double weight) { if (!constraints.containsKey(fi)) { constraints.put(fi,new OneLabelL2IndPRConstraint()); } constraints.get(fi).add(li, target, weight, numDimensions); numDimensions++; } public int numDimensions() { return numDimensions; } public void setStateLabelMap(StateLabelMap map) { this.map = map; } public void preProcess(FeatureVector fv) { cache.resetQuick(); int fi; // cache constrained input features for (int loc = 0; loc < fv.numLocations(); loc++) { fi = fv.indexAtLocation(loc); if (constraints.containsKey(fi)) { cache.add(fi); } } } // find examples that contain constrained input features public BitSet preProcess(InstanceList data) { // count int ii = 0; int fi; FeatureVector fv; BitSet bitSet = new BitSet(data.size()); for (Instance instance : data) { FeatureVectorSequence fvs = (FeatureVectorSequence)instance.getData(); for (int ip = 0; ip < fvs.size(); ip++) { fv = fvs.get(ip); for (int loc = 0; loc < fv.numLocations(); loc++) { fi = fv.indexAtLocation(loc); if (constraints.containsKey(fi)) { constraints.get(fi).count += 1; bitSet.set(ii); } } } ii++; } return bitSet; } public double getScore(FeatureVector input, int inputPosition, int srcIndex, int destIndex, double[] parameters) { double dot = 0; int li2 = map.getLabelIndex(destIndex); for (int i = 0; i < cache.size(); i++) { int fi = cache.getQuick(i); OneLabelL2IndPRConstraint constraint = constraints.get(fi); dot += constraint.getScore(li2, parameters); } return dot; } public void incrementExpectations(FeatureVector input, int inputPosition, int srcIndex, int destIndex, double prob) { int li2 = map.getLabelIndex(destIndex); for (int i = 0; i < cache.size(); i++) { constraints.get(cache.getQuick(i)).incrementExpectation(li2, prob); } } public void getExpectations(double[] expectations) { assert(expectations.length == numDimensions()) : expectations.length + " " + numDimensions(); for (int fi : constraints.keys()) { constraints.get(fi).getExpectations(expectations); } } public void addExpectations(double[] expectations) { assert(expectations.length == numDimensions()); for (int fi : constraints.keys()) { constraints.get(fi).addExpectations(expectations); } } public void zeroExpectations() { for (int fi : constraints.keys()) { constraints.get(fi).zeroExpectation(); } } public double getAuxiliaryValueContribution(double[] parameters) { double value = 0; for (int fi : constraints.keys()) { OneLabelL2IndPRConstraint constraint = constraints.get(fi); value += constraint.getProjectionValueContrib(parameters); } return value; } public double getCompleteValueContribution(double[] parameters) { double value = 0; for (int fi : constraints.keys()) { OneLabelL2IndPRConstraint constraint = constraints.get(fi); value += constraint.getCompleteValueContrib(); } return value; } public void getGradient(double[] parameters, double[] gradient) { for (int fi : constraints.keys()) { OneLabelL2IndPRConstraint constraint = constraints.get(fi); constraint.getGradient(parameters, gradient); } } protected class OneLabelL2IndPRConstraint { protected int index; protected double count; protected ArrayList<Integer> labels; protected ArrayList<Integer> paramIndices; protected ArrayList<Double> targets; protected ArrayList<Double> weights; protected HashMap<Integer,Integer> labelMap; protected double[] expectation; public OneLabelL2IndPRConstraint() { index = 0; count = 0; labels = new ArrayList<Integer>(); paramIndices = new ArrayList<Integer>(); targets = new ArrayList<Double>(); weights = new ArrayList<Double>(); labelMap = new HashMap<Integer,Integer>(); } public OneLabelL2IndPRConstraint copy() { OneLabelL2IndPRConstraint copy = new OneLabelL2IndPRConstraint(); copy.index = index; copy.count = count; copy.labels = labels; copy.paramIndices = paramIndices; copy.targets = targets; copy.weights = weights; copy.labelMap = labelMap; // this will be incremented in the copy copy.expectation = new double[index]; return copy; } public void add(int label, double target, double weight, int paramIndex) { targets.add(target); weights.add(weight); labels.add(label); paramIndices.add(paramIndex); labelMap.put(label, index); index++; } public void zeroExpectation() { this.expectation = new double[labels.size()]; } public void getExpectations(double[] expectations) { for (int i = 0; i < paramIndices.size(); i++) { expectations[paramIndices.get(i)] = expectation[i]; } } public void addExpectations(double[] expectations) { for (int i = 0; i < paramIndices.size(); i++) { expectation[i] += expectations[paramIndices.get(i)]; } } public void incrementExpectation(int li, double value) { if (labelMap.containsKey(li)) { int i = labelMap.get(li); expectation[i] += value; } } public double getScore(int li, double[] parameters) { if (labelMap.containsKey(li)) { int i = labelMap.get(li); if (normalized) { return parameters[paramIndices.get(i)] / count; } else { return parameters[paramIndices.get(i)]; } } return 0; } public double getProjectionValueContrib(double[] parameters) { double value = 0; for (int i = 0; i < paramIndices.size(); i++) { double param = parameters[paramIndices.get(i)]; value += targets.get(i) * param - (param * param) / (2 * weights.get(i)); } return value; } public double getCompleteValueContrib() { double value = 0; for (int i = 0; i < paramIndices.size(); i++) { if (normalized) { value += weights.get(i) * Math.pow(targets.get(i) - expectation[i]/count,2) / 2; } else { value += weights.get(i) * Math.pow(targets.get(i) - expectation[i],2) / 2; } } return value; } public void getGradient(double[] parameters, double[] gradient) { for (int i = 0; i < paramIndices.size(); i++) { int pi = paramIndices.get(i); if (normalized) { gradient[pi] += targets.get(i) - expectation[i] / count - parameters[pi] / weights.get(i); } else { gradient[pi] += targets.get(i) - expectation[i] - parameters[pi] / weights.get(i); } } } public int getNumConstrainedLabels() { return index; } } }