/** * Copyright (C) 2009 - present by OpenGamma Inc. and the OpenGamma group of companies * * Please see distribution for license. */ package com.opengamma.strata.math.impl.statistics.leastsquare; import static org.testng.AssertJUnit.assertEquals; import static org.testng.AssertJUnit.assertTrue; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.function.Function; import org.testng.annotations.Test; import com.opengamma.strata.collect.ArgChecker; import com.opengamma.strata.collect.array.DoubleArray; import com.opengamma.strata.math.impl.interpolation.BasisFunctionAggregation; import com.opengamma.strata.math.impl.interpolation.BasisFunctionGenerator; import com.opengamma.strata.math.impl.interpolation.PSplineFitter; import com.opengamma.strata.math.impl.statistics.distribution.NormalDistribution; import cern.jet.random.engine.MersenneTwister; import cern.jet.random.engine.MersenneTwister64; import cern.jet.random.engine.RandomEngine; /** * Tests {@link GeneralizedLeastSquare} and {@link PSplineFitter} fits using sine, trigonometric and generated spline basis functions. 
*/
@SuppressWarnings("deprecation")
@Test
public class GeneralizedLeastSquareTest {

  /** Debug switch: when true, the spline tests dump their fitted values to standard output. */
  private static boolean PRINT = false;

  /** Random engine seeded with {@code MersenneTwister.DEFAULT_SEED}; drives all generated sample points. */
  protected static final RandomEngine RANDOM = new MersenneTwister64(MersenneTwister.DEFAULT_SEED);
  /** Noise source used by {@link #testFit()}; draws from {@link #RANDOM}. */
  private static final NormalDistribution NORMAL = new NormalDistribution(0, 1.0, RANDOM);
  /** Weights aggregated with the sine / trigonometric basis functions to build the sample data. */
  private static final double[] WEIGHTS = {1.0, -0.5, 2.0, 0.23, 1.45};
  /** Tolerance used when comparing fitted weights and spline values against expected numbers. */
  private static final double TOLERANCE = 1e-8;

  // One-dimensional sample set: X[i] = i / 5, Y = TEST_FUNCTION(X), constant error SIGMA.
  private static final Double[] X;
  private static final double[] Y;
  private static final double[] SIGMA;

  // Two-dimensional sample set drawn from [0, 2) x [0, 2), evaluated with VECTOR_TEST_FUNCTION.
  private static final List<DoubleArray> X_TRIG;
  private static final List<Double> Y_TRIG;
  private static final List<Double> SIGMA_TRIG;

  // Two-dimensional sample set drawn from [0, 10) x [0, 10), evaluated with SIN_EXP_FUNCTION.
  private static final List<double[]> X_SIN_EXP;
  private static final List<Double> Y_SIN_EXP;
  private static final List<Double> SIGMA_COS_EXP;

  // Basis functions and the functions that generate the sample values.
  private static final List<Function<Double, Double>> SIN_FUNCTIONS;
  private static final Function<Double, Double> TEST_FUNCTION;
  private static final List<Function<DoubleArray, Double>> VECTOR_TRIG_FUNCTIONS;
  private static final Function<DoubleArray, Double> VECTOR_TEST_FUNCTION;
  private static Function<double[], Double> SIN_EXP_FUNCTION;
  private static final List<Function<Double, Double>> BASIS_FUNCTIONS;
  private static final List<Function<double[], Double>> BASIS_FUNCTIONS_2D;

  static {
    // Odd harmonics: sin((2k+1) x) in 1-D, and sin((2k+1) x0) * cos((2k+1) x1) for 2-D points.
    SIN_FUNCTIONS = new ArrayList<>();
    VECTOR_TRIG_FUNCTIONS = new ArrayList<>();
    for (int k = 0; k < WEIGHTS.length; k++) {
      final int frequency = 2 * k + 1;
      SIN_FUNCTIONS.add(x -> Math.sin(frequency * x));
      VECTOR_TRIG_FUNCTIONS.add(vec -> {
        ArgChecker.isTrue(vec.size() == 2);
        return Math.sin(frequency * vec.get(0)) * Math.cos(frequency * vec.get(1));
      });
    }
    TEST_FUNCTION = new BasisFunctionAggregation<>(SIN_FUNCTIONS, WEIGHTS);
    VECTOR_TEST_FUNCTION = new BasisFunctionAggregation<>(VECTOR_TRIG_FUNCTIONS, WEIGHTS);
    SIN_EXP_FUNCTION = xy -> Math.sin(Math.PI * xy[0] / 10.0) * Math.exp(-xy[1] / 5.);

    final int n = 10;
    X = new Double[n];
    Y = new double[n];
    SIGMA = new double[n];
    Arrays.fill(SIGMA, 0.01);
    X_TRIG = new ArrayList<>();
    Y_TRIG = new ArrayList<>();
    SIGMA_TRIG = new ArrayList<>();
    for (int i = 0; i < n; i++) {
      X[i] = i / 5.0;
      Y[i] = TEST_FUNCTION.apply(X[i]);
      // The two coordinates are drawn in order, so the sequence taken from RANDOM is fixed.
      final double[] coords = {2.0 * RANDOM.nextDouble(), 2.0 * RANDOM.nextDouble()};
      final DoubleArray trigPoint = DoubleArray.copyOf(coords);
      X_TRIG.add(trigPoint);
      Y_TRIG.add(VECTOR_TEST_FUNCTION.apply(trigPoint));
      SIGMA_TRIG.add(0.01);
    }

    SIGMA_COS_EXP = new ArrayList<>();
    X_SIN_EXP = new ArrayList<>();
    Y_SIN_EXP = new ArrayList<>();
    for (int i = 0; i < 20; i++) {
      final double[] coords = {10.0 * RANDOM.nextDouble(), 10.0 * RANDOM.nextDouble()};
      X_SIN_EXP.add(coords);
      Y_SIN_EXP.add(SIN_EXP_FUNCTION.apply(coords));
      SIGMA_COS_EXP.add(0.01);
    }

    final BasisFunctionGenerator generator = new BasisFunctionGenerator();
    BASIS_FUNCTIONS = generator.generateSet(0.0, 2.0, 20, 3);
    BASIS_FUNCTIONS_2D = generator.generateSet(
        new double[] {0.0, 0.0},
        new double[] {10.0, 10.0},
        new int[] {10, 10},
        new int[] {3, 3});
  }

  /**
   * Fits the noise-free 1-D samples with the sine basis and expects a zero chi-square
   * and fitted parameters equal to {@link #WEIGHTS}.
   */
  public void testPerfectFit() {
    final LeastSquareResults fit = new GeneralizedLeastSquare().solve(X, Y, SIGMA, SIN_FUNCTIONS);
    assertEquals(0.0, fit.getChiSq(), TOLERANCE);
    assertWeightsRecovered(fit.getFitParameters());
  }

  /**
   * Fits the noise-free 2-D samples with the trigonometric basis and expects a zero chi-square
   * and fitted parameters equal to {@link #WEIGHTS}.
   */
  public void testPerfectFitVector() {
    final LeastSquareResults fit =
        new GeneralizedLeastSquare().solve(X_TRIG, Y_TRIG, SIGMA_TRIG, VECTOR_TRIG_FUNCTIONS);
    assertEquals(0.0, fit.getChiSq(), TOLERANCE);
    assertWeightsRecovered(fit.getFitParameters());
  }

  /**
   * Perturbs each sample by {@code SIGMA[i]} times a draw from {@link #NORMAL} and checks that the
   * chi-square of the sine-basis fit stays below three times the number of samples.
   */
  public void testFit() {
    final GeneralizedLeastSquare gls = new GeneralizedLeastSquare();
    final double[] noisy = new double[Y.length];
    for (int i = 0; i < noisy.length; i++) {
      noisy[i] = Y[i] + SIGMA[i] * NORMAL.nextRandom();
    }
    final LeastSquareResults fit = gls.solve(X, noisy, SIGMA, SIN_FUNCTIONS);
    assertTrue(fit.getChiSq() < 3 * Y.length);
  }

  /**
   * Fits the 1-D samples with the generated 1-D basis set and checks the chi-square and the
   * fitted function value at 0.5.
   */
  public void testBSplineFit() {
    final GeneralizedLeastSquare gls = new GeneralizedLeastSquare();
    final LeastSquareResults fit = gls.solve(X, Y, SIGMA, BASIS_FUNCTIONS);
    final Function<Double, Double> spline =
        new BasisFunctionAggregation<>(BASIS_FUNCTIONS, fit.getFitParameters().toArray());
    assertEquals(0.0, fit.getChiSq(), 1e-12);
    assertEquals(-0.023605293, spline.apply(0.5), TOLERANCE);

    if (PRINT) {
      printFitSummary(fit.getChiSq(), fit.getFitParameters());
      printCurveAndSamples(spline);
    }
  }

  /**
   * Fits the sin-exp samples with the generated 2-D basis set and checks the chi-square and the
   * fitted function value at (4, 3).
   */
  public void testBSplineFit2D() {
    final GeneralizedLeastSquare gls = new GeneralizedLeastSquare();
    final LeastSquareResults fit = gls.solve(X_SIN_EXP, Y_SIN_EXP, SIGMA_COS_EXP, BASIS_FUNCTIONS_2D);
    final Function<double[], Double> spline =
        new BasisFunctionAggregation<>(BASIS_FUNCTIONS_2D, fit.getFitParameters().toArray());
    assertEquals(0.0, fit.getChiSq(), 1e-16);
    assertEquals(0.05161579, spline.apply(new double[] {4, 3}), TOLERANCE);

    if (PRINT) {
      printFitSummary(fit.getChiSq(), fit.getFitParameters());
      printSurface(spline);
    }
  }

  /**
   * Runs the six-argument {@code solve} overload (extra arguments 1000.0 and 2) on the 1-D samples
   * and checks the chi-square and the fitted function value at 1.1.
   */
  public void testPSplineFit() {
    final GeneralizedLeastSquare gls = new GeneralizedLeastSquare();
    final GeneralizedLeastSquareResults<Double> fit = gls.solve(X, Y, SIGMA, BASIS_FUNCTIONS, 1000.0, 2);
    final Function<Double, Double> spline = fit.getFunction();
    assertEquals(2225.7, fit.getChiSq(), 1e-1);
    assertEquals(-0.758963811327287, spline.apply(1.1), TOLERANCE);

    if (PRINT) {
      printFitSummary(fit.getChiSq(), fit.getFitParameters());
      printCurveAndSamples(spline);
    }
  }

  /**
   * Runs four fits over a fixed ten-point data set: y against x, y against ln(x), y*y*x against x
   * and y*y*x against ln(x). Nothing is asserted; the fits only have to complete, and the
   * resulting functions are printed when {@link #PRINT} is set.
   */
  public void testPSplineFit2() {
    final BasisFunctionGenerator generator = new BasisFunctionGenerator();
    final List<Function<Double, Double>> basisFuncs = generator.generateSet(0, 12, 100, 3);
    final List<Function<Double, Double>> basisFuncsLog = generator.generateSet(-5, 3, 100, 3);
    final GeneralizedLeastSquare gls = new GeneralizedLeastSquare();

    final double[] xData = {
        7. / 365, 14 / 365., 21 / 365., 1 / 12., 3 / 12., 0.5, 0.75, 1, 5, 10};
    final double[] yData = {
        0.972452371, 0.749039802, 0.759792085, 0.714206462, 0.604446956,
        0.517955313, 0.474807307, 0.443532132, 0.2404755, 0.197128583};
    final int n = xData.length;

    final double[] lnX = new double[n];
    Arrays.setAll(lnX, k -> Math.log(xData[k]));
    final double[] yData2 = new double[n];
    Arrays.setAll(yData2, k -> yData[k] * yData[k] * xData[k]);
    final double[] sigma = new double[n];
    Arrays.fill(sigma, 0.01);

    final GeneralizedLeastSquareResults<Double> results =
        gls.solve(xData, yData, sigma, basisFuncs, 1000.0, 2);
    final Function<Double, Double> spline = results.getFunction();

    final GeneralizedLeastSquareResults<Double> resultsLog =
        gls.solve(lnX, yData, sigma, basisFuncsLog, 1000.0, 2);
    final Function<Double, Double> splineLog = resultsLog.getFunction();

    final GeneralizedLeastSquareResults<Double> resultsVar =
        gls.solve(xData, yData2, sigma, basisFuncs, 1000.0, 2);
    final Function<Double, Double> splineVar = resultsVar.getFunction();

    final GeneralizedLeastSquareResults<Double> resultsVarLog =
        gls.solve(lnX, yData2, sigma, basisFuncsLog, 1000.0, 2);
    final Function<Double, Double> splineVarLog = resultsVarLog.getFunction();

    if (PRINT) {
      printFitSummary(results.getChiSq(), results.getFitParameters());
      for (int i = 0; i < 101; i++) {
        final double logX = -5 + 8 * i / 100.;
        final double x = Math.exp(logX);
        System.out.println(x + "\t" + logX + "\t" + spline.apply(x) + "\t" + splineLog.apply(logX) +
            "\t" + splineVar.apply(x) + "\t" + splineVarLog.apply(logX));
      }
      for (int i = 0; i < n; i++) {
        System.out.println(lnX[i] + "\t" + yData[i]);
      }
    }
  }

  /**
   * Fits the sin-exp samples through {@link PSplineFitter} and checks the chi-square and the
   * fitted function value at (4, 3).
   */
  public void testPSplineFit2D() {
    final PSplineFitter psf = new PSplineFitter();
    final GeneralizedLeastSquareResults<double[]> fit = psf.solve(
        X_SIN_EXP,
        Y_SIN_EXP,
        SIGMA_COS_EXP,
        new double[] {0.0, 0.0},
        new double[] {10.0, 10.0},
        new int[] {10, 10},
        new int[] {3, 3},
        new double[] {0.001, 0.001},
        new int[] {3, 3});
    assertEquals(0.0, fit.getChiSq(), 1e-9);
    final Function<double[], Double> spline = fit.getFunction();
    assertEquals(0.5333876489112092, spline.apply(new double[] {4, 3}), TOLERANCE);

    if (PRINT) {
      printFitSummary(fit.getChiSq(), fit.getFitParameters());
      printSurface(spline);
    }
  }

  //-------------------------------------------------------------------------
  /** Asserts that each fitted parameter matches the corresponding entry of {@link #WEIGHTS}. */
  private static void assertWeightsRecovered(final DoubleArray fitted) {
    int index = 0;
    for (final double expected : WEIGHTS) {
      assertEquals(expected, fitted.get(index), TOLERANCE);
      index++;
    }
  }

  /** Debug output: prints the chi-square and the fitted parameters, one per line. */
  private static void printFitSummary(final double chiSq, final Object fitParameters) {
    System.out.println("Chi^2:\t" + chiSq);
    System.out.println("weights:\t" + fitParameters);
  }

  /**
   * Debug output: prints the function at 101 evenly spaced points from 0 to 2, followed by the
   * (X, Y) sample points.
   */
  private static void printCurveAndSamples(final Function<Double, Double> spline) {
    for (int step = 0; step <= 100; step++) {
      final double x = step * 2.0 / 100.0;
      System.out.println(x + "\t" + spline.apply(x));
    }
    for (int i = 0; i < X.length; i++) {
      System.out.println(X[i] + "\t" + Y[i]);
    }
  }

  /**
   * Debug output: prints the function on a 101 x 101 grid over [0, 10] x [0, 10] as a
   * tab-separated table, with a header row of first-coordinate values and one row per
   * first-coordinate value.
   */
  private static void printSurface(final Function<double[], Double> spline) {
    final double[] point = new double[2];
    for (int col = 0; col <= 100; col++) {
      point[0] = col * 10.0 / 100.0;
      System.out.print("\t" + point[0]);
    }
    System.out.print("\n");
    for (int row = 0; row <= 100; row++) {
      point[0] = row * 10 / 100.0;
      System.out.print(point[0]);
      for (int col = 0; col <= 100; col++) {
        point[1] = col * 10.0 / 100.0;
        final double value = spline.apply(point);
        System.out.print("\t" + value);
      }
      System.out.print("\n");
    }
  }

}