/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.ignite.ml.math.impls.matrix; import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.function.BiConsumer; import java.util.function.Supplier; import org.apache.ignite.ml.math.ExternalizeTest; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.exceptions.CardinalityException; import org.apache.ignite.ml.math.exceptions.ColumnIndexException; import org.apache.ignite.ml.math.exceptions.IndexException; import org.apache.ignite.ml.math.exceptions.RowIndexException; import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException; import org.apache.ignite.ml.math.impls.vector.DenseLocalOffHeapVector; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.math.impls.vector.RandomVector; import org.apache.ignite.ml.math.impls.vector.SparseLocalVector; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; /** * Tests for {@link Matrix} implementations. */ public class MatrixImplementationsTest extends ExternalizeTest<Matrix> { /** */ private static final double DEFAULT_DELTA = 0.000000001d; /** */ private void consumeSampleMatrix(BiConsumer<Matrix, String> consumer) { new MatrixImplementationFixtures().consumeSampleMatrix(consumer); } /** */ @Test public void externalizeTest() { consumeSampleMatrix((m, desc) -> externalizeTest(m)); } /** */ @Test public void testLike() { consumeSampleMatrix((m, desc) -> { Class<? extends Matrix> cls = likeMatrixType(m); if (cls != null) { Matrix like = m.like(m.rowSize(), m.columnSize()); assertEquals("Wrong \"like\" matrix for " + desc + "; Unexpected rows.", like.rowSize(), m.rowSize()); assertEquals("Wrong \"like\" matrix for " + desc + "; Unexpected columns.", like.columnSize(), m.columnSize()); assertEquals("Wrong \"like\" matrix for " + desc + "; Unexpected class: " + like.getClass().toString(), cls, like.getClass()); return; } boolean expECaught = false; try { m.like(1, 1); } catch (UnsupportedOperationException uoe) { expECaught = true; } assertTrue("Expected exception was not caught for " + desc, expECaught); }); } /** */ @Test public void testCopy() { consumeSampleMatrix((m, desc) -> { Matrix cp = m.copy(); assertTrue("Incorrect copy for empty matrix " + desc, cp.equals(m)); if (!readOnly(m)) fillMatrix(m); cp = m.copy(); assertTrue("Incorrect copy for matrix " + desc, cp.equals(m)); }); } /** */ @Test public void testHaveLikeVector() throws InstantiationException, IllegalAccessException { for (Class<? extends Matrix> key : likeVectorTypesMap().keySet()) { Class<? extends Vector> val = likeVectorTypesMap().get(key); if (val == null && !ignore(key)) System.out.println("Missing test for implementation of likeMatrix for " + key.getSimpleName()); } } /** */ @Test public void testLikeVector() { consumeSampleMatrix((m, desc) -> { if (likeVectorTypesMap().containsKey(m.getClass())) { Vector likeVector = m.likeVector(m.columnSize()); assertNotNull(likeVector); assertEquals("Unexpected value for " + desc, likeVector.size(), m.columnSize()); return; } boolean expECaught = false; try { m.likeVector(1); } catch (UnsupportedOperationException uoe) { expECaught = true; } assertTrue("Expected exception was not caught for " + desc, expECaught); }); } /** */ @Test public void testAssignSingleElement() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; final double assignVal = Math.random(); m.assign(assignVal); for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) assertEquals("Unexpected value for " + desc + " at (" + i + "," + j + ")", assignVal, m.get(i, j), 0d); }); } /** */ @Test public void testAssignArray() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; double[][] data = new double[m.rowSize()][m.columnSize()]; for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) data[i][j] = Math.random(); m.assign(data); for (int i = 0; i < m.rowSize(); i++) { for (int j = 0; j < m.columnSize(); j++) assertEquals("Unexpected value for " + desc + " at (" + i + "," + j + ")", data[i][j], m.get(i, j), 0d); } }); } /** */ @Test public void testAssignFunction() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; m.assign((i, j) -> (double)(i * m.columnSize() + j)); for (int i = 0; i < m.rowSize(); i++) { for (int j = 0; j < m.columnSize(); j++) assertEquals("Unexpected value for " + desc + " at (" + i + "," + j + ")", (double)(i * m.columnSize() + j), m.get(i, j), 0d); } }); } /** */ @Test public void testPlus() { consumeSampleMatrix((m, desc) -> { if (readOnly(m)) return; double[][] data = fillAndReturn(m); double plusVal = Math.random(); Matrix plus = m.plus(plusVal); for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) assertEquals("Unexpected value for " + desc + " at (" + i + "," + j + ")", data[i][j] + plusVal, plus.get(i, j), 0d); }); } /** */ @Test public void testPlusMatrix() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; double[][] data = fillAndReturn(m); Matrix plus = m.plus(m); for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) assertEquals("Unexpected value for " + desc + " at (" + i + "," + j + ")", data[i][j] * 2.0, plus.get(i, j), 0d); }); } /** */ @Test public void testMinusMatrix() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; fillMatrix(m); Matrix minus = m.minus(m); for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) assertEquals("Unexpected value for " + desc + " at (" + i + "," + j + ")", 0.0, minus.get(i, j), 0d); }); } /** */ @Test public void testTimes() { consumeSampleMatrix((m, desc) -> { if (readOnly(m)) return; double[][] data = fillAndReturn(m); double timeVal = Math.random(); Matrix times = m.times(timeVal); for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) assertEquals("Unexpected value for " + desc + " at (" + i + "," + j + ")", data[i][j] * timeVal, times.get(i, j), 0d); }); } /** */ @Test public void testTimesVector() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; double[][] data = fillAndReturn(m); double[] arr = fillArray(m.columnSize()); Vector times = m.times(new DenseLocalOnHeapVector(arr)); assertEquals("Unexpected vector size for " + desc, times.size(), m.rowSize()); for (int i = 0; i < m.rowSize(); i++) { double exp = 0.0; for (int j = 0; j < m.columnSize(); j++) exp += arr[j] * data[i][j]; assertEquals("Unexpected value for " + desc + " at " + i, times.get(i), exp, 0d); } testInvalidCardinality(() -> m.times(new DenseLocalOnHeapVector(m.columnSize() + 1)), desc); }); } /** */ @Test public void testTimesMatrix() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; double[][] data = fillAndReturn(m); double[] arr = fillArray(m.columnSize()); Matrix mult = new DenseLocalOnHeapMatrix(m.columnSize(), 1); mult.setColumn(0, arr); Matrix times = m.times(mult); assertEquals("Unexpected rows for " + desc, times.rowSize(), m.rowSize()); assertEquals("Unexpected cols for " + desc, times.columnSize(), 1); for (int i = 0; i < m.rowSize(); i++) { double exp = 0.0; for (int j = 0; j < m.columnSize(); j++) exp += arr[j] * data[i][j]; assertEquals("Unexpected value for " + desc + " at " + i, exp, times.get(i, 0), 0d); } testInvalidCardinality(() -> m.times(new DenseLocalOnHeapMatrix(m.columnSize() + 1, 1)), desc); }); } /** */ @Test public void testDivide() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; double[][] data = fillAndReturn(m); double divVal = Math.random(); Matrix divide = m.divide(divVal); for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) assertEquals("Unexpected value for " + desc + " at (" + i + "," + j + ")", data[i][j] / divVal, divide.get(i, j), 0d); }); } /** */ @Test public void testTranspose() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; fillMatrix(m); Matrix transpose = m.transpose(); for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) assertEquals("Unexpected value for " + desc + " at (" + i + "," + j + ")", m.get(i, j), transpose.get(j, i), 0d); }); } /** */ @Test public void testDeterminant() { consumeSampleMatrix((m, desc) -> { if (m.rowSize() != m.columnSize()) return; if (ignore(m.getClass())) return; double[][] doubles = fillIntAndReturn(m); if (m.rowSize() == 1) { assertEquals("Unexpected value " + desc, m.determinant(), doubles[0][0], 0d); return; } if (m.rowSize() == 2) { double det = doubles[0][0] * doubles[1][1] - doubles[0][1] * doubles[1][0]; assertEquals("Unexpected value " + desc, m.determinant(), det, 0d); return; } if (m.rowSize() > 512) return; // IMPL NOTE if row size >= 30000 it takes unacceptably long for normal test run. Matrix diagMtx = m.like(m.rowSize(), m.columnSize()); diagMtx.assign(0); for (int i = 0; i < m.rowSize(); i++) diagMtx.set(i, i, m.get(i, i)); double det = 1; for (int i = 0; i < diagMtx.rowSize(); i++) det *= diagMtx.get(i, i); try { assertEquals("Unexpected value " + desc, det, diagMtx.determinant(), DEFAULT_DELTA); } catch (Exception e) { System.out.println(desc); throw e; } }); } /** */ @Test public void testInverse() { consumeSampleMatrix((m, desc) -> { if (m.rowSize() != m.columnSize()) return; if (ignore(m.getClass())) return; if (m.rowSize() > 256) return; // IMPL NOTE this is for quicker test run. fillNonSingularMatrix(m); assertTrue("Unexpected zero determinant " + desc, Math.abs(m.determinant()) > 0d); Matrix inverse = m.inverse(); Matrix mult = m.times(inverse); final double delta = 0.001d; assertEquals("Unexpected determinant " + desc, 1d, mult.determinant(), delta); assertEquals("Unexpected top left value " + desc, 1d, mult.get(0, 0), delta); if (m.rowSize() == 1) return; assertEquals("Unexpected center value " + desc, 1d, mult.get(m.rowSize() / 2, m.rowSize() / 2), delta); assertEquals("Unexpected bottom right value " + desc, 1d, mult.get(m.rowSize() - 1, m.rowSize() - 1), delta); assertEquals("Unexpected top right value " + desc, 0d, mult.get(0, m.rowSize() - 1), delta); assertEquals("Unexpected bottom left value " + desc, 0d, mult.get(m.rowSize() - 1, 0), delta); }); } /** */ @Test public void testMap() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; fillMatrix(m); m.map(x -> 10d); for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) assertEquals("Unexpected value for " + desc + " at (" + i + "," + j + ")", 10d, m.get(i, j), 0d); }); } /** */ @Test public void testMapMatrix() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; double[][] doubles = fillAndReturn(m); testMapMatrixWrongCardinality(m, desc); Matrix cp = m.copy(); m.map(cp, (m1, m2) -> m1 + m2); for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) assertEquals("Unexpected value for " + desc + " at (" + i + "," + j + ")", m.get(i, j), doubles[i][j] * 2, 0d); }); } /** */ @Test public void testViewRow() { consumeSampleMatrix((m, desc) -> { if (!readOnly(m)) fillMatrix(m); for (int i = 0; i < m.rowSize(); i++) { Vector vector = m.viewRow(i); assert vector != null; for (int j = 0; j < m.columnSize(); j++) assertEquals("Unexpected value for " + desc + " at (" + i + "," + j + ")", m.get(i, j), vector.get(j), 0d); } }); } /** */ @Test public void testViewCol() { consumeSampleMatrix((m, desc) -> { if (!readOnly(m)) fillMatrix(m); for (int i = 0; i < m.columnSize(); i++) { Vector vector = m.viewColumn(i); assert vector != null; for (int j = 0; j < m.rowSize(); j++) assertEquals("Unexpected value for " + desc + " at (" + i + "," + j + ")", m.get(j, i), vector.get(j), 0d); } }); } /** */ @Test public void testFoldRow() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; fillMatrix(m); Vector foldRows = m.foldRows(Vector::sum); for (int i = 0; i < m.rowSize(); i++) { Double locSum = 0d; for (int j = 0; j < m.columnSize(); j++) locSum += m.get(i, j); assertEquals("Unexpected value for " + desc + " at " + i, foldRows.get(i), locSum, 0d); } }); } /** */ @Test public void testFoldCol() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; fillMatrix(m); Vector foldCols = m.foldColumns(Vector::sum); for (int j = 0; j < m.columnSize(); j++) { Double locSum = 0d; for (int i = 0; i < m.rowSize(); i++) locSum += m.get(i, j); assertEquals("Unexpected value for " + desc + " at " + j, foldCols.get(j), locSum, 0d); } }); } /** */ @Test public void testSum() { consumeSampleMatrix((m, desc) -> { double[][] data = fillAndReturn(m); double sum = m.sum(); double rawSum = 0; for (double[] anArr : data) for (int j = 0; j < data[0].length; j++) rawSum += anArr[j]; assertEquals("Unexpected value for " + desc, rawSum, sum, 0d); }); } /** */ @Test public void testMax() { consumeSampleMatrix((m, desc) -> { double[][] doubles = fillAndReturn(m); double max = Double.NEGATIVE_INFINITY; for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) max = max < doubles[i][j] ? doubles[i][j] : max; assertEquals("Unexpected value for " + desc, m.maxValue(), max, 0d); }); } /** */ @Test public void testMin() { consumeSampleMatrix((m, desc) -> { double[][] doubles = fillAndReturn(m); double min = Double.MAX_VALUE; for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) min = min > doubles[i][j] ? doubles[i][j] : min; assertEquals("Unexpected value for " + desc, m.minValue(), min, 0d); }); } /** */ @Test public void testGetElement() { consumeSampleMatrix((m, desc) -> { if (!(readOnly(m))) fillMatrix(m); for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) { final Matrix.Element e = m.getElement(i, j); final String details = desc + " at [" + i + "," + j + "]"; assertEquals("Unexpected element row " + details, i, e.row()); assertEquals("Unexpected element col " + details, j, e.column()); final double val = m.get(i, j); assertEquals("Unexpected value for " + details, val, e.get(), 0d); boolean expECaught = false; final double newVal = val * 2.0; try { e.set(newVal); } catch (UnsupportedOperationException uoe) { if (!(readOnly(m))) throw uoe; expECaught = true; } if (readOnly(m)) { if (!expECaught) fail("Expected exception was not caught for " + details); continue; } assertEquals("Unexpected value set for " + details, newVal, m.get(i, j), 0d); } }); } /** */ @Test public void testGetX() { consumeSampleMatrix((m, desc) -> { if (!(readOnly(m))) fillMatrix(m); for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) assertEquals("Unexpected value for " + desc + " at [" + i + "," + j + "]", m.get(i, j), m.getX(i, j), 0d); }); } /** */ @Test public void testGetMetaStorage() { consumeSampleMatrix((m, desc) -> assertNotNull("Null meta storage in " + desc, m.getMetaStorage())); } /** */ @Test public void testGuid() { consumeSampleMatrix((m, desc) -> assertNotNull("Null guid in " + desc, m.guid())); } /** */ @Test public void testSwapRows() { consumeSampleMatrix((m, desc) -> { if (readOnly(m)) return; double[][] doubles = fillAndReturn(m); final int swap_i = m.rowSize() == 1 ? 0 : 1; final int swap_j = 0; Matrix swap = m.swapRows(swap_i, swap_j); for (int col = 0; col < m.columnSize(); col++) { assertEquals("Unexpected value for " + desc + " at col " + col + ", swap_i " + swap_i, swap.get(swap_i, col), doubles[swap_j][col], 0d); assertEquals("Unexpected value for " + desc + " at col " + col + ", swap_j " + swap_j, swap.get(swap_j, col), doubles[swap_i][col], 0d); } testInvalidRowIndex(() -> m.swapRows(-1, 0), desc + " negative first swap index"); testInvalidRowIndex(() -> m.swapRows(0, -1), desc + " negative second swap index"); testInvalidRowIndex(() -> m.swapRows(m.rowSize(), 0), desc + " too large first swap index"); testInvalidRowIndex(() -> m.swapRows(0, m.rowSize()), desc + " too large second swap index"); }); } /** */ @Test public void testSwapColumns() { consumeSampleMatrix((m, desc) -> { if (readOnly(m)) return; double[][] doubles = fillAndReturn(m); final int swap_i = m.columnSize() == 1 ? 0 : 1; final int swap_j = 0; Matrix swap = m.swapColumns(swap_i, swap_j); for (int row = 0; row < m.rowSize(); row++) { assertEquals("Unexpected value for " + desc + " at row " + row + ", swap_i " + swap_i, swap.get(row, swap_i), doubles[row][swap_j], 0d); assertEquals("Unexpected value for " + desc + " at row " + row + ", swap_j " + swap_j, swap.get(row, swap_j), doubles[row][swap_i], 0d); } testInvalidColIndex(() -> m.swapColumns(-1, 0), desc + " negative first swap index"); testInvalidColIndex(() -> m.swapColumns(0, -1), desc + " negative second swap index"); testInvalidColIndex(() -> m.swapColumns(m.columnSize(), 0), desc + " too large first swap index"); testInvalidColIndex(() -> m.swapColumns(0, m.columnSize()), desc + " too large second swap index"); }); } /** */ @Test public void testSetRow() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; fillMatrix(m); int rowIdx = m.rowSize() / 2; double[] newValues = fillArray(m.columnSize()); m.setRow(rowIdx, newValues); for (int col = 0; col < m.columnSize(); col++) assertEquals("Unexpected value for " + desc + " at " + col, newValues[col], m.get(rowIdx, col), 0d); testInvalidCardinality(() -> m.setRow(rowIdx, new double[m.columnSize() + 1]), desc); }); } /** */ @Test public void testSetColumn() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; fillMatrix(m); int colIdx = m.columnSize() / 2; double[] newValues = fillArray(m.rowSize()); m.setColumn(colIdx, newValues); for (int row = 0; row < m.rowSize(); row++) assertEquals("Unexpected value for " + desc + " at " + row, newValues[row], m.get(row, colIdx), 0d); testInvalidCardinality(() -> m.setColumn(colIdx, new double[m.rowSize() + 1]), desc); }); } /** */ @Test public void testViewPart() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; fillMatrix(m); int rowOff = m.rowSize() < 3 ? 0 : 1; int rows = m.rowSize() < 3 ? 1 : m.rowSize() - 2; int colOff = m.columnSize() < 3 ? 0 : 1; int cols = m.columnSize() < 3 ? 1 : m.columnSize() - 2; Matrix view1 = m.viewPart(rowOff, rows, colOff, cols); Matrix view2 = m.viewPart(new int[] {rowOff, colOff}, new int[] {rows, cols}); String details = desc + " view [" + rowOff + ", " + rows + ", " + colOff + ", " + cols + "]"; for (int i = 0; i < rows; i++) for (int j = 0; j < cols; j++) { assertEquals("Unexpected view1 value for " + details + " at (" + i + "," + j + ")", m.get(i + rowOff, j + colOff), view1.get(i, j), 0d); assertEquals("Unexpected view2 value for " + details + " at (" + i + "," + j + ")", m.get(i + rowOff, j + colOff), view2.get(i, j), 0d); } }); } /** */ @Test public void testDensity() { consumeSampleMatrix((m, desc) -> { if (!readOnly(m)) fillMatrix(m); assertTrue("Unexpected density with threshold 0 for " + desc, m.density(0.0)); assertFalse("Unexpected density with threshold 1 for " + desc, m.density(1.0)); }); } /** */ @Test public void testMaxAbsRowSumNorm() { consumeSampleMatrix((m, desc) -> { if (!readOnly(m)) fillMatrix(m); assertEquals("Unexpected value for " + desc, maxAbsRowSumNorm(m), m.maxAbsRowSumNorm(), 0d); }); } /** */ @Test public void testAssignRow() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; fillMatrix(m); int rowIdx = m.rowSize() / 2; double[] newValues = fillArray(m.columnSize()); m.assignRow(rowIdx, new DenseLocalOnHeapVector(newValues)); for (int col = 0; col < m.columnSize(); col++) assertEquals("Unexpected value for " + desc + " at " + col, newValues[col], m.get(rowIdx, col), 0d); testInvalidCardinality(() -> m.assignRow(rowIdx, new DenseLocalOnHeapVector(m.columnSize() + 1)), desc); }); } /** */ @Test public void testAssignColumn() { consumeSampleMatrix((m, desc) -> { if (ignore(m.getClass())) return; fillMatrix(m); int colIdx = m.columnSize() / 2; double[] newValues = fillArray(m.rowSize()); m.assignColumn(colIdx, new DenseLocalOnHeapVector(newValues)); for (int row = 0; row < m.rowSize(); row++) assertEquals("Unexpected value for " + desc + " at " + row, newValues[row], m.get(row, colIdx), 0d); }); } /** */ private double[] fillArray(int len) { double[] newValues = new double[len]; for (int i = 0; i < newValues.length; i++) newValues[i] = newValues.length - i; return newValues; } /** */ private double maxAbsRowSumNorm(Matrix m) { double max = 0.0; for (int x = 0; x < m.rowSize(); x++) { double sum = 0; for (int y = 0; y < m.columnSize(); y++) sum += Math.abs(m.getX(x, y)); if (sum > max) max = sum; } return max; } /** */ private void testInvalidRowIndex(Supplier<Matrix> supplier, String desc) { try { supplier.get(); } catch (RowIndexException | IndexException ie) { return; } fail("Expected exception was not caught for " + desc); } /** */ private void testInvalidColIndex(Supplier<Matrix> supplier, String desc) { try { supplier.get(); } catch (ColumnIndexException | IndexException ie) { return; } fail("Expected exception was not caught for " + desc); } /** */ private void testMapMatrixWrongCardinality(Matrix m, String desc) { for (int rowDelta : new int[] {-1, 0, 1}) for (int colDelta : new int[] {-1, 0, 1}) { if (rowDelta == 0 && colDelta == 0) continue; int rowNew = m.rowSize() + rowDelta; int colNew = m.columnSize() + colDelta; if (rowNew < 1 || colNew < 1) continue; testInvalidCardinality(() -> m.map(new DenseLocalOnHeapMatrix(rowNew, colNew), (m1, m2) -> m1 + m2), desc + " wrong cardinality when mapping to size " + rowNew + "x" + colNew); } } /** */ private void testInvalidCardinality(Supplier<Object> supplier, String desc) { try { supplier.get(); } catch (CardinalityException ce) { return; } fail("Expected exception was not caught for " + desc); } /** */ private boolean readOnly(Matrix m) { return m instanceof RandomMatrix; } /** */ private double[][] fillIntAndReturn(Matrix m) { double[][] data = new double[m.rowSize()][m.columnSize()]; if (readOnly(m)) { for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) data[i][j] = m.get(i, j); } else { for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) data[i][j] = i * m.rowSize() + j + 1; m.assign(data); } return data; } /** */ private double[][] fillAndReturn(Matrix m) { double[][] data = new double[m.rowSize()][m.columnSize()]; if (readOnly(m)) { for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) data[i][j] = m.get(i, j); } else { for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) data[i][j] = -0.5d + Math.random(); m.assign(data); } return data; } /** */ private void fillNonSingularMatrix(Matrix m) { for (int i = 0; i < m.rowSize(); i++) { m.set(i, i, 10); for (int j = 0; j < m.columnSize(); j++) if (j != i) m.set(i, j, 0.01d); } } /** */ private void fillMatrix(Matrix m) { for (int i = 0; i < m.rowSize(); i++) for (int j = 0; j < m.columnSize(); j++) m.set(i, j, Math.random()); } /** Ignore test for given matrix type. */ private boolean ignore(Class<? extends Matrix> clazz) { List<Class<? extends Matrix>> ignoredClasses = Arrays.asList(RandomMatrix.class, PivotedMatrixView.class, MatrixView.class, FunctionMatrix.class, TransposedMatrixView.class); for (Class<? extends Matrix> ignoredClass : ignoredClasses) if (ignoredClass.isAssignableFrom(clazz)) return true; return false; } /** */ private Class<? extends Matrix> likeMatrixType(Matrix m) { for (Class<? extends Matrix> clazz : likeTypesMap().keySet()) if (clazz.isAssignableFrom(m.getClass())) return likeTypesMap().get(clazz); return null; } /** */ private static Map<Class<? extends Matrix>, Class<? extends Vector>> likeVectorTypesMap() { return new LinkedHashMap<Class<? extends Matrix>, Class<? extends Vector>>() {{ put(DenseLocalOnHeapMatrix.class, DenseLocalOnHeapVector.class); put(DenseLocalOffHeapMatrix.class, DenseLocalOffHeapVector.class); put(RandomMatrix.class, RandomVector.class); put(SparseLocalOnHeapMatrix.class, SparseLocalVector.class); put(DenseLocalOnHeapMatrix.class, DenseLocalOnHeapVector.class); put(DiagonalMatrix.class, DenseLocalOnHeapVector.class); // IMPL NOTE per fixture // IMPL NOTE check for presence of all implementations here will be done in testHaveLikeMatrix via Fixture }}; } /** */ private static Map<Class<? extends Matrix>, Class<? extends Matrix>> likeTypesMap() { return new LinkedHashMap<Class<? extends Matrix>, Class<? extends Matrix>>() {{ put(DenseLocalOnHeapMatrix.class, DenseLocalOnHeapMatrix.class); put(DenseLocalOffHeapMatrix.class, DenseLocalOffHeapMatrix.class); put(RandomMatrix.class, RandomMatrix.class); put(SparseLocalOnHeapMatrix.class, SparseLocalOnHeapMatrix.class); put(DenseLocalOnHeapMatrix.class, DenseLocalOnHeapMatrix.class); put(DiagonalMatrix.class, DenseLocalOnHeapMatrix.class); // IMPL NOTE per fixture put(FunctionMatrix.class, FunctionMatrix.class); // IMPL NOTE check for presence of all implementations here will be done in testHaveLikeMatrix via Fixture }}; } }