package ch.akuhn.matrix;

import java.util.Arrays;

/**
 * Matrix where <CODE>a<SUB>ij</SUB> = a<SUB>ji</SUB></CODE> for all elements.
 * <P>
 * Only the lower triangle (diagonal included) is stored: the backing array is
 * jagged, and row <CODE>n</CODE> holds exactly <CODE>n + 1</CODE> entries.
 * Element <CODE>(i, j)</CODE> with <CODE>j &gt; i</CODE> is read from and
 * written to the slot of <CODE>(j, i)</CODE>.
 *
 * @author Adrian Kuhn
 */
public class SymmetricMatrix extends DenseMatrix {

    /**
     * Creates a {@code size} x {@code size} symmetric matrix.
     *
     * @param size number of rows (and columns)
     */
    public SymmetricMatrix(int size) {
        super(size, size);
    }

    /**
     * Creates a symmetric matrix from lower-triangular jagged data.
     *
     * @param values jagged array in which row {@code n} has {@code n + 1} entries;
     *               handed to the superclass constructor as-is
     */
    public SymmetricMatrix(double[][] values) {
        super(values);
    }

    /**
     * Checks that the backing array is lower-triangular, i.e. that row
     * {@code n} has exactly {@code n + 1} entries.
     *
     * @throws IllegalArgumentException if any row has a different length
     */
    @Override
    protected void assertInvariant() throws IllegalArgumentException {
        for (int n = 0; n < values.length; n++) {
            if (values[n].length != (n + 1)) {
                throw new IllegalArgumentException(
                        "row " + n + " must have " + (n + 1)
                        + " elements but has " + values[n].length);
            }
        }
    }

    /**
     * Allocates lower-triangular storage: row {@code n} gets {@code n + 1} slots.
     *
     * @param rows    number of rows
     * @param columns number of columns; asserted to equal {@code rows}
     * @return a freshly allocated jagged array with {@code rows} rows
     */
    @Override
    protected double[][] makeValues(int rows, int columns) {
        assert rows == columns;
        double[][] values = new double[rows][];
        for (int n = 0; n < values.length; n++) {
            values[n] = new double[n + 1];
        }
        return values;
    }

    /**
     * @return the number of columns, which equals {@link #rowCount()}
     */
    @Override
    public int columnCount() {
        return rowCount();
    }

    /**
     * Reads element {@code (row, column)}; indices above the diagonal are
     * mirrored onto the stored lower triangle.
     *
     * @param row    row index
     * @param column column index
     * @return the stored value
     */
    @Override
    public double get(int row, int column) {
        return row > column ? values[row][column] : values[column][row];
    }

    /**
     * Writes element {@code (row, column)}. Because only one slot is stored
     * per symmetric pair, this also changes element {@code (column, row)}.
     *
     * @param row    row index
     * @param column column index
     * @param value  value to store
     * @return the value that was stored
     */
    @Override
    public double put(int row, int column, double value) {
        return row > column
                ? (values[row][column] = value)
                : (values[column][row] = value);
    }

    /**
     * @return the number of rows, i.e. the length of the backing array
     */
    @Override
    public int rowCount() {
        return values.length;
    }

    /**
     * Builds a symmetric matrix from a full square array by copying the lower
     * triangle (diagonal included) of each row. Entries above the diagonal are
     * ignored.
     *
     * @param square square array; each row is asserted to have
     *               {@code square.length} entries
     * @return a new symmetric matrix backed by a copy of the lower triangle
     */
    public static DenseMatrix fromSquare(double[][] square) {
        double[][] jagged = new double[square.length][];
        for (int i = 0; i < jagged.length; i++) {
            assert square[i].length == square.length;
            jagged[i] = Arrays.copyOf(square[i], i + 1);
        }
        return new SymmetricMatrix(jagged);
    }

    /**
     * Wraps lower-triangular jagged data in a symmetric matrix.
     *
     * @param values jagged array in which row {@code n} has {@code n + 1} entries
     * @return a new symmetric matrix constructed from {@code values}
     */
    public static DenseMatrix fromJagged(double[][] values) {
        return new SymmetricMatrix(values);
    }

    /**
     * @return the backing jagged (lower-triangular) array itself, not a copy
     */
    @Override
    public double[][] unwrap() {
        return values;
    }

    /**
     * Computes the mean of every row of the full (mirrored) matrix.
     *
     * @return array whose entry {@code i} is the sum of row {@code i} divided
     *         by the number of columns
     */
    @Override
    public double[] rowwiseMean() {
        double[] mean = new double[rowCount()];
        for (int i = 0; i < values.length; i++) {
            for (int j = 0; j < i; j++) {
                // An off-diagonal entry belongs to row i and, mirrored, to row j.
                mean[i] += values[i][j];
                mean[j] += values[i][j];
            }
            // The diagonal entry belongs to row i exactly once.
            mean[i] += values[i][i];
        }
        for (int n = 0; n < mean.length; n++) {
            mean[n] /= mean.length;
        }
        return mean;
    }

    /**
     * Multiplies this matrix with a vector.
     *
     * @param v vector whose size is asserted to equal the row count
     * @return the product, wrapped in a new {@link Vector}
     */
    @Override
    public Vector mult(Vector v) {
        assert v.size() == values.length;
        double[] mult = new double[v.size()];
        for (int i = 0; i < values.length; i++) {
            for (int j = 0; j < i; j++) {
                // A stored entry (i, j) contributes to result i and, mirrored, to result j.
                mult[i] += values[i][j] * v.get(j);
                mult[j] += values[i][j] * v.get(i);
            }
            // The diagonal entry contributes once, to result i only.
            mult[i] += values[i][i] * v.get(i);
        }
        return Vector.wrap(mult);
    }

}