package edu.stanford.nlp.optimization; import java.io.Serializable; import java.text.DecimalFormat; import java.text.NumberFormat; import java.util.List; import edu.stanford.nlp.sequences.SeqClassifierFlags; /** * <p> * Stochastic Gradient Descent To Quasi Newton Minimizer * * An experimental minimizer which takes a stochastic function (one implementing AbstractStochasticCachingDiffFunction) * and executes SGD for the first couple passes, During the final iterations a series of approximate hessian vector * products are built up... These are then passed to the QNminimizer so that it can start right up without the typical * delay. * * @author <a href="mailto:akleeman@stanford.edu">Alex Kleeman</a> * @version 1.0 * @since 1.0 */ public class SGDToQNMinimizer implements Minimizer<DiffFunction>,Serializable { /** * */ private static final long serialVersionUID = -7551807670291500396L; private int k; private int bSize = 15; private boolean quiet = false; public boolean outputIterationsToFile = false; public int outputFrequency = 10; public double gain = 0.1; private List<double[]> gradList = null; private List<double[]> yList = null; private List<double[]> sList = null; private List<double[]> tmpYList = null; private List<double[]> tmpSList = null; private int memory = 5; public int SGDPasses = -1; public int QNPasses = -1; private int hessSampleSize = 50; private int QNMem = 10; private boolean toTest = false; public void shutUp() { this.quiet = true; } public void setBatchSize(int batchSize) { bSize = batchSize; } private static NumberFormat nf = new DecimalFormat("0.000E0"); public SGDToQNMinimizer(SeqClassifierFlags flags){ this.bSize = flags.stochasticBatchSize; this.gain = flags.initialGain; this.SGDPasses = flags.SGDPasses; this.QNPasses = flags.QNPasses; this.QNMem = flags.QNsize; this.outputIterationsToFile = flags.outputIterationsToFile; this.toTest = flags.testObjFunction; this.hessSampleSize = flags.SGD2QNhessSamples; } public SGDToQNMinimizer(double SGDGain, int batchSize, int sgdPasses, int qnPasses, int hessSamples, int QNMem){ this( SGDGain, batchSize, sgdPasses, qnPasses, hessSamples, QNMem,false); } public SGDToQNMinimizer(double SGDGain, int batchSize, int sgdPasses, int qnPasses, int hessSamples, int QNMem,boolean outputToFile){ this.bSize = batchSize; this.gain = SGDGain; this.SGDPasses = sgdPasses; this.QNPasses = qnPasses; this.QNMem = QNMem; this.outputIterationsToFile = outputToFile; this.hessSampleSize = hessSamples; } public SGDToQNMinimizer(double SGDGain, int batchSize, int SGDPasses, int QNPasses){ this(SGDGain,batchSize,SGDPasses,QNPasses,50,10); } public void setQNMem(int mem){ QNMem = mem; } public void setHessSampleSize(int size){ hessSampleSize = size; } protected String getName(){ int g = (int) (gain * 1000); return "SGD2QN" + bSize + "_g" + g; } public double[] minimize(DiffFunction function, double functionTolerance, double[] initial) { return minimize(function,functionTolerance,initial,-1); } public double[] minimize(DiffFunction function, double functionTolerance, double[] initial, int maxIterations) { sayln("SGDToQNMinimizer called on function of " + function.domainDimension() + " variables;"); // check for stochastic derivatives if (!(function instanceof AbstractStochasticCachingDiffFunction)) { throw new UnsupportedOperationException(); } AbstractStochasticCachingDiffFunction dfunction = (AbstractStochasticCachingDiffFunction) function; dfunction.method = StochasticCalculateMethods.GradientOnly; ScaledSGDMinimizer sgd = new ScaledSGDMinimizer(this.gain,this.bSize,this.SGDPasses,1,this.outputIterationsToFile); QNMinimizer qn = new QNMinimizer(this.QNMem,true); double[] x = sgd.minimize(dfunction, functionTolerance, initial, this.SGDPasses); QNMinimizer.QNInfo qnInfo = qn.new QNInfo(sgd.sList , sgd.yList); qnInfo.d = sgd.diag; qn.minimize(dfunction, functionTolerance, x, this.QNPasses, qnInfo); System.err.println(""); System.err.println("Minimization complete."); System.err.println(""); System.err.println("Exiting for Debug"); return x; } private void sayln(String s) { if (!quiet) { System.err.println(s); } } private void say(String s) { if (!quiet) { System.err.print(s); } } }