package edu.stanford.nlp.optimization; import edu.stanford.nlp.util.logging.Redwood; import java.io.Serializable; /** * Stochastic Gradient Descent To Quasi Newton Minimizer * * An experimental minimizer which takes a stochastic function (one implementing AbstractStochasticCachingDiffFunction) * and executes SGD for the first couple passes. During the final iterations a series of approximate hessian vector * products are built up. These are then passed to the QNminimizer so that it can start right up without the typical * delay. * * Note [2012] The basic idea here is good, but the original ScaledSGDMinimizer wasn't efficient, and so this would * be much more useful if rewritten to use the good StochasticInPlaceMinimizer instead. * * @author <a href="mailto:akleeman@stanford.edu">Alex Kleeman</a> * @version 1.0 * @since 1.0 */ public class SGDToQNMinimizer implements Minimizer<DiffFunction>, Serializable { /** A logger for this class */ private static Redwood.RedwoodChannels log = Redwood.channels(SGDToQNMinimizer.class); private static final long serialVersionUID = -7551807670291500396L; // private int k; private final int bSize; private boolean quiet = false; public boolean outputIterationsToFile = false; // public int outputFrequency = 10; public double gain = 0.1; // private List<double[]> gradList = null; // private List<double[]> yList = null; // private List<double[]> sList = null; // private List<double[]> tmpYList = null; // private List<double[]> tmpSList = null; // private int memory = 5; public int SGDPasses = -1; public int QNPasses = -1; private final int hessSampleSize; private final int QNMem; public SGDToQNMinimizer(double SGDGain, int batchSize, int SGDPasses, int QNPasses){ this(SGDGain, batchSize, SGDPasses, QNPasses, 50, 10); } public SGDToQNMinimizer(double SGDGain, int batchSize, int sgdPasses, int qnPasses, int hessSamples, int QNMem) { this(SGDGain, batchSize, sgdPasses, qnPasses, hessSamples, QNMem, false); } public SGDToQNMinimizer(double SGDGain, int batchSize, int sgdPasses, int qnPasses, int hessSamples, int QNMem, boolean outputToFile) { this.gain = SGDGain; this.bSize = batchSize; this.SGDPasses = sgdPasses; this.QNPasses = qnPasses; this.hessSampleSize = hessSamples; this.QNMem = QNMem; this.outputIterationsToFile = outputToFile; } public void shutUp() { this.quiet = true; } protected String getName() { int g = (int) (gain * 1000); return "SGD2QN" + bSize + "_g" + g; } public double[] minimize(DiffFunction function, double functionTolerance, double[] initial) { return minimize(function,functionTolerance,initial,-1); } public double[] minimize(DiffFunction function, double functionTolerance, double[] initial, int maxIterations) { sayln("SGDToQNMinimizer called on function of " + function.domainDimension() + " variables;"); // check for stochastic derivatives if (!(function instanceof AbstractStochasticCachingDiffFunction)) { throw new UnsupportedOperationException(); } AbstractStochasticCachingDiffFunction dfunction = (AbstractStochasticCachingDiffFunction) function; dfunction.method = StochasticCalculateMethods.GradientOnly; ScaledSGDMinimizer sgd = new ScaledSGDMinimizer(this.gain,this.bSize,this.SGDPasses,1,this.outputIterationsToFile); QNMinimizer qn = new QNMinimizer(this.QNMem,true); double[] x = sgd.minimize(dfunction, functionTolerance, initial, this.SGDPasses); QNMinimizer.QNInfo qnInfo = qn.new QNInfo(sgd.sList , sgd.yList); qnInfo.d = sgd.diag; qn.minimize(dfunction, functionTolerance, x, this.QNPasses, qnInfo); log.info(""); log.info("Minimization complete."); log.info(""); log.info("Exiting for Debug"); return x; } private void sayln(String s) { if (!quiet) { log.info(s); } } private void say(String s) { if (!quiet) { log.info(s); } } }