/** * Copyright (C) 2009 - present by OpenGamma Inc. and the OpenGamma group of companies * * Please see distribution for license. */ package com.opengamma.analytics.financial.model.finitedifference; import org.apache.commons.lang.Validate; import cern.jet.random.engine.MersenneTwister; import cern.jet.random.engine.MersenneTwister64; import com.opengamma.analytics.financial.model.option.pricing.analytic.formula.BlackFunctionData; import com.opengamma.analytics.financial.model.option.pricing.analytic.formula.BlackPriceFunction; import com.opengamma.analytics.financial.model.option.pricing.analytic.formula.CEVFunctionData; import com.opengamma.analytics.financial.model.option.pricing.analytic.formula.CEVPriceFunction; import com.opengamma.analytics.financial.model.option.pricing.analytic.formula.EuropeanVanillaOption; import com.opengamma.analytics.math.function.Function1D; /** * */ public class MarkovChain { private final double _vol1; private final double _vol2; private final double _lambda12; private final double _lambda21; private final double _probState1; @SuppressWarnings("unused") private final double _pi1; private final MersenneTwister _rand; public MarkovChain(final double vol1, final double vol2, final double lambda12, final double lambda21, final double probState1) { this(vol1, vol2, lambda12, lambda21, probState1, MersenneTwister.DEFAULT_SEED); } public MarkovChain(final double vol1, final double vol2, final double lambda12, final double lambda21, final double probState1, final int seed) { Validate.isTrue(vol1 >= 0); Validate.isTrue(vol2 >= 0); Validate.isTrue(lambda12 >= 0); Validate.isTrue(lambda21 >= 0); Validate.isTrue(probState1 >= 0 && probState1 <= 1.0); _vol1 = vol1; _vol2 = vol2; _lambda12 = lambda12; _lambda21 = lambda21; _probState1 = probState1; _pi1 = lambda21 / (lambda12 + lambda21); _rand = new MersenneTwister64(seed); } public double price(final double forward, final double df, final double strike, final double timeToExiry, final double[] sigmas) { final EuropeanVanillaOption option = new EuropeanVanillaOption(strike, timeToExiry, true); final BlackPriceFunction func = new BlackPriceFunction(); final Function1D<BlackFunctionData, Double> priceFunc = func.getPriceFunction(option); double sum = 0; for (final double sigma : sigmas) { final BlackFunctionData data = new BlackFunctionData(forward, df, sigma); sum += priceFunc.evaluate(data); } return sum / sigmas.length; } public double priceCEV(final double forward, final double df, final double strike, final double timeToExiry, final double beta, final double[] sigmas) { final EuropeanVanillaOption option = new EuropeanVanillaOption(strike, timeToExiry, true); final CEVPriceFunction func = new CEVPriceFunction(); final Function1D<CEVFunctionData, Double> priceFunc = func.getPriceFunction(option); double sum = 0; for (final double sigma : sigmas) { final CEVFunctionData data = new CEVFunctionData(forward, df, sigma, beta); sum += priceFunc.evaluate(data); } return sum / sigmas.length; } public double[][] price(final double[] forwards, final double[] df, final double[] strike, final double[] expiries, final double[][] sigmas) { final int nTime = forwards.length; final int nStrikes = strike.length; Validate.isTrue(nTime == df.length); Validate.isTrue(nTime == expiries.length); Validate.isTrue(nTime == sigmas.length); final BlackPriceFunction func = new BlackPriceFunction(); final double[][] price = new double[nTime][nStrikes]; double t, k; for (int j = 0; j < nTime; j++) { t = expiries[j]; final double[] tSigmas = sigmas[j]; for (int i = 0; i < nStrikes; i++) { k = strike[i]; final EuropeanVanillaOption option = new EuropeanVanillaOption(k, t, true); final Function1D<BlackFunctionData, Double> priceFunc = func.getPriceFunction(option); double sum = 0; for (final double sigma : tSigmas) { final BlackFunctionData data = new BlackFunctionData(forwards[j], df[j], sigma); sum += priceFunc.evaluate(data); } price[j][i] = sum / tSigmas.length; } } return price; } public double[] getMoments(@SuppressWarnings("unused") final double t, final double[] sigmas) { double sum1 = 0; double sum2 = 0; double sum3 = 0; for (final double sigma : sigmas) { final double var = sigma * sigma; sum1 += var; sum2 += var * var; sum3 += var * var * var; } final int n = sigmas.length; final double m1 = sum1 / n; final double m2 = (sum2 - n * m1 * m1) / (n - 1); final double m3 = (sum3 - 3 * m1 * sum2 + 2 * n * m1 * m1 * m1) / n; return new double[] {m1, m2, m3 }; } public double[] simulate(final double timeToExpiry, final int n) { double vol, lambda, tau; final double[] vols = new double[n]; for (int i = 0; i < n; i++) { boolean state1 = _probState1 > _rand.nextDouble(); double t = 0; double var = 0.0; while (t < timeToExpiry) { if (state1) { vol = _vol1; lambda = _lambda12; } else { vol = _vol2; lambda = _lambda21; } tau = -Math.log(_rand.nextDouble()) / lambda; if (t + tau < timeToExpiry) { var += tau * vol * vol; state1 = !state1; } else { var += (timeToExpiry - t) * vol * vol; } t += tau; } vols[i] = Math.sqrt(var / timeToExpiry); } return vols; } public double[][] simulate(final double[] expiries, final int n) { return simulate(expiries, n, 0.0, 1.0); } public double[][] simulate(final double[] expiries, final int n, final double a, final double b) { Validate.notNull(expiries); Validate.isTrue(b > a, "need b > a"); Validate.isTrue(a >= 0.0, "Nedd a >= 0.0"); Validate.isTrue(b <= 1.0, "Nedd b <= 1.0"); final int m = expiries.length; Validate.isTrue(m > 0); for (int j = 1; j < m; j++) { Validate.isTrue(expiries[j] > expiries[j - 1]); } double vol, lambda, tau; final double[][] vols = new double[m][n]; for (int i = 0; i < n; i++) { int j = 0; boolean state1 = _probState1 > _rand.nextDouble(); double t = 0; double var = 0.0; while (j < m && t < expiries[m - 1]) { if (state1) { vol = _vol1; lambda = _lambda12; } else { vol = _vol2; lambda = _lambda21; } if (t == 0) { tau = -Math.log(a + (b - a) * _rand.nextDouble()) / lambda; } else { tau = -Math.log(_rand.nextDouble()) / lambda; } state1 = !state1; t += tau; if (t < expiries[j]) { var += tau * vol * vol; } else { var += (expiries[j] - t + tau) * vol * vol; vols[j][i] = Math.sqrt(var / expiries[j]); j++; while (j < m && t > expiries[j]) { var += (expiries[j] - expiries[j - 1]) * vol * vol; vols[j][i] = Math.sqrt(var / expiries[j]); j++; } var += (t - expiries[j - 1]) * vol * vol; } } } return vols; } }