/**
* Copyright (C) 2009 - present by OpenGamma Inc. and the OpenGamma group of companies
*
* Please see distribution for license.
*/
package com.opengamma.analytics.math.matrix;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.lang.Validate;
import com.opengamma.analytics.math.linearalgebra.TridiagonalMatrix;
import com.opengamma.util.ArgumentChecker;
/**
* An absolutely minimal implementation of matrix algebra - only various multiplications covered. For more advanced
* stuff (e.g. calculating the inverse) use {@link ColtMatrixAlgebra} or {@link CommonsMatrixAlgebra}
*/
public class OGMatrixAlgebra extends MatrixAlgebra {
/**
* {@inheritDoc}
* @throws NotImplementedException
*/
@Override
public double getCondition(final Matrix<?> m) {
throw new NotImplementedException();
}
/**
* {@inheritDoc}
* @throws NotImplementedException
*/
@Override
public double getDeterminant(final Matrix<?> m) {
throw new NotImplementedException();
}
/**
* {@inheritDoc}
*/
@Override
public double getInnerProduct(final Matrix<?> m1, final Matrix<?> m2) {
Validate.notNull(m1, "m1");
Validate.notNull(m2, "m2");
if (m1 instanceof DoubleMatrix1D && m2 instanceof DoubleMatrix1D) {
final double[] a = ((DoubleMatrix1D) m1).getData();
final double[] b = ((DoubleMatrix1D) m2).getData();
final int l = a.length;
Validate.isTrue(l == b.length, "Matrix size mismacth");
double sum = 0.0;
for (int i = 0; i < l; i++) {
sum += a[i] * b[i];
}
return sum;
}
throw new IllegalArgumentException("Can only find inner product of DoubleMatrix1D; have " + m1.getClass() +
" and " + m2.getClass());
}
/**
* {@inheritDoc}
* @throws NotImplementedException
*/
@Override
public DoubleMatrix2D getInverse(final Matrix<?> m) {
throw new NotImplementedException();
}
/**
* {@inheritDoc}
* @throws NotImplementedException
*/
@Override
public double getNorm1(final Matrix<?> m) {
throw new NotImplementedException();
}
/**
* {@inheritDoc} This is only implemented for {@link DoubleMatrix1D}.
* @throws IllegalArgumentException If the matrix is not a {@link DoubleMatrix1D}
*/
@Override
public double getNorm2(final Matrix<?> m) {
Validate.notNull(m, "m");
if (m instanceof DoubleMatrix1D) {
final double[] a = ((DoubleMatrix1D) m).getData();
final int l = a.length;
double sum = 0.0;
for (int i = 0; i < l; i++) {
sum += a[i] * a[i];
}
return Math.sqrt(sum);
} else if (m instanceof DoubleMatrix2D) {
throw new NotImplementedException();
}
throw new IllegalArgumentException("Can only find norm2 of a DoubleMatrix1D; have " + m.getClass());
}
/**
* {@inheritDoc}
* @throws NotImplementedException
*/
@Override
public double getNormInfinity(final Matrix<?> m) {
throw new NotImplementedException();
}
/**
* {@inheritDoc}
*/
@Override
public DoubleMatrix2D getOuterProduct(final Matrix<?> m1, final Matrix<?> m2) {
Validate.notNull(m1, "m1");
Validate.notNull(m2, "m2");
if (m1 instanceof DoubleMatrix1D && m2 instanceof DoubleMatrix1D) {
final double[] a = ((DoubleMatrix1D) m1).getData();
final double[] b = ((DoubleMatrix1D) m2).getData();
final int m = a.length;
final int n = b.length;
final double[][] res = new double[m][n];
int i, j;
for (i = 0; i < m; i++) {
for (j = 0; j < n; j++) {
res[i][j] = a[i] * b[j];
}
}
return new DoubleMatrix2D(res);
}
throw new IllegalArgumentException("Can only find outer product of DoubleMatrix1D; have " + m1.getClass() +
" and " + m2.getClass());
}
/**
* {@inheritDoc}
* @throws NotImplementedException
*/
@Override
public DoubleMatrix2D getPower(final Matrix<?> m, final int p) {
throw new NotImplementedException();
}
/**
* {@inheritDoc}
*/
@Override
public double getTrace(final Matrix<?> m) {
Validate.notNull(m, "m");
if (m instanceof DoubleMatrix2D) {
final double[][] data = ((DoubleMatrix2D) m).getData();
final int rows = data.length;
Validate.isTrue(rows == data[0].length, "Matrix not square");
double sum = 0.0;
for (int i = 0; i < rows; i++) {
sum += data[i][i];
}
return sum;
}
throw new IllegalArgumentException("Can only take the trace of DoubleMatrix2D; have " + m.getClass());
}
/**
* {@inheritDoc}
*/
@Override
public DoubleMatrix2D getTranspose(final Matrix<?> m) {
Validate.notNull(m, "m");
if (m instanceof IdentityMatrix) {
return (IdentityMatrix) m;
}
if (m instanceof DoubleMatrix2D) {
final double[][] data = ((DoubleMatrix2D) m).getData();
final int rows = data.length;
final int cols = data[0].length;
final double[][] res = new double[cols][rows];
int i, j;
for (i = 0; i < cols; i++) {
for (j = 0; j < rows; j++) {
res[i][j] = data[j][i];
}
}
return new DoubleMatrix2D(res);
}
throw new IllegalArgumentException("Can only take transpose of DoubleMatrix2D; have " + m.getClass());
}
/**
* {@inheritDoc} The following combinations of input matrices m1 and m2 are allowed:
* <ul>
* <li>m1 = 2-D matrix, m2 = 2-D matrix, returns $\mathbf{C} = \mathbf{AB}$
* <li>m1 = 2-D matrix, m2 = 1-D matrix, returns $\mathbf{C} = \mathbf{A}b$
* <li>m1 = 1-D matrix, m2 = 2-D matrix, returns $\mathbf{C} = a^T\mathbf{B}$
* </ul>
*/
@Override
public Matrix<?> multiply(final Matrix<?> m1, final Matrix<?> m2) {
Validate.notNull(m1, "m1");
Validate.notNull(m2, "m2");
if (m1 instanceof IdentityMatrix) {
if (m2 instanceof IdentityMatrix) {
return multiply((IdentityMatrix) m1, (IdentityMatrix) m2);
} else if (m2 instanceof DoubleMatrix1D) {
return multiply((IdentityMatrix) m1, (DoubleMatrix1D) m2);
} else if (m2 instanceof DoubleMatrix2D) {
return multiply((IdentityMatrix) m1, (DoubleMatrix2D) m2);
}
throw new IllegalArgumentException("can only handle IdentityMatrix by DoubleMatrix2D or DoubleMatrix1D, have " +
m1.getClass() + " and " + m2.getClass());
}
if (m2 instanceof IdentityMatrix) {
if (m1 instanceof DoubleMatrix1D) {
return multiply((DoubleMatrix1D) m1, (IdentityMatrix) m2);
} else if (m1 instanceof DoubleMatrix2D) {
return multiply((DoubleMatrix2D) m1, (IdentityMatrix) m2);
}
throw new IllegalArgumentException("can only handle DoubleMatrix2D or DoubleMatrix1D by IdentityMatrix, have " +
m1.getClass() + " and " + m2.getClass());
}
if (m1 instanceof TridiagonalMatrix && m2 instanceof DoubleMatrix1D) {
return multiply((TridiagonalMatrix) m1, (DoubleMatrix1D) m2);
} else if (m1 instanceof DoubleMatrix1D && m2 instanceof TridiagonalMatrix) {
return multiply((DoubleMatrix1D) m1, (TridiagonalMatrix) m2);
} else if (m1 instanceof DoubleMatrix2D && m2 instanceof DoubleMatrix2D) {
return multiply((DoubleMatrix2D) m1, (DoubleMatrix2D) m2);
} else if (m1 instanceof DoubleMatrix2D && m2 instanceof DoubleMatrix1D) {
return multiply((DoubleMatrix2D) m1, (DoubleMatrix1D) m2);
} else if (m1 instanceof DoubleMatrix1D && m2 instanceof DoubleMatrix2D) {
return multiply((DoubleMatrix1D) m1, (DoubleMatrix2D) m2);
}
throw new IllegalArgumentException(
"Can only multiply two DoubleMatrix2D; a DoubleMatrix2D and a DoubleMatrix1D; or a DoubleMatrix1D and a DoubleMatrix2D. have " +
m1.getClass() + " and " + m2.getClass());
}
/**
* {@inheritDoc}
* @throws NotImplementedException
*/
@Override
public DoubleMatrix2D getPower(final Matrix<?> m, final double p) {
throw new NotImplementedException();
}
private DoubleMatrix2D multiply(final IdentityMatrix idet, final DoubleMatrix2D m) {
ArgumentChecker.isTrue(idet.getSize() == m.getNumberOfRows(),
"size of identity matrix ({}) does not match number or rows of m ({})", idet.getSize(), m.getNumberOfRows());
return m;
}
private DoubleMatrix2D multiply(final DoubleMatrix2D m, final IdentityMatrix idet) {
ArgumentChecker.isTrue(idet.getSize() == m.getNumberOfColumns(),
"size of identity matrix ({}) does not match number or columns of m ({})", idet.getSize(),
m.getNumberOfColumns());
return m;
}
private IdentityMatrix multiply(final IdentityMatrix i1, final IdentityMatrix i2) {
ArgumentChecker.isTrue(i1.getSize() == i2.getSize(),
"size of identity matrix 1 ({}) does not match size of identity matrix 2 ({})", i1.getSize(), i2.getSize());
return i1;
}
private DoubleMatrix2D multiply(final DoubleMatrix2D m1, final DoubleMatrix2D m2) {
final double[][] a = m1.getData();
final double[][] b = m2.getData();
final int p = b.length;
Validate.isTrue(
a[0].length == p,
"Matrix size mismatch. m1 is " + m1.getNumberOfRows() + " by " + m1.getNumberOfColumns() + ", but m2 is " +
m2.getNumberOfRows() + " by " + m2.getNumberOfColumns());
final int m = a.length;
final int n = b[0].length;
double sum;
final double[][] res = new double[m][n];
int i, j, k;
for (i = 0; i < m; i++) {
for (j = 0; j < n; j++) {
sum = 0.0;
for (k = 0; k < p; k++) {
sum += a[i][k] * b[k][j];
}
res[i][j] = sum;
}
}
return new DoubleMatrix2D(res);
}
private DoubleMatrix1D multiply(final IdentityMatrix matrix, final DoubleMatrix1D vector) {
ArgumentChecker.isTrue(matrix.getSize() == vector.getNumberOfElements(),
"size of identity matrix ({}) does not match size of vector ({})", matrix.getSize(),
vector.getNumberOfElements());
return vector;
}
private DoubleMatrix1D multiply(final DoubleMatrix1D vector, final IdentityMatrix matrix) {
ArgumentChecker.isTrue(matrix.getSize() == vector.getNumberOfElements(),
"size of identity matrix ({}) does not match size of vector ({})", matrix.getSize(),
vector.getNumberOfElements());
return vector;
}
private DoubleMatrix1D multiply(final DoubleMatrix2D matrix, final DoubleMatrix1D vector) {
final double[][] a = matrix.getData();
final double[] b = vector.getData();
final int n = b.length;
Validate.isTrue(a[0].length == n, "Matrix/vector size mismatch");
final int m = a.length;
final double[] res = new double[m];
int i, j;
double sum;
for (i = 0; i < m; i++) {
sum = 0.0;
for (j = 0; j < n; j++) {
sum += a[i][j] * b[j];
}
res[i] = sum;
}
return new DoubleMatrix1D(res);
}
private DoubleMatrix1D multiply(final TridiagonalMatrix matrix, final DoubleMatrix1D vector) {
final double[] a = matrix.getLowerSubDiagonalData();
final double[] b = matrix.getDiagonalData();
final double[] c = matrix.getUpperSubDiagonalData();
final double[] x = vector.getData();
final int n = x.length;
Validate.isTrue(b.length == n, "Matrix/vector size mismatch");
final double[] res = new double[n];
int i;
res[0] = b[0] * x[0] + c[0] * x[1];
res[n - 1] = b[n - 1] * x[n - 1] + a[n - 2] * x[n - 2];
for (i = 1; i < n - 1; i++) {
res[i] = a[i - 1] * x[i - 1] + b[i] * x[i] + c[i] * x[i + 1];
}
return new DoubleMatrix1D(res);
}
private DoubleMatrix1D multiply(final DoubleMatrix1D vector, final DoubleMatrix2D matrix) {
final double[] a = vector.getData();
final double[][] b = matrix.getData();
final int n = a.length;
Validate.isTrue(b.length == n, "Matrix/vector size mismatch");
final int m = b[0].length;
final double[] res = new double[m];
int i, j;
double sum;
for (i = 0; i < m; i++) {
sum = 0.0;
for (j = 0; j < n; j++) {
sum += a[j] * b[j][i];
}
res[i] = sum;
}
return new DoubleMatrix1D(res);
}
private DoubleMatrix1D multiply(final DoubleMatrix1D vector, final TridiagonalMatrix matrix) {
final double[] a = matrix.getLowerSubDiagonalData();
final double[] b = matrix.getDiagonalData();
final double[] c = matrix.getUpperSubDiagonalData();
final double[] x = vector.getData();
final int n = x.length;
Validate.isTrue(b.length == n, "Matrix/vector size mismatch");
final double[] res = new double[n];
int i;
res[0] = b[0] * x[0] + a[0] * x[1];
res[n - 1] = b[n - 1] * x[n - 1] + c[n - 2] * x[n - 2];
for (i = 1; i < n - 1; i++) {
res[i] = a[i] * x[i + 1] + b[i] * x[i] + c[i - 1] * x[i - 1];
}
return new DoubleMatrix1D(res);
}
}