/** * */ package edu.berkeley.nlp.math; /** * @author petrov * Orthant-Wise L-BFGS * */ public class OW_LBFGSMinimizer extends LBFGSMinimizer implements GradientMinimizer { /** * @param iterations */ public OW_LBFGSMinimizer(int iterations) { super(iterations); } public double[] minimize(DifferentiableRegularizableFunction function, double[] initial, double tolerance) { BacktrackingLineSearcher lineSearcher = new BacktrackingLineSearcher(); lineSearcher.sufficientDecreaseConstant = 0; double[] guess = DoubleArrays.clone(initial); for (int iteration = 0; iteration < maxIterations; iteration++) { double[] derivative = function.derivativeAt(guess); double value = function.valueAt(guess); double[] direction = getSearchDirection(function.dimension(), derivative); double[] unregularizedDerivative = function.unregularizedDerivativeAt(guess); double[] orthant = getOrthant(initial, derivative); DoubleArrays.project(direction, derivative);//orthant);// //p^k: project search direction onto orthant defined by gradient DoubleArrays.scale(direction, -1.0); // System.out.println(" Derivative is: "+DoubleArrays.toString(derivative, 100)); // DoubleArrays.assign(direction, derivative); // System.out.println(" Looking in direction: "+DoubleArrays.toString(direction, 100)); if (iteration == 0) lineSearcher.stepSizeMultiplier = initialStepSizeMultiplier; else lineSearcher.stepSizeMultiplier = stepSizeMultiplier; double[] nextGuess = lineSearcher.minimize(function, guess, direction, true); double nextValue = function.valueAt(nextGuess); // double[] nextDerivative = function.derivativeAt(nextGuess); double[] unregularizedNextDerivative = function.unregularizedDerivativeAt(nextGuess); System.out.printf("Iteration %d ended with value %.6f\n", iteration, nextValue); if (iteration >= minIterations && converged(value, nextValue, tolerance)) return nextGuess; // update with unregularized derivatives! updateHistories(guess, nextGuess, unregularizedDerivative, unregularizedNextDerivative); guess = nextGuess; value = nextValue; // derivative = nextDerivative; unregularizedDerivative = unregularizedNextDerivative; if (iterCallbackFunction != null) { iterCallbackFunction.callback(guess,iteration); } } //System.err.println("LBFGSMinimizer.minimize: Exceeded maxIterations without converging."); return guess; } /** * @param initial * @param derivative * @return */ private double[] getOrthant(double[] initial, double[] derivative) { double[] orthant=new double[initial.length]; for (int i=0; i<initial.length; i++) { if (initial[i]!=0) orthant[i] = Math.signum(initial[i]); else orthant[i] = Math.signum(-derivative[i]); } return orthant; } }