package edu.stanford.nlp.loglinear.learning;

import edu.stanford.nlp.util.logging.Redwood;

import edu.stanford.nlp.loglinear.model.ConcatVector;

/**
 * Created on 8/26/15.
 * @author keenon
 * <p>
 * Handles optimizing an AbstractDifferentiableFunction through AdaGrad, guarded by backtracking.
 */
public class BacktrackingAdaGradOptimizer extends AbstractBatchOptimizer {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(BacktrackingAdaGradOptimizer.class);

  // Learning rate. This magic number was arrived at by tinkering against the CoNLL benchmark.
  final static double alpha = 0.1;

  @Override
  public boolean updateWeights(ConcatVector weights, ConcatVector gradient, double logLikelihood, OptimizationState optimizationState, boolean quiet) {
    AdaGradOptimizationState s = (AdaGradOptimizationState) optimizationState;

    double logLikelihoodChange = logLikelihood - s.lastLogLikelihood;

    if (logLikelihoodChange == 0) {
      if (!quiet) log.info("\tlogLikelihood improvement = 0: quitting");
      return true;
    }

    // Check if we should backtrack
    else if (logLikelihoodChange < 0) {
      // If we should, cut lastDerivative by half and move the weights back by that amount,
      // i.e. undo half of the previous step
      s.lastDerivative.mapInPlace((d) -> d / 2);
      weights.addVectorInPlace(s.lastDerivative, -1.0);

      if (!quiet) log.info("\tBACKTRACK...");

      // if the squared norm of lastDerivative falls below a threshold, it means we've converged
      if (s.lastDerivative.dotProduct(s.lastDerivative) < 1.0e-10) {
        if (!quiet) log.info("\tBacktracking derivative squared norm " + s.lastDerivative.dotProduct(s.lastDerivative) + " < 1.0e-10: quitting");
        return true;
      }
    }

    // Apply AdaGrad
    else {
      // Accumulate the squared gradient
      ConcatVector squared = gradient.deepClone();
      squared.mapInPlace((d) -> d * d);
      s.adagradAccumulator.addVectorInPlace(squared, 1.0);

      // Scale each coordinate by alpha / sqrt(accumulated squared gradient),
      // falling back to alpha for coordinates that have never seen a nonzero gradient
      ConcatVector sqrt = s.adagradAccumulator.deepClone();
      sqrt.mapInPlace((d) -> {
        if (d == 0) return alpha;
        else return alpha / Math.sqrt(d);
      });

      gradient.elementwiseProductInPlace(sqrt);
      weights.addVectorInPlace(gradient, 1.0);

      // Setup for backtracking, in case necessary. At this point 'gradient' holds the
      // scaled step we just took, so lastDerivative records the actual step.
      s.lastDerivative = gradient;
      s.lastLogLikelihood = logLikelihood;

      if (!quiet) log.info("\tLL: " + logLikelihood);
    }

    return false;
  }

  protected class AdaGradOptimizationState extends OptimizationState {
    ConcatVector lastDerivative = new ConcatVector(0);
    ConcatVector adagradAccumulator = new ConcatVector(0);
    double lastLogLikelihood = Double.NEGATIVE_INFINITY;
  }

  @Override
  protected OptimizationState getFreshOptimizationState(ConcatVector initialWeights) {
    return new AdaGradOptimizationState();
  }

}
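
// ---------------------------------------------------------------------------
// Illustration only (not part of the library): a minimal sketch of the same
// backtracking-AdaGrad update on a toy one-dimensional problem, using plain
// doubles instead of ConcatVector. The class name BacktrackingAdaGradSketch
// and everything inside it are hypothetical and exist purely to show the
// update rule in isolation; the real optimizer is driven by
// AbstractBatchOptimizer's training loop, which is not reproduced here.
class BacktrackingAdaGradSketch {

  private static final double ALPHA = 0.1; // same step-size constant as above

  public static void main(String[] args) {
    // Maximize logLikelihood(w) = -(w - 3)^2, whose gradient is -2 * (w - 3).
    // Starting close to the optimum makes the first AdaGrad step overshoot,
    // which exercises the backtracking branch as well.
    double weight = 2.99;
    double adagradAccumulator = 0.0;
    double lastStep = 0.0;
    double lastLogLikelihood = Double.NEGATIVE_INFINITY;

    for (int iter = 0; iter < 1000; iter++) {
      double logLikelihood = -(weight - 3) * (weight - 3);
      double gradient = -2 * (weight - 3);

      double change = logLikelihood - lastLogLikelihood;
      if (change == 0) break;                    // no improvement: converged
      if (change < 0) {
        lastStep /= 2;                           // backtrack: undo half of the last step
        weight -= lastStep;
        if (lastStep * lastStep < 1.0e-10) break;
        continue;
      }

      adagradAccumulator += gradient * gradient; // running sum of squared gradients
      double scale = adagradAccumulator == 0
          ? ALPHA
          : ALPHA / Math.sqrt(adagradAccumulator);
      double step = gradient * scale;            // per-coordinate scaled step
      weight += step;

      lastStep = step;                           // remember the step actually taken
      lastLogLikelihood = logLikelihood;
    }

    System.out.println("weight converged to roughly " + weight); // close to 3.0
  }
}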