package ch.akuhn.matrix;
import static org.junit.Assert.assertEquals;
import org.junit.Test;
/**
 * Unit tests for {@link Matrix} and {@link SparseMatrix}.
 *
 * <p>Covers:
 * <ul>
 *   <li>matrix-vector products, both plain and transposed;</li>
 *   <li>rejection of non-conforming sizes;</li>
 *   <li>the max/min/mean statistics of an empty matrix.</li>
 * </ul>
 */
public class MatrixTest {

    /**
     * Delta passed to {@code assertEquals(double, double, double)}.
     *
     * <p>Comparisons in this class are intentionally exact. Every expected value
     * is either NaN or a small integer built from sums of products of small
     * integers, and such values are exactly representable as doubles.
     * ({@code Double.MIN_VALUE} is the smallest positive double, not a useful
     * tolerance, so {@code 0.0} states the intent directly.)
     */
    private static final double DELTA = 0.0;

    /** Builds the 3x5 sparse matrix shared by the sparse-matrix tests. */
    private static Matrix newSparseFixture() {
        return new SparseMatrix(new double[][] {
            {0, 1, 0, 2, 0},
            {3, 0, 4, 0, 5},
            {0, 6, 0, 7, 0}});
    }

    /** Checks the dimensions and two stored entries of the sparse fixture. */
    private static void assertSparseFixtureLayout(Matrix a) {
        assertEquals(3, a.rowCount());
        assertEquals(5, a.columnCount());
        assertEquals(2, a.get(0, 3), DELTA);
        assertEquals(3, a.get(1, 0), DELTA);
    }

    /** A (3x5, sparse) times x (5) yields a 3-vector of row dot products. */
    @Test
    public void shouldMultiplySparseMatrixAndVector() {
        Vector x = Vector.from(8, 9, 10, 11, 12);
        Matrix A = newSparseFixture();
        assertEquals(5, x.size());
        assertSparseFixtureLayout(A);

        Vector y = A.mult(x);

        assertEquals(3, y.size());
        assertEquals(1 * 9 + 2 * 11, y.get(0), DELTA);
        assertEquals(3 * 8 + 4 * 10 + 5 * 12, y.get(1), DELTA);
        assertEquals(6 * 9 + 7 * 11, y.get(2), DELTA);
    }

    /** A^T (5x3, from the sparse fixture) times x (3) yields a 5-vector of column dot products. */
    @Test
    public void shouldMultiplyTransposedSparseMatrixAndVector() {
        Vector x = Vector.from(8, 9, 10);
        Matrix A = newSparseFixture();
        assertEquals(3, x.size());
        assertSparseFixtureLayout(A);

        Vector y = A.transposeMultiply(x);

        assertEquals(5, y.size());
        assertEquals(3 * 9, y.get(0), DELTA);
        assertEquals(1 * 8 + 6 * 10, y.get(1), DELTA);
        assertEquals(4 * 9, y.get(2), DELTA);
        assertEquals(2 * 8 + 7 * 10, y.get(3), DELTA);
        assertEquals(5 * 9, y.get(4), DELTA);
    }

    /** A dense 2x3 matrix built row by row via {@code Matrix.from} times a 3-vector. */
    @Test
    public void shouldMultiplyMatrixAndVector() {
        Vector x = Vector.from(7, 8, 9);
        Matrix A = Matrix.from(2, 3,
                1, 2, 3,
                4, 5, 6);
        assertEquals(3, x.size());
        assertEquals(2, A.rowCount());
        assertEquals(3, A.columnCount());
        assertEquals(5, A.get(1, 1), DELTA);

        Vector y = A.mult(x);

        assertEquals(2, y.size());
        assertEquals(7 * 1 + 8 * 2 + 9 * 3, y.get(0), DELTA);
        assertEquals(7 * 4 + 8 * 5 + 9 * 6, y.get(1), DELTA);
    }

    /** The transpose of a dense 3x2 matrix times a 3-vector yields a 2-vector. */
    @Test
    public void shouldMultiplyTransposedMatrixAndVector() {
        Vector x = Vector.from(7, 8, 9);
        Matrix A = Matrix.from(3, 2,
                1, 2,
                3, 4,
                5, 6);
        assertEquals(3, x.size());
        assertEquals(3, A.rowCount());
        assertEquals(2, A.columnCount());
        assertEquals(4, A.get(1, 1), DELTA);

        Vector y = A.transposeMultiply(x);

        assertEquals(2, y.size());
        assertEquals(7 * 1 + 8 * 3 + 9 * 5, y.get(0), DELTA);
        assertEquals(7 * 2 + 8 * 4 + 9 * 6, y.get(1), DELTA);
    }

    /**
     * A 3x2 matrix can only multiply a 2-vector, so a 3-vector must be rejected.
     *
     * <p>NOTE(review): this expects {@link AssertionError}. If the size check in
     * {@code mult} is a Java {@code assert} statement, the test only passes when
     * assertions are enabled ({@code -ea}). Confirm this against the implementation.
     */
    @Test(expected = AssertionError.class)
    public void shouldFailWhenSizeDoesNotConform() {
        Matrix.dense(3, 2).mult(Vector.dense(3));
    }

    /**
     * An empty matrix has no maximum, and NaN is expected.
     *
     * <p>JUnit's {@code assertEquals(double, double, double)} considers NaN equal
     * to NaN.
     */
    @Test
    public void whenEmptyshouldNotHaveMaximum() {
        assertEquals(Double.NaN, Matrix.dense(0, 0).max(), DELTA);
    }

    /** An empty matrix has no minimum, and NaN is expected (NaN equals NaN in JUnit). */
    @Test
    public void whenEmptyshouldNotHaveMinimum() {
        assertEquals(Double.NaN, Matrix.dense(0, 0).min(), DELTA);
    }

    /** An empty matrix has no mean, and NaN is expected (NaN equals NaN in JUnit). */
    @Test
    public void whenEmptyshouldNotHaveMean() {
        assertEquals(Double.NaN, Matrix.dense(0, 0).mean(), DELTA);
    }
}