package edu.stanford.nlp.optimization;

import java.util.ArrayList;
import java.util.List;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;

/**
 * Online limited-memory quasi-Newton (BFGS) implementation, based on the algorithms in
 * <p>
 * Nocedal, Jorge, and Stephen J. Wright. 2000. Numerical Optimization. Springer. pp. 224 ff.
 * <p>
 * and modified to the online version presented in
 * <p>
 * Schraudolph, Yu, and Günter (2007). A Stochastic Quasi-Newton Method for Online Convex Optimization.
 * <p>
 * As of now, it requires a stochastic differentiable function
 * (AbstractStochasticCachingDiffFunction) as input.
 * <p>
 * The basic way to use the minimizer is with a null constructor, then
 * the simple minimize method:
 * !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
 * THIS IS NOT UPDATED FOR THE STOCHASTIC VERSION YET.
 * !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
 * <p>
 * <code>Minimizer qnm = new QNMinimizer();</code>
 * <br><code>DiffFunction df = new SomeDiffFunction();</code>
 * <br><code>double tol = 1e-4;</code>
 * <br><code>double[] initial = getInitialGuess();</code>
 * <br><code>double[] minimum = qnm.minimize(df, tol, initial);</code>
 * <p>
 * If you do not choose a value of M, it will use the maximum amount of memory
 * available, up to an M of 20. This will slow things down a bit at first due
 * to forced garbage collection, but is probably faster overall because you are
 * guaranteed the largest possible M.
 * <p>
 * The stochastic version was written by Alex Kleeman, but about 95% of the code
 * was taken directly from the previous QNMinimizer written mostly by Jenny.
 *
 * @author <a href="mailto:jrfinkel@stanford.edu">Jenny Finkel</a>
 * @author Galen Andrew
 * @author <a href="mailto:akleeman@stanford.edu">Alex Kleeman</a>
 * @version 1.0
 * @since 1.0
 */
public class SQNMinimizer<T extends Function> extends StochasticMinimizer<T> {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(SQNMinimizer.class);

  private int M = 0;            // memory: number of (s, y) curvature pairs retained
  private double lambda = 1.0;  // regularizer added to y to keep the curvature pairs positive definite
  private double cPosDef = 1;
  private double epsilon = 1e-10;

  private List<double[]> sList = new ArrayList<>();
  private List<double[]> yList = new ArrayList<>();
  private List<Double> roList = new ArrayList<>();

  double[] dir, s, y;
  double ro;

  public void setM(int m) {
    M = m;
  }

  public SQNMinimizer(int m) {
    M = m;
  }

  public SQNMinimizer() {
  }

  public SQNMinimizer(int mem, double initialGain, int batchSize, boolean output) {
    gain = initialGain;
    bSize = batchSize;
    this.M = mem;
    this.outputIterationsToFile = output;
  }

  @Override
  public String getName() {
    int g = (int) (gain * 1000.0);
    return "SQN" + bSize + "_g" + g;
  }

  /** Computes d = a + c * b and returns d. */
  private static double[] plusAndConstMult(double[] a, double[] b, double c, double[] d) {
    for (int i = 0; i < a.length; i++) {
      d[i] = a[i] + c * b[i];
    }
    return d;
  }

  @Override
  public Pair<Integer, Double> tune(edu.stanford.nlp.optimization.Function function, double[] initial, long msPerTest) {
    log.info("No tuning set yet");
    return new Pair<>(bSize, gain);
  }
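  /*
   * computeDir below is the standard L-BFGS two-loop recursion (Nocedal &
   * Wright, the pages cited above), computing dir = -H_k * fg from the stored
   * (s_i, y_i) pairs. In outline:
   *
   *   q = fg
   *   for i = m-1 ... 0:  a_i = ro_i * (s_i . q);  q -= a_i * y_i
   *   q *= gamma,   where gamma = (s_{m-1} . y_{m-1}) / (y_{m-1} . y_{m-1})
   *   for i = 0 ... m-1:  b = ro_i * (y_i . q);    q += (a_i - b) * s_i
   *   dir = -q
   *
   * with ro_i = 1 / (s_i . y_i). The cPosDef factor on a_i and the epsilon
   * scaling of the very first step (when no pairs are stored yet) are the
   * only departures here from the textbook recursion.
   */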
  private void computeDir(double[] dir, double[] fg) throws SQNMinimizer.SurpriseConvergence {
    System.arraycopy(fg, 0, dir, 0, fg.length);

    int mmm = sList.size();
    double[] as = new double[mmm];
    double[] factors = new double[dir.length];

    for (int i = mmm - 1; i >= 0; i--) {
      as[i] = roList.get(i) * ArrayMath.innerProduct(sList.get(i), dir);
      plusAndConstMult(dir, yList.get(i), -as[i], dir);
    }

    // multiply by the initial Hessian approximation (gamma scaling)
    if (mmm != 0) {
      double[] y = yList.get(mmm - 1);
      double yDotY = ArrayMath.innerProduct(y, y);
      if (yDotY == 0) {
        throw new SQNMinimizer.SurpriseConvergence("Y is 0!!");
      }
      double gamma = ArrayMath.innerProduct(sList.get(mmm - 1), y) / yDotY;
      ArrayMath.multiplyInPlace(dir, gamma);
    } else {
      // This is a safety feature preventing too large an initial step (see Schraudolph, Yu & Günter)
      ArrayMath.multiplyInPlace(dir, epsilon);
    }

    for (int i = 0; i < mmm; i++) {
      double b = roList.get(i) * ArrayMath.innerProduct(yList.get(i), dir);
      plusAndConstMult(dir, sList.get(i), cPosDef * as[i] - b, dir);
      // factors is accumulated here but not otherwise used
      plusAndConstMult(ArrayMath.pairwiseMultiply(yList.get(i), sList.get(i)), factors, 1, factors);
    }

    ArrayMath.multiplyInPlace(dir, -1);
  }

  @Override
  protected void init(AbstractStochasticCachingDiffFunction func) {
    sList = new ArrayList<>();
    yList = new ArrayList<>();
    dir = new double[func.domainDimension()];
  }

  @Override
  protected void takeStep(AbstractStochasticCachingDiffFunction dfunction) {
    try {
      computeDir(dir, newGrad);
    } catch (SQNMinimizer.SurpriseConvergence s) {
      clearStuff();
    }

    double thisGain = gain * gainSchedule(k, 5 * numBatches);
    for (int i = 0; i < x.length; i++) {
      newX[i] = x[i] + thisGain * dir[i];
    }

    // Get a new pair...
    say(" A ");
    if (M > 0 && sList.size() == M) {
      // at capacity: recycle the oldest pair's arrays
      s = sList.remove(0);
      y = yList.remove(0);
    } else {
      s = new double[x.length];
      y = new double[x.length];
    }

    dfunction.recalculatePrevBatch = true;
    System.arraycopy(dfunction.derivativeAt(newX, bSize), 0, y, 0, grad.length);

    // compute s_k = newX - x and y_k = grad(newX) - grad(x) + lambda * s_k, on the same batch
    ro = 0;
    for (int i = 0; i < x.length; i++) {
      s[i] = newX[i] - x[i];
      y[i] = y[i] - newGrad[i] + lambda * s[i];
      ro += s[i] * y[i];
    }
    ro = 1.0 / ro;

    sList.add(s);
    yList.add(y);
    roList.add(ro);
  }

  private void clearStuff() {
    sList = null;
    yList = null;
    roList = null;
  }

  private static class SurpriseConvergence extends Throwable {

    private static final long serialVersionUID = -4377976289620760327L;

    public SurpriseConvergence(String s) {
      super(s);
    }
  }

}
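/*
 * A minimal usage sketch for the stochastic version (MyStochasticObjective,
 * data, and initialGuess are hypothetical; assumes the usual
 * Minimizer.minimize(function, tolerance, initial) entry point inherited via
 * StochasticMinimizer):
 *
 *   AbstractStochasticCachingDiffFunction f = new MyStochasticObjective(data);
 *   SQNMinimizer<AbstractStochasticCachingDiffFunction> sqn =
 *       new SQNMinimizer<>(10, 0.1, 20, false);  // M = 10, gain = 0.1, batch size 20, no iteration file output
 *   double[] minimum = sqn.minimize(f, 1e-4, initialGuess);
 */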