/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.math.impls.vector;
import java.util.Arrays;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.functions.Functions;
import org.junit.Test;
import static java.util.function.DoubleUnaryOperator.identity;
import static org.junit.Assert.assertTrue;
/** See also: {@link AbstractVectorTest} and {@link VectorToMatrixTest}. */
public class VectorFoldMapTest {
/** */
@Test
public void mapVectorTest() {
operationVectorTest((operand1, operand2) -> operand1 + operand2, (Vector v1, Vector v2) -> v1.map(v2, Functions.PLUS));
}
/** */
@Test
public void mapDoubleFunctionTest() {
consumeSampleVectors((v, desc) -> operatorTest(v, desc,
(vec) -> vec.map(Functions.INV), (val) -> 1.0 / val));
}
/** */
@Test
public void mapBiFunctionTest() {
consumeSampleVectors((v, desc) -> operatorTest(v, desc,
(vec) -> vec.map(Functions.PLUS, 1.0), (val) -> 1.0 + val));
}
/** */
@Test
public void foldMapTest() {
toDoubleTest(
ref -> Arrays.stream(ref).map(identity()).sum(),
(v) -> v.foldMap(Functions.PLUS, Functions.IDENTITY, 0.0));
}
/** */
@Test
public void foldMapVectorTest() {
toDoubleTest(
ref -> 2.0 * Arrays.stream(ref).sum(),
(v) -> v.foldMap(v, Functions.PLUS, Functions.PLUS, 0.0));
}
/** */
private void operatorTest(Vector v, String desc, Function<Vector, Vector> op, Function<Double, Double> refOp) {
final int size = v.size();
final double[] ref = new double[size];
VectorImplementationsTest.ElementsChecker checker = new VectorImplementationsTest.ElementsChecker(v, ref, desc);
Vector actual = op.apply(v);
for (int idx = 0; idx < size; idx++)
ref[idx] = refOp.apply(ref[idx]);
checker.assertCloseEnough(actual, ref);
}
/** */
private void toDoubleTest(Function<double[], Double> calcRef, Function<Vector, Double> calcVec) {
consumeSampleVectors((v, desc) -> {
final int size = v.size();
final double[] ref = new double[size];
new VectorImplementationsTest.ElementsChecker(v, ref, desc); // IMPL NOTE this initialises vector and reference array
final VectorImplementationsTest.Metric metric = new VectorImplementationsTest.Metric(calcRef.apply(ref), calcVec.apply(v));
assertTrue("Not close enough at " + desc
+ ", " + metric, metric.closeEnough());
});
}
/** */
private void operationVectorTest(BiFunction<Double, Double, Double> operation,
BiFunction<Vector, Vector, Vector> vecOperation) {
consumeSampleVectors((v, desc) -> {
// TODO find out if more elaborate testing scenario is needed or it's okay as is.
final int size = v.size();
final double[] ref = new double[size];
final VectorImplementationsTest.ElementsChecker checker = new VectorImplementationsTest.ElementsChecker(v, ref, desc);
final Vector operand = v.copy();
for (int idx = 0; idx < size; idx++)
ref[idx] = operation.apply(ref[idx], ref[idx]);
checker.assertCloseEnough(vecOperation.apply(v, operand), ref);
});
}
/** */
private void consumeSampleVectors(BiConsumer<Vector, String> consumer) {
new VectorImplementationsFixtures().consumeSampleVectors(null, consumer);
}
}