/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.math.impls.vector;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Spliterator;
import java.util.function.BiConsumer;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.impls.MathTestConstants;
import org.junit.Test;
import static java.util.Spliterator.ORDERED;
import static java.util.Spliterator.SIZED;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
 * Tests for the iteration-related API of {@link Vector}: {@code all()}, {@code nonZeroes()},
 * {@code nonZeroElements()}, {@code allSpliterator()} and {@code nonZeroSpliterator()}.
 */
public class VectorIterableTest {
    /**
     * Checks that {@code all()} yields elements with consecutive indices starting from 0
     * and that the number of yielded elements equals {@code size()}.
     */
    @Test
    public void allTest() {
        consumeSampleVectors(
            (v, desc) -> {
                int expIdx = 0;

                for (Vector.Element e : v.all()) {
                    assertEquals("Unexpected index for " + desc, expIdx, e.index());

                    expIdx++;
                }

                // After the loop expIdx equals the number of elements visited.
                assertEquals("Unexpected amount of elements for " + desc, v.size(), expIdx);
            }
        );
    }

    /** Checks that the iterator of {@code all()} throws {@link NoSuchElementException} once exhausted. */
    @Test
    public void allTestBound() {
        consumeSampleVectors(
            (v, desc) -> iteratorTestBound(v.all().iterator(), desc)
        );
    }

    /**
     * Checks {@code nonZeroes()} and {@code nonZeroElements()} on two small dense vectors: one holding
     * non-zero values at even indices only, the other at odd indices only.
     */
    @Test
    public void nonZeroesTestBasic() {
        final int size = 5;

        final double[] nonZeroesOddData = new double[size], nonZeroesEvenData = new double[size];

        for (int idx = 0; idx < size; idx++) {
            final boolean odd = (idx & 1) == 1;

            nonZeroesOddData[idx] = odd ? 1 : 0;
            nonZeroesEvenData[idx] = odd ? 0 : 1;
        }

        // Sanity checks of the fixture itself before the iterables are exercised.
        assertTrue("Arrays failed to initialize.",
            !isZero(nonZeroesEvenData[0])
                && isZero(nonZeroesEvenData[1])
                && isZero(nonZeroesOddData[0])
                && !isZero(nonZeroesOddData[1]));

        final Vector nonZeroesEvenVec = new DenseLocalOnHeapVector(nonZeroesEvenData),
            nonZeroesOddVec = new DenseLocalOnHeapVector(nonZeroesOddData);

        assertTrue("Vectors failed to initialize.",
            !isZero(nonZeroesEvenVec.getElement(0).get())
                && isZero(nonZeroesEvenVec.getElement(1).get())
                && isZero(nonZeroesOddVec.getElement(0).get())
                && !isZero(nonZeroesOddVec.getElement(1).get()));

        assertTrue("Iterator(s) failed to start.",
            nonZeroesEvenVec.nonZeroes().iterator().next() != null
                && nonZeroesOddVec.nonZeroes().iterator().next() != null);

        int nonZeroesActual = 0;

        for (Vector.Element e : nonZeroesEvenVec.nonZeroes()) {
            final int idx = e.index();
            final boolean odd = (idx & 1) == 1;
            final double val = e.get();

            assertTrue("Not an even index " + idx + ", for value " + val, !odd);
            assertTrue("Zero value " + val + " at even index " + idx, !isZero(val));

            nonZeroesActual++;
        }

        // Even indices in [0, size) are 0, 2, 4, ... - there are ceil(size / 2) of them.
        final int nonZeroesEvenExp = (size + 1) / 2;

        assertEquals("Unexpected num of iterated even non-zeroes.", nonZeroesEvenExp, nonZeroesActual);
        assertEquals("Unexpected nonZeroElements of even.", nonZeroesEvenExp, nonZeroesEvenVec.nonZeroElements());

        nonZeroesActual = 0;

        for (Vector.Element e : nonZeroesOddVec.nonZeroes()) {
            final int idx = e.index();
            final boolean odd = (idx & 1) == 1;
            final double val = e.get();

            assertTrue("Not an odd index " + idx + ", for value " + val, odd);
            assertTrue("Zero value " + val + " at odd index " + idx, !isZero(val));

            nonZeroesActual++;
        }

        // Odd indices in [0, size) are 1, 3, ... - there are floor(size / 2) of them.
        final int nonZeroesOddExp = size / 2;

        assertEquals("Unexpected num of iterated odd non-zeroes.", nonZeroesOddExp, nonZeroesActual);
        assertEquals("Unexpected nonZeroElements of odd.", nonZeroesOddExp, nonZeroesOddVec.nonZeroElements());
    }

    /**
     * Checks that {@code nonZeroes()} never yields a zero value and that the number of elements
     * it skips matches the expected number of zeroes.
     */
    @Test
    public void nonZeroesTest() {
        // todo make RandomVector constructor that accepts a function and use it here
        //  in order to *reliably* test non-zeroes in there
        consumeSampleVectors(
            (v, desc) -> consumeSampleVectorsWithZeroes(v, (vec, numZeroes)
                -> {
                // Start from the full size and subtract each visited non-zero; the remainder is the zero count.
                int numZeroesActual = vec.size();

                for (Vector.Element e : vec.nonZeroes()) {
                    numZeroesActual--;

                    assertTrue("Unexpected zero at " + desc + ", index " + e.index(), !isZero(e.get()));
                }

                assertEquals("Unexpected num zeroes at " + desc, (int)numZeroes, numZeroesActual);
            }));
    }

    /** Checks that the iterator of {@code nonZeroes()} throws {@link NoSuchElementException} once exhausted. */
    @Test
    public void nonZeroesTestBound() {
        consumeSampleVectors(
            (v, desc) -> consumeSampleVectorsWithZeroes(v, (vec, numZeroes)
                -> iteratorTestBound(vec.nonZeroes().iterator(), desc)));
    }

    /** Checks that {@code size() - nonZeroElements()} equals the expected number of zeroes. */
    @Test
    public void nonZeroElementsTest() {
        consumeSampleVectors(
            (v, desc) -> consumeSampleVectorsWithZeroes(v, (vec, numZeroes)
                -> assertEquals("Unexpected num zeroes at " + desc,
                (int)numZeroes, vec.size() - vec.nonZeroElements())));
    }

    /**
     * Checks {@code allSpliterator()}: it is not null, reports {@code ORDERED | SIZED}, does not split,
     * and reports a size equal to {@code size()}.
     */
    @Test
    public void allSpliteratorTest() {
        consumeSampleVectors(
            (v, desc) -> {
                final String desc1 = " " + desc;

                Spliterator<Double> spliterator = v.allSpliterator();

                assertNotNull(MathTestConstants.NULL_VAL + desc1, spliterator);
                assertNull(MathTestConstants.NOT_NULL_VAL + desc1, spliterator.trySplit());
                assertTrue(MathTestConstants.UNEXPECTED_VAL + desc1, spliterator.hasCharacteristics(ORDERED | SIZED));

                if (!readOnly(v))
                    fillWithNonZeroes(v);

                // Repeat the checks on a fresh spliterator taken after the (possible) refill.
                spliterator = v.allSpliterator();

                assertNotNull(MathTestConstants.NULL_VAL + desc1, spliterator);
                assertEquals(MathTestConstants.VAL_NOT_EQUALS + desc1, v.size(), spliterator.estimateSize());
                assertEquals(MathTestConstants.VAL_NOT_EQUALS + desc1, v.size(), spliterator.getExactSizeIfKnown());
                assertTrue(MathTestConstants.UNEXPECTED_VAL + desc1, spliterator.hasCharacteristics(ORDERED | SIZED));

                Spliterator<Double> secondHalf = spliterator.trySplit();

                assertNull(MathTestConstants.NOT_NULL_VAL + desc1, secondHalf);

                // The result is deliberately not asserted: this only verifies that advancing does not throw.
                spliterator.tryAdvance(x -> {
                });
            }
        );
    }

    /**
     * Checks {@code nonZeroSpliterator()}: it is not null, reports {@code ORDERED | SIZED}, does not split,
     * reports a size equal to the number of non-zero elements, and can be advanced at least once.
     */
    @Test
    public void nonZeroSpliteratorTest() {
        consumeSampleVectors(
            (v, desc) -> consumeSampleVectorsWithZeroes(v, (vec, numZeroes)
                -> {
                final String desc1 = " Num zeroes " + numZeroes + " " + desc;

                Spliterator<Double> spliterator = vec.nonZeroSpliterator();

                assertNotNull(MathTestConstants.NULL_VAL + desc1, spliterator);
                assertNull(MathTestConstants.NOT_NULL_VAL + desc1, spliterator.trySplit());
                assertTrue(MathTestConstants.UNEXPECTED_VAL + desc1, spliterator.hasCharacteristics(ORDERED | SIZED));

                // Expected size derived from the zero count supplied by the fixture.
                final long expNonZeroes = vec.size() - numZeroes;

                spliterator = vec.nonZeroSpliterator();

                assertNotNull(MathTestConstants.NULL_VAL + desc1, spliterator);
                assertEquals(MathTestConstants.VAL_NOT_EQUALS + desc1, expNonZeroes, spliterator.estimateSize());
                assertEquals(MathTestConstants.VAL_NOT_EQUALS + desc1, expNonZeroes, spliterator.getExactSizeIfKnown());
                assertTrue(MathTestConstants.UNEXPECTED_VAL + desc1, spliterator.hasCharacteristics(ORDERED | SIZED));

                Spliterator<Double> secondHalf = spliterator.trySplit();

                assertNull(MathTestConstants.NOT_NULL_VAL + desc1, secondHalf);

                // Expected size derived independently by copying the values out and counting non-zeroes.
                double[] data = new double[vec.size()];

                for (Vector.Element e : vec.all())
                    data[e.index()] = e.get();

                final long nonZeroesInData = Arrays.stream(data).filter(x -> x != 0d).count();

                spliterator = vec.nonZeroSpliterator();

                assertNotNull(MathTestConstants.NULL_VAL + desc1, spliterator);
                assertEquals(MathTestConstants.VAL_NOT_EQUALS + desc1, nonZeroesInData, spliterator.estimateSize());
                assertEquals(MathTestConstants.VAL_NOT_EQUALS + desc1, nonZeroesInData,
                    spliterator.getExactSizeIfKnown());
                assertTrue(MathTestConstants.UNEXPECTED_VAL + desc1, spliterator.hasCharacteristics(ORDERED | SIZED));

                secondHalf = spliterator.trySplit();

                assertNull(MathTestConstants.NOT_NULL_VAL + desc1, secondHalf);

                if (!spliterator.tryAdvance(x -> {
                }))
                    fail(MathTestConstants.NO_NEXT_ELEMENT + desc1);
            }));
    }

    /**
     * Drains the iterator, asserting every element is non-null, then verifies that one more
     * {@code next()} call throws {@link NoSuchElementException}.
     *
     * @param it Iterator to drain.
     * @param desc Description of the vector under test, appended to failure messages.
     */
    private void iteratorTestBound(Iterator<Vector.Element> it, String desc) {
        while (it.hasNext())
            assertNotNull("Unexpected null element for " + desc, it.next());

        try {
            it.next();
        }
        catch (NoSuchElementException ignored) {
            // Expected: the iterator is exhausted.
            return;
        }

        fail("Expected exception missed for " + desc);
    }

    /**
     * Feeds the consumer with the sample vector in several states together with the number of zeroes
     * the vector holds in that state.
     * <p>
     * A read-only sample (see {@link #readOnly(Vector)}) is passed once, as is, with its zeroes counted.
     * A writable sample is overwritten and passed with: no zeroes; all zeroes; a single zero at the first,
     * middle and last index in turn; and, when it has at least 3 elements, two zeroes at one third and
     * two thirds of its size. The sample is left in the last state it was modified to.
     *
     * @param sample Vector to prepare and pass to the consumer.
     * @param consumer Receives the vector and the number of zeroes it currently holds.
     */
    private void consumeSampleVectorsWithZeroes(Vector sample,
        BiConsumer<Vector, Integer> consumer) {
        if (readOnly(sample)) {
            int numZeroes = 0;

            for (Vector.Element e : sample.all())
                if (isZero(e.get()))
                    numZeroes++;

            consumer.accept(sample, numZeroes);

            return;
        }

        fillWithNonZeroes(sample);

        consumer.accept(sample, 0);

        final int sampleSize = sample.size();

        // Nothing to zero out in an empty vector.
        if (sampleSize == 0)
            return;

        for (Vector.Element e : sample.all())
            e.set(0);

        consumer.accept(sample, sampleSize);

        fillWithNonZeroes(sample);

        for (int testIdx : new int[] {0, sampleSize / 2, sampleSize - 1}) {
            final Vector.Element e = sample.getElement(testIdx);

            final double backup = e.get();

            e.set(0);

            consumer.accept(sample, 1);

            // Restore the value so that each iteration has exactly one zero.
            e.set(backup);
        }

        // With fewer than 3 elements the two indices below are not guaranteed to be distinct.
        if (sampleSize < 3)
            return;

        sample.getElement(sampleSize / 3).set(0);
        sample.getElement((2 * sampleSize) / 3).set(0);

        consumer.accept(sample, 2);
    }

    /**
     * Sets the elements of the vector to 1, 2, 3, ... in iteration order of {@code all()} and asserts
     * that as many elements were written as the vector's size.
     *
     * @param sample Vector to fill.
     */
    private void fillWithNonZeroes(Vector sample) {
        int idx = 0;

        for (Vector.Element e : sample.all())
            e.set(1 + idx++);

        assertEquals("Not all filled with non-zeroes", sample.size(), idx);
    }

    /**
     * Runs the consumer through {@link VectorImplementationsFixtures#consumeSampleVectors},
     * passing {@code null} for its first argument.
     *
     * @param consumer Receives each sample vector and its description.
     */
    private void consumeSampleVectors(BiConsumer<Vector, String> consumer) {
        new VectorImplementationsFixtures().consumeSampleVectors(null, consumer);
    }

    /**
     * @param val Value to check.
     * @return {@code true} if the value compares equal to {@code 0.0}.
     */
    private boolean isZero(double val) {
        return val == 0.0;
    }

    /**
     * @param v Vector to check.
     * @return {@code true} if the vector is a {@link RandomVector} or a {@link ConstantVector};
     *      the tests in this class do not write into such vectors.
     */
    private boolean readOnly(Vector v) {
        return v instanceof RandomVector || v instanceof ConstantVector;
    }
}