/**
* Copyright (C) 2009 - present by OpenGamma Inc. and the OpenGamma group of companies
*
* Please see distribution for license.
*/
package com.opengamma.analytics.math.curve;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import com.opengamma.util.ArgumentChecker;
/**
 * Shifts an {@link InterpolatedDoublesCurve}. If the <i>x</i> value(s) of the shift(s) are not among the nodal points of
 * the original curve, they are added to the nodal points of the new curve, with a <i>y</i> value equal to the original
 * curve's interpolated value at that point plus the shift.
 * <p>
 * The original curve is never modified; each method builds and returns a new curve using the original's interpolator.
 */
public class InterpolatedCurveShiftFunction implements CurveShiftFunction<InterpolatedDoublesCurve> {

  /**
   * {@inheritDoc}
   * <p>
   * The new curve is named {@code "PARALLEL_SHIFT_"} followed by the original curve's name.
   */
  @Override
  public InterpolatedDoublesCurve evaluate(final InterpolatedDoublesCurve curve, final double shift) {
    ArgumentChecker.notNull(curve, "curve");
    return evaluate(curve, shift, "PARALLEL_SHIFT_" + curve.getName());
  }

  /**
   * {@inheritDoc}
   * <p>
   * Adds {@code shift} to every nodal <i>y</i> value. The nodal <i>x</i> values and the interpolator are unchanged.
   */
  @Override
  public InterpolatedDoublesCurve evaluate(final InterpolatedDoublesCurve curve, final double shift, final String newName) {
    ArgumentChecker.notNull(curve, "curve");
    final double[] xData = curve.getXDataAsPrimitive();
    final double[] yData = curve.getYDataAsPrimitive();
    final double[] shiftedY = new double[yData.length];
    for (int i = 0; i < yData.length; i++) {
      shiftedY[i] = yData[i] + shift;
    }
    // x values are taken unchanged from an existing curve, so they are already sorted
    return InterpolatedDoublesCurve.fromSorted(xData, shiftedY, curve.getInterpolator(), newName);
  }

  /**
   * {@inheritDoc}
   * <p>
   * The new curve is named {@code "SINGLE_SHIFT_"} followed by the original curve's name.
   */
  @Override
  public InterpolatedDoublesCurve evaluate(final InterpolatedDoublesCurve curve, final double x, final double shift) {
    ArgumentChecker.notNull(curve, "curve");
    return evaluate(curve, x, shift, "SINGLE_SHIFT_" + curve.getName());
  }

  /**
   * {@inheritDoc}
   * <p>
   * If {@code x} is exactly one of the nodal <i>x</i> values, {@code shift} is added to that node's <i>y</i> value.
   * Otherwise a new node is added at {@code x}, with <i>y</i> value {@code curve.getYValue(x) + shift}. All other
   * nodes are left unchanged.
   */
  @Override
  public InterpolatedDoublesCurve evaluate(final InterpolatedDoublesCurve curve, final double x, final double shift, final String newName) {
    ArgumentChecker.notNull(curve, "curve");
    final double[] xData = curve.getXDataAsPrimitive();
    final double[] yData = curve.getYDataAsPrimitive();
    final int n = xData.length;
    // binarySearch relies on the curve's x data being sorted
    final int index = Arrays.binarySearch(xData, x);
    if (index >= 0) {
      // Copy before mutating so the original curve's data is never touched
      final double[] shiftedY = Arrays.copyOf(yData, n);
      shiftedY[index] += shift;
      return InterpolatedDoublesCurve.fromSorted(xData, shiftedY, curve.getInterpolator(), newName);
    }
    // x is not a node: append it (copyOf pads the extra slot) and let the factory order the data
    final double[] newX = Arrays.copyOf(xData, n + 1);
    final double[] newY = Arrays.copyOf(yData, n + 1);
    newX[n] = x;
    newY[n] = curve.getYValue(x) + shift;
    return InterpolatedDoublesCurve.from(newX, newY, curve.getInterpolator(), newName);
  }

  /**
   * {@inheritDoc}
   * <p>
   * The new curve is named {@code "MULTIPLE_POINT_SHIFT_"} followed by the original curve's name.
   */
  @Override
  public InterpolatedDoublesCurve evaluate(final InterpolatedDoublesCurve curve, final double[] xShift, final double[] yShift) {
    ArgumentChecker.notNull(curve, "curve");
    return evaluate(curve, xShift, yShift, "MULTIPLE_POINT_SHIFT_" + curve.getName());
  }

  /**
   * {@inheritDoc}
   * <p>
   * Shifts are applied in array order. A shift at an existing nodal <i>x</i> value is added to that node's <i>y</i>
   * value. A shift at any other <i>x</i> value adds a new node whose <i>y</i> value is the original curve's
   * interpolated value there plus the shift. If the same <i>x</i> appears more than once in {@code xShift}, the
   * corresponding shifts accumulate. If both arrays are empty, the result has the same nodes as the original curve
   * under the new name.
   */
  @Override
  public InterpolatedDoublesCurve evaluate(final InterpolatedDoublesCurve curve, final double[] xShift, final double[] yShift, final String newName) {
    ArgumentChecker.notNull(curve, "curve");
    ArgumentChecker.notNull(xShift, "x shifts");
    ArgumentChecker.notNull(yShift, "y shifts");
    ArgumentChecker.isTrue(xShift.length == yShift.length, "number of x shifts {} must equal number of y shifts {}", xShift.length, yShift.length);
    if (xShift.length == 0) {
      // Nodes come straight from an existing curve, so they are already sorted; no re-sort is needed
      return InterpolatedDoublesCurve.fromSorted(curve.getXDataAsPrimitive(), curve.getYDataAsPrimitive(), curve.getInterpolator(), newName);
    }
    final List<Double> newX = new ArrayList<>(Arrays.asList(curve.getXData()));
    final List<Double> newY = new ArrayList<>(Arrays.asList(curve.getYData()));
    for (int i = 0; i < xShift.length; i++) {
      // Search the growing list so repeated x values (including newly added ones) accumulate their shifts
      final int index = newX.indexOf(xShift[i]);
      if (index >= 0) {
        newY.set(index, newY.get(index) + yShift[i]);
      } else {
        newX.add(xShift[i]);
        // New nodes are valued off the ORIGINAL curve, not the partially shifted one
        newY.add(curve.getYValue(xShift[i]) + yShift[i]);
      }
    }
    // Appended nodes may be out of order, so use the factory that sorts
    return InterpolatedDoublesCurve.from(newX, newY, curve.getInterpolator(), newName);
  }
}