/**
* Copyright (C) 2009 - present by OpenGamma Inc. and the OpenGamma group of companies
*
* Please see distribution for license.
*/
package com.opengamma.analytics.math.differentiation;
import org.apache.commons.lang.Validate;
import com.opengamma.analytics.math.MathException;
import com.opengamma.analytics.math.function.Function1D;
import com.opengamma.util.ArgumentChecker;
/**
 * Differentiates a scalar function with respect to its argument using finite difference.
 * <p>
 * For a function $y = f(x)$ where $x$ and $y$ are scalars, this class produces
 * a gradient function $g(x)$, i.e. a function that returns the gradient for
 * each point $x$, where $g$ is the scalar $\frac{dy}{dx}$.
 * <p>
 * {@link #differentiate(Function1D)} uses the configured {@link FiniteDifferenceType} with step size $\epsilon$.
 * {@link #differentiate(Function1D, Function1D)} ignores the configured type and uses second-order three-point
 * stencils chosen so that the function is only evaluated at points inside the supplied domain.
 */
public class ScalarFirstOrderDifferentiator implements Differentiator<Double, Double, Double> {
  /** Default step size. */
  private static final double DEFAULT_EPS = 1e-5;
  /** Smallest accepted step size: the square root of the smallest positive normal double. */
  private static final double MIN_EPS = Math.sqrt(Double.MIN_NORMAL);
  /** Default differencing type. */
  private static final FiniteDifferenceType DIFF_TYPE = FiniteDifferenceType.CENTRAL;
  /** The step size. */
  private final double _eps;
  /** Twice the step size, cached because it is the denominator of the central and three-point formulas. */
  private final double _twoEps;
  /** The differencing type used by {@link #differentiate(Function1D)}. */
  private final FiniteDifferenceType _differenceType;

  /**
   * Uses the default values of differencing type (central) and eps (10<sup>-5</sup>).
   */
  public ScalarFirstOrderDifferentiator() {
    this(DIFF_TYPE, DEFAULT_EPS);
  }

  /**
   * Uses the default value of eps (10<sup>-5</sup>).
   *
   * @param differenceType The differencing type to be used in calculating the gradient function, not null
   */
  public ScalarFirstOrderDifferentiator(final FiniteDifferenceType differenceType) {
    this(differenceType, DEFAULT_EPS);
  }

  /**
   * Creates a differentiator with an explicit differencing type and step size.
   *
   * @param differenceType {@link FiniteDifferenceType#FORWARD}, {@link FiniteDifferenceType#BACKWARD}, or
   *  {@link FiniteDifferenceType#CENTRAL}. In most situations, {@link FiniteDifferenceType#CENTRAL} is preferable. Not null
   * @param eps The step size used to approximate the derivative. If this value is too small, the result will most likely
   *  be dominated by noise. Use around 10<sup>-5</sup> times the domain size. Must be finite and not less than
   *  the square root of {@link Double#MIN_NORMAL}.
   * @throws IllegalArgumentException if {@code differenceType} is null, or {@code eps} is NaN, infinite or too small
   */
  public ScalarFirstOrderDifferentiator(final FiniteDifferenceType differenceType, final double eps) {
    Validate.notNull(differenceType, "differenceType");
    // NaN and +Infinity both compare false against MIN_EPS, so they must be rejected explicitly
    if (Double.isNaN(eps) || Double.isInfinite(eps)) {
      throw new IllegalArgumentException("eps must be a finite number; have " + eps);
    }
    if (eps < MIN_EPS) {
      throw new IllegalArgumentException("eps is too small. A good value is 1e-5*size of domain. The minimum value is " + MIN_EPS);
    }
    _differenceType = differenceType;
    _eps = eps;
    _twoEps = 2 * _eps;
  }

  /**
   * Returns a function computing the first derivative of {@code function} with the configured differencing type:
   * <ul>
   * <li>forward: $(f(x + \epsilon) - f(x)) / \epsilon$</li>
   * <li>central: $(f(x + \epsilon) - f(x - \epsilon)) / 2\epsilon$</li>
   * <li>backward: $(f(x) - f(x - \epsilon)) / \epsilon$</li>
   * </ul>
   *
   * @param function the function to differentiate, not null
   * @return the derivative function; it rejects a null argument
   */
  @Override
  public Function1D<Double, Double> differentiate(final Function1D<Double, Double> function) {
    Validate.notNull(function, "function");
    switch (_differenceType) {
      case FORWARD:
        return new Function1D<Double, Double>() {
          @SuppressWarnings("synthetic-access")
          @Override
          public Double evaluate(final Double x) {
            Validate.notNull(x, "x");
            return (function.evaluate(x + _eps) - function.evaluate(x)) / _eps;
          }
        };
      case CENTRAL:
        return new Function1D<Double, Double>() {
          @SuppressWarnings("synthetic-access")
          @Override
          public Double evaluate(final Double x) {
            Validate.notNull(x, "x");
            return (function.evaluate(x + _eps) - function.evaluate(x - _eps)) / _twoEps;
          }
        };
      case BACKWARD:
        return new Function1D<Double, Double>() {
          @SuppressWarnings("synthetic-access")
          @Override
          public Double evaluate(final Double x) {
            Validate.notNull(x, "x");
            return (function.evaluate(x) - function.evaluate(x - _eps)) / _eps;
          }
        };
      default:
        throw new IllegalArgumentException("Can only handle forward, backward and central differencing");
    }
  }

  /**
   * Returns a function computing the first derivative of {@code function}, evaluating {@code function} only at
   * points for which {@code domain} returns true.
   * <p>
   * The configured differencing type is NOT used here. At a point $x$ the derivative is computed with a
   * second-order accurate three-point stencil:
   * <ul>
   * <li>central, on $x - \epsilon, x + \epsilon$, when both are in the domain;</li>
   * <li>forward, on $x, x + \epsilon, x + 2\epsilon$, when $x - \epsilon$ is outside the domain;</li>
   * <li>backward, on $x - 2\epsilon, x - \epsilon, x$, when $x + \epsilon$ is outside the domain.</li>
   * </ul>
   * The returned function rejects a null $x$, rejects an $x$ outside the domain (via {@link ArgumentChecker}), and
   * throws {@link MathException} when no stencil fits entirely inside the domain.
   *
   * @param function the function to differentiate, not null
   * @param domain returns true for points at which {@code function} may be evaluated, not null
   * @return the derivative function
   */
  @Override
  public Function1D<Double, Double> differentiate(final Function1D<Double, Double> function, final Function1D<Double, Boolean> domain) {
    Validate.notNull(function, "function");
    Validate.notNull(domain, "domain");
    // Stencil weights, already divided by 2 * eps; index i applies to the i-th of the three stencil points
    final double[] wFwd = new double[] {-3. / _twoEps, 4. / _twoEps, -1. / _twoEps };
    final double[] wCent = new double[] {-1. / _twoEps, 0., 1. / _twoEps };
    final double[] wBack = new double[] {1. / _twoEps, -4. / _twoEps, 3. / _twoEps };
    return new Function1D<Double, Double>() {
      @SuppressWarnings("synthetic-access")
      @Override
      public Double evaluate(final Double x) {
        Validate.notNull(x, "x");
        ArgumentChecker.isTrue(domain.evaluate(x), "point {} is not in the function domain", x.toString());
        final double[] y = new double[3];
        final double[] w;
        if (!domain.evaluate(x + _eps)) {
          // backward stencil: every point it touches must be inside the domain
          if (!domain.evaluate(x - _eps) || !domain.evaluate(x - _twoEps)) {
            throw new MathException("cannot get derivative at point " + x.toString());
          }
          y[0] = function.evaluate(x - _twoEps);
          y[1] = function.evaluate(x - _eps);
          y[2] = function.evaluate(x);
          w = wBack;
        } else if (!domain.evaluate(x - _eps)) {
          // forward stencil: every point it touches must be inside the domain
          if (!domain.evaluate(x + _twoEps)) {
            throw new MathException("cannot get derivative at point " + x.toString());
          }
          y[0] = function.evaluate(x);
          y[1] = function.evaluate(x + _eps);
          y[2] = function.evaluate(x + _twoEps);
          w = wFwd;
        } else {
          // central stencil: the middle point has zero weight, so it is never evaluated (y[1] stays 0)
          y[0] = function.evaluate(x - _eps);
          y[2] = function.evaluate(x + _eps);
          w = wCent;
        }
        double res = y[0] * w[0] + y[2] * w[2];
        if (w[1] != 0) {
          res += y[1] * w[1];
        }
        return res;
      }
    };
  }
}