/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math;
import com.google.common.collect.Maps;
import org.apache.mahout.math.function.Functions;
import org.junit.Before;
import org.junit.Test;
import java.util.Map;
public final class TestMatrixView extends MahoutTestCase {
private static final int ROW = AbstractMatrix.ROW;
private static final int COL = AbstractMatrix.COL;
private final double[][] values = {{0.0, 1.1, 2.2}, {1.1, 2.2, 3.3},
{3.3, 4.4, 5.5}, {5.5, 6.6, 7.7}, {7.7, 8.8, 9.9}};
private Matrix test;
@Override
@Before
public void setUp() throws Exception {
super.setUp();
int[] offset = {1, 1};
int[] card = {3, 2};
test = new MatrixView(new DenseMatrix(values), offset, card);
}
@Test
public void testCardinality() {
assertEquals("row cardinality", values.length - 2, test.rowSize());
assertEquals("col cardinality", values[0].length - 1, test.columnSize());
}
@Test
public void testCopy() {
Matrix copy = test.clone();
assertTrue("wrong class", copy instanceof MatrixView);
for (int row = 0; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
assertEquals("value[" + row + "][" + col + ']',
test.getQuick(row, col), copy.getQuick(row, col), EPSILON);
}
}
}
@Test
public void testGetQuick() {
for (int row = 0; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
assertEquals("value[" + row + "][" + col + ']',
values[row + 1][col + 1], test.getQuick(row, col), EPSILON);
}
}
}
@Test
public void testLike() {
Matrix like = test.like();
assertTrue("type", like instanceof DenseMatrix);
assertEquals("rows", test.rowSize(), like.rowSize());
assertEquals("columns", test.columnSize(), like.columnSize());
}
@Test
public void testLikeIntInt() {
Matrix like = test.like(4, 4);
assertTrue("type", like instanceof DenseMatrix);
assertEquals("rows", 4, like.rowSize());
assertEquals("columns", 4, like.columnSize());
}
@Test
public void testSetQuick() {
for (int row = 0; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
test.setQuick(row, col, 1.23);
assertEquals("value[" + row + "][" + col + ']', 1.23, test.getQuick(
row, col), EPSILON);
}
}
}
@Test
public void testSize() {
assertEquals("row size", values.length - 2, test.rowSize());
assertEquals("col size", values[0].length - 1, test.columnSize());
}
@Test
public void testViewPart() {
int[] offset = {1, 1};
int[] size = {2, 1};
Matrix view = test.viewPart(offset, size);
for (int row = 0; row < view.rowSize(); row++) {
for (int col = 0; col < view.columnSize(); col++) {
assertEquals("value[" + row + "][" + col + ']',
values[row + 2][col + 2], view.getQuick(row, col), EPSILON);
}
}
}
@Test(expected = IndexException.class)
public void testViewPartCardinality() {
int[] offset = {1, 1};
int[] size = {3, 3};
test.viewPart(offset, size);
}
@Test(expected = IndexException.class)
public void testViewPartIndexOver() {
int[] offset = {1, 1};
int[] size = {2, 2};
test.viewPart(offset, size);
}
@Test(expected = IndexException.class)
public void testViewPartIndexUnder() {
int[] offset = {-1, -1};
int[] size = {2, 2};
test.viewPart(offset, size);
}
@Test
public void testAssignDouble() {
test.assign(4.53);
for (int row = 0; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
assertEquals("value[" + row + "][" + col + ']', 4.53, test.getQuick(
row, col), EPSILON);
}
}
}
@Test
public void testAssignDoubleArrayArray() {
test.assign(new double[3][2]);
for (int row = 0; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
assertEquals("value[" + row + "][" + col + ']', 0.0, test.getQuick(row,
col), EPSILON);
}
}
}
@Test(expected = CardinalityException.class)
public void testAssignDoubleArrayArrayCardinality() {
test.assign(new double[test.rowSize() + 1][test.columnSize()]);
}
@Test
public void testAssignMatrixBinaryFunction() {
test.assign(test, Functions.PLUS);
for (int row = 0; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
assertEquals("value[" + row + "][" + col + ']',
2 * values[row + 1][col + 1], test.getQuick(row, col), EPSILON);
}
}
}
@Test(expected = CardinalityException.class)
public void testAssignMatrixBinaryFunctionCardinality() {
test.assign(test.transpose(), Functions.PLUS);
}
@Test
public void testAssignMatrix() {
Matrix value = test.like();
value.assign(test);
for (int row = 0; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
assertEquals("value[" + row + "][" + col + ']',
test.getQuick(row, col), value.getQuick(row, col), EPSILON);
}
}
}
@Test(expected = CardinalityException.class)
public void testAssignMatrixCardinality() {
test.assign(test.transpose());
}
@Test
public void testAssignUnaryFunction() {
test.assign(Functions.NEGATE);
for (int row = 0; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
assertEquals("value[" + row + "][" + col + ']',
-values[row + 1][col + 1], test.getQuick(row, col), EPSILON);
}
}
}
@Test
public void testDivide() {
Matrix value = test.divide(4.53);
for (int row = 0; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
assertEquals("value[" + row + "][" + col + ']',
values[row + 1][col + 1] / 4.53, value.getQuick(row, col), EPSILON);
}
}
}
@Test
public void testGet() {
for (int row = 0; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
assertEquals("value[" + row + "][" + col + ']',
values[row + 1][col + 1], test.get(row, col), EPSILON);
}
}
}
@Test(expected = IndexException.class)
public void testGetIndexUnder() {
for (int row = -1; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
test.get(row, col);
}
}
}
@Test(expected = IndexException.class)
public void testGetIndexOver() {
for (int row = 0; row < test.rowSize() + 1; row++) {
for (int col = 0; col < test.columnSize(); col++) {
test.get(row, col);
}
}
}
@Test
public void testMinus() {
Matrix value = test.minus(test);
for (int row = 0; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
assertEquals("value[" + row + "][" + col + ']', 0.0, value.getQuick(row, col), EPSILON);
}
}
}
@Test(expected = CardinalityException.class)
public void testMinusCardinality() {
test.minus(test.transpose());
}
@Test
public void testPlusDouble() {
Matrix value = test.plus(4.53);
for (int row = 0; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
assertEquals("value[" + row + "][" + col + ']',
values[row + 1][col + 1] + 4.53, value.getQuick(row, col), EPSILON);
}
}
}
@Test
public void testPlusMatrix() {
Matrix value = test.plus(test);
for (int row = 0; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
assertEquals("value[" + row + "][" + col + ']',
values[row + 1][col + 1] * 2, value.getQuick(row, col), EPSILON);
}
}
}
@Test(expected = CardinalityException.class)
public void testPlusMatrixCardinality() {
test.plus(test.transpose());
}
@Test(expected = IndexException.class)
public void testSetUnder() {
for (int row = -1; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
test.set(row, col, 1.23);
}
}
}
@Test(expected = IndexException.class)
public void testSetOver() {
for (int row = 0; row < test.rowSize() + 1; row++) {
for (int col = 0; col < test.columnSize(); col++) {
test.set(row, col, 1.23);
}
}
}
@Test
public void testTimesDouble() {
Matrix value = test.times(4.53);
for (int row = 0; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
assertEquals("value[" + row + "][" + col + ']',
values[row + 1][col + 1] * 4.53, value.getQuick(row, col), EPSILON);
}
}
}
@Test
public void testTimesMatrix() {
Matrix transpose = test.transpose();
Matrix value = test.times(transpose);
assertEquals("rows", test.rowSize(), value.rowSize());
assertEquals("cols", test.rowSize(), value.columnSize());
// TODO: check the math too, lazy
}
@Test(expected = CardinalityException.class)
public void testTimesMatrixCardinality() {
Matrix other = test.like(5, 8);
test.times(other);
}
@Test
public void testTranspose() {
Matrix transpose = test.transpose();
assertEquals("rows", test.columnSize(), transpose.rowSize());
assertEquals("cols", test.rowSize(), transpose.columnSize());
for (int row = 0; row < test.rowSize(); row++) {
for (int col = 0; col < test.columnSize(); col++) {
assertEquals("value[" + row + "][" + col + ']',
test.getQuick(row, col), transpose.getQuick(col, row), EPSILON);
}
}
}
@Test
public void testZSum() {
double sum = test.zSum();
assertEquals("zsum", 29.7, sum, EPSILON);
}
@Test
public void testAssignRow() {
double[] data = {2.1, 3.2};
test.assignRow(1, new DenseVector(data));
assertEquals("test[1][0]", 2.1, test.getQuick(1, 0), EPSILON);
assertEquals("test[1][1]", 3.2, test.getQuick(1, 1), EPSILON);
}
@Test(expected = CardinalityException.class)
public void testAssignRowCardinality() {
double[] data = {2.1, 3.2, 4.3};
test.assignRow(1, new DenseVector(data));
}
@Test
public void testAssignColumn() {
double[] data = {2.1, 3.2, 4.3};
test.assignColumn(1, new DenseVector(data));
assertEquals("test[0][1]", 2.1, test.getQuick(0, 1), EPSILON);
assertEquals("test[1][1]", 3.2, test.getQuick(1, 1), EPSILON);
assertEquals("test[2][1]", 4.3, test.getQuick(2, 1), EPSILON);
}
@Test(expected = CardinalityException.class)
public void testAssignColumnCardinality() {
double[] data = {2.1, 3.2};
test.assignColumn(1, new DenseVector(data));
}
@Test
public void testViewRow() {
Vector row = test.viewRow(1);
assertEquals("row size", 2, row.getNumNondefaultElements());
}
@Test(expected = IndexException.class)
public void testViewRowIndexUnder() {
test.viewRow(-1);
}
@Test(expected = IndexException.class)
public void testViewRowIndexOver() {
test.viewRow(5);
}
@Test
public void testViewColumn() {
Vector column = test.viewColumn(1);
assertEquals("row size", 3, column.getNumNondefaultElements());
int i = 0;
for (double x : new double[]{3.3, 5.5, 7.7}) {
assertEquals(x, column.get(i++), 0);
}
}
@Test(expected = IndexException.class)
public void testViewColumnIndexUnder() {
test.viewColumn(-1);
}
@Test(expected = IndexException.class)
public void testViewColumnIndexOver() {
test.viewColumn(5);
}
@Test
public void testLabelBindings() {
assertNull("row bindings", test.getRowLabelBindings());
assertNull("col bindings", test.getColumnLabelBindings());
Map<String, Integer> rowBindings = Maps.newHashMap();
rowBindings.put("Fee", 0);
rowBindings.put("Fie", 1);
test.setRowLabelBindings(rowBindings);
assertEquals("row", rowBindings, test.getRowLabelBindings());
Map<String, Integer> colBindings = Maps.newHashMap();
colBindings.put("Foo", 0);
colBindings.put("Bar", 1);
test.setColumnLabelBindings(colBindings);
assertEquals("row", rowBindings, test.getRowLabelBindings());
assertEquals("Fee", test.get(0, 1), test.get("Fee", "Bar"), EPSILON);
double[] newrow = {9, 8};
test.set("Fie", newrow);
assertEquals("FeeBar", test.get(0, 1), test.get("Fee", "Bar"), EPSILON);
}
@Test(expected = IllegalStateException.class)
public void testSettingLabelBindings() {
assertNull("row bindings", test.getRowLabelBindings());
assertNull("col bindings", test.getColumnLabelBindings());
test.set("Fee", "Foo", 1, 1, 9);
assertNotNull("row", test.getRowLabelBindings());
assertNotNull("row", test.getRowLabelBindings());
assertEquals("Fee", 1, test.getRowLabelBindings().get("Fee").intValue());
assertEquals("Foo", 1, test.getColumnLabelBindings().get("Foo").intValue());
assertEquals("FeeFoo", test.get(1, 1), test.get("Fee", "Foo"), EPSILON);
test.get("Fie", "Foe");
}
@Test
public void testLabelBindingSerialization() {
assertNull("row bindings", test.getRowLabelBindings());
assertNull("col bindings", test.getColumnLabelBindings());
Map<String, Integer> rowBindings = Maps.newHashMap();
rowBindings.put("Fee", 0);
rowBindings.put("Fie", 1);
rowBindings.put("Foe", 2);
test.setRowLabelBindings(rowBindings);
assertEquals("row", rowBindings, test.getRowLabelBindings());
Map<String, Integer> colBindings = Maps.newHashMap();
colBindings.put("Foo", 0);
colBindings.put("Bar", 1);
colBindings.put("Baz", 2);
test.setColumnLabelBindings(colBindings);
assertEquals("col", colBindings, test.getColumnLabelBindings());
}
}