/** * Copyright (C) 2015 - present by OpenGamma Inc. and the OpenGamma group of companies * * Please see distribution for license. */ package com.opengamma.strata.pricer.sensitivity; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; import java.util.function.BiFunction; import java.util.function.Function; import org.joda.beans.MetaProperty; import com.google.common.collect.ImmutableMap; import com.opengamma.strata.basics.currency.Currency; import com.opengamma.strata.basics.currency.CurrencyAmount; import com.opengamma.strata.collect.array.DoubleArray; import com.opengamma.strata.collect.tuple.Pair; import com.opengamma.strata.market.curve.Curve; import com.opengamma.strata.market.param.CurrencyParameterSensitivities; import com.opengamma.strata.pricer.DiscountFactors; import com.opengamma.strata.pricer.SimpleDiscountFactors; import com.opengamma.strata.pricer.ZeroRateDiscountFactors; import com.opengamma.strata.pricer.bond.ImmutableLegalEntityDiscountingProvider; import com.opengamma.strata.pricer.bond.LegalEntityDiscountingProvider; import com.opengamma.strata.pricer.rate.ImmutableRatesProvider; import com.opengamma.strata.pricer.rate.RatesProvider; /** * Computes the curve parameter sensitivity by finite difference. * <p> * This is based on an {@link ImmutableRatesProvider} or {@link ImmutableLegalEntityDiscountingProvider}, * and calculates the sensitivity by finite difference. */ public class RatesFiniteDifferenceSensitivityCalculator { /** * Default implementation. The shift is one basis point (0.0001). */ public static final RatesFiniteDifferenceSensitivityCalculator DEFAULT = new RatesFiniteDifferenceSensitivityCalculator(1.0E-4); /** * The shift used for finite difference. */ private final double shift; /** * Create an instance of the finite difference calculator. * * @param shift the shift used in the finite difference computation */ public RatesFiniteDifferenceSensitivityCalculator(double shift) { this.shift = shift; } //------------------------------------------------------------------------- /** * Computes the first order sensitivities of a function of a RatesProvider to a double by finite difference. * <p> * The finite difference is computed by forward type. * The function should return a value in the same currency for any rate provider. * * @param provider the rates provider * @param valueFn the function from a rate provider to a currency amount for which the sensitivity should be computed * @return the curve sensitivity */ public CurrencyParameterSensitivities sensitivity( RatesProvider provider, Function<ImmutableRatesProvider, CurrencyAmount> valueFn) { ImmutableRatesProvider immProv = provider.toImmutableRatesProvider(); CurrencyAmount valueInit = valueFn.apply(immProv); CurrencyParameterSensitivities discounting = sensitivity( immProv, immProv.getDiscountCurves(), (base, bumped) -> base.toBuilder().discountCurves(bumped).build(), valueFn, valueInit); CurrencyParameterSensitivities forward = sensitivity( immProv, immProv.getIndexCurves(), (base, bumped) -> base.toBuilder().indexCurves(bumped).build(), valueFn, valueInit); return discounting.combinedWith(forward); } // computes the sensitivity with respect to the curves private <T> CurrencyParameterSensitivities sensitivity( ImmutableRatesProvider provider, Map<T, Curve> baseCurves, BiFunction<ImmutableRatesProvider, Map<T, Curve>, ImmutableRatesProvider> storeBumpedFn, Function<ImmutableRatesProvider, CurrencyAmount> valueFn, CurrencyAmount valueInit) { CurrencyParameterSensitivities result = CurrencyParameterSensitivities.empty(); for (Entry<T, Curve> entry : baseCurves.entrySet()) { Curve curve = entry.getValue(); DoubleArray sensitivity = DoubleArray.of(curve.getParameterCount(), i -> { Curve dscBumped = curve.withParameter(i, curve.getParameter(i) + shift); Map<T, Curve> mapBumped = new HashMap<>(baseCurves); mapBumped.put(entry.getKey(), dscBumped); ImmutableRatesProvider providerDscBumped = storeBumpedFn.apply(provider, mapBumped); return (valueFn.apply(providerDscBumped).getAmount() - valueInit.getAmount()) / shift; }); result = result.combinedWith(curve.createParameterSensitivity(valueInit.getCurrency(), sensitivity)); } return result; } //------------------------------------------------------------------------- /** * Computes the first order sensitivities of a function of a LegalEntityDiscountingProvider to a double by finite difference. * <p> * The finite difference is computed by forward type. * The function should return a value in the same currency for any rates provider of LegalEntityDiscountingProvider. * * @param provider the rates provider * @param valueFn the function from a rate provider to a currency amount for which the sensitivity should be computed * @return the curve sensitivity */ public CurrencyParameterSensitivities sensitivity( LegalEntityDiscountingProvider provider, Function<ImmutableLegalEntityDiscountingProvider, CurrencyAmount> valueFn) { ImmutableLegalEntityDiscountingProvider immProv = provider.toImmutableLegalEntityDiscountingProvider(); CurrencyAmount valueInit = valueFn.apply(immProv); CurrencyParameterSensitivities discounting = sensitivity( immProv, valueFn, ImmutableLegalEntityDiscountingProvider.meta().repoCurves(), valueInit); CurrencyParameterSensitivities forward = sensitivity( immProv, valueFn, ImmutableLegalEntityDiscountingProvider.meta().issuerCurves(), valueInit); return discounting.combinedWith(forward); } private <T> CurrencyParameterSensitivities sensitivity( ImmutableLegalEntityDiscountingProvider provider, Function<ImmutableLegalEntityDiscountingProvider, CurrencyAmount> valueFn, MetaProperty<ImmutableMap<Pair<T, Currency>, DiscountFactors>> metaProperty, CurrencyAmount valueInit) { ImmutableMap<Pair<T, Currency>, DiscountFactors> baseCurves = metaProperty.get(provider); CurrencyParameterSensitivities result = CurrencyParameterSensitivities.empty(); for (Pair<T, Currency> key : baseCurves.keySet()) { DiscountFactors discountFactors = baseCurves.get(key); Curve curve = checkDiscountFactors(discountFactors); int paramCount = curve.getParameterCount(); double[] sensitivity = new double[paramCount]; for (int i = 0; i < paramCount; i++) { Curve dscBumped = curve.withParameter(i, curve.getParameter(i) + shift); Map<Pair<T, Currency>, DiscountFactors> mapBumped = new HashMap<>(baseCurves); mapBumped.put(key, createDiscountFactors(discountFactors, dscBumped)); ImmutableLegalEntityDiscountingProvider providerDscBumped = provider.toBuilder().set(metaProperty, mapBumped).build(); sensitivity[i] = (valueFn.apply(providerDscBumped).getAmount() - valueInit.getAmount()) / shift; } result = result.combinedWith( curve.createParameterSensitivity(valueInit.getCurrency(), DoubleArray.copyOf(sensitivity))); } return result; } //------------------------------------------------------------------------- // check that the discountFactors is ZeroRateDiscountFactors or SimpleDiscountFactors private Curve checkDiscountFactors(DiscountFactors discountFactors) { if (discountFactors instanceof ZeroRateDiscountFactors) { return ((ZeroRateDiscountFactors) discountFactors).getCurve(); } else if (discountFactors instanceof SimpleDiscountFactors) { return ((SimpleDiscountFactors) discountFactors).getCurve(); } throw new IllegalArgumentException("Not supported"); } // return correct instance of DiscountFactors private DiscountFactors createDiscountFactors(DiscountFactors originalDsc, Curve bumpedCurve) { if (originalDsc instanceof ZeroRateDiscountFactors) { return ZeroRateDiscountFactors.of(originalDsc.getCurrency(), originalDsc.getValuationDate(), bumpedCurve); } else if (originalDsc instanceof SimpleDiscountFactors) { return SimpleDiscountFactors.of(originalDsc.getCurrency(), originalDsc.getValuationDate(), bumpedCurve); } throw new IllegalArgumentException("Not supported"); } }