/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
import java.util.Arrays;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.exception.MathUnsupportedOperationException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.PositionChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer;
import org.apache.commons.math3.optim.univariate.BracketFinder;
import org.apache.commons.math3.optim.univariate.BrentOptimizer;
import org.apache.commons.math3.optim.univariate.SearchInterval;
import org.apache.commons.math3.optim.univariate.SimpleUnivariateValueChecker;
import org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction;
import org.apache.commons.math3.optim.univariate.UnivariatePointValuePair;
import org.apache.commons.math3.util.FastMath;
import gdsc.core.utils.DoubleEquality;
/**
* Powell's algorithm.
* <p>
* The class is based on the org.apache.commons.math3.optim.nonlinear.scalar.noderiv.PowellOptimizer but updated to
* support: (a) convergence on the position; (b) convergence when using the original basis vectors; (c) support bounds
* checking on the current point within the optimisation space.
*/
public class CustomPowellOptimizer extends MultivariateOptimizer
{
/**
* Minimum relative tolerance.
*/
private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d);
/**
* Relative threshold.
*/
private final double relativeThreshold;
/**
* Absolute threshold.
*/
private final double absoluteThreshold;
/**
* Line search.
*/
private final LineSearch line;
/** Convergence tolerance on position */
private PositionChecker positionChecker = null;
/** Only allow convergence when using initial basis vectors */
private final boolean basisConvergence;
/** Allow custom basis search direction */
private double[] basis = null;
/** Flags to indicate if bounds are present */
private boolean isLower, isUpper;
private double[] lower, upper;
/**
* The initial step is used to construct the basis vectors for the search direction. By default the identity matrix
* is used for the search. The magnitude of each diagonal position can be set using this data object.
*/
public static class BasisStep implements OptimizationData
{
/** Initial step in each direction. */
private final double[] step;
/**
* @param step
* Initial step for the bracket search.
*/
public BasisStep(double[] step)
{
this.step = step;
}
/**
* Gets the initial step.
*
* @return the initial step.
*/
public double[] getStep()
{
return step;
}
}
/**
* This constructor allows to specify a user-defined convergence checker,
* in addition to the parameters that control the default convergence
* checking procedure. <br/>
* The internal line search tolerances are set to the square-root of their
* corresponding value in the multivariate optimizer.
*
* @param rel
* Relative threshold.
* @param abs
* Absolute threshold.
* @param checker
* Convergence checker.
* @param basisConvergence
* Only allow convergence when using initial basis vectors
* @throws NotStrictlyPositiveException
* if {@code abs <= 0}.
* @throws NumberIsTooSmallException
* if {@code rel < 2 * Math.ulp(1d)}.
*/
public CustomPowellOptimizer(double rel, double abs, ConvergenceChecker<PointValuePair> checker,
boolean basisConvergence)
{
this(rel, abs, FastMath.sqrt(rel), FastMath.sqrt(abs), checker, basisConvergence);
}
/**
* This constructor allows to specify a user-defined convergence checker,
* in addition to the parameters that control the default convergence
* checking procedure and the line search tolerances.
*
* @param rel
* Relative threshold for this optimizer.
* @param abs
* Absolute threshold for this optimizer.
* @param lineRel
* Relative threshold for the internal line search optimizer.
* @param lineAbs
* Absolute threshold for the internal line search optimizer.
* @param checker
* Convergence checker.
* @param basisConvergence
* Only allow convergence when using initial basis vectors. If true then the vectors are reset if they
* have been modified and the search continues.
* @throws NotStrictlyPositiveException
* if {@code abs <= 0}.
* @throws NumberIsTooSmallException
* if {@code rel < 2 * Math.ulp(1d)}.
*/
public CustomPowellOptimizer(double rel, double abs, double lineRel, double lineAbs,
ConvergenceChecker<PointValuePair> checker, boolean basisConvergence)
{
super(checker);
if (rel < MIN_RELATIVE_TOLERANCE)
{
throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
}
if (abs <= 0)
{
throw new NotStrictlyPositiveException(abs);
}
relativeThreshold = rel;
absoluteThreshold = abs;
this.basisConvergence = basisConvergence;
// Create the line search optimizer.
line = new LineSearch(lineRel, lineAbs);
}
/**
* The parameters control the default convergence checking procedure. <br/>
* The internal line search tolerances are set to the square-root of their
* corresponding value in the multivariate optimizer.
*
* @param rel
* Relative threshold.
* @param abs
* Absolute threshold.
* @throws NotStrictlyPositiveException
* if {@code abs <= 0}.
* @throws NumberIsTooSmallException
* if {@code rel < 2 * Math.ulp(1d)}.
*/
public CustomPowellOptimizer(double rel, double abs)
{
this(rel, abs, null, false);
}
/**
* Builds an instance with the default convergence checking procedure.
*
* @param rel
* Relative threshold.
* @param abs
* Absolute threshold.
* @param lineRel
* Relative threshold for the internal line search optimizer.
* @param lineAbs
* Absolute threshold for the internal line search optimizer.
* @throws NotStrictlyPositiveException
* if {@code abs <= 0}.
* @throws NumberIsTooSmallException
* if {@code rel < 2 * Math.ulp(1d)}.
*/
public CustomPowellOptimizer(double rel, double abs, double lineRel, double lineAbs)
{
this(rel, abs, lineRel, lineAbs, null, false);
}
/** {@inheritDoc} */
@Override
protected PointValuePair doOptimize()
{
final GoalType goal = getGoalType();
final double[] guess = getStartPoint();
final int n = guess.length;
// Mark when we have modified the basis vectors
boolean nonBasis = false;
double[][] direc = createBasisVectors(n);
final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
//int resets = 0;
//PointValuePair solution = null;
//PointValuePair finalSolution = null;
//int solutionIter = 0, solutionEval = 0;
//double startValue = 0;
//try
//{
double[] x = guess;
// Ensure the point is within bounds
applyBounds(x);
double fVal = computeObjectiveValue(x);
//startValue = fVal;
double[] x1 = x.clone();
while (true)
{
incrementIterationCount();
final double fX = fVal;
double fX2 = 0;
double delta = 0;
int bigInd = 0;
for (int i = 0; i < n; i++)
{
fX2 = fVal;
final UnivariatePointValuePair optimum = line.search(x, direc[i]);
fVal = optimum.getValue();
x = newPoint(x, direc[i], optimum.getPoint());
if ((fX2 - fVal) > delta)
{
delta = fX2 - fVal;
bigInd = i;
}
}
boolean stop = false;
if (positionChecker != null)
{
// Check for convergence on the position
stop = positionChecker.converged(x1, x);
}
if (!stop)
{
// Default convergence check on value
//stop = 2 * (fX - fVal) <= (relativeThreshold * (FastMath.abs(fX) + FastMath.abs(fVal)) + absoluteThreshold);
// Check if we have improved from an impossible position
if (Double.isInfinite(fX) || Double.isNaN(fX))
{
if (Double.isInfinite(fVal) || Double.isNaN(fVal))
{
// Nowhere to go
stop = true;
}
// else: this is better as we now have a value, so continue
}
else
{
stop = DoubleEquality.almostEqualRelativeOrAbsolute(fX, fVal, relativeThreshold, absoluteThreshold);
}
}
final PointValuePair previous = new PointValuePair(x1, fX);
final PointValuePair current = new PointValuePair(x, fVal);
if (!stop && checker != null)
{ // User-defined stopping criteria.
stop = checker.converged(getIterations(), previous, current);
}
boolean reset = false;
if (stop)
{
// Only allow convergence using the basis vectors, i.e. we cannot move along any dimension
if (basisConvergence && nonBasis)
{
// Reset to the basis vectors and continue
reset = true;
//resets++;
}
else
{
//System.out.printf("Resets = %d\n", resets);
final PointValuePair answer;
if (goal == GoalType.MINIMIZE)
{
answer = (fVal < fX) ? current : previous;
}
else
{
answer = (fVal > fX) ? current : previous;
}
return answer;
// XXX Debugging
// Continue the algorithm to see how far it goes
//if (solution == null)
//{
// solution = answer;
// solutionIter = getIterations();
// solutionEval = getEvaluations();
//}
//finalSolution = answer;
}
}
if (reset)
{
direc = createBasisVectors(n);
nonBasis = false;
}
final double[] d = new double[n];
final double[] x2 = new double[n];
for (int i = 0; i < n; i++)
{
d[i] = x[i] - x1[i];
x2[i] = x[i] + d[i];
}
applyBounds(x2);
x1 = x.clone();
fX2 = computeObjectiveValue(x2);
// See if we can continue along the overall search direction to find a better value
if (fX > fX2)
{
// Check if:
// 1. The decrease along the average direction was not due to any single direction's decrease
// 2. There is a substantial second derivative along the average direction and we are close to
// it minimum
double t = 2 * (fX + fX2 - 2 * fVal);
double temp = fX - fVal - delta;
t *= temp * temp;
temp = fX - fX2;
t -= delta * temp * temp;
if (t < 0.0)
{
final UnivariatePointValuePair optimum = line.search(x, d);
fVal = optimum.getValue();
if (reset)
{
x = newPoint(x, d, optimum.getPoint());
continue;
}
else
{
final double[][] result = newPointAndDirection(x, d, optimum.getPoint());
x = result[0];
final int lastInd = n - 1;
direc[bigInd] = direc[lastInd];
direc[lastInd] = result[1];
nonBasis = true;
}
}
}
}
//}
//catch (RuntimeException e)
//{
// if (solution != null)
// {
// System.out.printf("Start %f : Initial %f (%d,%d) : Final %f (%d,%d) : %f\n", startValue,
// solution.getValue(), solutionIter, solutionEval, finalSolution.getValue(), getIterations(),
// getEvaluations(), DoubleEquality.relativeError(finalSolution.getValue(), solution.getValue()));
// return finalSolution;
// }
// throw e;
//}
}
private double[][] createBasisVectors(final int n)
{
double[][] direc = new double[n][n];
double[] step;
if (basis != null && basis.length == n)
{
step = basis;
}
else
{
step = new double[n];
Arrays.fill(step, 1);
}
for (int i = 0; i < n; i++)
{
direc[i][i] = step[i];
}
return direc;
}
/**
* Compute a new point (in the original space) and a new direction
* vector, resulting from the line search.
*
* @param p
* Point used in the line search.
* @param d
* Direction used in the line search.
* @param optimum
* Optimum found by the line search.
* @return a 2-element array containing the new point (at index 0) and
* the new direction (at index 1).
*/
private double[][] newPointAndDirection(final double[] p, final double[] d, final double optimum)
{
final int n = p.length;
final double[] nP = new double[n];
final double[] nD = new double[n];
for (int i = 0; i < n; i++)
{
nD[i] = d[i] * optimum;
nP[i] = p[i] + nD[i];
}
applyBounds(nP);
final double[][] result = new double[2][];
result[0] = nP;
result[1] = nD;
return result;
}
/**
* Compute a new point (in the original space) resulting from the line search.
*
* @param p
* Point used in the line search.
* @param d
* Direction used in the line search.
* @param optimum
* Optimum found by the line search.
* @return array containing the new point.
*/
private double[] newPoint(final double[] p, final double[] d, final double optimum)
{
final int n = p.length;
final double[] nP = new double[n];
for (int i = 0; i < n; i++)
{
nP[i] = p[i] + d[i] * optimum;
}
applyBounds(nP);
return nP;
}
/**
* Value that will pass the precondition check for {@link BrentOptimizer} but will not pass the convergence
* check, so that the custom checker
* will always decide when to stop the line search.
*/
private static final double REL_TOL_UNUSED;
static
{
REL_TOL_UNUSED = 2 * FastMath.ulp(1d);
}
/**
* Class for finding the minimum of the objective function along a given
* direction.
*/
private class LineSearch extends BrentOptimizer
{
/**
* Value that will pass the precondition check for {@link BrentOptimizer} but will not pass the convergence
* check, so that the custom checker
* will always decide when to stop the line search.
*/
private static final double ABS_TOL_UNUSED = Double.MIN_VALUE;
/**
* Automatic bracketing.
*/
private final BracketFinder bracket = new BracketFinder();
/**
* The "BrentOptimizer" default stopping criterion uses the tolerances
* to check the domain (point) values, not the function values.
* We thus create a custom checker to use function values.
*
* @param rel
* Relative threshold.
* @param abs
* Absolute threshold.
*/
LineSearch(double rel, double abs)
{
super(REL_TOL_UNUSED, ABS_TOL_UNUSED, new SimpleUnivariateValueChecker(rel, abs));
}
/**
* Find the minimum of the function {@code f(p + alpha * d)}.
*
* @param p
* Starting point.
* @param d
* Search direction.
* @return the optimum.
* @throws org.apache.commons.math3.exception.TooManyEvaluationsException
* if the number of evaluations is exceeded.
*/
public UnivariatePointValuePair search(final double[] p, final double[] d)
{
final int n = p.length;
final UnivariateFunction f = new UnivariateFunction()
{
final double[] x = new double[n];
public double value(double alpha)
{
for (int i = 0; i < n; i++)
{
x[i] = p[i] + alpha * d[i];
}
// Ensure the point is within bounds
applyBounds(x);
return CustomPowellOptimizer.this.computeObjectiveValue(x);
}
};
final GoalType goal = CustomPowellOptimizer.this.getGoalType();
bracket.search(f, goal, 0, 1);
// Passing "MAX_VALUE" as a dummy value because it is the enclosing
// class that counts the number of evaluations (and will eventually
// generate the exception).
return optimize(new MaxEval(Integer.MAX_VALUE), new UnivariateObjectiveFunction(f), goal,
new SearchInterval(bracket.getLo(), bracket.getHi(), bracket.getMid()));
}
}
/**
* Scans the list of (required and optional) optimization data that
* characterize the problem.
*
* @param optData
* Optimization data.
* The following data will be looked for:
* <ul>
* <li>{@link PositionChecker}</li>
* <li>{@link BasisStep}</li>
* </ul>
*/
@Override
protected void parseOptimizationData(OptimizationData... optData)
{
// Allow base class to register its own data.
super.parseOptimizationData(optData);
// The existing values (as set by the previous call) are reused if
// not provided in the argument list.
for (OptimizationData data : optData)
{
if (data instanceof PositionChecker)
{
positionChecker = (PositionChecker) data;
continue;
}
if (data instanceof BasisStep)
{
basis = ((BasisStep) data).getStep();
continue;
}
}
checkParameters();
}
/**
* @throws MathUnsupportedOperationException
* if bounds were passed to the {@link #optimize(OptimizationData[]) optimize} method and the lower is
* above the upper.
* @throws MathUnsupportedOperationException
* if the basis step passed to the {@link #optimize(OptimizationData[]) optimize} method is zero for any
* dimension
*/
private void checkParameters()
{
lower = getLowerBound();
upper = getUpperBound();
isLower = checkArray(lower, Double.NEGATIVE_INFINITY);
isUpper = checkArray(upper, Double.POSITIVE_INFINITY);
// Check that the upper bound is above the lower bound
if (isUpper && isLower)
{
for (int i = 0; i < lower.length; i++)
if (lower[i] > upper[i])
throw new MathUnsupportedOperationException(LocalizedFormats.CONSTRAINT);
}
if (basis != null)
{
for (double d : basis)
if (d == 0)
throw new MathUnsupportedOperationException(LocalizedFormats.CONSTRAINT);
}
}
/**
* Check if the array contains anything other than value
*
* @param array
* @param value
* @return True if the array has another value
*/
private boolean checkArray(double[] array, double value)
{
if (array == null)
return false;
for (int i = 0; i < array.length; i++)
if (value != array[i])
return true;
return false;
}
/**
* Check the point falls within the configured bounds truncating if necessary
*
* @param point
*/
private void applyBounds(double[] point)
{
if (isUpper)
{
for (int i = 0; i < point.length; i++)
if (point[i] > upper[i])
point[i] = upper[i];
}
if (isLower)
{
for (int i = 0; i < point.length; i++)
if (point[i] < lower[i])
point[i] = lower[i];
}
}
}