/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.ignite.ml.math.impls.vector; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Function; import org.apache.ignite.ml.math.Vector; import org.junit.Test; import static org.junit.Assert.assertTrue; /** */ public class VectorNormTest { /** */ @Test public void normalizeTest() { normalizeTest(2, (val, len) -> val / len, Vector::normalize); } /** */ @Test public void normalizePowerTest() { for (double pow : new double[] {0, 0.5, 1, 2, 2.5, Double.POSITIVE_INFINITY}) normalizeTest(pow, (val, norm) -> val / norm, (v) -> v.normalize(pow)); } /** */ @Test public void logNormalizeTest() { normalizeTest(2, (val, len) -> Math.log1p(val) / (len * Math.log(2)), Vector::logNormalize); } /** */ @Test public void logNormalizePowerTest() { for (double pow : new double[] {1.1, 2, 2.5}) normalizeTest(pow, (val, norm) -> Math.log1p(val) / (norm * Math.log(pow)), (v) -> v.logNormalize(pow)); } /** */ @Test public void kNormTest() { for (double pow : new double[] {0, 0.5, 1, 2, 2.5, Double.POSITIVE_INFINITY}) toDoubleTest(pow, ref -> new Norm(ref, pow).calculate(), v -> v.kNorm(pow)); } /** */ @Test public void getLengthSquaredTest() { toDoubleTest(2.0, ref -> new Norm(ref, 2).sumPowers(), Vector::getLengthSquared); } /** */ @Test public void getDistanceSquaredTest() { consumeSampleVectors((v, desc) -> { new VectorImplementationsTest.ElementsChecker(v, desc); // IMPL NOTE this initialises vector final int size = v.size(); final Vector vOnHeap = new DenseLocalOnHeapVector(size); final Vector vOffHeap = new DenseLocalOffHeapVector(size); invertValues(v, vOnHeap); invertValues(v, vOffHeap); for (int idx = 0; idx < size; idx++) { final double exp = v.get(idx); final int idxMirror = size - 1 - idx; assertTrue("On heap vector difference at " + desc + ", idx " + idx, exp - vOnHeap.get(idxMirror) == 0); assertTrue("Off heap vector difference at " + desc + ", idx " + idx, exp - vOffHeap.get(idxMirror) == 0); } final double exp = vOnHeap.minus(v).getLengthSquared(); // IMPL NOTE this won't mutate vOnHeap final VectorImplementationsTest.Metric metric = new VectorImplementationsTest.Metric(exp, v.getDistanceSquared(vOnHeap)); assertTrue("On heap vector not close enough at " + desc + ", " + metric, metric.closeEnough()); final VectorImplementationsTest.Metric metric1 = new VectorImplementationsTest.Metric(exp, v.getDistanceSquared(vOffHeap)); assertTrue("Off heap vector not close enough at " + desc + ", " + metric1, metric1.closeEnough()); }); } /** */ @Test public void dotTest() { consumeSampleVectors((v, desc) -> { new VectorImplementationsTest.ElementsChecker(v, desc); // IMPL NOTE this initialises vector final int size = v.size(); final Vector v1 = new DenseLocalOnHeapVector(size); invertValues(v, v1); final double actual = v.dot(v1); double exp = 0; for (Vector.Element e : v.all()) exp += e.get() * v1.get(e.index()); final VectorImplementationsTest.Metric metric = new VectorImplementationsTest.Metric(exp, actual); assertTrue("Dot product not close enough at " + desc + ", " + metric, metric.closeEnough()); }); } /** */ private void invertValues(Vector src, Vector dst) { final int size = src.size(); for (Vector.Element e : src.all()) { final int idx = size - 1 - e.index(); final double val = e.get(); dst.set(idx, val); } } /** */ private void toDoubleTest(Double val, Function<double[], Double> calcRef, Function<Vector, Double> calcVec) { consumeSampleVectors((v, desc) -> { final int size = v.size(); final double[] ref = new double[size]; new VectorImplementationsTest.ElementsChecker(v, ref, desc); // IMPL NOTE this initialises vector and reference array final double exp = calcRef.apply(ref); final double obtained = calcVec.apply(v); final VectorImplementationsTest.Metric metric = new VectorImplementationsTest.Metric(exp, obtained); assertTrue("Not close enough at " + desc + (val == null ? "" : ", value " + val) + ", " + metric, metric.closeEnough()); }); } /** */ private void normalizeTest(double pow, BiFunction<Double, Double, Double> operation, Function<Vector, Vector> vecOperation) { consumeSampleVectors((v, desc) -> { final int size = v.size(); final double[] ref = new double[size]; final boolean nonNegative = pow != (int)pow; final VectorImplementationsTest.ElementsChecker checker = new VectorImplementationsTest.ElementsChecker(v, ref, desc + ", pow = " + pow, nonNegative); final double norm = new Norm(ref, pow).calculate(); for (int idx = 0; idx < size; idx++) ref[idx] = operation.apply(ref[idx], norm); checker.assertCloseEnough(vecOperation.apply(v), ref); }); } /** */ private void consumeSampleVectors(BiConsumer<Vector, String> consumer) { new VectorImplementationsFixtures().consumeSampleVectors(null, consumer); } /** */ private static class Norm { /** */ private final double[] arr; /** */ private final Double pow; /** */ Norm(double[] arr, double pow) { this.arr = arr; this.pow = pow; } /** */ double calculate() { if (pow.equals(0.0)) return countNonZeroes(); // IMPL NOTE this is beautiful if you think of it if (pow.equals(Double.POSITIVE_INFINITY)) return maxAbs(); return Math.pow(sumPowers(), 1 / pow); } /** */ double sumPowers() { if (pow.equals(0.0)) return countNonZeroes(); double norm = 0; for (double val : arr) norm += pow == 1 ? Math.abs(val) : Math.pow(val, pow); return norm; } /** */ private int countNonZeroes() { int cnt = 0; final Double zero = 0.0; for (double val : arr) if (!zero.equals(val)) cnt++; return cnt; } /** */ private double maxAbs() { double res = 0; for (double val : arr) { final double abs = Math.abs(val); if (abs > res) res = abs; } return res; } } }