package edu.berkeley.nlp.math;
import java.io.Serializable;
import java.util.LinkedList;
import edu.berkeley.nlp.util.CallbackFunction;
import edu.berkeley.nlp.util.Logger;
/**
 * Limited-memory BFGS (L-BFGS) minimizer for {@link DifferentiableFunction}s.
 * <p>
 * Each iteration builds a search direction from the most recent
 * {@code maxHistorySize} pairs of input differences (s = x_{k+1} - x_k) and
 * gradient differences (y = g_{k+1} - g_k) using the two-loop recursion. It then
 * delegates the step length to a {@link BacktrackingLineSearcher}. Iteration
 * stops in either of two cases:
 * <ul>
 * <li>the relative change in function value falls below the caller's tolerance,
 * once at least {@code minIterations} iterations have run; or</li>
 * <li>{@code maxIterations} iterations have run.</li>
 * </ul>
 * <p>
 * The difference history lives on the instance. {@link #minimize} only clears
 * it when periodic dumping ({@link #dumpHistoryPeriodically(int)}) or
 * dump-before-converge is enabled. Call {@link #dumpHistory()} yourself before
 * reusing an instance on an unrelated problem.
 *
 * @author Dan Klein
 */
public class LBFGSMinimizer implements GradientMinimizer, Serializable
{
	private static final long serialVersionUID = 36473897808840226L;

	/** Small constant added to the denominator of the relative-change convergence test. */
	double EPS = 1e-10;

	int maxIterations = 20;

	/** Maximum number of (s, y) difference pairs retained for the inverse-Hessian approximation. */
	int maxHistorySize = 5;

	/** Input differences s = nextGuess - guess; index 0 is the most recent. */
	LinkedList<double[]> inputDifferenceVectorList = new LinkedList<double[]>();

	/** Gradient differences y = nextDerivative - derivative; index 0 is the most recent. */
	LinkedList<double[]> derivativeDifferenceVectorList = new LinkedList<double[]>();

	transient CallbackFunction iterCallbackFunction = null;

	/** Convergence is not accepted while the iteration index is below this value. */
	int minIterations = -1;

	/** Value assigned to the line searcher's stepSizeMultiplier on iteration 0. */
	double initialStepSizeMultiplier = 0.01;

	/** Value assigned to the line searcher's stepSizeMultiplier on later iterations. */
	double stepSizeMultiplier = 0.5;

	/** If true, the first time convergence is detected the history is cleared and the iteration redone. */
	boolean dumpHistoryBeforeConverge = false;

	// NOTE(review): never reset, so the "dump before converge" retry happens at
	// most once per instance, not once per minimize() call.
	boolean alreadyDumped = false;

	/** If positive, the history is cleared whenever iteration % historyDropIters == 0. */
	int historyDropIters = -1;

	boolean verbose = true;

	public void setDumpHistoryBeforeConverge(boolean dumpHistoryBeforeConverge)
	{
		this.dumpHistoryBeforeConverge = dumpHistoryBeforeConverge;
	}

	public void setVerbose(boolean verbose)
	{
		this.verbose = verbose;
	}

	/**
	 * Clears the difference history every {@code numIters} iterations. A
	 * non-positive value disables periodic clearing.
	 */
	public void dumpHistoryPeriodically(int numIters)
	{
		this.historyDropIters = numIters;
	}

	/**
	 * Sets the iteration index before which convergence is never accepted.
	 */
	public void setMinIterations(int minIterations)
	{
		this.minIterations = minIterations;
	}

	/**
	 * @deprecated misspelled legacy name, kept for existing callers; use
	 *             {@link #setMinIterations(int)}.
	 */
	@Deprecated
	public void setMinIteratons(int minIterations)
	{
		setMinIterations(minIterations);
	}

	public void setMaxIterations(int maxIterations)
	{
		this.maxIterations = maxIterations;
	}

	public void setInitialStepSizeMultiplier(double initialStepSizeMultiplier)
	{
		this.initialStepSizeMultiplier = initialStepSizeMultiplier;
	}

	public void setStepSizeMultiplier(double stepSizeMultiplier)
	{
		this.stepSizeMultiplier = stepSizeMultiplier;
	}

	/**
	 * Applies the current L-BFGS inverse-Hessian approximation to
	 * {@code derivative}.
	 * <p>
	 * The result is H * derivative, NOT negated. Callers that want a descent
	 * direction must flip its sign.
	 *
	 * @param dimension length of the parameter vector
	 * @param derivative gradient to transform
	 * @return a new array holding H * derivative
	 */
	public double[] getSearchDirection(int dimension, double[] derivative)
	{
		double[] initialInverseHessianDiagonal = getInitialInverseHessianDiagonal(dimension);
		return implicitMultiply(initialInverseHessianDiagonal, derivative);
	}

	/**
	 * Returns the diagonal of the initial inverse-Hessian guess H0 as a constant
	 * vector. The constant is 1.0 with no history, otherwise (s . y) / (y . y)
	 * for the most recent pair. Subclasses may override.
	 */
	protected double[] getInitialInverseHessianDiagonal(int dimension)
	{
		return DoubleArrays.constantArray(initialInverseHessianScale(), dimension);
	}

	public double[] minimize(DifferentiableFunction function, double[] initial, double tolerance)
	{
		return minimize(function, initial, tolerance, false);
	}

	/**
	 * Minimizes {@code function} starting from a copy of {@code initial}.
	 *
	 * @param function objective to minimize
	 * @param initial starting point (not modified)
	 * @param tolerance relative-change threshold passed to
	 *            {@link #converged(double, double, double)}
	 * @param printProgress if true (and verbose), log the value after each iteration
	 * @return the converged point, or the last accepted point once
	 *         maxIterations is reached
	 */
	public double[] minimize(DifferentiableFunction function, double[] initial, double tolerance, boolean printProgress)
	{
		BacktrackingLineSearcher lineSearcher = new BacktrackingLineSearcher();
		double[] guess = DoubleArrays.clone(initial);
		// Value and gradient at the current guess. They are evaluated lazily on
		// the first pass, then carried forward from the line-search endpoint, so
		// every accepted point is evaluated only once.
		double[] derivative = null;
		double value = 0.0;
		int iteration;
		for (iteration = 0; iteration < maxIterations; iteration++)
		{
			if (historyDropIters > 0 && iteration % historyDropIters == 0)
			{
				dumpHistory();
				if (verbose) Logger.logs("[LBFGSMinimizer.minimize] Dumped History at iter %d", iteration);
			}
			if (derivative == null)
			{
				derivative = function.derivativeAt(guess);
				value = function.valueAt(guess);
			}
			double[] initialInverseHessianDiagonal = getInitialInverseHessianDiagonal(function);
			double[] direction = implicitMultiply(initialInverseHessianDiagonal, derivative);
			// implicitMultiply yields H * g; negate it to get a descent direction.
			DoubleArrays.scale(direction, -1.0);
			if (iteration == 0)
				lineSearcher.stepSizeMultiplier = initialStepSizeMultiplier;
			else
				lineSearcher.stepSizeMultiplier = stepSizeMultiplier;
			double[] nextGuess = doLineSearch(function, lineSearcher, guess, direction);
			double nextValue = function.valueAt(nextGuess);
			double[] nextDerivative = function.derivativeAt(nextGuess);
			if (printProgress) printProgress(iteration, nextValue);
			if (iteration >= minIterations && converged(value, nextValue, tolerance))
			{
				if (verbose) Logger.logs("[LBFGSMinimizer.minimize] Converged.");
				if (dumpHistoryBeforeConverge && !alreadyDumped)
				{
					// Clear the curvature history and redo this iteration from the same
					// point (guess/value/derivative are unchanged) before accepting.
					dumpHistory();
					if (verbose) Logger.logs("[LBFGSMinimizer.minimize] Dumping History. Doing Iteration Over");
					alreadyDumped = true;
					iteration--;
					continue;
				}
				return nextGuess;
			}
			updateHistories(guess, nextGuess, derivative, nextDerivative);
			guess = nextGuess;
			value = nextValue;
			derivative = nextDerivative;
			if (iterCallbackFunction != null)
			{
				iterCallbackFunction.callback(guess, iteration, value, derivative);
			}
		}
		if (verbose) Logger.logs("[LBFGSMinimizer.minimize] Stopped after " + iteration + " iterations.");
		return guess;
	}

	/**
	 * Performs the line search for one iteration. This is an entry point for
	 * subclasses that want a different step-selection strategy.
	 *
	 * @param function objective being minimized
	 * @param lineSearcher searcher whose stepSizeMultiplier has already been set
	 *            for this iteration
	 * @param guess current point
	 * @param direction descent direction (already negated)
	 * @return the next point chosen along {@code direction}
	 */
	protected double[] doLineSearch(DifferentiableFunction function, BacktrackingLineSearcher lineSearcher, double[] guess, double[] direction)
	{
		return lineSearcher.minimize(function, guess, direction);
	}

	private void printProgress(int iteration, double nextValue)
	{
		if (verbose) Logger.logs("[LBFGSMinimizer.minimize] Iteration %d ended with value %.6f", iteration, nextValue);
	}

	/**
	 * Relative-change convergence test.
	 *
	 * @return true if the values are identical, or if
	 *         |next - value| / ((|next| + |value| + EPS) / 2) < tolerance
	 */
	protected boolean converged(double value, double nextValue, double tolerance)
	{
		if (value == nextValue) return true;
		double valueChange = Math.abs(nextValue - value);
		// Sum magnitudes rather than taking |next + value|. This keeps the
		// denominator >= EPS / 2 even when the two values have opposite signs
		// and would otherwise cancel to ~0.
		double valueAverage = (Math.abs(nextValue) + Math.abs(value) + EPS) / 2.0;
		return valueChange / valueAverage < tolerance;
	}

	/**
	 * Records the new (s, y) pair as the most recent history entry, evicting
	 * the oldest entries beyond maxHistorySize.
	 */
	protected void updateHistories(double[] guess, double[] nextGuess, double[] derivative, double[] nextDerivative)
	{
		double[] guessChange = DoubleArrays.addMultiples(nextGuess, 1.0, guess, -1.0);
		double[] derivativeChange = DoubleArrays.addMultiples(nextDerivative, 1.0, derivative, -1.0);
		pushOntoList(guessChange, inputDifferenceVectorList);
		pushOntoList(derivativeChange, derivativeDifferenceVectorList);
	}

	private void pushOntoList(double[] vector, LinkedList<double[]> vectorList)
	{
		vectorList.addFirst(vector);
		trimHistory(vectorList);
	}

	/** Drops the oldest entries until at most max(0, maxHistorySize) remain. */
	private void trimHistory(LinkedList<double[]> vectorList)
	{
		// Loop (not a single removal) so a reduced maxHistorySize takes full effect.
		while (vectorList.size() > Math.max(0, maxHistorySize)) vectorList.removeLast();
	}

	private int historySize()
	{
		return inputDifferenceVectorList.size();
	}

	/**
	 * Sets how many (s, y) pairs are retained. Existing history is trimmed
	 * immediately if it exceeds the new limit.
	 */
	public void setMaxHistorySize(int maxHistorySize)
	{
		this.maxHistorySize = maxHistorySize;
		trimHistory(inputDifferenceVectorList);
		trimHistory(derivativeDifferenceVectorList);
	}

	private double[] getInputDifference(int num)
	{
		// 0 is previous, 1 is the one before that
		return inputDifferenceVectorList.get(num);
	}

	private double[] getDerivativeDifference(int num)
	{
		return derivativeDifferenceVectorList.get(num);
	}

	private double[] getLastDerivativeDifference()
	{
		return derivativeDifferenceVectorList.getFirst();
	}

	private double[] getLastInputDifference()
	{
		return inputDifferenceVectorList.getFirst();
	}

	/**
	 * L-BFGS two-loop recursion: returns H * derivative, where H is the
	 * limited-memory inverse-Hessian approximation built on the diagonal H0.
	 * <p>
	 * History index 0 is the newest pair. The first loop must therefore run
	 * newest to oldest and the second loop oldest to newest. This order makes H
	 * satisfy the secant equation H y = s for the most recent pair.
	 *
	 * @throws RuntimeException if some stored pair has s . y == 0
	 */
	private double[] implicitMultiply(double[] initialInverseHessianDiagonal, double[] derivative)
	{
		int m = historySize();
		// rho[i] holds s_i . y_i (the reciprocal of the textbook rho).
		double[] rho = new double[m];
		double[] alpha = new double[m];
		double[] right = DoubleArrays.clone(derivative);
		// First loop: newest pair to oldest.
		for (int i = 0; i < m; i++)
		{
			double[] inputDifference = getInputDifference(i);
			double[] derivativeDifference = getDerivativeDifference(i);
			rho[i] = DoubleArrays.innerProduct(inputDifference, derivativeDifference);
			if (rho[i] == 0.0) throw new RuntimeException("[LBFGSMinimizer.implicitMultiply]: Curvature problem.");
			alpha[i] = DoubleArrays.innerProduct(inputDifference, right) / rho[i];
			right = DoubleArrays.addMultiples(right, 1.0, derivativeDifference, -1.0 * alpha[i]);
		}
		double[] left = DoubleArrays.pointwiseMultiply(initialInverseHessianDiagonal, right);
		// Second loop: oldest pair back to newest.
		for (int i = m - 1; i >= 0; i--)
		{
			double[] inputDifference = getInputDifference(i);
			double[] derivativeDifference = getDerivativeDifference(i);
			double beta = DoubleArrays.innerProduct(derivativeDifference, left) / rho[i];
			left = DoubleArrays.addMultiples(left, 1.0, inputDifference, alpha[i] - beta);
		}
		return left;
	}

	private double[] getInitialInverseHessianDiagonal(DifferentiableFunction function)
	{
		return DoubleArrays.constantArray(initialInverseHessianScale(), function.dimension());
	}

	/**
	 * Scalar for the initial inverse-Hessian diagonal. Returns 1.0 with no
	 * history, otherwise (y . s) / (y . y) for the most recent pair.
	 */
	private double initialInverseHessianScale()
	{
		if (derivativeDifferenceVectorList.isEmpty()) return 1.0;
		double[] lastDerivativeDifference = getLastDerivativeDifference();
		double[] lastInputDifference = getLastInputDifference();
		double num = DoubleArrays.innerProduct(lastDerivativeDifference, lastInputDifference);
		double den = DoubleArrays.innerProduct(lastDerivativeDifference, lastDerivativeDifference);
		return num / den;
	}

	/**
	 * User callback function to test or examine weights at the end of each
	 * iteration.
	 *
	 * @param callbackFunction
	 *            Will get called with the following args (double[]
	 *            currentGuess, int iterDone, double value, double[] derivative).
	 *            You don't have to read any or all of these.
	 */
	public void setIterationCallbackFunction(CallbackFunction callbackFunction)
	{
		this.iterCallbackFunction = callbackFunction;
	}

	public LBFGSMinimizer()
	{
	}

	public LBFGSMinimizer(int maxIterations)
	{
		this.maxIterations = maxIterations;
	}

	/** Discards all stored (s, y) difference pairs. */
	public void dumpHistory()
	{
		inputDifferenceVectorList.clear();
		derivativeDifferenceVectorList.clear();
	}
}