/** * Copyright (C) 2013 - present by OpenGamma Inc. and the OpenGamma group of companies * * Please see distribution for license. */ package com.opengamma.strata.math.impl.statistics.leastsquare; import static com.opengamma.strata.math.impl.interpolation.PenaltyMatrixGenerator.getPenaltyMatrix; import static org.testng.AssertJUnit.assertEquals; import java.util.function.Function; import org.apache.commons.math3.random.Well44497b; import org.testng.annotations.Test; import com.opengamma.strata.collect.array.DoubleArray; import com.opengamma.strata.collect.array.DoubleMatrix; import com.opengamma.strata.math.impl.matrix.CommonsMatrixAlgebra; import com.opengamma.strata.math.impl.matrix.MatrixAlgebra; /** * Test {@link NonLinearLeastSquareWithPenalty}. */ @Test public class NonLinearLeastSquareWithPenaltyTest { private static final MatrixAlgebra MA = new CommonsMatrixAlgebra(); private static NonLinearLeastSquareWithPenalty NLLSWP = new NonLinearLeastSquareWithPenalty(); static int N_SWAPS = 8; public void linearTest() { boolean print = false; if (print) { System.out.println("NonLinearLeastSquareWithPenaltyTest.linearTest"); } int nWeights = 20; int diffOrder = 2; double lambda = 100.0; DoubleMatrix penalty = (DoubleMatrix) MA.scale(getPenaltyMatrix(nWeights, diffOrder), lambda); int[] onIndex = new int[] {1, 4, 11, 12, 15, 17}; double[] obs = new double[] {0, 1.0, 1.0, 1.0, 0.0, 0.0}; int n = onIndex.length; Function<DoubleArray, DoubleArray> func = new Function<DoubleArray, DoubleArray>() { @Override public DoubleArray apply(DoubleArray x) { return DoubleArray.of(n, i -> x.get(onIndex[i])); } }; Function<DoubleArray, DoubleMatrix> jac = new Function<DoubleArray, DoubleMatrix>() { @Override public DoubleMatrix apply(DoubleArray x) { return DoubleMatrix.of( n, nWeights, (i, j) -> j == onIndex[i] ? 1d : 0d); } }; Well44497b random = new Well44497b(0L); DoubleArray start = DoubleArray.of(nWeights, i -> random.nextDouble()); LeastSquareWithPenaltyResults lsRes = NLLSWP.solve( DoubleArray.copyOf(obs), DoubleArray.filled(n, 0.01), func, jac, start, penalty); if (print) { System.out.println("chi2: " + lsRes.getChiSq()); System.out.println(lsRes.getFitParameters()); } for (int i = 0; i < n; i++) { assertEquals(obs[i], lsRes.getFitParameters().get(onIndex[i]), 0.01); } double expPen = 20.87912357454752; assertEquals(expPen, lsRes.getPenalty(), 1e-9); } }