/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.math.impls.matrix;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.impls.storage.matrix.FunctionMatrixStorage;
import org.jetbrains.annotations.NotNull;
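/** Collection of fixtures that supply sample matrices of various implementations and sizes to a consumer. */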
class MatrixImplementationFixtures {
/** Factories of all matrix fixtures; every call creates a fresh fixture positioned at its first sample size. */
private static final List<Supplier<Iterable<Matrix>>> suppliers = Arrays.<Supplier<Iterable<Matrix>>>asList(
    DenseLocalOnHeapMatrixFixture::new,
    DenseLocalOffHeapMatrixFixture::new,
    RandomMatrixFixture::new,
    SparseLocalOnHeapMatrixFixture::new,
    PivotedMatrixViewFixture::new,
    MatrixViewFixture::new,
    FunctionMatrixFixture::new,
    DiagonalMatrixFixture::new,
    TransposedMatrixViewFixture::new
);
/**
 * Feeds every sample matrix produced by the registered fixtures to the given consumer.
 * <p>
 * Each fixture is obtained from its supplier and iterated. The consumer receives the matrix together with the
 * fixture's textual description of that matrix, and the matrix is destroyed afterwards.</p>
 *
 * @param consumer Callback invoked with each sample matrix and its description.
 */
void consumeSampleMatrix(BiConsumer<Matrix, String> consumer) {
    for (Supplier<Iterable<Matrix>> fixtureSupplier : suppliers) {
        final Iterable<Matrix> fixture = fixtureSupplier.get();

        final Iterator<Matrix> iter = fixture.iterator();

        while (iter.hasNext()) {
            // The fixtures in this file advance their size index inside next() and toString() reads that index,
            // so the description must be captured before next() to match the matrix that is actually produced.
            final String desc = fixture.toString();

            final Matrix matrix = iter.next();

            try {
                consumer.accept(matrix, desc);
            }
            finally {
                // Destroy the matrix even when the consumer throws.
                matrix.destroy();
            }
        }
    }
}
/** Fixture that creates {@code DenseLocalOnHeapMatrix} instances for each sample size. */
private static class DenseLocalOnHeapMatrixFixture extends MatrixSizeIterator {
    /** Registers the matrix constructor together with the description used by {@code toString()}. */
    DenseLocalOnHeapMatrixFixture() {
        super((rows, cols) -> new DenseLocalOnHeapMatrix(rows, cols), "DenseLocalOnHeapMatrix");
    }
}
/** Fixture that creates {@code DenseLocalOffHeapMatrix} instances for each sample size. */
private static class DenseLocalOffHeapMatrixFixture extends MatrixSizeIterator {
    /** Registers the matrix constructor together with the description used by {@code toString()}. */
    DenseLocalOffHeapMatrixFixture() {
        super((rows, cols) -> new DenseLocalOffHeapMatrix(rows, cols), "DenseLocalOffHeapMatrix");
    }
}
/** Fixture that creates {@code RandomMatrix} instances for each sample size. */
private static class RandomMatrixFixture extends MatrixSizeIterator {
    /** Registers the matrix constructor together with the description used by {@code toString()}. */
    RandomMatrixFixture() {
        super((rows, cols) -> new RandomMatrix(rows, cols), "RandomMatrix");
    }
}
/** Fixture that creates {@code SparseLocalOnHeapMatrix} instances for each sample size. */
private static class SparseLocalOnHeapMatrixFixture extends MatrixSizeIterator {
    /** Registers the matrix constructor together with the description used by {@code toString()}. */
    SparseLocalOnHeapMatrixFixture() {
        super((rows, cols) -> new SparseLocalOnHeapMatrix(rows, cols), "SparseLocalOnHeapMatrix");
    }
}
/** Fixture that wraps each sample matrix created by the parent iterator into a {@code PivotedMatrixView}. */
private static class PivotedMatrixViewFixture extends WrapperMatrixIterator {
    /** Registers the wrapping function together with the description used by {@code toString()}. */
    PivotedMatrixViewFixture() {
        super(wrapped -> new PivotedMatrixView(wrapped), "PivotedMatrixView over DenseLocalOnHeapMatrix");
    }
}
/** Fixture that wraps each sample matrix into a {@code MatrixView} spanning the whole wrapped matrix. */
private static class MatrixViewFixture extends WrapperMatrixIterator {
    /** Registers the wrapping function together with the description used by {@code toString()}. */
    MatrixViewFixture() {
        super(
            src -> {
                // The view starts at (0, 0) and covers all rows and columns of the source matrix.
                return new MatrixView(src, 0, 0, src.rowSize(), src.columnSize());
            },
            "MatrixView over DenseLocalOnHeapMatrix");
    }
}
/** Fixture that wraps each sample matrix into a {@code FunctionMatrixForTest}. */
private static class FunctionMatrixFixture extends WrapperMatrixIterator {
    /** Registers the wrapping function together with the description used by {@code toString()}. */
    FunctionMatrixFixture() {
        super(wrapped -> new FunctionMatrixForTest(wrapped), "FunctionMatrix wrapping DenseLocalOnHeapMatrix");
    }
}
/** Fixture that produces a {@code DiagonalMatrix} built over a square {@code DenseLocalOnHeapMatrix}. */
private static class DiagonalMatrixFixture extends DiagonalIterator {
    /** Registers the underlying matrix constructor together with the description used by {@code toString()}. */
    DiagonalMatrixFixture() {
        super(DenseLocalOnHeapMatrix::new, "DiagonalMatrix over DenseLocalOnHeapMatrix");
    }

    /**
     * Returns an iterator over the diagonal sample matrices. The iteration position is kept in this fixture,
     * not in the returned iterator.
     *
     * @return Iterator producing one {@code DiagonalMatrix} per configured size.
     */
    @NotNull
    @Override public Iterator<Matrix> iterator() {
        return new Iterator<Matrix>() {
            /** {@inheritDoc} */
            @Override public boolean hasNext() {
                return hasNextSize(getSizeIdx());
            }

            /**
             * {@inheritDoc}
             *
             * @throws NoSuchElementException If all configured sizes have already been consumed.
             */
            @Override public Matrix next() {
                // Honor the Iterator contract instead of failing with an NPE on the null end marker.
                if (!hasNext())
                    throw new NoSuchElementException("No more matrices in diagonal matrix fixture.");

                Integer size = getSize(getSizeIdx());

                assert size == 1 : "Only size 1 allowed for diagonal matrix fixture.";

                Matrix matrix = getConstructor().apply(size, size);

                nextIdx();

                return new DiagonalMatrix(matrix);
            }
        };
    }
}
/** Fixture that wraps each sample matrix created by the parent iterator into a {@code TransposedMatrixView}. */
private static class TransposedMatrixViewFixture extends WrapperMatrixIterator {
    /** Registers the wrapping function together with the description used by {@code toString()}. */
    TransposedMatrixViewFixture() {
        super(wrapped -> new TransposedMatrixView(wrapped), "TransposedMatrixView over DenseLocalOnHeapMatrix");
    }
}
/**
 * Base class for fixtures of square matrices. It keeps the sequence of sizes, the current position in that
 * sequence, the matrix constructor and a description; subclasses provide the actual iterator.
 */
private static abstract class DiagonalIterator implements Iterable<Matrix> {
    /** Sizes to produce; the trailing {@code null} marks the end of the sequence. */
    private final Integer[] sizes = {1, null};

    /** Index of the current entry in {@link #sizes}. */
    private int sizeIdx = 0;

    /** Creates a matrix from a row count and a column count. */
    private BiFunction<Integer, Integer, ? extends Matrix> constructor;

    /** Description used as the prefix of {@link #toString()}. */
    private String desc;

    /**
     * @param constructor Function creating a matrix from a row count and a column count.
     * @param desc Description of the fixture.
     */
    DiagonalIterator(BiFunction<Integer, Integer, ? extends Matrix> constructor, String desc) {
        this.desc = desc;
        this.constructor = constructor;
    }

    /** @return Function creating a matrix from a row count and a column count. */
    public BiFunction<Integer, Integer, ? extends Matrix> getConstructor() {
        return this.constructor;
    }

    /** @return Index of the current entry in the size sequence. */
    int getSizeIdx() {
        return this.sizeIdx;
    }

    /** @return Description followed by the size at the current index, reported as both rows and columns. */
    @Override public String toString() {
        final Integer curSize = sizes[sizeIdx];

        return desc + "{rows=" + curSize + ", cols=" + curSize + "}";
    }

    /**
     * @param idx Index in the size sequence.
     * @return {@code true} if the entry at the given index is a real size rather than the end marker.
     */
    boolean hasNextSize(int idx) {
        final Integer candidate = sizes[idx];

        return candidate != null;
    }

    /**
     * @param idx Index in the size sequence.
     * @return Size stored at the given index, or {@code null} for the end marker.
     */
    Integer getSize(int idx) {
        return sizes[idx];
    }

    /** Moves the current position to the following entry of the size sequence. */
    void nextIdx() {
        sizeIdx += 1;
    }
}
/**
 * Fixture that creates a {@code DenseLocalOnHeapMatrix} for each sample size and passes it through a wrapping
 * function before handing it out.
 */
private static class WrapperMatrixIterator extends MatrixSizeIterator {
    /** Function that turns the freshly created matrix into the matrix returned by the iterator. */
    private final Function<Matrix, Matrix> wrapperCtor;

    /**
     * @param wrapperCtor Function applied to every created matrix.
     * @param desc Description of the fixture.
     */
    WrapperMatrixIterator(Function<Matrix, Matrix> wrapperCtor, String desc) {
        super(DenseLocalOnHeapMatrix::new, desc);

        this.wrapperCtor = wrapperCtor;
    }

    /**
     * Returns an iterator over the wrapped sample matrices. The iteration position is kept in this fixture,
     * not in the returned iterator.
     *
     * @return Iterator producing one wrapped matrix per configured size.
     */
    @NotNull
    @Override public Iterator<Matrix> iterator() {
        return new Iterator<Matrix>() {
            /** {@inheritDoc} */
            @Override public boolean hasNext() {
                return hasNextCol(getSizeIdx()) && hasNextRow(getSizeIdx());
            }

            /**
             * {@inheritDoc}
             *
             * @throws NoSuchElementException If all configured sizes have already been consumed.
             */
            @Override public Matrix next() {
                // Honor the Iterator contract instead of failing with an NPE on the null end marker.
                if (!hasNext())
                    throw new NoSuchElementException("No more matrices in wrapper matrix fixture.");

                Matrix matrix = getConstructor().apply(getRow(getSizeIdx()), getCol(getSizeIdx()));

                nextIdx();

                return wrapperCtor.apply(matrix);
            }
        };
    }
}
/**
 * Iterable fixture that builds matrices over a fixed sequence of row/column sizes using the supplied constructor.
 * <p>
 * The position in the size sequence is stored in this object rather than in the returned iterator, so all
 * iterators obtained from the same instance share it. {@link #toString()} reads the same position, hence it
 * describes the size that the following {@code next()} call will produce.</p>
 */
private static class MatrixSizeIterator implements Iterable<Matrix> {
    /** Row counts of the sample matrices; the trailing {@code null} marks the end of the sequence. */
    private final Integer[] rows = new Integer[] {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 512, 1024, null};

    /** Column counts of the sample matrices; the trailing {@code null} marks the end of the sequence. */
    private final Integer[] cols = new Integer[] {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 1024, 512, null};

    /** Index of the current entry in {@link #rows} and {@link #cols}. */
    private int sizeIdx = 0;

    /** Creates a matrix from a row count and a column count. */
    private final BiFunction<Integer, Integer, ? extends Matrix> constructor;

    /** Description used as the prefix of {@link #toString()}. */
    private final String desc;

    /**
     * @param constructor Function creating a matrix from a row count and a column count.
     * @param desc Description of the fixture.
     */
    MatrixSizeIterator(BiFunction<Integer, Integer, ? extends Matrix> constructor, String desc) {
        this.constructor = constructor;
        this.desc = desc;
    }

    /** @return Function creating a matrix from a row count and a column count. */
    public BiFunction<Integer, Integer, ? extends Matrix> getConstructor() {
        return constructor;
    }

    /** @return Index of the current entry in the size sequence. */
    int getSizeIdx() {
        return sizeIdx;
    }

    /** @return Description followed by the row and column counts at the current index. */
    @Override public String toString() {
        return desc + "{rows=" + rows[sizeIdx] + ", cols=" + cols[sizeIdx] + "}";
    }

    /**
     * @param idx Index in the size sequence.
     * @return {@code true} if the row entry at the given index is a real size rather than the end marker.
     */
    boolean hasNextRow(int idx) {
        return rows[idx] != null;
    }

    /**
     * @param idx Index in the size sequence.
     * @return {@code true} if the column entry at the given index is a real size rather than the end marker.
     */
    boolean hasNextCol(int idx) {
        return cols[idx] != null;
    }

    /**
     * @param idx Index in the size sequence.
     * @return Row count stored at the given index, or {@code null} for the end marker.
     */
    Integer getRow(int idx) {
        return rows[idx];
    }

    /**
     * @param idx Index in the size sequence; must not point at the {@code null} end marker because the value
     *      is unboxed.
     * @return Column count stored at the given index.
     */
    int getCol(int idx) {
        return cols[idx];
    }

    /**
     * Returns an iterator over the sample matrices. The iteration position is kept in this fixture, not in the
     * returned iterator.
     *
     * @return Iterator producing one matrix per configured size.
     */
    @NotNull
    @Override public Iterator<Matrix> iterator() {
        return new Iterator<Matrix>() {
            /** {@inheritDoc} */
            @Override public boolean hasNext() {
                return hasNextCol(sizeIdx) && hasNextRow(sizeIdx);
            }

            /**
             * {@inheritDoc}
             *
             * @throws NoSuchElementException If all configured sizes have already been consumed.
             */
            @Override public Matrix next() {
                // Honor the Iterator contract instead of passing the null end marker to the constructor.
                if (!hasNext())
                    throw new NoSuchElementException("No more matrix sizes in fixture: " + desc);

                Matrix matrix = constructor.apply(rows[sizeIdx], cols[sizeIdx]);

                nextIdx();

                return matrix;
            }
        };
    }

    /** Moves the current position to the following entry of the size sequence. */
    void nextIdx() {
        sizeIdx++;
    }
}
/**
 * Subclass tweaked for serialization: it keeps a reference to the matrix it delegates to, writes that matrix
 * alongside the parent state and rebuilds the function storage from it when read back.
 */
private static class FunctionMatrixForTest extends FunctionMatrix {
    /** Matrix whose {@code get}/{@code set} back this function matrix; {@code null} after the no-arg constructor. */
    Matrix underlying;

    /** No-arg constructor; leaves {@link #underlying} unset. */
    public FunctionMatrixForTest() {
        // No-op.
    }

    /**
     * @param underlying Matrix providing the dimensions and the element accessors.
     */
    FunctionMatrixForTest(Matrix underlying) {
        super(underlying.rowSize(), underlying.columnSize(), underlying::get, underlying::set);

        this.underlying = underlying;
    }

    /**
     * {@inheritDoc}
     * <p>
     * The returned instance is created over the same underlying matrix object.</p>
     */
    @Override public Matrix copy() {
        return new FunctionMatrixForTest(underlying);
    }

    /** {@inheritDoc} */
    @Override public void writeExternal(ObjectOutput out) throws IOException {
        super.writeExternal(out);

        out.writeObject(underlying);
    }

    /** {@inheritDoc} */
    @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
        super.readExternal(in);

        underlying = (Matrix)in.readObject();

        // Rebind the storage to the deserialized matrix.
        setStorage(new FunctionMatrixStorage(underlying.rowSize(), underlying.columnSize(),
            underlying::get, underlying::set));
    }

    /**
     * {@inheritDoc}
     * <p>
     * Tolerates a {@code null} underlying matrix, consistently with {@link #equals(Object)}.</p>
     */
    @Override public int hashCode() {
        int res = 1;

        // equals() accepts a null underlying matrix, so hashCode() must not throw for it.
        res = res * 37 + (underlying != null ? underlying.hashCode() : 0);

        return res;
    }

    /** {@inheritDoc} */
    @Override public boolean equals(Object o) {
        if (this == o)
            return true;

        if (o == null || getClass() != o.getClass())
            return false;

        FunctionMatrixForTest that = (FunctionMatrixForTest)o;

        return underlying != null ? underlying.equals(that.underlying) : that.underlying == null;
    }
}
}