/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.grmm.util; import cc.mallet.optimize.Optimizable; import cc.mallet.util.MalletLogger; import java.util.logging.Logger; /** * Created: Aug 27, 2004 * * @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A> * @version $Id: CachingOptimizable.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $ */ public class CachingOptimizable { private static abstract class Base implements Optimizable { static final Logger logger = MalletLogger.getLogger (CachingOptimizable.class.getName ()); double cachedValue = -123456789; double[] cachedGradient; protected boolean cachedValueStale = true; protected boolean cachedGradientStale = true; protected abstract void setParametersInternal (double[] params); public void setParameters (double[] params) { if (params.length != getNumParameters ()) throw new IllegalArgumentException ("Argument is not of the " + " correct dimensions"); cachedValueStale = cachedGradientStale = true; setParametersInternal (params); } /** * Sets one parameter of the maximizable object. This default implementation * inefficiently uses both <tt>getParameters(double[])</tt> and <tt>setParameters(double[])</tt>. * Subclasses may override this method for more efficient implemetnations. * * @param index * @param value */ public void setParameter (int index, double value) { cachedValueStale = cachedGradientStale = true; double[] params = new double[getNumParameters ()]; getParameters (params); params[index] = value; setParametersInternal (params); } /** * Returns one parameter of the maximizable object. This default implementation * inefficiently calls through to <tt>getParameters(double[])</tt>. * Subclasses may override this method for more efficient implemetnations. * * @param index * @return The value of parameter <tt>index<tt> */ public double getParameter (int index) { double[] params = new double[getNumParameters ()]; getParameters (params); return params[index]; } public void forceStale () { cachedValueStale = cachedGradientStale = true; } } /**/ public static abstract class ByGradient extends Base implements Optimizable.ByGradientValue { protected abstract double computeValue (); protected abstract void computeValueGradient (double[] buffer); public void getValueGradient (double[] buffer) { if (buffer.length != getNumParameters ()) throw new IllegalArgumentException ("Argument is not of the " + " correct dimensions"); if (cachedValueStale) { cachedValue = computeValue (); cachedValueStale = false; } if (cachedGradientStale) { if (cachedGradient == null) { cachedGradient = new double[getNumParameters ()]; } computeValueGradient (cachedGradient); cachedGradientStale = false; } System.arraycopy (cachedGradient, 0, buffer, 0, cachedGradient.length); } public double getValue () { if (cachedValueStale) { long startTime = System.currentTimeMillis(); cachedValue = computeValue (); long endTime = System.currentTimeMillis(); logger.info ("Optimizable computeValue time (ms) ="+(endTime-startTime)); logger.info ("computeValue() = " + cachedValue); cachedValueStale = false; } return cachedValue; } /** * Sets the cached gradient. This is useful for subclasses that * need to compute the value and the gradient at the same time. * If they call this method in computeValue(), then * their computeValueGradient() will never be called. * * @param gradient */ protected void setCachedGradient (double[] gradient) { if (cachedGradient == null) { cachedGradient = new double[getNumParameters ()]; } System.arraycopy (gradient, 0, cachedGradient, 0, gradient.length); cachedGradientStale = false; } } public static abstract class ByBatchGradient extends Base implements Optimizable.ByBatchGradient { private int lastIndex; private int[] lastAssns; public void getBatchValueGradient (double[] buffer, int batchIndex, int[] batchAssignments) { if (buffer.length != getNumParameters ()) throw new IllegalArgumentException ("Argument is not of the " + " correct dimensions"); if ((batchIndex != lastIndex) || (batchAssignments != lastAssns)) { forceStale (); lastIndex = batchIndex; lastAssns = batchAssignments; } if (cachedValueStale) { cachedValue = computeBatchValue (batchIndex, batchAssignments); cachedValueStale = false; } if (cachedGradientStale) { if (cachedGradient == null) { cachedGradient = new double[getNumParameters ()]; } computeBatchGradient (cachedGradient, batchIndex, batchAssignments); cachedGradientStale = false; } System.arraycopy (cachedGradient, 0, buffer, 0, cachedGradient.length); } public double getBatchValue (int batchIndex, int[] batchAssignments) { if ((batchIndex != lastIndex) || (batchAssignments != lastAssns)) { forceStale (); lastIndex = batchIndex; lastAssns = batchAssignments; } if (cachedValueStale) { cachedValue = computeBatchValue (batchIndex, batchAssignments); logger.info ("computeValue() = " + cachedValue); cachedValueStale = false; } return cachedValue; } protected abstract double computeBatchValue (int batchIndex, int[] batchAssignments); protected abstract void computeBatchGradient (double[] buffer, int batchIndex, int[] batchAssignments); } }