/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.math; import com.google.common.collect.Maps; import org.apache.mahout.common.RandomUtils; import org.apache.mahout.math.function.Functions; import org.apache.mahout.math.function.VectorFunction; import org.junit.Before; import org.junit.Test; import java.util.Iterator; import java.util.Map; import java.util.Random; public abstract class MatrixTest extends MahoutTestCase { protected static final int ROW = AbstractMatrix.ROW; protected static final int COL = AbstractMatrix.COL; private final double[][] values = {{1.1, 2.2}, {3.3, 4.4}, {5.5, 6.6}}; private final double[] vectorAValues = {1.0 / 1.1, 2.0 / 1.1}; private Matrix test; @Override @Before public void setUp() throws Exception { super.setUp(); test = matrixFactory(values); } public abstract Matrix matrixFactory(double[][] values); @Test public void testCardinality() { assertEquals("row cardinality", values.length, test.rowSize()); assertEquals("col cardinality", values[0].length, test.columnSize()); } @Test public void testCopy() { Matrix copy = test.clone(); assertSame("wrong class", copy.getClass(), test.getClass()); for (int row = 0; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { assertEquals("value[" + row + "][" + col + ']', test.getQuick(row, col), copy.getQuick(row, col), EPSILON); } } } @Test public void testIterate() { Iterator<MatrixSlice> it = test.iterator(); MatrixSlice m; while(it.hasNext() && (m = it.next()) != null) { Vector v = m.vector(); Vector w = test instanceof SparseColumnMatrix ? test.viewColumn(m.index()) : test.viewRow(m.index()); assertEquals("iterator: " + v + ", randomAccess: " + w, v, w); } } @Test public void testGetQuick() { for (int row = 0; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { assertEquals("value[" + row + "][" + col + ']', values[row][col], test .getQuick(row, col), EPSILON); } } } @Test public void testLike() { Matrix like = test.like(); assertSame("type", like.getClass(), test.getClass()); assertEquals("rows", test.rowSize(), like.rowSize()); assertEquals("columns", test.columnSize(), like.columnSize()); } @Test public void testLikeIntInt() { Matrix like = test.like(4, 4); assertSame("type", like.getClass(), test.getClass()); assertEquals("rows", 4, like.rowSize()); assertEquals("columns", 4, like.columnSize()); } @Test public void testSetQuick() { for (int row = 0; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { test.setQuick(row, col, 1.23); assertEquals("value[" + row + "][" + col + ']', 1.23, test.getQuick( row, col), EPSILON); } } } @Test public void testSize() { int[] c = test.getNumNondefaultElements(); assertEquals("row size", values.length, c[ROW]); assertEquals("col size", values[0].length, c[COL]); } @Test public void testViewPart() { int[] offset = {1, 1}; int[] size = {2, 1}; Matrix view = test.viewPart(offset, size); assertEquals(2, view.rowSize()); assertEquals(1, view.columnSize()); for (int row = 0; row < view.rowSize(); row++) { for (int col = 0; col < view.columnSize(); col++) { assertEquals("value[" + row + "][" + col + ']', values[row + 1][col + 1], view.get(row, col), EPSILON); } } } @Test(expected = IndexException.class) public void testViewPartCardinality() { int[] offset = {1, 1}; int[] size = {3, 3}; test.viewPart(offset, size); } @Test(expected = IndexException.class) public void testViewPartIndexOver() { int[] offset = {1, 1}; int[] size = {2, 2}; test.viewPart(offset, size); } @Test(expected = IndexException.class) public void testViewPartIndexUnder() { int[] offset = {-1, -1}; int[] size = {2, 2}; test.viewPart(offset, size); } @Test public void testAssignDouble() { test.assign(4.53); for (int row = 0; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { assertEquals("value[" + row + "][" + col + ']', 4.53, test.getQuick( row, col), EPSILON); } } } @Test public void testAssignDoubleArrayArray() { test.assign(new double[3][2]); for (int row = 0; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { assertEquals("value[" + row + "][" + col + ']', 0.0, test.getQuick(row, col), EPSILON); } } } @Test(expected = CardinalityException.class) public void testAssignDoubleArrayArrayCardinality() { test.assign(new double[test.rowSize() + 1][test.columnSize()]); } @Test public void testAssignMatrixBinaryFunction() { test.assign(test, Functions.PLUS); for (int row = 0; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { assertEquals("value[" + row + "][" + col + ']', 2 * values[row][col], test.getQuick(row, col), EPSILON); } } } @Test(expected = CardinalityException.class) public void testAssignMatrixBinaryFunctionCardinality() { test.assign(test.transpose(), Functions.PLUS); } @Test public void testAssignMatrix() { Matrix value = test.like(); value.assign(test); for (int row = 0; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { assertEquals("value[" + row + "][" + col + ']', test.getQuick(row, col), value.getQuick(row, col), EPSILON); } } } @Test(expected = CardinalityException.class) public void testAssignMatrixCardinality() { test.assign(test.transpose()); } @Test public void testAssignUnaryFunction() { test.assign(Functions.mult(-1)); for (int row = 0; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { assertEquals("value[" + row + "][" + col + ']', -values[row][col], test .getQuick(row, col), EPSILON); } } } @Test public void testRowView() { assertEquals(test.columnSize(), test.viewRow(1).size()); assertEquals(test.columnSize(), test.viewRow(2).size()); Random gen = RandomUtils.getRandom(); for (int row = 0; row < test.rowSize(); row++) { int j = gen.nextInt(test.columnSize()); double old = test.get(row, j); double v = gen.nextGaussian(); test.viewRow(row).set(j, v); assertEquals(v, test.get(row, j), 0); assertEquals(v, test.viewRow(row).get(j), 0); test.set(row, j, old); assertEquals(old, test.get(row, j), 0); assertEquals(old, test.viewRow(row).get(j), 0); } } @Test public void testColumnView() { assertEquals(test.rowSize(), test.viewColumn(0).size()); assertEquals(test.rowSize(), test.viewColumn(1).size()); Random gen = RandomUtils.getRandom(); for (int col = 0; col < test.columnSize(); col++) { int j = gen.nextInt(test.columnSize()); double old = test.get(col, j); double v = gen.nextGaussian(); test.viewColumn(col).set(j, v); assertEquals(v, test.get(j, col), 0); assertEquals(v, test.viewColumn(col).get(j), 0); test.set(j, col, old); assertEquals(old, test.get(j, col), 0); assertEquals(old, test.viewColumn(col).get(j), 0); } } @Test public void testAggregateRows() { Vector v = test.aggregateRows(new VectorFunction() { @Override public double apply(Vector v) { return v.zSum(); } }); for (int i = 0; i < test.numRows(); i++) { assertEquals(test.viewRow(i).zSum(), v.get(i), EPSILON); } } @Test public void testAggregateCols() { Vector v = test.aggregateColumns(new VectorFunction() { @Override public double apply(Vector v) { return v.zSum(); } }); for (int i = 0; i < test.numCols(); i++) { assertEquals(test.viewColumn(i).zSum(), v.get(i), EPSILON); } } @Test public void testAggregate() { double total = test.aggregate(Functions.PLUS, Functions.IDENTITY); assertEquals(test.aggregateRows(new VectorFunction() { @Override public double apply(Vector v) { return v.zSum(); } }).zSum(), total, EPSILON); } @Test public void testDivide() { Matrix value = test.divide(4.53); for (int row = 0; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { assertEquals("value[" + row + "][" + col + ']', values[row][col] / 4.53, value.getQuick(row, col), EPSILON); } } } @Test public void testGet() { for (int row = 0; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { assertEquals("value[" + row + "][" + col + ']', values[row][col], test .get(row, col), EPSILON); } } } @Test(expected = IndexException.class) public void testGetIndexUnder() { for (int row = -1; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { test.get(row, col); } } } @Test(expected = IndexException.class) public void testGetIndexOver() { for (int row = 0; row < test.rowSize() + 1; row++) { for (int col = 0; col < test.columnSize(); col++) { test.get(row, col); } } } @Test public void testMinus() { Matrix value = test.minus(test); for (int row = 0; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { assertEquals("value[" + row + "][" + col + ']', 0.0, value.getQuick( row, col), EPSILON); } } } @Test(expected = CardinalityException.class) public void testMinusCardinality() { test.minus(test.transpose()); } @Test public void testPlusDouble() { Matrix value = test.plus(4.53); for (int row = 0; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { assertEquals("value[" + row + "][" + col + ']', values[row][col] + 4.53, value.getQuick(row, col), EPSILON); } } } @Test public void testPlusMatrix() { Matrix value = test.plus(test); for (int row = 0; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { assertEquals("value[" + row + "][" + col + ']', values[row][col] * 2, value.getQuick(row, col), EPSILON); } } } @Test(expected = CardinalityException.class) public void testPlusMatrixCardinality() { test.plus(test.transpose()); } @Test(expected = IndexException.class) public void testSetUnder() { for (int row = -1; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { test.set(row, col, 1.23); } } } @Test(expected = IndexException.class) public void testSetOver() { for (int row = 0; row < test.rowSize() + 1; row++) { for (int col = 0; col < test.columnSize(); col++) { test.set(row, col, 1.23); } } } @Test public void testTimesDouble() { Matrix value = test.times(4.53); for (int row = 0; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { assertEquals("value[" + row + "][" + col + ']', values[row][col] * 4.53, value.getQuick(row, col), EPSILON); } } } @Test public void testTimesMatrix() { Matrix transpose = test.transpose(); Matrix value = test.times(transpose); assertEquals("rows", test.rowSize(), value.rowSize()); assertEquals("cols", test.rowSize(), value.columnSize()); Matrix expected = new DenseMatrix(new double[][]{{5.0, 11.0, 17.0}, {11.0, 25.0, 39.0}, {17.0, 39.0, 61.0}}).times(1.21); for (int i = 0; i < expected.numCols(); i++) { for (int j = 0; j < expected.numRows(); j++) { assertTrue("Matrix times transpose not correct: " + i + ", " + j + "\nexpected:\n\t" + expected + "\nactual:\n\t" + value, Math.abs(expected.get(i, j) - value.get(i, j)) < 1.0e-12); } } Matrix timestest = new DenseMatrix(10, 1); /* will throw ArrayIndexOutOfBoundsException exception without MAHOUT-26 */ timestest.transpose().times(timestest); } @Test(expected = CardinalityException.class) public void testTimesVector() { Vector vectorA = new DenseVector(vectorAValues); Vector testTimesVectorA = test.times(vectorA); Vector expected = new DenseVector(new double[]{5.0, 11.0, 17.0}); assertTrue("Matrix times vector not equals: " + vectorA + " != " + testTimesVectorA, expected.minus(testTimesVectorA).norm(2) < 1.0e-12); test.times(testTimesVectorA); } @Test public void testTimesSquaredTimesVector() { Vector vectorA = new DenseVector(vectorAValues); Vector ttA = test.timesSquared(vectorA); Vector ttASlow = test.transpose().times(test.times(vectorA)); assertTrue("M'Mv != M.timesSquared(v): " + ttA + " != " + ttASlow, ttASlow.minus(ttA).norm(2) < 1.0e-12); } @Test(expected = CardinalityException.class) public void testTimesMatrixCardinality() { Matrix other = test.like(5, 8); test.times(other); } @Test public void testTranspose() { Matrix transpose = test.transpose(); assertEquals("rows", test.columnSize(), transpose.rowSize()); assertEquals("cols", test.rowSize(), transpose.columnSize()); for (int row = 0; row < test.rowSize(); row++) { for (int col = 0; col < test.columnSize(); col++) { assertEquals("value[" + row + "][" + col + ']', test.getQuick(row, col), transpose.getQuick(col, row), EPSILON); } } } @Test public void testZSum() { double sum = test.zSum(); assertEquals("zsum", 23.1, sum, EPSILON); } @Test public void testAssignRow() { double[] data = {2.1, 3.2}; test.assignRow(1, new DenseVector(data)); assertEquals("test[1][0]", 2.1, test.getQuick(1, 0), EPSILON); assertEquals("test[1][1]", 3.2, test.getQuick(1, 1), EPSILON); } @Test(expected = CardinalityException.class) public void testAssignRowCardinality() { double[] data = {2.1, 3.2, 4.3}; test.assignRow(1, new DenseVector(data)); } @Test public void testAssignColumn() { double[] data = {2.1, 3.2, 4.3}; test.assignColumn(1, new DenseVector(data)); assertEquals("test[0][1]", 2.1, test.getQuick(0, 1), EPSILON); assertEquals("test[1][1]", 3.2, test.getQuick(1, 1), EPSILON); assertEquals("test[2][1]", 4.3, test.getQuick(2, 1), EPSILON); } @Test(expected = CardinalityException.class) public void testAssignColumnCardinality() { double[] data = {2.1, 3.2}; test.assignColumn(1, new DenseVector(data)); } @Test public void testViewRow() { Vector row = test.viewRow(1); assertEquals("row size", 2, row.getNumNondefaultElements()); } @Test(expected = IndexException.class) public void testViewRowIndexUnder() { test.viewRow(-1); } @Test(expected = IndexException.class) public void testViewRowIndexOver() { test.viewRow(5); } @Test public void testViewColumn() { Vector column = test.viewColumn(1); assertEquals("row size", 3, column.getNumNondefaultElements()); } @Test(expected = IndexException.class) public void testViewColumnIndexUnder() { test.viewColumn(-1); } @Test(expected = IndexException.class) public void testViewColumnIndexOver() { test.viewColumn(5); } @Test public void testDeterminant() { Matrix m = matrixFactory(new double[][]{{1, 3, 4}, {5, 2, 3}, {1, 4, 2}}); assertEquals("determinant", 43.0, m.determinant(), EPSILON); } @Test public void testLabelBindings() { Matrix m = matrixFactory(new double[][]{{1, 3, 4}, {5, 2, 3}, {1, 4, 2}}); assertNull("row bindings", m.getRowLabelBindings()); assertNull("col bindings", m.getColumnLabelBindings()); Map<String, Integer> rowBindings = Maps.newHashMap(); rowBindings.put("Fee", 0); rowBindings.put("Fie", 1); rowBindings.put("Foe", 2); m.setRowLabelBindings(rowBindings); assertEquals("row", rowBindings, m.getRowLabelBindings()); Map<String, Integer> colBindings = Maps.newHashMap(); colBindings.put("Foo", 0); colBindings.put("Bar", 1); colBindings.put("Baz", 2); m.setColumnLabelBindings(colBindings); assertEquals("row", rowBindings, m.getRowLabelBindings()); assertEquals("Fee", m.get(0, 1), m.get("Fee", "Bar"), EPSILON); double[] newrow = {9, 8, 7}; m.set("Foe", newrow); assertEquals("FeeBaz", m.get(0, 2), m.get("Fee", "Baz"), EPSILON); } @Test(expected = IllegalStateException.class) public void testSettingLabelBindings() { Matrix m = matrixFactory(new double[][]{{1, 3, 4}, {5, 2, 3}, {1, 4, 2}}); assertNull("row bindings", m.getRowLabelBindings()); assertNull("col bindings", m.getColumnLabelBindings()); m.set("Fee", "Foo", 1, 2, 9); assertNotNull("row", m.getRowLabelBindings()); assertNotNull("row", m.getRowLabelBindings()); assertEquals("Fee", 1, m.getRowLabelBindings().get("Fee").intValue()); assertEquals("Fee", 2, m.getColumnLabelBindings().get("Foo").intValue()); assertEquals("FeeFoo", m.get(1, 2), m.get("Fee", "Foo"), EPSILON); m.get("Fie", "Foe"); } @Test public void testLabelBindingSerialization() { Matrix m = matrixFactory(new double[][]{{1, 3, 4}, {5, 2, 3}, {1, 4, 2}}); assertNull("row bindings", m.getRowLabelBindings()); assertNull("col bindings", m.getColumnLabelBindings()); Map<String, Integer> rowBindings = Maps.newHashMap(); rowBindings.put("Fee", 0); rowBindings.put("Fie", 1); rowBindings.put("Foe", 2); m.setRowLabelBindings(rowBindings); assertEquals("row", rowBindings, m.getRowLabelBindings()); Map<String, Integer> colBindings = Maps.newHashMap(); colBindings.put("Foo", 0); colBindings.put("Bar", 1); colBindings.put("Baz", 2); m.setColumnLabelBindings(colBindings); assertEquals("col", colBindings, m.getColumnLabelBindings()); } }