/** * Copyright (C) 2009 - present by OpenGamma Inc. and the OpenGamma group of companies * * Please see distribution for license. */ package com.opengamma.analytics.math.rootfinding.newton; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.primitives.Doubles; import com.opengamma.analytics.math.MathException; import com.opengamma.analytics.math.differentiation.VectorFieldFirstOrderDifferentiator; import com.opengamma.analytics.math.function.Function1D; import com.opengamma.analytics.math.matrix.DoubleMatrix1D; import com.opengamma.analytics.math.matrix.DoubleMatrix2D; import com.opengamma.analytics.math.matrix.MatrixAlgebra; import com.opengamma.analytics.math.matrix.OGMatrixAlgebra; import com.opengamma.analytics.math.rootfinding.VectorRootFinder; import com.opengamma.util.ArgumentChecker; /** * Base implementation for all Newton-Raphson style multi-dimensional root finding (i.e. using the Jacobian matrix as a basis for some iterative process) */ public class NewtonVectorRootFinder extends VectorRootFinder { private static final Logger s_logger = LoggerFactory.getLogger(NewtonVectorRootFinder.class); private static final double ALPHA = 1e-4; private static final double BETA = 1.5; private static final int FULL_RECALC_FREQ = 20; private final double _absoluteTol, _relativeTol; private final int _maxSteps; private final NewtonRootFinderDirectionFunction _directionFunction; private final NewtonRootFinderMatrixInitializationFunction _initializationFunction; private final NewtonRootFinderMatrixUpdateFunction _updateFunction; private final MatrixAlgebra _algebra = new OGMatrixAlgebra(); public NewtonVectorRootFinder(final double absoluteTol, final double relativeTol, final int maxSteps, final NewtonRootFinderDirectionFunction directionFunction, final NewtonRootFinderMatrixInitializationFunction initializationFunction, final NewtonRootFinderMatrixUpdateFunction updateFunction) { ArgumentChecker.notNegative(absoluteTol, "absolute tolerance"); ArgumentChecker.notNegative(relativeTol, "relative tolerance"); ArgumentChecker.notNegative(maxSteps, "maxSteps"); _absoluteTol = absoluteTol; _relativeTol = relativeTol; _maxSteps = maxSteps; _directionFunction = directionFunction; _initializationFunction = initializationFunction; _updateFunction = updateFunction; } @Override public DoubleMatrix1D getRoot(final Function1D<DoubleMatrix1D, DoubleMatrix1D> function, final DoubleMatrix1D startPosition) { final VectorFieldFirstOrderDifferentiator jac = new VectorFieldFirstOrderDifferentiator(); return getRoot(function, jac.differentiate(function), startPosition); } /** *@param function a vector function (i.e. vector to vector) *@param jacobianFunction calculates the Jacobian * @param startPosition where to start the root finder for. Note if multiple roots exist which one if found (if at all) will depend on startPosition * @return the vector root of the collection of functions */ @SuppressWarnings("synthetic-access") public DoubleMatrix1D getRoot(final Function1D<DoubleMatrix1D, DoubleMatrix1D> function, final Function1D<DoubleMatrix1D, DoubleMatrix2D> jacobianFunction, final DoubleMatrix1D startPosition) { checkInputs(function, startPosition); final DataBundle data = new DataBundle(); final DoubleMatrix1D y = function.evaluate(startPosition); data.setX(startPosition); data.setY(y); data.setG0(_algebra.getInnerProduct(y, y)); DoubleMatrix2D estimate = _initializationFunction.getInitializedMatrix(jacobianFunction, startPosition); if (!getNextPosition(function, estimate, data)) { if (isConverged(data)) { return data.getX(); // this can happen if the starting position is the root } throw new MathException("Cannot work with this starting position. Please choose another point"); } int count = 0; int jacReconCount = 1; while (!isConverged(data)) { // Want to reset the Jacobian every so often even if backtracking is working if ((jacReconCount) % FULL_RECALC_FREQ == 0) { estimate = _initializationFunction.getInitializedMatrix(jacobianFunction, data.getX()); jacReconCount = 1; } else { estimate = _updateFunction.getUpdatedMatrix(jacobianFunction, data.getX(), data.getDeltaX(), data.getDeltaY(), estimate); jacReconCount++; } // if backtracking fails, could be that Jacobian estimate has drifted too far if (!getNextPosition(function, estimate, data)) { estimate = _initializationFunction.getInitializedMatrix(jacobianFunction, data.getX()); jacReconCount = 1; if (!getNextPosition(function, estimate, data)) { if (isConverged(data)) { return data.getX(); //non-standard exit. Cannot find an improvement from this position, so provided we are close enough to the root, exit. } String msg = "Failed to converge in backtracking, even after a Jacobian recalculation." + getErrorMessage(data, jacobianFunction); s_logger.info(msg); throw new MathException(msg); } } count++; if (count > _maxSteps) { throw new MathException("Failed to converge - maximum iterations of " + _maxSteps + " reached." + getErrorMessage(data, jacobianFunction)); } } return data.getX(); } private String getErrorMessage(final DataBundle data, final Function1D<DoubleMatrix1D, DoubleMatrix2D> jacobianFunction) { String msg = "Final position:" + data.getX() + "\nlast deltaX:" + data.getDeltaX() + "\n function value:" + data.getY() + "\nJacobian: \n" + jacobianFunction.evaluate(data.getX()); return msg; } private boolean getNextPosition(final Function1D<DoubleMatrix1D, DoubleMatrix1D> function, final DoubleMatrix2D estimate, final DataBundle data) { final DoubleMatrix1D p = _directionFunction.getDirection(estimate, data.getY()); if (data.getLambda0() < 1.0) { data.setLambda0(1.0); } else { data.setLambda0(data.getLambda0() * BETA); } updatePosition(p, function, data); final double g1 = data.getG1(); // the function is invalid at the new position, try to recover if (!Doubles.isFinite(g1)) { bisectBacktrack(p, function, data); } if (data.getG1() > data.getG0() / (1 + ALPHA * data.getLambda0())) { quadraticBacktrack(p, function, data); int count = 0; while (data.getG1() > data.getG0() / (1 + ALPHA * data.getLambda0())) { if (count > 5) { return false; } cubicBacktrack(p, function, data); count++; } } final DoubleMatrix1D deltaX = data.getDeltaX(); final DoubleMatrix1D deltaY = data.getDeltaY(); data.setG0(data.getG1()); data.setX((DoubleMatrix1D) _algebra.add(data.getX(), deltaX)); data.setY((DoubleMatrix1D) _algebra.add(data.getY(), deltaY)); return true; } protected void updatePosition(final DoubleMatrix1D p, final Function1D<DoubleMatrix1D, DoubleMatrix1D> function, final DataBundle data) { final double lambda0 = data.getLambda0(); final DoubleMatrix1D deltaX = (DoubleMatrix1D) _algebra.scale(p, -lambda0); final DoubleMatrix1D xNew = (DoubleMatrix1D) _algebra.add(data.getX(), deltaX); final DoubleMatrix1D yNew = function.evaluate(xNew); data.setDeltaX(deltaX); data.setDeltaY((DoubleMatrix1D) _algebra.subtract(yNew, data.getY())); data.setG2(data.getG1()); data.setG1(_algebra.getInnerProduct(yNew, yNew)); } private void bisectBacktrack(final DoubleMatrix1D p, final Function1D<DoubleMatrix1D, DoubleMatrix1D> function, final DataBundle data) { do { data.setLambda0(data.getLambda0() * 0.1); updatePosition(p, function, data); if (data.getLambda0() == 0.0) { throw new MathException("Failed to converge"); } } while (Double.isNaN(data.getG1()) || Double.isInfinite(data.getG1()) || Double.isNaN(data.getG2()) || Double.isInfinite(data.getG2())); } private void quadraticBacktrack(final DoubleMatrix1D p, final Function1D<DoubleMatrix1D, DoubleMatrix1D> function, final DataBundle data) { final double lambda0 = data.getLambda0(); final double g0 = data.getG0(); final double lambda = Math.max(0.01 * lambda0, g0 * lambda0 * lambda0 / (data.getG1() + g0 * (2 * lambda0 - 1))); data.swapLambdaAndReplace(lambda); updatePosition(p, function, data); } private void cubicBacktrack(final DoubleMatrix1D p, final Function1D<DoubleMatrix1D, DoubleMatrix1D> function, final DataBundle data) { double temp1, temp2, temp3, temp4, temp5; final double lambda0 = data.getLambda0(); final double lambda1 = data.getLambda1(); final double g0 = data.getG0(); temp1 = 1.0 / lambda0 / lambda0; temp2 = 1.0 / lambda1 / lambda1; temp3 = data.getG1() + g0 * (2 * lambda0 - 1.0); temp4 = data.getG2() + g0 * (2 * lambda1 - 1.0); temp5 = 1.0 / (lambda0 - lambda1); final double a = temp5 * (temp1 * temp3 - temp2 * temp4); final double b = temp5 * (-lambda1 * temp1 * temp3 + lambda0 * temp2 * temp4); double lambda = (-b + Math.sqrt(b * b + 6 * a * g0)) / 3 / a; lambda = Math.min(Math.max(lambda, 0.01 * lambda0), 0.75 * lambda1); // make sure new lambda is between 1% & 75% of old value data.swapLambdaAndReplace(lambda); updatePosition(p, function, data); } private boolean isConverged(final DataBundle data) { final DoubleMatrix1D deltaX = data.getDeltaX(); final DoubleMatrix1D x = data.getX(); final int n = deltaX.getNumberOfElements(); double diff, scale; for (int i = 0; i < n; i++) { diff = Math.abs(deltaX.getEntry(i)); scale = Math.abs(x.getEntry(i)); if (diff > _absoluteTol + scale * _relativeTol) { return false; } } return (Math.sqrt(data.getG0()) < _absoluteTol); } private static class DataBundle { private double _g0; private double _g1; private double _g2; private double _lambda0; private double _lambda1; private DoubleMatrix1D _deltaY; private DoubleMatrix1D _y; private DoubleMatrix1D _deltaX; private DoubleMatrix1D _x; public double getG0() { return _g0; } public double getG1() { return _g1; } public double getG2() { return _g2; } public double getLambda0() { return _lambda0; } public double getLambda1() { return _lambda1; } public DoubleMatrix1D getDeltaY() { return _deltaY; } public DoubleMatrix1D getY() { return _y; } public DoubleMatrix1D getDeltaX() { return _deltaX; } public DoubleMatrix1D getX() { return _x; } public void setG0(final double g0) { _g0 = g0; } public void setG1(final double g1) { _g1 = g1; } public void setG2(final double g2) { _g2 = g2; } public void setLambda0(final double lambda0) { _lambda0 = lambda0; } public void setDeltaY(final DoubleMatrix1D deltaY) { _deltaY = deltaY; } public void setY(final DoubleMatrix1D y) { _y = y; } public void setDeltaX(final DoubleMatrix1D deltaX) { _deltaX = deltaX; } public void setX(final DoubleMatrix1D x) { _x = x; } public void swapLambdaAndReplace(final double lambda0) { _lambda1 = _lambda0; _lambda0 = lambda0; } } }