/**
* Copyright (C) 2009 - present by OpenGamma Inc. and the OpenGamma group of companies
*
* Please see distribution for license.
*/
package com.opengamma.strata.math.impl.statistics.leastsquare;
import static org.testng.AssertJUnit.assertEquals;
import static org.testng.AssertJUnit.assertTrue;
import java.util.function.Function;
import org.testng.annotations.Test;
import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.math.impl.function.ParameterizedFunction;
import com.opengamma.strata.math.impl.linearalgebra.LUDecompositionCommons;
import com.opengamma.strata.math.impl.linearalgebra.LUDecompositionResult;
import com.opengamma.strata.math.impl.matrix.MatrixAlgebra;
import com.opengamma.strata.math.impl.matrix.OGMatrixAlgebra;
import com.opengamma.strata.math.impl.statistics.distribution.NormalDistribution;
import cern.jet.random.engine.MersenneTwister;
import cern.jet.random.engine.MersenneTwister64;
/**
 * Tests {@link NonLinearLeastSquare}.
 * <p>
 * The data set is {@code sin(x)} sampled at 20 equally spaced points {@code x = -pi + i * pi / 10},
 * {@code i = 0..19}. It is fitted with the four-parameter model {@code a0 * sin(a1 * x + a2) + a3},
 * which reproduces the data exactly at {@code (1, 1, 0, 0)}.
 */
@Test
public class NonLinearLeastSquareTest {

  /**
   * Standard normal generator with a fixed seed.
   * NOTE(review): it is shared by several tests, so the draws each test sees depend on test execution order.
   */
  private static final NormalDistribution NORMAL = new NormalDistribution(0, 1.0, new MersenneTwister64(MersenneTwister.DEFAULT_SEED));
  /** Number of sample points. */
  private static final int N_POINTS = 20;
  /** Number of parameters in the fitted model. */
  private static final int N_PARAMS = 4;
  /** Parameters (a0, a1, a2, a3) for which the model reproduces the target sin(x) exactly. */
  private static final DoubleArray TRUE_PARAMETERS = DoubleArray.of(1.0, 1.0, 0.0, 0.0);
  /** Tolerance for fits to noise-free data. */
  private static final double EXACT_TOL = 1e-8;
  /** Tolerance when comparing the linear prediction with the re-fitted parameter shift. */
  private static final double PERTURBATION_TOL = 5e-5;

  /** Sample abscissas. */
  private static final DoubleArray X;
  /** Noise-free observations, {@code sin(X)}. */
  private static final DoubleArray Y;
  /** Per-point measurement errors, {@code 0.1 * exp(|x| / pi)}. */
  private static final DoubleArray SIGMA;
  /** Solver under test, shared by the tests that do not build their own. */
  private static final NonLinearLeastSquare LS;

  /** The function that generates the data. */
  private static final Function<Double, Double> TARGET = new Function<Double, Double>() {
    @Override
    public Double apply(final Double x) {
      return Math.sin(x);
    }
  };

  /**
   * The model evaluated at every point of {@code X} for a parameter vector.
   * This is an anonymous class rather than a lambda: a lambda in a field initializer
   * may not read the blank-final {@code X}.
   */
  private static final Function<DoubleArray, DoubleArray> FUNCTION = new Function<DoubleArray, DoubleArray>() {
    @SuppressWarnings("synthetic-access")
    @Override
    public DoubleArray apply(final DoubleArray a) {
      ArgChecker.isTrue(a.size() == N_PARAMS, "four parameters");
      final int n = X.size();
      final double[] res = new double[n];
      for (int i = 0; i < n; i++) {
        res[i] = a.get(0) * Math.sin(a.get(1) * X.get(i) + a.get(2)) + a.get(3);
      }
      return DoubleArray.copyOf(res);
    }
  };

  /** The model {@code a0 * sin(a1 * x + a2) + a3} at a single point. */
  private static final ParameterizedFunction<Double, DoubleArray, Double> PARAM_FUNCTION = new ParameterizedFunction<Double, DoubleArray, Double>() {
    @Override
    public Double evaluate(final Double x, final DoubleArray a) {
      ArgChecker.isTrue(a.size() == getNumberOfParameters(), "four parameters");
      return a.get(0) * Math.sin(a.get(1) * x + a.get(2)) + a.get(3);
    }

    @Override
    public int getNumberOfParameters() {
      return N_PARAMS;
    }
  };

  /** Analytic gradient of {@link #PARAM_FUNCTION} with respect to (a0, a1, a2, a3) at a single point. */
  private static final ParameterizedFunction<Double, DoubleArray, DoubleArray> PARAM_GRAD = new ParameterizedFunction<Double, DoubleArray, DoubleArray>() {
    @Override
    public DoubleArray evaluate(final Double x, final DoubleArray a) {
      ArgChecker.isTrue(a.size() == getNumberOfParameters(), "four parameters");
      final double temp1 = Math.sin(a.get(1) * x + a.get(2));
      final double temp2 = Math.cos(a.get(1) * x + a.get(2));
      final double[] res = new double[N_PARAMS];
      res[0] = temp1;
      res[2] = a.get(0) * temp2;
      // d/da1 is x times d/da2 by the chain rule
      res[1] = x * res[2];
      res[3] = 1.0;
      return DoubleArray.copyOf(res);
    }

    @Override
    public int getNumberOfParameters() {
      return N_PARAMS;
    }
  };

  /**
   * Jacobian of {@link #FUNCTION}: row i is {@link #PARAM_GRAD} evaluated at {@code X.get(i)}.
   * This is an anonymous class for the same reason as {@link #FUNCTION}.
   */
  private static final Function<DoubleArray, DoubleMatrix> GRAD = new Function<DoubleArray, DoubleMatrix>() {
    @SuppressWarnings("synthetic-access")
    @Override
    public DoubleMatrix apply(final DoubleArray a) {
      final int n = X.size();
      final int m = a.size();
      final double[][] res = new double[n][m];
      for (int i = 0; i < n; i++) {
        final DoubleArray temp = PARAM_GRAD.evaluate(X.get(i), a);
        ArgChecker.isTrue(m == temp.size(), "gradient size must match number of parameters");
        for (int j = 0; j < m; j++) {
          res[i][j] = temp.get(j);
        }
      }
      return DoubleMatrix.copyOf(res);
    }
  };

  static {
    X = DoubleArray.of(N_POINTS, i -> -Math.PI + i * Math.PI / 10);
    Y = DoubleArray.of(N_POINTS, i -> TARGET.apply(X.get(i)));
    SIGMA = DoubleArray.of(N_POINTS, i -> 0.1 * Math.exp(Math.abs(X.get(i)) / Math.PI));
    LS = new NonLinearLeastSquare();
  }

  /** Noise-free data, analytic gradient: each sigma overload must recover the exact parameters. */
  public void solveExactTest() {
    final DoubleArray start = DoubleArray.of(1.2, 0.8, -0.2, -0.3);
    // per-point sigmas
    assertExactFit(LS.solve(X, Y, SIGMA, PARAM_FUNCTION, PARAM_GRAD, start));
    // one common sigma for all points
    assertExactFit(LS.solve(X, Y, SIGMA.get(0), PARAM_FUNCTION, PARAM_GRAD, start));
    // overload without an explicit sigma
    assertExactFit(LS.solve(X, Y, PARAM_FUNCTION, PARAM_GRAD, start));
  }

  /** Noise-free data, vector-valued model with no user-supplied Jacobian. */
  public void solveExactTest2() {
    final DoubleArray start = DoubleArray.of(0.2, 1.8, 0.2, 0.3);
    assertExactFit(LS.solve(Y, SIGMA, FUNCTION, start));
  }

  /** Noise-free data, scalar model with no analytic gradient. */
  public void solveExactWithoutGradientTest() {
    final DoubleArray start = DoubleArray.of(1.2, 0.8, -0.2, -0.3);
    final NonLinearLeastSquare ls = new NonLinearLeastSquare();
    assertExactFit(ls.solve(X, Y, SIGMA, PARAM_FUNCTION, start));
  }

  /**
   * Data with Gaussian noise scaled by SIGMA.
   * Checks that the reduced chi-squared is of order one, and that the fitted parameters lie within
   * three "standard deviations" of the truth, measured with the fit covariance.
   */
  public void solveRandomNoiseTest() {
    final MatrixAlgebra ma = new OGMatrixAlgebra();
    final double[] y = new double[N_POINTS];
    for (int i = 0; i < N_POINTS; i++) {
      y[i] = Y.get(i) + SIGMA.get(i) * NORMAL.nextRandom();
    }
    final DoubleArray start = DoubleArray.of(0.7, 1.4, 0.2, -0.3);
    final NonLinearLeastSquare ls = new NonLinearLeastSquare();
    final LeastSquareResults res = ls.solve(X, DoubleArray.copyOf(y), SIGMA, PARAM_FUNCTION, PARAM_GRAD, start);

    final int degreesOfFreedom = N_POINTS - N_PARAMS;
    final double chiSqDoF = res.getChiSq() / degreesOfFreedom;
    assertTrue(chiSqDoF > 0.25);
    assertTrue(chiSqDoF < 3.0);

    // z = sqrt(delta^T * C^-1 * delta), where delta = fitted - true and C is the fit covariance
    final DoubleArray delta = (DoubleArray) ma.subtract(res.getFitParameters(), TRUE_PARAMETERS);
    final LUDecompositionCommons decmp = new LUDecompositionCommons();
    final LUDecompositionResult decmpRes = decmp.apply(res.getCovariance());
    final DoubleMatrix invCovariance = decmpRes.solve(DoubleMatrix.identity(N_PARAMS));
    final double z = Math.sqrt(ma.getInnerProduct(delta, ma.multiply(invCovariance, delta)));
    assertTrue(z < 3.0);
  }

  /**
   * Small data perturbations.
   * The parameter shift predicted by multiplying the perturbation by the matrix from
   * {@code calInverseJacobian} must agree with the shift obtained by re-fitting the perturbed data.
   */
  public void smallPertubationTest() {
    final MatrixAlgebra ma = new OGMatrixAlgebra();
    final double[] dy = new double[N_POINTS];
    for (int i = 0; i < N_POINTS; i++) {
      dy[i] = 0.1 * SIGMA.get(i) * NORMAL.nextRandom();
    }
    final DoubleArray deltaY = DoubleArray.copyOf(dy);
    final NonLinearLeastSquare ls = new NonLinearLeastSquare();

    // linear prediction of the parameter shift, evaluated at the exact solution
    final DoubleMatrix invJacobian = ls.calInverseJacobian(SIGMA, FUNCTION, GRAD, TRUE_PARAMETERS);
    final DoubleArray deltaParms = (DoubleArray) ma.multiply(invJacobian, deltaY);

    // actual shift from re-fitting the perturbed data, starting at the exact solution
    final DoubleArray y = (DoubleArray) ma.add(Y, deltaY);
    final LeastSquareResults lsRes = ls.solve(X, y, SIGMA, PARAM_FUNCTION, PARAM_GRAD, TRUE_PARAMETERS);
    final DoubleArray trueDeltaParms = (DoubleArray) ma.subtract(lsRes.getFitParameters(), TRUE_PARAMETERS);

    for (int j = 0; j < N_PARAMS; j++) {
      assertEquals(trueDeltaParms.get(j), deltaParms.get(j), PERTURBATION_TOL);
    }
  }

  /**
   * Asserts that a fit to noise-free data has zero chi-squared and recovers {@link #TRUE_PARAMETERS}.
   *
   * @param result  the fit result to check
   */
  private static void assertExactFit(final LeastSquareResults result) {
    assertEquals(0.0, result.getChiSq(), EXACT_TOL);
    for (int j = 0; j < N_PARAMS; j++) {
      assertEquals(TRUE_PARAMETERS.get(j), result.getFitParameters().get(j), EXACT_TOL);
    }
  }

}