/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.ignite.ml.math.impls.vector; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.Arrays; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.NoSuchElementException; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.StorageConstants; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.math.impls.storage.vector.FunctionVectorStorage; import org.jetbrains.annotations.NotNull; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; /** */ class VectorImplementationsFixtures { /** */ private static final List<Supplier<Iterable<Vector>>> suppliers = Arrays.asList( (Supplier<Iterable<Vector>>)DenseLocalOnHeapVectorFixture::new, (Supplier<Iterable<Vector>>)DenseLocalOffHeapVectorFixture::new, (Supplier<Iterable<Vector>>)SparseLocalVectorFixture::new, (Supplier<Iterable<Vector>>)RandomVectorFixture::new, (Supplier<Iterable<Vector>>)ConstantVectorFixture::new, (Supplier<Iterable<Vector>>)DelegatingVectorFixture::new, (Supplier<Iterable<Vector>>)FunctionVectorFixture::new, (Supplier<Iterable<Vector>>)SingleElementVectorFixture::new, (Supplier<Iterable<Vector>>)PivotedVectorViewFixture::new, (Supplier<Iterable<Vector>>)SingleElementVectorViewFixture::new, (Supplier<Iterable<Vector>>)MatrixVectorViewFixture::new, (Supplier<Iterable<Vector>>)SparseLocalOffHeapVectorFixture::new ); /** */ void consumeSampleVectors(Consumer<Integer> paramsConsumer, BiConsumer<Vector, String> consumer) { for (Supplier<Iterable<Vector>> fixtureSupplier : VectorImplementationsFixtures.suppliers) { final Iterable<Vector> fixture = fixtureSupplier.get(); for (Vector v : fixture) { if (paramsConsumer != null) paramsConsumer.accept(v.size()); consumer.accept(v, fixture.toString()); } } } /** */ void selfTest() { new VectorSizesExtraIterator<>("VectorSizesExtraIterator test", (size, shallowCp) -> new DenseLocalOnHeapVector(new double[size], shallowCp), null, "shallow copy", new Boolean[] {false, true, null}).selfTest(); new VectorSizesIterator("VectorSizesIterator test", DenseLocalOffHeapVector::new, null).selfTest(); } /** */ private static class DenseLocalOnHeapVectorFixture extends VectorSizesExtraFixture<Boolean> { /** */ DenseLocalOnHeapVectorFixture() { super("DenseLocalOnHeapVector", (size, shallowCp) -> new DenseLocalOnHeapVector(new double[size], shallowCp), "shallow copy", new Boolean[] {false, true, null}); } } /** */ private static class DenseLocalOffHeapVectorFixture extends VectorSizesFixture { /** */ DenseLocalOffHeapVectorFixture() { super("DenseLocalOffHeapVector", DenseLocalOffHeapVector::new); } } /** */ private static class SparseLocalVectorFixture extends VectorSizesExtraFixture<Integer> { /** */ SparseLocalVectorFixture() { super("SparseLocalVector", SparseLocalVector::new, "access mode", new Integer[] {StorageConstants.SEQUENTIAL_ACCESS_MODE, StorageConstants.RANDOM_ACCESS_MODE, null}); } } /** */ private static class RandomVectorFixture extends VectorSizesFixture { /** */ RandomVectorFixture() { super("RandomVector", RandomVector::new); } } /** */ private static class ConstantVectorFixture extends VectorSizesExtraFixture<Double> { /** */ ConstantVectorFixture() { super("ConstantVector", ConstantVector::new, "value", new Double[] {-1.0, 0.0, 0.5, 1.0, 2.0, null}); } } /** */ private static class FunctionVectorFixture extends VectorSizesExtraFixture<Double> { /** */ FunctionVectorFixture() { super("FunctionVector", (size, scale) -> new FunctionVectorForTest(new double[size], scale), "scale", new Double[] {0.5, 1.0, 2.0, null}); } } /** */ private static class SingleElementVectorFixture implements Iterable<Vector> { /** */ private final Supplier<TwoParamsIterator<Integer, Double>> iter; /** */ private final AtomicReference<String> ctxDescrHolder = new AtomicReference<>("Iterator not started."); /** */ SingleElementVectorFixture() { iter = () -> new TwoParamsIterator<Integer, Double>("SingleElementVector", null, ctxDescrHolder::set, "size", new Integer[] {1, null}, "value", new Double[] {-1.0, 0.0, 0.5, 1.0, 2.0, null}) { /** {@inheritDoc} */ @Override BiFunction<Integer, Double, Vector> ctor() { return (size, value) -> new SingleElementVector(size, 0, value); } }; } /** {@inheritDoc} */ @NotNull @Override public Iterator<Vector> iterator() { return iter.get();//( } /** {@inheritDoc} */ @Override public String toString() { // IMPL NOTE index within bounds is expected to be guaranteed by proper code in this class return ctxDescrHolder.get(); } } /** */ private static class PivotedVectorViewFixture extends VectorSizesFixture { /** */ PivotedVectorViewFixture() { super("PivotedVectorView", PivotedVectorViewFixture::pivotedVectorView); } /** */ private static PivotedVectorView pivotedVectorView(int size) { final DenseLocalOnHeapVector vec = new DenseLocalOnHeapVector(size); final int[] pivot = new int[size]; for (int idx = 0; idx < size; idx++) pivot[idx] = size - 1 - idx; PivotedVectorView tmp = new PivotedVectorView(vec, pivot); final int[] unpivot = new int[size]; for (int idx = 0; idx < size; idx++) unpivot[idx] = tmp.unpivot(idx); final int[] idxRecovery = new int[size]; for (int idx = 0; idx < size; idx++) idxRecovery[idx] = idx; return new PivotedVectorView(new PivotedVectorView(tmp, unpivot), idxRecovery); } } /** */ private static class SingleElementVectorViewFixture implements Iterable<Vector> { /** */ private final Supplier<TwoParamsIterator<Integer, Double>> iter; /** */ private final AtomicReference<String> ctxDescrHolder = new AtomicReference<>("Iterator not started."); /** */ SingleElementVectorViewFixture() { iter = () -> new TwoParamsIterator<Integer, Double>("SingleElementVectorView", null, ctxDescrHolder::set, "size", new Integer[] {1, null}, "value", new Double[] {-1.0, 0.0, 0.5, 1.0, 2.0, null}) { /** {@inheritDoc} */ @Override BiFunction<Integer, Double, Vector> ctor() { return (size, value) -> new SingleElementVectorView(new SingleElementVector(size, 0, value), 0); } }; } /** {@inheritDoc} */ @NotNull @Override public Iterator<Vector> iterator() { return iter.get(); } /** {@inheritDoc} */ @Override public String toString() { // IMPL NOTE index within bounds is expected to be guaranteed by proper code in this class return ctxDescrHolder.get(); } } /** */ private static class MatrixVectorViewFixture extends VectorSizesExtraFixture<Integer> { /** */ MatrixVectorViewFixture() { super("MatrixVectorView", MatrixVectorViewFixture::newView, "stride kind", new Integer[] {0, 1, 2, null}); } /** */ private static Vector newView(int size, int strideKind) { final Matrix parent = new DenseLocalOnHeapMatrix(size, size); return new MatrixVectorView(parent, 0, 0, strideKind != 1 ? 1 : 0, strideKind != 0 ? 1 : 0); } } /** */ private static class VectorSizesExtraFixture<T> implements Iterable<Vector> { /** */ private final Supplier<VectorSizesExtraIterator<T>> iter; /** */ private final AtomicReference<String> ctxDescrHolder = new AtomicReference<>("Iterator not started."); /** */ VectorSizesExtraFixture(String vectorKind, BiFunction<Integer, T, Vector> ctor, String extraParamName, T[] extras) { iter = () -> new VectorSizesExtraIterator<>(vectorKind, ctor, ctxDescrHolder::set, extraParamName, extras); } /** {@inheritDoc} */ @NotNull @Override public Iterator<Vector> iterator() { return iter.get(); } /** {@inheritDoc} */ @Override public String toString() { // IMPL NOTE index within bounds is expected to be guaranteed by proper code in this class return ctxDescrHolder.get(); } } /** */ private static abstract class VectorSizesFixture implements Iterable<Vector> { /** */ private final Supplier<VectorSizesIterator> iter; /** */ private final AtomicReference<String> ctxDescrHolder = new AtomicReference<>("Iterator not started."); /** */ VectorSizesFixture(String vectorKind, Function<Integer, Vector> ctor) { iter = () -> new VectorSizesIterator(vectorKind, ctor, ctxDescrHolder::set); } /** {@inheritDoc} */ @NotNull @Override public Iterator<Vector> iterator() { return iter.get(); } /** {@inheritDoc} */ @Override public String toString() { // IMPL NOTE index within bounds is expected to be guaranteed by proper code in this class return ctxDescrHolder.get(); } } /** */ private static class VectorSizesExtraIterator<T> extends VectorSizesIterator { /** */ private final T[] extras; /** */ private int extraIdx = 0; /** */ private final BiFunction<Integer, T, Vector> ctor; /** */ private final String extraParamName; /** * @param vectorKind Descriptive name to use for context logging. * @param ctor Constructor for objects to iterate over. * @param ctxDescrConsumer Context logging consumer. * @param extraParamName Name of extra parameter to iterate over. * @param extras Array of extra parameter values to iterate over. */ VectorSizesExtraIterator(String vectorKind, BiFunction<Integer, T, Vector> ctor, Consumer<String> ctxDescrConsumer, String extraParamName, T[] extras) { super(vectorKind, null, ctxDescrConsumer); this.ctor = ctor; this.extraParamName = extraParamName; this.extras = extras; } /** {@inheritDoc} */ @Override public boolean hasNext() { return super.hasNext() && hasNextExtra(extraIdx); } /** {@inheritDoc} */ @Override void nextIdx() { assert extras[extraIdx] != null : "Index(es) out of bound at " + VectorSizesExtraIterator.this; if (hasNextExtra(extraIdx + 1)) { extraIdx++; return; } extraIdx = 0; super.nextIdx(); } /** {@inheritDoc} */ @Override public String toString() { // IMPL NOTE index within bounds is expected to be guaranteed by proper code in this class return "{" + super.toString() + ", " + extraParamName + "=" + extras[extraIdx] + '}'; } /** {@inheritDoc} */ @Override BiFunction<Integer, Integer, Vector> ctor() { return (size, delta) -> ctor.apply(size + delta, extras[extraIdx]); } /** */ void selfTest() { final Set<Integer> extraIdxs = new HashSet<>(); int cnt = 0; while (hasNext()) { assertNotNull("Expect not null vector at " + this, next()); if (extras[extraIdx] != null) extraIdxs.add(extraIdx); cnt++; } assertEquals("Extra param tested", extraIdxs.size(), extras.length - 1); assertEquals("Combinations tested mismatch.", 7 * 3 * (extras.length - 1), cnt); } /** */ private boolean hasNextExtra(int idx) { return extras[idx] != null; } } /** */ private static class VectorSizesIterator extends TwoParamsIterator<Integer, Integer> { /** */ private final Function<Integer, Vector> ctor; /** */ VectorSizesIterator(String vectorKind, Function<Integer, Vector> ctor, Consumer<String> ctxDescrConsumer) { super(vectorKind, null, ctxDescrConsumer, "size", new Integer[] {2, 4, 8, 16, 32, 64, 128, null}, "size delta", new Integer[] {-1, 0, 1, null}); this.ctor = ctor; } /** {@inheritDoc} */ @Override BiFunction<Integer, Integer, Vector> ctor() { return (size, delta) -> ctor.apply(size + delta); } } /** */ private static class TwoParamsIterator<T, U> implements Iterator<Vector> { /** */ private final T params1[]; /** */ private final U params2[]; /** */ private final String vectorKind; /** */ private final String param1Name; /** */ private final String param2Name; /** */ private final BiFunction<T, U, Vector> ctor; /** */ private final Consumer<String> ctxDescrConsumer; /** */ private int param1Idx = 0; /** */ private int param2Idx = 0; /** */ TwoParamsIterator(String vectorKind, BiFunction<T, U, Vector> ctor, Consumer<String> ctxDescrConsumer, String param1Name, T[] params1, String param2Name, U[] params2) { this.param1Name = param1Name; this.params1 = params1; this.param2Name = param2Name; this.params2 = params2; this.vectorKind = vectorKind; this.ctor = ctor; this.ctxDescrConsumer = ctxDescrConsumer; } /** {@inheritDoc} */ @Override public boolean hasNext() { return hasNextParam1(param1Idx) && hasNextParam2(param2Idx); } /** {@inheritDoc} */ @Override public Vector next() { if (!hasNext()) throw new NoSuchElementException(TwoParamsIterator.this.toString()); if (ctxDescrConsumer != null) ctxDescrConsumer.accept(toString()); Vector res = ctor().apply(params1[param1Idx], params2[param2Idx]); nextIdx(); return res; } /** */ void selfTest() { final Set<Integer> sizeIdxs = new HashSet<>(), deltaIdxs = new HashSet<>(); int cnt = 0; while (hasNext()) { assertNotNull("Expect not null vector at " + this, next()); if (params1[param1Idx] != null) sizeIdxs.add(param1Idx); if (params2[param2Idx] != null) deltaIdxs.add(param2Idx); cnt++; } assertEquals("Sizes tested mismatch.", sizeIdxs.size(), params1.length - 1); assertEquals("Deltas tested", deltaIdxs.size(), params2.length - 1); assertEquals("Combinations tested mismatch.", (params1.length - 1) * (params2.length - 1), cnt); } /** IMPL NOTE override in subclasses if needed */ void nextIdx() { assert params1[param1Idx] != null && params2[param2Idx] != null : "Index(es) out of bound at " + TwoParamsIterator.this; if (hasNextParam2(param2Idx + 1)) { param2Idx++; return; } param2Idx = 0; param1Idx++; } /** {@inheritDoc} */ @Override public String toString() { // IMPL NOTE index within bounds is expected to be guaranteed by proper code in this class return vectorKind + "{" + param1Name + "=" + params1[param1Idx] + ", " + param2Name + "=" + params2[param2Idx] + '}'; } /** IMPL NOTE override in subclasses if needed */ BiFunction<T, U, Vector> ctor() { return ctor; } /** */ private boolean hasNextParam1(int idx) { return params1[idx] != null; } /** */ private boolean hasNextParam2(int idx) { return params2[idx] != null; } } /** Delegating vector with dense local onheap vector */ private static class DelegatingVectorFixture implements Iterable<Vector> { /** */ private final Supplier<VectorSizesExtraIterator<Boolean>> iter; /** */ private final AtomicReference<String> ctxDescrHolder = new AtomicReference<>("Iterator not started."); /** */ DelegatingVectorFixture() { iter = () -> new VectorSizesExtraIterator<>("DelegatingVector with DenseLocalOnHeapVector", (size, shallowCp) -> new DelegatingVector(new DenseLocalOnHeapVector(new double[size], shallowCp)), ctxDescrHolder::set, "shallow copy", new Boolean[] {false, true, null}); } /** {@inheritDoc} */ @NotNull @Override public Iterator<Vector> iterator() { return iter.get(); } /** {@inheritDoc} */ @Override public String toString() { // IMPL NOTE index within bounds is expected to be guaranteed by proper code in this class return ctxDescrHolder.get(); } } /** Subclass tweaked for serialization */ private static class FunctionVectorForTest extends FunctionVector { /** */ double[] arr; /** */ double scale; /** */ public FunctionVectorForTest() { // No-op. } /** */ FunctionVectorForTest(double[] arr, double scale) { super(arr.length, idx -> arr[idx] * scale, (idx, value) -> arr[idx] = value / scale); this.arr = arr; this.scale = scale; } /** {@inheritDoc} */ @Override public void writeExternal(ObjectOutput out) throws IOException { super.writeExternal(out); out.writeObject(arr); out.writeDouble(scale); } /** {@inheritDoc} */ @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { super.readExternal(in); arr = (double[])in.readObject(); scale = in.readDouble(); setStorage(new FunctionVectorStorage(arr.length, idx -> arr[idx] * scale, (idx, value) -> arr[idx] = value / scale)); } /** {@inheritDoc} */ @Override public int hashCode() { int res = 1; res = res * 37 + Double.hashCode(scale); res = res * 37 + Integer.hashCode(getStorage().size()); return res; } /** {@inheritDoc} */ @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; FunctionVectorForTest that = (FunctionVectorForTest)o; return new Double(scale).equals(that.scale) && (arr != null ? Arrays.equals(arr, that.arr) : that.arr == null); } } /** */ private static class SparseLocalOffHeapVectorFixture extends VectorSizesFixture { /** */ SparseLocalOffHeapVectorFixture() { super("SparseLocalOffHeapVector", SparseLocalOffHeapVector::new); } } }