/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.ignite.ml.math.impls.vector; import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.function.BiConsumer; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOffHeapMatrix; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.math.impls.matrix.RandomMatrix; import org.apache.ignite.ml.math.impls.matrix.SparseLocalOnHeapMatrix; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; /** Tests for methods of Vector that involve Matrix. */ public class VectorToMatrixTest { /** */ private static final Map<Class<? extends Vector>, Class<? extends Matrix>> typesMap = typesMap(); /** */ private static final List<Class<? extends Vector>> likeMatrixUnsupported = Arrays.asList(FunctionVector.class, SingleElementVector.class, SingleElementVectorView.class, ConstantVector.class); /** */ @Test public void testHaveLikeMatrix() throws InstantiationException, IllegalAccessException { for (Class<? extends Vector> key : typesMap.keySet()) { Class<? extends Matrix> val = typesMap.get(key); if (val == null && likeMatrixSupported(key)) System.out.println("Missing test for implementation of likeMatrix for " + key.getSimpleName()); } } /** */ @Test public void testLikeMatrixUnsupported() throws Exception { consumeSampleVectors((v, desc) -> { if (likeMatrixSupported(v.getClass())) return; boolean expECaught = false; try { assertNull("Null view instead of exception in " + desc, v.likeMatrix(1, 1)); } catch (UnsupportedOperationException uoe) { expECaught = true; } assertTrue("Expected exception was not caught in " + desc, expECaught); }); } /** */ @Test public void testLikeMatrix() { consumeSampleVectors((v, desc) -> { if (!availableForTesting(v)) return; final Matrix matrix = v.likeMatrix(1, 1); Class<? extends Vector> key = v.getClass(); Class<? extends Matrix> expMatrixType = typesMap.get(key); assertNotNull("Expect non-null matrix for " + key.getSimpleName() + " in " + desc, matrix); Class<? extends Matrix> actualMatrixType = matrix.getClass(); assertTrue("Expected matrix type " + expMatrixType.getSimpleName() + " should be assignable from actual type " + actualMatrixType.getSimpleName() + " in " + desc, expMatrixType.isAssignableFrom(actualMatrixType)); for (int rows : new int[] {1, 2}) for (int cols : new int[] {1, 2}) { final Matrix actualMatrix = v.likeMatrix(rows, cols); String details = "rows " + rows + " cols " + cols; assertNotNull("Expect non-null matrix for " + details + " in " + desc, actualMatrix); assertEquals("Unexpected number of rows in " + desc, rows, actualMatrix.rowSize()); assertEquals("Unexpected number of cols in " + desc, cols, actualMatrix.columnSize()); } }); } /** */ @Test public void testToMatrix() { consumeSampleVectors((v, desc) -> { if (!availableForTesting(v)) return; fillWithNonZeroes(v); final Matrix matrixRow = v.toMatrix(true); final Matrix matrixCol = v.toMatrix(false); for (Vector.Element e : v.all()) assertToMatrixValue(desc, matrixRow, matrixCol, e.get(), e.index()); }); } /** */ @Test public void testToMatrixPlusOne() { consumeSampleVectors((v, desc) -> { if (!availableForTesting(v)) return; fillWithNonZeroes(v); for (double zeroVal : new double[] {-1, 0, 1, 2}) { final Matrix matrixRow = v.toMatrixPlusOne(true, zeroVal); final Matrix matrixCol = v.toMatrixPlusOne(false, zeroVal); final Metric metricRow0 = new Metric(zeroVal, matrixRow.get(0, 0)); assertTrue("Not close enough row like " + metricRow0 + " at index 0 in " + desc, metricRow0.closeEnough()); final Metric metricCol0 = new Metric(zeroVal, matrixCol.get(0, 0)); assertTrue("Not close enough cols like " + metricCol0 + " at index 0 in " + desc, metricCol0.closeEnough()); for (Vector.Element e : v.all()) assertToMatrixValue(desc, matrixRow, matrixCol, e.get(), e.index() + 1); } }); } /** */ @Test public void testCross() { consumeSampleVectors((v, desc) -> { if (!availableForTesting(v)) return; fillWithNonZeroes(v); for (int delta : new int[] {-1, 0, 1}) { final int size2 = v.size() + delta; if (size2 < 1) return; final Vector v2 = new DenseLocalOnHeapVector(size2); for (Vector.Element e : v2.all()) e.set(size2 - e.index()); assertCross(v, v2, desc); } }); } /** */ private void assertCross(Vector v1, Vector v2, String desc) { assertNotNull(v1); assertNotNull(v2); final Matrix res = v1.cross(v2); assertNotNull("Cross matrix is expected to be not null in " + desc, res); assertEquals("Unexpected number of rows in cross Matrix in " + desc, v1.size(), res.rowSize()); assertEquals("Unexpected number of cols in cross Matrix in " + desc, v2.size(), res.columnSize()); for (int row = 0; row < v1.size(); row++) for (int col = 0; col < v2.size(); col++) { final Metric metric = new Metric(v1.get(row) * v2.get(col), res.get(row, col)); assertTrue("Not close enough cross " + metric + " at row " + row + " at col " + col + " in " + desc, metric.closeEnough()); } } /** */ private void assertToMatrixValue(String desc, Matrix matrixRow, Matrix matrixCol, double exp, int idx) { final Metric metricRow = new Metric(exp, matrixRow.get(0, idx)); assertTrue("Not close enough row like " + metricRow + " at index " + idx + " in " + desc, metricRow.closeEnough()); final Metric metricCol = new Metric(exp, matrixCol.get(idx, 0)); assertTrue("Not close enough cols like " + matrixCol + " at index " + idx + " in " + desc, metricCol.closeEnough()); } /** */ private void fillWithNonZeroes(Vector sample) { if (sample instanceof RandomVector) return; for (Vector.Element e : sample.all()) e.set(1 + e.index()); } /** */ private boolean availableForTesting(Vector v) { assertNotNull("Error in test: vector is null", v); if (!likeMatrixSupported(v.getClass())) return false; final boolean availableForTesting = typesMap.get(v.getClass()) != null; final Matrix actualLikeMatrix = v.likeMatrix(1, 1); assertTrue("Need to enable matrix testing for vector type " + v.getClass().getSimpleName(), availableForTesting || actualLikeMatrix == null); return availableForTesting; } /** Ignore test for given vector type. */ private boolean likeMatrixSupported(Class<? extends Vector> clazz) { for (Class<? extends Vector> ignoredClass : likeMatrixUnsupported) if (ignoredClass.isAssignableFrom(clazz)) return false; return true; } /** */ private void consumeSampleVectors(BiConsumer<Vector, String> consumer) { new VectorImplementationsFixtures().consumeSampleVectors(null, consumer); } /** */ private static Map<Class<? extends Vector>, Class<? extends Matrix>> typesMap() { return new LinkedHashMap<Class<? extends Vector>, Class<? extends Matrix>>() {{ put(DenseLocalOnHeapVector.class, DenseLocalOnHeapMatrix.class); put(DenseLocalOffHeapVector.class, DenseLocalOffHeapMatrix.class); put(RandomVector.class, RandomMatrix.class); put(SparseLocalVector.class, SparseLocalOnHeapMatrix.class); put(SingleElementVector.class, null); // todo find out if we need SingleElementMatrix to match, or skip it put(ConstantVector.class, null); put(FunctionVector.class, null); put(PivotedVectorView.class, DenseLocalOnHeapMatrix.class); // IMPL NOTE per fixture put(SingleElementVectorView.class, null); put(MatrixVectorView.class, DenseLocalOnHeapMatrix.class); // IMPL NOTE per fixture put(DelegatingVector.class, DenseLocalOnHeapMatrix.class); // IMPL NOTE per fixture // IMPL NOTE check for presence of all implementations here will be done in testHaveLikeMatrix via Fixture }}; } /** */ private static class Metric { // todo consider if softer tolerance (like say 0.1 or 0.01) would make sense here /** */ private final double exp; /** */ private final double obtained; /** **/ Metric(double exp, double obtained) { this.exp = exp; this.obtained = obtained; } /** */ boolean closeEnough() { return new Double(exp).equals(obtained); } /** {@inheritDoc} */ @Override public String toString() { return "Metric{" + "expected=" + exp + ", obtained=" + obtained + '}'; } } }