/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.math.impls.matrix;
import java.util.Arrays;
import org.apache.ignite.ml.math.Matrix;
import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
/** */
public class PivotedMatrixViewConstructorTest {
/** */
@Test
public void invalidArgsTest() {
Matrix m = new DenseLocalOnHeapMatrix(1, 1);
int[] pivot = new int[] {0};
DenseLocalOnHeapMatrixConstructorTest.verifyAssertionError(() -> new PivotedMatrixView(null),
"Null parent matrix.");
DenseLocalOnHeapMatrixConstructorTest.verifyAssertionError(() -> new PivotedMatrixView(null, pivot),
"Null parent matrix, with pivot.");
DenseLocalOnHeapMatrixConstructorTest.verifyAssertionError(() -> new PivotedMatrixView(m, null),
"Null pivot.");
DenseLocalOnHeapMatrixConstructorTest.verifyAssertionError(() -> new PivotedMatrixView(m, null, pivot),
"Null row pivot.");
DenseLocalOnHeapMatrixConstructorTest.verifyAssertionError(() -> new PivotedMatrixView(m, pivot, null),
"Null col pivot.");
}
/** */
@Test
public void basicTest() {
Matrix m = new DenseLocalOnHeapMatrix(2, 2);
int[] pivot = new int[] {0, 1};
PivotedMatrixView view = new PivotedMatrixView(m, pivot);
assertEquals("Rows in view.", m.rowSize(), view.rowSize());
assertEquals("Cols in view.", m.columnSize(), view.columnSize());
assertTrue("Row pivot array in view.", Arrays.equals(pivot, view.rowPivot()));
assertTrue("Col pivot array in view.", Arrays.equals(pivot, view.columnPivot()));
Assert.assertEquals("Base matrix in view.", m, view.getBaseMatrix());
assertEquals("Row pivot value in view.", 0, view.rowPivot(0));
assertEquals("Col pivot value in view.", 0, view.columnPivot(0));
assertEquals("Row unpivot value in view.", 0, view.rowUnpivot(0));
assertEquals("Col unpivot value in view.", 0, view.columnUnpivot(0));
Matrix swap = view.swap(1, 1);
for (int row = 0; row < view.rowSize(); row++)
for (int col = 0; col < view.columnSize(); col++)
assertEquals("Unexpected swap value set at (" + row + "," + col + ").",
view.get(row, col), swap.get(row, col), 0d);
//noinspection EqualsWithItself
assertTrue("View is expected to be equal to self.", view.equals(view));
//noinspection ObjectEqualsNull
assertFalse("View is expected to be not equal to null.", view.equals(null));
}
/** */
@Test
public void pivotTest() {
int[] pivot = new int[] {2, 1, 0, 3};
for (Matrix m : new Matrix[] {
new DenseLocalOnHeapMatrix(3, 3),
new DenseLocalOnHeapMatrix(3, 4), new DenseLocalOnHeapMatrix(4, 3)})
pivotTest(m, pivot);
}
/** */
private void pivotTest(Matrix parent, int[] pivot) {
for (int row = 0; row < parent.rowSize(); row++)
for (int col = 0; col < parent.columnSize(); col++)
parent.set(row, col, row * parent.columnSize() + col + 1);
Matrix view = new PivotedMatrixView(parent, pivot);
int rows = parent.rowSize();
int cols = parent.columnSize();
assertEquals("Rows in view.", rows, view.rowSize());
assertEquals("Cols in view.", cols, view.columnSize());
for (int row = 0; row < rows; row++)
for (int col = 0; col < cols; col++)
assertEquals("Unexpected value at " + row + "x" + col,
parent.get(pivot[row], pivot[col]), view.get(row, col), 0d);
int min = rows < cols ? rows : cols;
for (int idx = 0; idx < min; idx++)
view.set(idx, idx, 0d);
for (int idx = 0; idx < min; idx++)
assertEquals("Unexpected value set at " + idx,
0d, parent.get(pivot[idx], pivot[idx]), 0d);
}
}