/** * Copyright (C) 2009 - present by OpenGamma Inc. and the OpenGamma group of companies * * Please see distribution for license. */ package com.opengamma.analytics.math.interpolation.data; import static com.opengamma.analytics.math.matrix.MatrixAlgebraFactory.OG_ALGEBRA; import java.io.Serializable; import org.apache.commons.lang.ObjectUtils; import com.opengamma.analytics.math.linearalgebra.InverseTridiagonalMatrixCalculator; import com.opengamma.analytics.math.linearalgebra.TridiagonalMatrix; import com.opengamma.analytics.math.matrix.DoubleMatrix1D; import com.opengamma.analytics.math.matrix.DoubleMatrix2D; import com.opengamma.util.ArgumentChecker; /** * */ public class Interpolator1DCubicSplineDataBundle implements Interpolator1DDataBundle, Serializable { private final Interpolator1DDataBundle _underlyingData; private double[] _secondDerivatives; private double[][] _secondDerivativesSensitivities; private final double _leftFirstDev; private final double _rightFirstDev; private final boolean _leftNatural; private final boolean _rightNatural; public Interpolator1DCubicSplineDataBundle(final Interpolator1DDataBundle underlyingData) { ArgumentChecker.notNull(underlyingData, "underlying data"); _underlyingData = underlyingData; _leftFirstDev = 0; _rightFirstDev = 0; _leftNatural = true; _rightNatural = true; } /** * Data bundle for a cubic spline * @param underlyingData the data * @param leftGrad The gradient of the function at the left most knot. <b>Note: </b>to leave this unspecified (i.e. natural with zero second derivative), * set the value to Double.POSITIVE_INFINITY * @param rightGrad The gradient of the function at the right most knot. <b>Note: </b>to leave this unspecified (i.e. natural with zero second derivative), * set the value to Double.POSITIVE_INFINITY */ public Interpolator1DCubicSplineDataBundle(final Interpolator1DDataBundle underlyingData, final double leftGrad, final double rightGrad) { ArgumentChecker.notNull(underlyingData, "underlying data"); _underlyingData = underlyingData; if (Double.isInfinite(leftGrad)) { _leftFirstDev = 0; _leftNatural = true; } else { _leftFirstDev = leftGrad; _leftNatural = false; } if (Double.isInfinite(rightGrad)) { _rightFirstDev = 0; _rightNatural = true; } else { _rightFirstDev = leftGrad; _rightNatural = false; } } private double[] calculateSecondDerivative() { final double[] x = getKeys(); final double[] y = getValues(); final int n = x.length; final double[] deltaX = new double[n - 1]; final double[] deltaYOverDeltaX = new double[n - 1]; final double[] oneOverDeltaX = new double[n - 1]; for (int i = 0; i < n - 1; i++) { deltaX[i] = x[i + 1] - x[i]; oneOverDeltaX[i] = 1.0 / deltaX[i]; deltaYOverDeltaX[i] = (y[i + 1] - y[i]) * oneOverDeltaX[i]; } final DoubleMatrix2D inverseTriDiag = getInverseTridiagonalMatrix(deltaX); final DoubleMatrix1D rhsVector = getRHSVector(deltaYOverDeltaX); return ((DoubleMatrix1D) OG_ALGEBRA.multiply(inverseTriDiag, rhsVector)).getData(); } @Override public boolean containsKey(final Double key) { return _underlyingData.containsKey(key); } @Override public Double firstKey() { return _underlyingData.firstKey(); } @Override public Double firstValue() { return _underlyingData.firstValue(); } @Override public Double get(final Double key) { return _underlyingData.get(key); } @Override public InterpolationBoundedValues getBoundedValues(final Double key) { return _underlyingData.getBoundedValues(key); } @Override public double[] getKeys() { return _underlyingData.getKeys(); } @Override public int getLowerBoundIndex(final Double value) { return _underlyingData.getLowerBoundIndex(value); } @Override public Double getLowerBoundKey(final Double value) { return _underlyingData.getLowerBoundKey(value); } @Override public double[] getValues() { return _underlyingData.getValues(); } @Override public Double higherKey(final Double key) { return _underlyingData.higherKey(key); } @Override public Double higherValue(final Double key) { return _underlyingData.higherValue(key); } @Override public Double lastKey() { return _underlyingData.lastKey(); } @Override public Double lastValue() { return _underlyingData.lastValue(); } @Override public int size() { return _underlyingData.size(); } public double[] getSecondDerivatives() { if (_secondDerivatives == null) { _secondDerivatives = calculateSecondDerivative(); } return _secondDerivatives; } //TODO not ideal that it recomputes the inverse matrix public double[][] getSecondDerivativesSensitivities() { if (_secondDerivativesSensitivities == null) { final double[] x = getKeys(); final double[] y = getValues(); final int n = x.length; final double[] deltaX = new double[n - 1]; final double[] deltaYOverDeltaX = new double[n - 1]; final double[] oneOverDeltaX = new double[n - 1]; for (int i = 0; i < n - 1; i++) { deltaX[i] = x[i + 1] - x[i]; oneOverDeltaX[i] = 1.0 / deltaX[i]; deltaYOverDeltaX[i] = (y[i + 1] - y[i]) * oneOverDeltaX[i]; } final DoubleMatrix2D inverseTriDiag = getInverseTridiagonalMatrix(deltaX); final DoubleMatrix2D rhsMatrix = getRHSMatrix(oneOverDeltaX); _secondDerivativesSensitivities = ((DoubleMatrix2D) OG_ALGEBRA.multiply(inverseTriDiag, rhsMatrix)).getData(); } return _secondDerivativesSensitivities; } private DoubleMatrix2D getRHSMatrix(final double[] oneOverDeltaX) { final int n = oneOverDeltaX.length + 1; final double[][] res = new double[n][n]; for (int i = 1; i < n - 1; i++) { res[i][i - 1] = oneOverDeltaX[i - 1]; res[i][i] = -oneOverDeltaX[i] - oneOverDeltaX[i - 1]; res[i][i + 1] = oneOverDeltaX[i]; } if (!_leftNatural) { res[0][0] = oneOverDeltaX[0]; res[0][1] = -oneOverDeltaX[0]; } if (!_rightNatural) { res[n - 1][n - 1] = -oneOverDeltaX[n - 2]; res[n - 2][n - 2] = oneOverDeltaX[n - 2]; } return new DoubleMatrix2D(res); } private DoubleMatrix1D getRHSVector(final double[] deltaYOverDeltaX) { final int n = deltaYOverDeltaX.length + 1; final double[] res = new double[n]; for (int i = 1; i < n - 1; i++) { res[i] = deltaYOverDeltaX[i] - deltaYOverDeltaX[i - 1]; } if (!_leftNatural) { res[0] = _leftFirstDev - deltaYOverDeltaX[0]; } if (!_rightNatural) { res[n - 1] = _rightFirstDev - deltaYOverDeltaX[n - 2]; } return new DoubleMatrix1D(res); } private DoubleMatrix2D getInverseTridiagonalMatrix(final double[] deltaX) { final InverseTridiagonalMatrixCalculator invertor = new InverseTridiagonalMatrixCalculator(); final int n = deltaX.length + 1; final double[] a = new double[n]; final double[] b = new double[n - 1]; final double[] c = new double[n - 1]; for (int i = 1; i < n - 1; i++) { a[i] = (deltaX[i - 1] + deltaX[i]) / 3.0; b[i] = deltaX[i] / 6.0; c[i - 1] = deltaX[i - 1] / 6.0; } // Boundary condition if (_leftNatural) { a[0] = 1.0; b[0] = 0.0; } else { a[0] = -deltaX[0] / 3.0; b[0] = deltaX[0] / 6.0; } if (_rightNatural) { a[n - 1] = 1.0; c[n - 2] = 0.0; } else { a[n - 1] = deltaX[n - 2] / 3.0; c[n - 2] = deltaX[n - 2] / 6.0; } final TridiagonalMatrix tridiagonal = new TridiagonalMatrix(a, b, c); return invertor.evaluate(tridiagonal); } @Override public void setYValueAtIndex(final int index, final double y) { ArgumentChecker.notNegative(index, "index"); if (index >= size()) { throw new IllegalArgumentException("Index was greater than number of data points"); } _underlyingData.setYValueAtIndex(index, y); _secondDerivatives = null; _secondDerivativesSensitivities = null; } @Override public int hashCode() { final int prime = 31; int result = 1; long temp; temp = Double.doubleToLongBits(_leftFirstDev); result = prime * result + (int) (temp ^ (temp >>> 32)); result = prime * result + 1237; temp = Double.doubleToLongBits(_rightFirstDev); result = prime * result + (int) (temp ^ (temp >>> 32)); result = prime * result + 1237; result = prime * result + ((_underlyingData == null) ? 0 : _underlyingData.hashCode()); return result; } @Override public boolean equals(final Object obj) { if (this == obj) { return true; } if (obj == null) { return false; } if (getClass() != obj.getClass()) { return false; } final Interpolator1DCubicSplineDataBundle other = (Interpolator1DCubicSplineDataBundle) obj; if (!ObjectUtils.equals(_underlyingData, other._underlyingData)) { return false; } if (Double.doubleToLongBits(_leftFirstDev) != Double.doubleToLongBits(other._leftFirstDev)) { return false; } if (_leftNatural != other._leftNatural) { return false; } if (Double.doubleToLongBits(_rightFirstDev) != Double.doubleToLongBits(other._rightFirstDev)) { return false; } if (_rightNatural != other._rightNatural) { return false; } return true; } }