/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml;

import java.util.stream.IntStream;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Precision;
import org.apache.ignite.ml.math.Vector;
import org.junit.Assert;

/**
 * Collection of static methods used in math unit tests.
 */
public class TestUtils {
    /**
     * Private constructor: this class only exposes static helpers.
     */
    private TestUtils() {
        super();
    }

    /**
     * Verifies that expected and actual are within delta. Delegates to
     * {@link Assert#assertEquals(String, double, double, double)} with a {@code null} message.
     *
     * @param exp Expected value.
     * @param actual Actual value.
     * @param delta Maximum allowed delta between {@code exp} and {@code actual}.
     */
    public static void assertEquals(double exp, double actual, double delta) {
        Assert.assertEquals(null, exp, actual, delta);
    }

    /**
     * Verifies that expected and actual are within delta. When {@code exp} is NaN the check
     * is instead that {@code actual} is NaN as well.
     *
     * @param msg The identifying message for the assertion error (not used in the NaN branch).
     * @param exp Expected value.
     * @param actual Actual value.
     * @param delta Maximum allowed delta between {@code exp} and {@code actual}.
     */
    public static void assertEquals(String msg, double exp, double actual, double delta) {
        // Check for NaN.
        if (Double.isNaN(exp))
            Assert.assertTrue("" + actual + " is not NaN.", Double.isNaN(actual));
        else
            Assert.assertEquals(msg, exp, actual, delta);
    }

    /**
     * Verifies that two double arrays have equal entries, up to tolerance.
     *
     * @param exp Expected array.
     * @param observed Actual array.
     * @param tolerance Comparison tolerance value.
     */
    public static void assertEquals(double exp[], double observed[], double tolerance) {
        assertEquals("Array comparison failure", exp, observed, tolerance);
    }

    /**
     * Asserts that all entries of the specified vectors are equal to within a
     * positive {@code delta}.
     *
     * @param msg The identifying message for the assertion error (can be {@code null}).
     * @param exp Expected value.
     * @param actual Actual value.
     * @param delta The maximum difference between the entries of the expected and actual vectors for which both entries
     * are still considered equal.
     */
    public static void assertEquals(final String msg, final double[] exp, final Vector actual, final double delta) {
        final String msgAndSep = prefix(msg);

        Assert.assertEquals(msgAndSep + "dimension", exp.length, actual.size());

        for (int i = 0; i < exp.length; i++)
            Assert.assertEquals(msgAndSep + "entry #" + i, exp[i], actual.getX(i), delta);
    }

    /**
     * Asserts that all entries of the specified vectors are equal to within a
     * positive {@code delta}.
     *
     * @param msg The identifying message for the assertion error (can be {@code null}).
     * @param exp Expected value.
     * @param actual Actual value.
     * @param delta The maximum difference between the entries of the expected and actual vectors for which both entries
     * are still considered equal.
     */
    public static void assertEquals(final String msg, final Vector exp, final Vector actual, final double delta) {
        final String msgAndSep = prefix(msg);

        Assert.assertEquals(msgAndSep + "dimension", exp.size(), actual.size());

        final int dim = exp.size();

        for (int i = 0; i < dim; i++)
            Assert.assertEquals(msgAndSep + "entry #" + i, exp.getX(i), actual.getX(i), delta);
    }

    /**
     * Verifies that two matrices are close: the maximum absolute row sum of
     * {@code exp - actual} must be strictly less than {@code tolerance}.
     *
     * @param msg The identifying message for the assertion error.
     * @param exp Expected matrix.
     * @param actual Actual matrix.
     * @param tolerance Comparison tolerance value; the assertion fails when the norm of the difference
     * is greater than or equal to it.
     */
    public static void assertEquals(String msg, Matrix exp, Matrix actual, double tolerance) {
        Assert.assertNotNull(msg + "\nObserved should not be null", actual);

        if (exp.columnSize() != actual.columnSize() || exp.rowSize() != actual.rowSize()) {
            String msgBuff = msg + "\nObserved has incorrect dimensions." +
                "\nobserved is " + actual.rowSize() +
                " x " + actual.columnSize() +
                "\nexpected " + exp.rowSize() +
                " x " + exp.columnSize();

            Assert.fail(msgBuff);
        }

        Matrix delta = exp.minus(actual);

        if (TestUtils.maximumAbsoluteRowSum(delta) >= tolerance) {
            String msgBuff = msg + "\nExpected: " + exp +
                "\nObserved: " + actual +
                "\nexpected - observed: " + delta;

            Assert.fail(msgBuff);
        }
    }

    /**
     * Verifies that two matrices are equal: same dimensions and every pair of
     * entries compared with a zero delta.
     *
     * @param exp Expected matrix.
     * @param actual Actual matrix.
     */
    public static void assertEquals(Matrix exp, Matrix actual) {
        Assert.assertNotNull("Observed should not be null", actual);

        if (exp.columnSize() != actual.columnSize() || exp.rowSize() != actual.rowSize()) {
            String msgBuff = "Observed has incorrect dimensions." +
                "\nobserved is " + actual.rowSize() +
                " x " + actual.columnSize() +
                "\nexpected " + exp.rowSize() +
                " x " + exp.columnSize();

            Assert.fail(msgBuff);
        }

        for (int i = 0; i < exp.rowSize(); ++i)
            for (int j = 0; j < exp.columnSize(); ++j) {
                double eij = exp.getX(i, j);
                double aij = actual.getX(i, j);

                // TODO: Check precision here.
                Assert.assertEquals(eij, aij, 0.0);
            }
    }

    /**
     * Verifies that two double arrays are close (sup norm). Entries are compared one by one with
     * {@code Precision.equalsIncludingNaN}; every mismatching index is reported in the failure message.
     *
     * @param msg The identifying message for the assertion error.
     * @param exp Expected array.
     * @param actual Actual array.
     * @param tolerance Comparison tolerance value.
     */
    public static void assertEquals(String msg, double[] exp, double[] actual, double tolerance) {
        StringBuilder out = new StringBuilder(msg);

        if (exp.length != actual.length) {
            out.append("\n Arrays not same length. \n");
            out.append("expected has length ");
            out.append(exp.length);
            out.append(" observed length = ");
            out.append(actual.length);

            Assert.fail(out.toString());
        }

        boolean failure = false;

        for (int i = 0; i < exp.length; i++)
            if (!Precision.equalsIncludingNaN(exp[i], actual[i], tolerance)) {
                failure = true;

                out.append("\n Elements at index ");
                out.append(i);
                out.append(" differ. ");
                out.append(" expected = ");
                out.append(exp[i]);
                out.append(" observed = ");
                out.append(actual[i]);
            }

        if (failure)
            Assert.fail(out.toString());
    }

    /**
     * Verifies that two float arrays are close (sup norm). Entries are compared one by one with
     * {@code Precision.equalsIncludingNaN}; every mismatching index is reported in the failure message.
     *
     * @param msg The identifying message for the assertion error.
     * @param exp Expected array.
     * @param actual Actual array.
     * @param tolerance Comparison tolerance value.
     */
    public static void assertEquals(String msg, float[] exp, float[] actual, float tolerance) {
        StringBuilder out = new StringBuilder(msg);

        if (exp.length != actual.length) {
            out.append("\n Arrays not same length. \n");
            out.append("expected has length ");
            out.append(exp.length);
            out.append(" observed length = ");
            out.append(actual.length);

            Assert.fail(out.toString());
        }

        boolean failure = false;

        for (int i = 0; i < exp.length; i++)
            if (!Precision.equalsIncludingNaN(exp[i], actual[i], tolerance)) {
                failure = true;

                out.append("\n Elements at index ");
                out.append(i);
                out.append(" differ. ");
                out.append(" expected = ");
                out.append(exp[i]);
                out.append(" observed = ");
                out.append(actual[i]);
            }

        if (failure)
            Assert.fail(out.toString());
    }

    /**
     * Computes the maximum absolute row sum of a matrix: for every row the absolute values of its
     * entries are summed, and the largest such sum is returned.
     *
     * @param mtx Matrix to measure.
     * @return Largest per-row sum of absolute entry values, or {@code 0.0} when the matrix has no rows.
     */
    public static double maximumAbsoluteRowSum(Matrix mtx) {
        final int cols = mtx.columnSize();

        // Absolute values are taken per entry, before summing, so that entries of opposite sign
        // in the same row cannot cancel each other out.
        return IntStream.range(0, mtx.rowSize())
            .mapToDouble(row -> IntStream.range(0, cols).mapToDouble(col -> Math.abs(mtx.getX(row, col))).sum())
            .max()
            .orElse(0.0);
    }

    /**
     * Builds the prefix put in front of per-check assertion messages.
     *
     * @param msg Caller supplied message, possibly {@code null} or empty.
     * @return Empty string for a {@code null} or empty message, otherwise the message followed by {@code ", "}.
     */
    private static String prefix(String msg) {
        return msg == null || msg.isEmpty() ? "" : msg + ", ";
    }
}