package cc.mallet.fst; import java.util.logging.Logger; import cc.mallet.optimize.Optimizable; import cc.mallet.types.MatrixOps; import cc.mallet.util.MalletLogger; /** * A CRF objective function that is the sum of multiple * objective functions that implement Optimizable.ByGradientValue. * * @author Gregory Druck * @author Gaurav Chandalia */ public class CRFOptimizableByGradientValues implements Optimizable.ByGradientValue { private static Logger logger = MalletLogger.getLogger(CRFOptimizableByGradientValues.class.getName()); private int cachedValueWeightsStamp; private int cachedGradientWeightsStamp; private double cachedValue = Double.NEGATIVE_INFINITY; private double[] cachedGradient; private Optimizable.ByGradientValue[] optimizables; private CRF crf; /** * @param crf CRF whose parameters we wish to estimate. * @param opts Optimizable.ByGradientValue objective functions. * * Parameters are estimated by maximizing the sum of the individual * objective functions. */ public CRFOptimizableByGradientValues (CRF crf, Optimizable.ByGradientValue[] opts) { this.crf = crf; this.optimizables = opts; this.cachedGradient = new double[crf.parameters.getNumFactors()]; this.cachedValueWeightsStamp = -1; this.cachedGradientWeightsStamp = -1; } public int getNumParameters () { return crf.parameters.getNumFactors(); } public void getParameters (double[] buffer) { crf.parameters.getParameters(buffer); } public double getParameter (int index) { return crf.parameters.getParameter(index); } public void setParameters (double [] buff) { crf.parameters.setParameters(buff); crf.weightsValueChanged(); } public void setParameter (int index, double value) { crf.parameters.setParameter(index, value); crf.weightsValueChanged(); } /** Returns the log probability of the training sequence labels and the prior over parameters. */ public double getValue () { if (crf.weightsValueChangeStamp != cachedValueWeightsStamp) { // The cached value is not up to date; it was calculated for a different set of CRF weights. cachedValue = 0; for (int i = 0; i < optimizables.length; i++) cachedValue += optimizables[i].getValue(); cachedValueWeightsStamp = crf.weightsValueChangeStamp; // cachedValue is now no longer stale logger.info ("getValue() = "+cachedValue); } return cachedValue; } public void getValueGradient (double [] buffer) { if (cachedGradientWeightsStamp != crf.weightsValueChangeStamp) { getValue (); MatrixOps.setAll(cachedGradient, 0); double[] b2 = new double[buffer.length]; for (int i = 0; i < optimizables.length; i++) { MatrixOps.setAll(b2, 0); optimizables[i].getValueGradient(b2); MatrixOps.plusEquals(cachedGradient, b2); } cachedGradientWeightsStamp = crf.weightsValueChangeStamp; } System.arraycopy(cachedGradient, 0, buffer, 0, cachedGradient.length); } }