/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.math.impls.vector;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.ignite.IgniteException;
import org.apache.ignite.ml.math.ExternalizeTest;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.exceptions.CardinalityException;
import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/** See also: {@link AbstractVectorTest} and {@link VectorToMatrixTest}. */
public class VectorImplementationsTest { // todo split this to smaller cohesive test classes
/** */
@Test
public void vectorImplementationsFixturesTest() {
new VectorImplementationsFixtures().selfTest();
}
/** */
@Test
public void setGetTest() {
consumeSampleVectors((v, desc) -> mutateAtIdxTest(v, desc, (vec, idx, val) -> {
vec.set(idx, val);
return val;
}));
}
/** */
@Test
public void setXTest() {
consumeSampleVectors((v, desc) -> mutateAtIdxTest(v, desc, (vec, idx, val) -> {
vec.setX(idx, val);
return val;
}));
}
/** */
@Test
public void incrementTest() {
consumeSampleVectors((v, desc) -> mutateAtIdxTest(v, desc, (vec, idx, val) -> {
double old = vec.get(idx);
vec.increment(idx, val);
return old + val;
}));
}
/** */
@Test
public void incrementXTest() {
consumeSampleVectors((v, desc) -> mutateAtIdxTest(v, desc, (vec, idx, val) -> {
double old = vec.getX(idx);
vec.incrementX(idx, val);
return old + val;
}));
}
/** */
@Test
public void operateXOutOfBoundsTest() {
consumeSampleVectors((v, desc) -> {
if (v instanceof DenseLocalOffHeapVector || v instanceof SparseLocalVector || v instanceof SparseLocalOffHeapVector)
return; // todo find out if it's OK to skip by instances here
boolean expECaught = false;
try {
v.getX(-1);
}
catch (ArrayIndexOutOfBoundsException | IgniteException e) {
expECaught = true;
}
if (!getXOutOfBoundsOK(v))
assertTrue("Expect exception at negative index getX in " + desc, expECaught);
expECaught = false;
try {
v.setX(-1, 0);
}
catch (ArrayIndexOutOfBoundsException | IgniteException e) {
expECaught = true;
}
assertTrue("Expect exception at negative index setX in " + desc, expECaught);
expECaught = false;
try {
v.incrementX(-1, 1);
}
catch (ArrayIndexOutOfBoundsException | IgniteException e) {
expECaught = true;
}
assertTrue("Expect exception at negative index incrementX in " + desc, expECaught);
expECaught = false;
try {
v.getX(v.size());
}
catch (ArrayIndexOutOfBoundsException | IgniteException e) {
expECaught = true;
}
if (!getXOutOfBoundsOK(v))
assertTrue("Expect exception at too large index getX in " + desc, expECaught);
expECaught = false;
try {
v.setX(v.size(), 1);
}
catch (ArrayIndexOutOfBoundsException | IgniteException e) {
expECaught = true;
}
assertTrue("Expect exception at too large index setX in " + desc, expECaught);
expECaught = false;
try {
v.incrementX(v.size(), 1);
}
catch (ArrayIndexOutOfBoundsException | IgniteException e) {
expECaught = true;
}
assertTrue("Expect exception at too large index incrementX in " + desc, expECaught);
});
}
/** */
@Test
public void sizeTest() {
final AtomicReference<Integer> expSize = new AtomicReference<>(0);
consumeSampleVectors(
expSize::set,
(v, desc) -> Assert.assertEquals("Expected size for " + desc,
(int)expSize.get(), v.size())
);
}
/** */
@Test
public void getElementTest() {
consumeSampleVectors((v, desc) -> new ElementsChecker(v, desc).assertCloseEnough(v));
}
/** */
@Test
public void copyTest() {
consumeSampleVectors((v, desc) -> new ElementsChecker(v, desc).assertCloseEnough(v.copy()));
}
/** */
@Test
public void divideTest() {
operationTest((val, operand) -> val / operand, Vector::divide);
}
/** */
@Test
public void likeTest() {
for (int card : new int[] {1, 2, 4, 8, 16, 32, 64, 128})
consumeSampleVectors((v, desc) -> {
Class<? extends Vector> expType = expLikeType(v);
if (expType == null) {
try {
v.like(card);
}
catch (UnsupportedOperationException uoe) {
return;
}
fail("Expected exception wasn't caught for " + desc);
return;
}
Vector vLike = v.like(card);
assertNotNull("Expect non-null like vector for " + expType.getSimpleName() + " in " + desc, vLike);
assertEquals("Expect size equal to cardinality at " + desc, card, vLike.size());
Class<? extends Vector> actualType = vLike.getClass();
assertTrue("Actual vector type " + actualType.getSimpleName()
+ " should be assignable from expected type " + expType.getSimpleName() + " in " + desc,
actualType.isAssignableFrom(expType));
});
}
/** */
@Test
public void minusTest() {
operationVectorTest((operand1, operand2) -> operand1 - operand2, Vector::minus);
}
/** */
@Test
public void plusVectorTest() {
operationVectorTest((operand1, operand2) -> operand1 + operand2, Vector::plus);
}
/** */
@Test
public void plusDoubleTest() {
operationTest((val, operand) -> val + operand, Vector::plus);
}
/** */
@Test
public void timesVectorTest() {
operationVectorTest((operand1, operand2) -> operand1 * operand2, Vector::times);
}
/** */
@Test
public void timesDoubleTest() {
operationTest((val, operand) -> val * operand, Vector::times);
}
/** */
@Test
public void viewPartTest() {
consumeSampleVectors((v, desc) -> {
final int size = v.size();
final double[] ref = new double[size];
final int delta = size > 32 ? 3 : 1; // IMPL NOTE this is for faster test execution
final ElementsChecker checker = new ElementsChecker(v, ref, desc);
for (int off = 0; off < size; off += delta)
for (int len = 1; len < size - off; len += delta)
checker.assertCloseEnough(v.viewPart(off, len), Arrays.copyOfRange(ref, off, off + len));
});
}
/** */
@Test
public void sumTest() {
toDoubleTest(
ref -> Arrays.stream(ref).sum(),
Vector::sum);
}
/** */
@Test
public void minValueTest() {
toDoubleTest(
ref -> Arrays.stream(ref).min().getAsDouble(),
Vector::minValue);
}
/** */
@Test
public void maxValueTest() {
toDoubleTest(
ref -> Arrays.stream(ref).max().getAsDouble(),
Vector::maxValue);
}
/** */
@Test
public void sortTest() {
consumeSampleVectors((v, desc) -> {
if (readOnly(v) || !v.isArrayBased()) {
boolean expECaught = false;
try {
v.sort();
}
catch (UnsupportedOperationException uoe) {
expECaught = true;
}
assertTrue("Expected exception was not caught for sort in " + desc, expECaught);
return;
}
final int size = v.size();
final double[] ref = new double[size];
new ElementsChecker(v, ref, desc).assertCloseEnough(v.sort(), Arrays.stream(ref).sorted().toArray());
});
}
/** */
@Test
public void metaAttributesTest() {
consumeSampleVectors((v, desc) -> {
assertNotNull("Null meta storage in " + desc, v.getMetaStorage());
final String key = "test key";
final String val = "test value";
final String details = "key [" + key + "] for " + desc;
v.setAttribute(key, val);
assertTrue("Expect to have meta attribute for " + details, v.hasAttribute(key));
assertEquals("Unexpected meta attribute value for " + details, val, v.getAttribute(key));
v.removeAttribute(key);
assertFalse("Expect not to have meta attribute for " + details, v.hasAttribute(key));
assertNull("Unexpected meta attribute value for " + details, v.getAttribute(key));
});
}
/** */
@Test
public void assignDoubleTest() {
consumeSampleVectors((v, desc) -> {
if (readOnly(v))
return;
for (double val : new double[] {0, -1, 0, 1}) {
v.assign(val);
for (int idx = 0; idx < v.size(); idx++) {
final Metric metric = new Metric(val, v.get(idx));
assertTrue("Not close enough at index " + idx + ", val " + val + ", " + metric
+ ", " + desc, metric.closeEnough());
}
}
});
}
/** */
@Test
public void assignDoubleArrTest() {
consumeSampleVectors((v, desc) -> {
if (readOnly(v))
return;
final int size = v.size();
final double[] ref = new double[size];
final ElementsChecker checker = new ElementsChecker(v, ref, desc);
for (int idx = 0; idx < size; idx++)
ref[idx] = -ref[idx];
v.assign(ref);
checker.assertCloseEnough(v, ref);
assignDoubleArrWrongCardinality(v, desc);
});
}
/** */
@Test
public void assignVectorTest() {
consumeSampleVectors((v, desc) -> {
if (readOnly(v))
return;
final int size = v.size();
final double[] ref = new double[size];
final ElementsChecker checker = new ElementsChecker(v, ref, desc);
for (int idx = 0; idx < size; idx++)
ref[idx] = -ref[idx];
v.assign(new DenseLocalOnHeapVector(ref));
checker.assertCloseEnough(v, ref);
assignVectorWrongCardinality(v, desc);
});
}
/** */
@Test
public void assignFunctionTest() {
consumeSampleVectors((v, desc) -> {
if (readOnly(v))
return;
final int size = v.size();
final double[] ref = new double[size];
final ElementsChecker checker = new ElementsChecker(v, ref, desc);
for (int idx = 0; idx < size; idx++)
ref[idx] = -ref[idx];
v.assign((idx) -> ref[idx]);
checker.assertCloseEnough(v, ref);
});
}
/** */
@Test
public void minElementTest() {
consumeSampleVectors((v, desc) -> {
final ElementsChecker checker = new ElementsChecker(v, desc);
final Vector.Element minE = v.minElement();
final int minEIdx = minE.index();
assertTrue("Unexpected index from minElement " + minEIdx + ", " + desc,
minEIdx >= 0 && minEIdx < v.size());
final Metric metric = new Metric(minE.get(), v.minValue());
assertTrue("Not close enough minElement at index " + minEIdx + ", " + metric
+ ", " + desc, metric.closeEnough());
checker.assertNewMinElement(v);
});
}
/** */
@Test
public void maxElementTest() {
consumeSampleVectors((v, desc) -> {
final ElementsChecker checker = new ElementsChecker(v, desc);
final Vector.Element maxE = v.maxElement();
final int minEIdx = maxE.index();
assertTrue("Unexpected index from minElement " + minEIdx + ", " + desc,
minEIdx >= 0 && minEIdx < v.size());
final Metric metric = new Metric(maxE.get(), v.maxValue());
assertTrue("Not close enough maxElement at index " + minEIdx + ", " + metric
+ ", " + desc, metric.closeEnough());
checker.assertNewMaxElement(v);
});
}
/** */
@Test
public void externalizeTest() {
(new ExternalizeTest<Vector>() {
/** {@inheritDoc} */
@Override public void externalizeTest() {
consumeSampleVectors((v, desc) -> {
if (v instanceof SparseLocalOffHeapVector)
return; //TODO: wait till SparseLocalOffHeapVector externalization support.
externalizeTest(v);
});
}
}).externalizeTest();
}
/** */
@Test
public void hashCodeTest() {
consumeSampleVectors((v, desc) -> assertTrue("Zero hash code for " + desc, v.hashCode() != 0));
}
/** */
private boolean getXOutOfBoundsOK(Vector v) {
// todo find out if this is indeed OK
return v instanceof RandomVector || v instanceof ConstantVector
|| v instanceof SingleElementVector || v instanceof SingleElementVectorView;
}
/** */
private void mutateAtIdxTest(Vector v, String desc, MutateAtIdx operation) {
if (readOnly(v)) {
if (v.size() < 1)
return;
boolean expECaught = false;
try {
operation.apply(v, 0, 1);
}
catch (UnsupportedOperationException uoe) {
expECaught = true;
}
assertTrue("Expect exception at attempt to mutate element in " + desc, expECaught);
return;
}
for (double val : new double[] {0, -1, 0, 1})
for (int idx = 0; idx < v.size(); idx++) {
double exp = operation.apply(v, idx, val);
final Metric metric = new Metric(exp, v.get(idx));
assertTrue("Not close enough at index " + idx + ", val " + val + ", " + metric
+ ", " + desc, metric.closeEnough());
}
}
/** */
private Class<? extends Vector> expLikeType(Vector v) {
Class<? extends Vector> clazz = v.getClass();
if (clazz.isAssignableFrom(PivotedVectorView.class) || clazz.isAssignableFrom(SingleElementVectorView.class))
return null;
if (clazz.isAssignableFrom(MatrixVectorView.class) || clazz.isAssignableFrom(DelegatingVector.class))
return DenseLocalOnHeapVector.class; // IMPL NOTE per fixture
return clazz;
}
/** */
private void toDoubleTest(Function<double[], Double> calcRef, Function<Vector, Double> calcVec) {
consumeSampleVectors((v, desc) -> {
final int size = v.size();
final double[] ref = new double[size];
new ElementsChecker(v, ref, desc); // IMPL NOTE this initialises vector and reference array
final Metric metric = new Metric(calcRef.apply(ref), calcVec.apply(v));
assertTrue("Not close enough at " + desc
+ ", " + metric, metric.closeEnough());
});
}
/** */
private void operationVectorTest(BiFunction<Double, Double, Double> operation,
BiFunction<Vector, Vector, Vector> vecOperation) {
consumeSampleVectors((v, desc) -> {
// TODO find out if more elaborate testing scenario is needed or it's okay as is.
final int size = v.size();
final double[] ref = new double[size];
final ElementsChecker checker = new ElementsChecker(v, ref, desc);
final Vector operand = v.copy();
for (int idx = 0; idx < size; idx++)
ref[idx] = operation.apply(ref[idx], ref[idx]);
checker.assertCloseEnough(vecOperation.apply(v, operand), ref);
assertWrongCardinality(v, desc, vecOperation);
});
}
/** */
private void assignDoubleArrWrongCardinality(Vector v, String desc) {
boolean expECaught = false;
try {
v.assign(new double[v.size() + 1]);
}
catch (CardinalityException ce) {
expECaught = true;
}
assertTrue("Expect exception at too large size in " + desc, expECaught);
if (v.size() < 2)
return;
expECaught = false;
try {
v.assign(new double[v.size() - 1]);
}
catch (CardinalityException ce) {
expECaught = true;
}
assertTrue("Expect exception at too small size in " + desc, expECaught);
}
/** */
private void assignVectorWrongCardinality(Vector v, String desc) {
boolean expECaught = false;
try {
v.assign(new DenseLocalOnHeapVector(v.size() + 1));
}
catch (CardinalityException ce) {
expECaught = true;
}
assertTrue("Expect exception at too large size in " + desc, expECaught);
if (v.size() < 2)
return;
expECaught = false;
try {
v.assign(new DenseLocalOnHeapVector(v.size() - 1));
}
catch (CardinalityException ce) {
expECaught = true;
}
assertTrue("Expect exception at too small size in " + desc, expECaught);
}
/** */
private void assertWrongCardinality(
Vector v, String desc, BiFunction<Vector, Vector, Vector> vecOperation) {
boolean expECaught = false;
try {
vecOperation.apply(v, new DenseLocalOnHeapVector(v.size() + 1));
}
catch (CardinalityException ce) {
expECaught = true;
}
assertTrue("Expect exception at too large size in " + desc, expECaught);
if (v.size() < 2)
return;
expECaught = false;
try {
vecOperation.apply(v, new DenseLocalOnHeapVector(v.size() - 1));
}
catch (CardinalityException ce) {
expECaught = true;
}
assertTrue("Expect exception at too small size in " + desc, expECaught);
}
/** */
private void operationTest(BiFunction<Double, Double, Double> operation,
BiFunction<Vector, Double, Vector> vecOperation) {
for (double val : new double[] {0, 0.1, 1, 2, 10})
consumeSampleVectors((v, desc) -> {
final int size = v.size();
final double[] ref = new double[size];
final ElementsChecker checker = new ElementsChecker(v, ref, "val " + val + ", " + desc);
for (int idx = 0; idx < size; idx++)
ref[idx] = operation.apply(ref[idx], val);
checker.assertCloseEnough(vecOperation.apply(v, val), ref);
});
}
/** */
private void consumeSampleVectors(BiConsumer<Vector, String> consumer) {
consumeSampleVectors(null, consumer);
}
/** */
private void consumeSampleVectors(Consumer<Integer> paramsConsumer, BiConsumer<Vector, String> consumer) {
new VectorImplementationsFixtures().consumeSampleVectors(paramsConsumer, consumer);
}
/** */
private static boolean readOnly(Vector v) {
return v instanceof RandomVector || v instanceof ConstantVector;
}
/** */
private interface MutateAtIdx {
/** */
double apply(Vector v, int idx, double val);
}
/** */
static class ElementsChecker {
/** */
private final String fixtureDesc;
/** */
private final double[] refReadOnly;
/** */
private final boolean nonNegative;
/** */
ElementsChecker(Vector v, double[] ref, String fixtureDesc, boolean nonNegative) {
this.fixtureDesc = fixtureDesc;
this.nonNegative = nonNegative;
refReadOnly = readOnly(v) && ref == null ? new double[v.size()] : null;
init(v, ref);
}
/** */
ElementsChecker(Vector v, double[] ref, String fixtureDesc) {
this(v, ref, fixtureDesc, false);
}
/** */
ElementsChecker(Vector v, String fixtureDesc) {
this(v, null, fixtureDesc);
}
/** */
void assertCloseEnough(Vector obtained, double[] exp) {
final int size = obtained.size();
for (int i = 0; i < size; i++) {
final Vector.Element e = obtained.getElement(i);
if (refReadOnly != null && exp == null)
exp = refReadOnly;
final Metric metric = new Metric(exp == null ? generated(i) : exp[i], e.get());
assertEquals("Unexpected vector index at " + fixtureDesc, i, e.index());
assertTrue("Not close enough at index " + i + ", size " + size + ", " + metric
+ ", " + fixtureDesc, metric.closeEnough());
}
}
/** */
void assertCloseEnough(Vector obtained) {
assertCloseEnough(obtained, null);
}
/** */
void assertNewMinElement(Vector v) {
if (readOnly(v))
return;
int exp = v.size() / 2;
v.set(exp, -(v.size() * 2 + 1));
assertEquals("Unexpected minElement index at " + fixtureDesc, exp, v.minElement().index());
}
/** */
void assertNewMaxElement(Vector v) {
if (readOnly(v))
return;
int exp = v.size() / 2;
v.set(exp, v.size() * 2 + 1);
assertEquals("Unexpected minElement index at " + fixtureDesc, exp, v.maxElement().index());
}
/** */
private void init(Vector v, double[] ref) {
if (readOnly(v)) {
initReadonly(v, ref);
return;
}
for (Vector.Element e : v.all()) {
int idx = e.index();
// IMPL NOTE introduce negative values because their absence
// blocked catching an ugly bug in AbstractVector#kNorm
int val = generated(idx);
e.set(val);
if (ref != null)
ref[idx] = val;
}
}
/** */
private void initReadonly(Vector v, double[] ref) {
if (refReadOnly != null)
for (Vector.Element e : v.all())
refReadOnly[e.index()] = e.get();
if (ref != null)
for (Vector.Element e : v.all())
ref[e.index()] = e.get();
}
/** */
private int generated(int idx) {
return nonNegative || (idx & 1) == 0 ? idx : -idx;
}
}
/** */
static class Metric { // todo consider if softer tolerance (like say 0.1 or 0.01) would make sense here
/** */
private final double exp;
/** */
private final double obtained;
/** **/
Metric(double exp, double obtained) {
this.exp = exp;
this.obtained = obtained;
}
/** */
boolean closeEnough() {
return new Double(exp).equals(obtained) || closeEnoughToZero();
}
/** {@inheritDoc} */
@Override public String toString() {
return "Metric{" + "expected=" + exp +
", obtained=" + obtained +
'}';
}
/** */
private boolean closeEnoughToZero() {
return (new Double(exp).equals(0.0) && new Double(obtained).equals(-0.0))
|| (new Double(exp).equals(-0.0) && new Double(obtained).equals(0.0));
}
}
}