/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.math.impls.vector;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.exceptions.IndexException;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* Tests for {@link MatrixVectorView}.
*/
public class MatrixVectorViewTest {
/** */
private static final String UNEXPECTED_VALUE = "Unexpected value";
/** */
private static final int SMALL_SIZE = 3;
/** */
private static final int IMPOSSIBLE_SIZE = -1;
/** */
private Matrix parent;
/** */
@Before
public void setup() {
parent = newMatrix(SMALL_SIZE, SMALL_SIZE);
}
/** */
@Test
public void testDiagonal() {
Vector vector = parent.viewDiagonal();
for (int i = 0; i < SMALL_SIZE; i++)
assertView(i, i, vector, i);
}
/** */
@Test
public void testRow() {
for (int i = 0; i < SMALL_SIZE; i++) {
Vector viewRow = parent.viewRow(i);
for (int j = 0; j < SMALL_SIZE; j++)
assertView(i, j, viewRow, j);
}
}
/** */
@Test
public void testCols() {
for (int i = 0; i < SMALL_SIZE; i++) {
Vector viewCol = parent.viewColumn(i);
for (int j = 0; j < SMALL_SIZE; j++)
assertView(j, i, viewCol, j);
}
}
/** */
@Test
public void basicTest() {
for (int rowSize : new int[] {1, 2, 3, 4})
for (int colSize : new int[] {1, 2, 3, 4})
for (int row = 0; row < rowSize; row++)
for (int col = 0; col < colSize; col++)
for (int rowStride = 0; rowStride < rowSize; rowStride++)
for (int colStride = 0; colStride < colSize; colStride++)
if (rowStride != 0 || colStride != 0)
assertMatrixVectorView(newMatrix(rowSize, colSize), row, col, rowStride, colStride);
}
/** */
@Test(expected = AssertionError.class)
public void parentNullTest() {
//noinspection ConstantConditions
assertEquals(IMPOSSIBLE_SIZE,
new MatrixVectorView(null, 1, 1, 1, 1).size());
}
/** */
@Test(expected = IndexException.class)
public void rowNegativeTest() {
//noinspection ConstantConditions
assertEquals(IMPOSSIBLE_SIZE,
new MatrixVectorView(parent, -1, 1, 1, 1).size());
}
/** */
@Test(expected = IndexException.class)
public void colNegativeTest() {
//noinspection ConstantConditions
assertEquals(IMPOSSIBLE_SIZE,
new MatrixVectorView(parent, 1, -1, 1, 1).size());
}
/** */
@Test(expected = IndexException.class)
public void rowTooLargeTest() {
//noinspection ConstantConditions
assertEquals(IMPOSSIBLE_SIZE,
new MatrixVectorView(parent, parent.rowSize() + 1, 1, 1, 1).size());
}
/** */
@Test(expected = IndexException.class)
public void colTooLargeTest() {
//noinspection ConstantConditions
assertEquals(IMPOSSIBLE_SIZE,
new MatrixVectorView(parent, 1, parent.columnSize() + 1, 1, 1).size());
}
/** */
@Test(expected = AssertionError.class)
public void rowStrideNegativeTest() {
//noinspection ConstantConditions
assertEquals(IMPOSSIBLE_SIZE,
new MatrixVectorView(parent, 1, 1, -1, 1).size());
}
/** */
@Test(expected = AssertionError.class)
public void colStrideNegativeTest() {
//noinspection ConstantConditions
assertEquals(IMPOSSIBLE_SIZE,
new MatrixVectorView(parent, 1, 1, 1, -1).size());
}
/** */
@Test(expected = AssertionError.class)
public void rowStrideTooLargeTest() {
//noinspection ConstantConditions
assertEquals(IMPOSSIBLE_SIZE,
new MatrixVectorView(parent, 1, 1, parent.rowSize() + 1, 1).size());
}
/** */
@Test(expected = AssertionError.class)
public void colStrideTooLargeTest() {
//noinspection ConstantConditions
assertEquals(IMPOSSIBLE_SIZE,
new MatrixVectorView(parent, 1, 1, 1, parent.columnSize() + 1).size());
}
/** */
@Test(expected = AssertionError.class)
public void bothStridesZeroTest() {
//noinspection ConstantConditions
assertEquals(IMPOSSIBLE_SIZE,
new MatrixVectorView(parent, 1, 1, 0, 0).size());
}
/** */
private void assertMatrixVectorView(Matrix parent, int row, int col, int rowStride, int colStride) {
MatrixVectorView view = new MatrixVectorView(parent, row, col, rowStride, colStride);
String desc = "parent [" + parent.rowSize() + "x" + parent.columnSize() + "], view ["
+ row + "x" + col + "], strides [" + rowStride + ", " + colStride + "]";
final int size = view.size();
final int sizeByRows = rowStride == 0 ? IMPOSSIBLE_SIZE : (parent.rowSize() - row) / rowStride;
final int sizeByCols = colStride == 0 ? IMPOSSIBLE_SIZE : (parent.columnSize() - col) / colStride;
assertTrue("Size " + size + " differs from expected for " + desc,
size == sizeByRows || size == sizeByCols);
for (int idx = 0; idx < size; idx++) {
final int rowIdx = row + idx * rowStride;
final int colIdx = col + idx * colStride;
assertEquals(UNEXPECTED_VALUE + " at view index " + idx + desc,
parent.get(rowIdx, colIdx), view.get(idx), 0d);
}
}
/** */
private Matrix newMatrix(int rowSize, int colSize) {
Matrix res = new DenseLocalOnHeapMatrix(rowSize, colSize);
for (int i = 0; i < res.rowSize(); i++)
for (int j = 0; j < res.columnSize(); j++)
res.set(i, j, i * res.rowSize() + j);
return res;
}
/** */
private void assertView(int row, int col, Vector view, int viewIdx) {
assertValue(row, col, view, viewIdx);
parent.set(row, col, parent.get(row, col) + 1);
assertValue(row, col, view, viewIdx);
view.set(viewIdx, view.get(viewIdx) + 2);
assertValue(row, col, view, viewIdx);
}
/** */
private void assertValue(int row, int col, Vector view, int viewIdx) {
assertEquals(UNEXPECTED_VALUE + " at row " + row + " col " + col, parent.get(row, col), view.get(viewIdx), 0d);
}
}