/** * Copyright (C) 2011 - present by OpenGamma Inc. and the OpenGamma group of companies * * Please see distribution for license. */ package com.opengamma.financial.analytics; import it.unimi.dsi.fastutil.objects.ObjectArrayList; import java.util.Arrays; import org.apache.commons.lang.Validate; import com.opengamma.financial.analytics.QuickSorter.ArrayQuickSorter; import com.opengamma.util.ParallelArrayBinarySort; /** * @param <S> * @param <T> */ public abstract class LabelledMatrix2D<S extends Comparable<S>, T extends Comparable<T>> { private final S[] _xKeys; private final Object[] _xLabels; private final String _xTitle; private final T[] _yKeys; private final Object[] _yLabels; private final String _yTitle; private final double[][] _values; private final String _valuesTitle; public LabelledMatrix2D(final S[] xKeys, final T[] yKeys, final double[][] values) { this(xKeys, LabelledMatrixUtils.toString(xKeys), yKeys, LabelledMatrixUtils.toString(yKeys), values); } public LabelledMatrix2D(final S[] xKeys, final Object[] xLabels, final T[] yKeys, final Object[] yLabels, final double[][] values) { this(xKeys, xLabels, null, yKeys, yLabels, null, values, null); } public LabelledMatrix2D(final S[] xKeys, final Object[] xLabels, final String xTitle, final T[] yKeys, final Object[] yLabels, final String yTitle, final double[][] values, final String valuesTitle) { Validate.notNull(xKeys, "x keys"); final int m = xKeys.length; Validate.notNull(xLabels, "x labels"); Validate.isTrue(xLabels.length == m); Validate.notNull(yKeys, "y keys"); final int n = yKeys.length; Validate.notNull(yLabels, "y labels"); Validate.notNull(yLabels.length == n); Validate.notNull(values, "values"); Validate.isTrue(values.length == n, "number of rows of data and y keys must be the same length"); _xKeys = Arrays.copyOf(xKeys, m); _yKeys = Arrays.copyOf(yKeys, n); _xLabels = new Object[m]; _yLabels = new Object[n]; _xTitle = xTitle; _yTitle = yTitle; _values = new double[n][m]; for (int i = 0; i < n; i++) { Validate.isTrue(values[i].length == m, "number of columns of data and x keys must be the same length"); _yLabels[i] = yLabels[i]; for (int j = 0; j < m; j++) { if (i == 0) { _xLabels[j] = xLabels[j]; } _values[i][j] = values[i][j]; } } _valuesTitle = valuesTitle; quickSortX(); quickSortY(); } public S[] getXKeys() { return _xKeys; } public Object[] getXLabels() { return _xLabels; } public String getXTitle() { return _xTitle; } public T[] getYKeys() { return _yKeys; } public Object[] getYLabels() { return _yLabels; } public String getYTitle() { return _yTitle; } public double[][] getValues() { return _values; } public String getValuesTitle() { return _valuesTitle; } public abstract <X> int compareX(S key1, S key2, X tolerance); public abstract <Y> int compareY(T key1, T key2, Y tolerance); public abstract LabelledMatrix2D<S, T> getMatrix(S[] xKeys, Object[] xLabels, String xTitle, T[] yKeys, Object[] yLabels, String yTitle, double[][] values, String valuesTitle); public abstract LabelledMatrix2D<S, T> getMatrix(S[] xKeys, Object[] xLabels, T[] yKeys, Object[] yLabels, double[][] values); //TODO this needs rewriting //TODO this ignores labels - using the original labels first and only using the labels from other when a new row / column is added public <X, Y> LabelledMatrix2D<S, T> add(final LabelledMatrix2D<S, T> other, final X xTolerance, final Y yTolerance) { Validate.notNull(other, "labelled matrix"); final S[] otherXKeys = other.getXKeys(); final Object[] otherXLabels = other.getXLabels(); final T[] otherYKeys = other.getYKeys(); final Object[] otherYLabels = other.getYLabels(); final S[] originalXKeys = getXKeys(); final Object[] originalXLabels = getXLabels(); final T[] originalYKeys = getYKeys(); final Object[] originalYLabels = getYLabels(); final int m1 = originalXKeys.length; final int m2 = otherXKeys.length; final int n1 = originalYKeys.length; final int n2 = otherYKeys.length; final ObjectArrayList<S> newXKeysList = new ObjectArrayList<S>(originalXKeys); final ObjectArrayList<Object> newXLabelsList = new ObjectArrayList<Object>(originalXLabels); final ObjectArrayList<T> newYKeysList = new ObjectArrayList<T>(originalYKeys); final ObjectArrayList<Object> newYLabelsList = new ObjectArrayList<Object>(originalYLabels); for (int i = 0; i < m2; i++) { final int index = binarySearchInXWithTolerance(originalXKeys, otherXKeys[i], xTolerance); if (index < 0) { newXKeysList.add(otherXKeys[i]); newXLabelsList.add(otherXLabels[i]); } } for (int i = 0; i < n2; i++) { final int index = binarySearchInYWithTolerance(originalYKeys, otherYKeys[i], yTolerance); if (index < 0) { newYKeysList.add(otherYKeys[i]); newYLabelsList.add(otherYLabels[i]); } } final S[] newXKeys = newXKeysList.toArray(originalXKeys); final Object[] newXLabels = newXLabelsList.toArray(); final T[] newYKeys = newYKeysList.toArray(originalYKeys); final Object[] newYLabels = newYLabelsList.toArray(); ParallelArrayBinarySort.parallelBinarySort(newXKeys, newXLabels); ParallelArrayBinarySort.parallelBinarySort(newYKeys, newYLabels); final int totalX = newXKeys.length; final int totalY = newYKeys.length; final double[][] newValues = new double[totalY][totalX]; for (int i = 0; i < n1; i++) { final int indexY = binarySearchInYWithTolerance(newYKeys, originalYKeys[i], yTolerance); for (int j = 0; j < m1; j++) { final int indexX = binarySearchInXWithTolerance(newXKeys, originalXKeys[j], xTolerance); newValues[indexY][indexX] = _values[i][j]; } } for (int i = 0; i < n2; i++) { final int indexY = binarySearchInYWithTolerance(newYKeys, otherYKeys[i], yTolerance); for (int j = 0; j < m2; j++) { final int indexX = binarySearchInXWithTolerance(newXKeys, otherXKeys[j], xTolerance); newValues[indexY][indexX] += other._values[i][j]; } } return getMatrix(newXKeys, newXLabels, getXTitle(), newYKeys, newYLabels, getYTitle(), newValues, getValuesTitle()); } protected <X> int binarySearchInXWithTolerance(final S[] keys, final S key, final X tolerance) { int low = 0; int high = keys.length - 1; while (low <= high) { final int mid = (low + high) >>> 1; final S midVal = keys[mid]; final int comparison = compareX(key, midVal, tolerance); if (comparison == 0) { return mid; } else if (comparison == 1) { low = mid + 1; } else { high = mid - 1; } } return -(low + 1); } protected <Y> int binarySearchInYWithTolerance(final T[] keys, final T key, final Y tolerance) { int low = 0; int high = keys.length - 1; while (low <= high) { final int mid = (low + high) >>> 1; final T midVal = keys[mid]; final int comparison = compareY(key, midVal, tolerance); if (comparison == 0) { return mid; } else if (comparison > 0) { low = mid + 1; } else { high = mid - 1; } } return -(low + 1); } private void quickSortX() { (new ArrayQuickSorter<S>(_xKeys) { @Override protected int compare(final S first, final S second) { return first.compareTo(second); } @Override protected void swap(final int first, final int second) { super.swap(first, second); swap(_xLabels, first, second); final int y = _yKeys.length; for (int iy = 0; iy < y; iy++) { swap(_values[iy], first, second); } } }).sort(); } private void quickSortY() { (new ArrayQuickSorter<T>(_yKeys) { @Override protected int compare(final T first, final T second) { return first.compareTo(second); } @Override protected void swap(final int first, final int second) { super.swap(first, second); swap(_yLabels, first, second); swap(_values, first, second); } }).sort(); } @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + Arrays.hashCode(_values); result = prime * result + ((_valuesTitle == null) ? 0 : _valuesTitle.hashCode()); result = prime * result + Arrays.hashCode(_xKeys); result = prime * result + Arrays.hashCode(_xLabels); result = prime * result + ((_xTitle == null) ? 0 : _xTitle.hashCode()); result = prime * result + Arrays.hashCode(_yKeys); result = prime * result + Arrays.hashCode(_yLabels); result = prime * result + ((_yTitle == null) ? 0 : _yTitle.hashCode()); return result; } @SuppressWarnings("rawtypes") @Override public boolean equals(final Object obj) { if (this == obj) { return true; } if (!(obj instanceof LabelledMatrix2D)) { return false; } final LabelledMatrix2D other = (LabelledMatrix2D) obj; final double[][] otherValues = other._values; for (int i = 0; i < _values.length; i++) { if (!Arrays.equals(_values[i], otherValues[i])) { return false; } } if (_valuesTitle == null) { if (other._valuesTitle != null) { return false; } } else if (!_valuesTitle.equals(other._valuesTitle)) { return false; } if (!Arrays.equals(_xKeys, other._xKeys)) { return false; } if (!Arrays.equals(_xLabels, other._xLabels)) { return false; } if (_xTitle == null) { if (other._xTitle != null) { return false; } } else if (!_xTitle.equals(other._xTitle)) { return false; } if (!Arrays.equals(_yKeys, other._yKeys)) { return false; } if (!Arrays.equals(_yLabels, other._yLabels)) { return false; } if (_yTitle == null) { if (other._yTitle != null) { return false; } } else if (!_yTitle.equals(other._yTitle)) { return false; } return true; } }