/* * (c) Copyright Christian P. Fries, Germany. All rights reserved. Contact: email@christian-fries.de. * * Created on 13.08.2004 */ package net.finmath.montecarlo.conditionalexpectation; import net.finmath.functions.LinearAlgebra; import net.finmath.stochastic.RandomVariableInterface; /** * A service that allows to estimate conditional expectation via regression. * In oder to estimate the conditional expectation, basis functions have to be * specified. * * The class can either estimate and predict the conditional expectation within * the same simulation (which will eventually introduce a small foresight bias) * or use a different simulation for estimation (using <code>basisFunctionsEstimator</code>) * to predict conditional expectation within another simulation * (using <code>basisFunctionsPredictor</code>). In the latter case, the * basis functions have to correspond to the same entities, however, generated in * different simulations (number of path, etc., may be different). * * @author Christian Fries */ public class MonteCarloConditionalExpectationRegression implements MonteCarloConditionalExpectation { private RandomVariableInterface[] basisFunctionsEstimator = null; private RandomVariableInterface[] basisFunctionsPredictor = null; /** * Creates a class for conditional expectation estimation. * * @param basisFunctions A vector of random variables to be used as basis functions. */ public MonteCarloConditionalExpectationRegression(RandomVariableInterface[] basisFunctions) { super(); this.basisFunctionsEstimator = basisFunctions; this.basisFunctionsPredictor = basisFunctions; } /** * Creates a class for conditional expectation estimation. * * @param basisFunctionsEstimator A vector of random variables to be used as basis functions for estimation. * @param basisFunctionsPredictor A vector of random variables to be used as basis functions for prediction. */ public MonteCarloConditionalExpectationRegression(RandomVariableInterface[] basisFunctionsEstimator, RandomVariableInterface[] basisFunctionsPredictor) { super(); this.basisFunctionsEstimator = basisFunctionsEstimator; this.basisFunctionsPredictor = basisFunctionsPredictor; } @Override public RandomVariableInterface getConditionalExpectation(RandomVariableInterface randomVariable) { // Get regression parameters x as the solution of XTX x = XT y double[] linearRegressionParameters = getLinearRegressionParameters(randomVariable); // Calculate estimate, i.e. X x RandomVariableInterface[] basisFunctions = getNonZeroBasisFunctions(basisFunctionsPredictor); RandomVariableInterface conditionalExpectation = basisFunctions[0].mult(linearRegressionParameters[0]); for(int i=1; i<basisFunctions.length; i++) { conditionalExpectation = conditionalExpectation.addProduct(basisFunctions[i], linearRegressionParameters[i]); } return conditionalExpectation; } /** * Return the solution x of XTX x = XT y for a given y. * @TODO Performance upon repeated call can be optimized by caching XTX. * * @param dependents The sample vector of the random variable y. * @return The solution x of XTX x = XT y. */ public double[] getLinearRegressionParameters(RandomVariableInterface dependents) { // Build XTX - the symmetric matrix consisting of the scalar products of the basis functions. RandomVariableInterface[] basisFunctions = getNonZeroBasisFunctions(basisFunctionsEstimator); double[][] XTX = new double[basisFunctions.length][basisFunctions.length]; for(int i=0; i<basisFunctions.length; i++) { for(int j=i; j<basisFunctions.length; j++) { XTX[i][j] = basisFunctions[i].getAverage(basisFunctions[j]); // Scalar product XTX[j][i] = XTX[i][j]; // Symmetric matrix } } // Build XTy - the projection of the dependents random variable on the basis functions. double[] XTy = new double[basisFunctions.length]; for(int i=0; i<basisFunctions.length; i++) { XTy[i] = dependents.getAverage(basisFunctions[i]); // Scalar product } // Solve X^T X x = X^T y - which gives us the regression coefficients x = linearRegressionParameters // @TODO A performance improvement is possible here by caching the SVD decomposition of the basis functions double[] linearRegressionParameters = LinearAlgebra.solveLinearEquationLeastSquare(XTX, XTy); return linearRegressionParameters; } private RandomVariableInterface[] getNonZeroBasisFunctions(RandomVariableInterface[] basisFunctions) { int numberOfNonZeroBasisFunctions = 0; for(int indexBasisFunction = 0; indexBasisFunction<basisFunctions.length; indexBasisFunction++) { if(basisFunctions[indexBasisFunction] != null) { numberOfNonZeroBasisFunctions++; } } RandomVariableInterface[] nonZerobasisFunctions = new RandomVariableInterface[numberOfNonZeroBasisFunctions]; int indexOfNonZeroBasisFunctions = 0; for (RandomVariableInterface basisFunction : basisFunctions) { if (basisFunction != null) { nonZerobasisFunctions[indexOfNonZeroBasisFunctions] = basisFunction; indexOfNonZeroBasisFunctions++; } } return nonZerobasisFunctions; } }