/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.fst.semi_supervised.pr; import java.util.ArrayList; import java.util.Iterator; import cc.mallet.fst.CRF; import cc.mallet.fst.Transducer; import cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint; import cc.mallet.types.FeatureVector; import cc.mallet.types.Sequence; /** * Auxiliar model (q) for E-step/I-projection in Posterior Regularization (PR). * * @author Gregory Druck */ public class PRAuxiliaryModel extends Transducer { private static final long serialVersionUID = 1L; private int numParameters; private double[][] parameters; private ArrayList<PRConstraint> constraints; private CRF baseModel; public PRAuxiliaryModel(CRF baseModel, ArrayList<PRConstraint> constraints) { this.baseModel = baseModel; this.constraints = constraints; int index = 0; this.parameters = new double[constraints.size()][]; for (PRConstraint constraint : constraints) { parameters[index] = new double[constraint.numDimensions()]; index++; numParameters += constraint.numDimensions(); } } private PRAuxiliaryModel(CRF baseModel, ArrayList<PRConstraint> constraints, double[][] parameters) { this.baseModel = baseModel; this.constraints = constraints; this.parameters = parameters; for (PRConstraint constraint : constraints) { numParameters += constraint.numDimensions(); } } public PRAuxiliaryModel copy() { ArrayList<PRConstraint> copy = new ArrayList<PRConstraint>(); for (PRConstraint constraint : constraints) { copy.add(constraint.copy()); } // parameters should not be changed in copy, // so we can use the same parameters array return new PRAuxiliaryModel(baseModel,copy,parameters); } public void preProcess(int index, int position, Sequence input) { for (PRConstraint constraint : constraints) { constraint.preProcess((FeatureVector)input.get(position)); } } public double getValue() { double value = 0; int index = 0; for (PRConstraint constraint : constraints) { value += constraint.getAuxiliaryValueContribution(parameters[index]); index++; } return value; } public double getCompleteValueContribution() { double value = 0; int index = 0; for (PRConstraint constraint : constraints) { value += constraint.getCompleteValueContribution(parameters[index]); index++; } return value; } public void getValueGradient(double[] gradient) { int index = 0; int start = 0; for (PRConstraint constraint: constraints) { double[] constraintGradient = new double[constraint.numDimensions()]; constraint.getGradient(parameters[index], constraintGradient); System.arraycopy(constraintGradient, 0, gradient, start, constraintGradient.length); start += constraint.numDimensions(); index++; } } public double getWeight(int index, int position, Sequence input, TransitionIterator iter) { double weight = 0; int si1 = iter.getSourceState().getIndex(); int si2 = iter.getDestinationState().getIndex(); int constrIndex = 0; for (PRConstraint constraint : constraints) { weight += constraint.getScore((FeatureVector)input.get(position), position, si1, si2, parameters[constrIndex]); constrIndex++; } return weight; } public void incrementTransition(int index, int position, Sequence input, TransitionIterator iter, double prob) { int si1 = iter.getSourceState().getIndex(); int si2 = iter.getDestinationState().getIndex(); for (PRConstraint constraint : constraints) { constraint.incrementExpectations((FeatureVector)input.get(position), position, si1, si2, prob); } } public void zeroExpectations() { for (PRConstraint constraint : constraints) { constraint.zeroExpectations(); } } public int numParameters() { return numParameters; } public void getParameters(double[] params) { assert(params.length == numParameters); int start = 0; for (int i = 0; i < this.parameters.length; i++) { System.arraycopy(this.parameters[i], 0, params, start, this.parameters[i].length); start += this.parameters[i].length; } } public double getParameter(int index) { assert(index > 0); int constrIndex = 0; for (PRConstraint constraint : constraints) { if (index < constraint.numDimensions()) { return parameters[constrIndex][index]; } constrIndex++; index -= constraint.numDimensions(); } throw new RuntimeException("index not found: " + index); } public void setParameters(double[] params) { assert(params.length == numParameters); int start = 0; for (int i = 0; i < parameters.length; i++) { System.arraycopy(params, start, this.parameters[i], 0, this.parameters[i].length); start += parameters[i].length; } } public void setParameter(int index, double value) { assert(index > 0); int constrIndex = 0; for (PRConstraint constraint : constraints) { if (index < constraint.numDimensions()) { parameters[constrIndex][index] = value; return; } constrIndex++; index -= constraint.numDimensions(); } throw new RuntimeException("index not found: " + index); } public int numConstraints() { return constraints.size(); } public PRConstraint getConstraint(int index) { return constraints.get(index); } public CRF getBaseModel() { return baseModel; } @Override public int numStates() { return baseModel.numStates(); } @Override public State getState(int index) { return baseModel.getState(index); } @Override public Iterator initialStateIterator() { return baseModel.initialStateIterator(); } }