/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.math.impls.matrix;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import org.apache.ignite.lang.IgniteUuid;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.MatrixStorage;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.decompositions.LUDecomposition;
import org.apache.ignite.ml.math.exceptions.CardinalityException;
import org.apache.ignite.ml.math.exceptions.ColumnIndexException;
import org.apache.ignite.ml.math.exceptions.RowIndexException;
import org.apache.ignite.ml.math.functions.Functions;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.functions.IntIntToDoubleFunction;
import org.apache.ignite.ml.math.impls.vector.MatrixVectorView;
/**
 * This class provides a helper implementation of the {@link Matrix}
 * interface to minimize the effort required to implement it.
 * Subclasses may override some of the implemented methods if a more
 * specific or optimized implementation is desirable.
 * <p>
 * The positions of the minimum and maximum elements are cached. Every write performed by this class either goes
 * through {@link #storageSet(int, int, double)} or explicitly drops those caches, so {@link #minElement()} and
 * {@link #maxElement()} are recomputed after the matrix content changes.
 *
 * TODO: add row/column optimization.
 */
public abstract class AbstractMatrix implements Matrix {
    // Stochastic sparsity analysis.
    /** Two-sided z-score for a 95% confidence interval (not used by the current sparsity test). */
    private static final double Z95 = 1.959964;

    /** Two-sided z-score for an 80% confidence interval, used by {@link #density(double)}. */
    private static final double Z80 = 1.281552;

    /** Upper bound on the number of random samples taken by {@link #density(double)}. */
    private static final int MAX_SAMPLES = 500;

    /** Number of random samples taken by {@link #density(double)} before the first decision. */
    private static final int MIN_SAMPLES = 15;

    /** Cached minimum element, {@code null} when it has to be recomputed. */
    private Element minElm;

    /** Cached maximum element, {@code null} when it has to be recomputed. */
    private Element maxElm;

    /** Matrix storage implementation. */
    private MatrixStorage sto;

    /** Meta attributes storage. */
    private Map<String, Object> meta = new HashMap<>();

    /** Matrix's GUID. */
    private IgniteUuid guid = IgniteUuid.randomUuid();

    /**
     * @param sto Backing {@link MatrixStorage}.
     */
    public AbstractMatrix(MatrixStorage sto) {
        this.sto = sto;
    }

    /**
     * Creates a matrix without backing storage. The storage must be supplied via
     * {@link #setStorage(MatrixStorage)} or {@link #readExternal(ObjectInput)} before the matrix is used.
     */
    public AbstractMatrix() {
        // No-op.
    }

    /**
     * Replaces the backing storage and drops the cached min/max elements.
     *
     * @param sto Backing {@link MatrixStorage}, must not be {@code null}.
     */
    protected void setStorage(MatrixStorage sto) {
        assert sto != null;

        this.sto = sto;

        invalidateCache();
    }

    /** Drops the cached min/max elements so they are recomputed on the next request. */
    private void invalidateCache() {
        minElm = null;
        maxElm = null;
    }

    /**
     * Writes a value straight into the storage (no bounds check) and invalidates the min/max cache.
     *
     * @param row Row index in the matrix.
     * @param col Column index in the matrix.
     * @param v Value to set.
     */
    protected void storageSet(int row, int col, double v) {
        sto.set(row, col, v);

        // Reset cached values.
        invalidateCache();
    }

    /**
     * Reads a value straight from the storage (no bounds check).
     *
     * @param row Row index in the matrix.
     * @param col Column index in the matrix.
     * @return Value stored at {@code (row, col)}.
     */
    protected double storageGet(int row, int col) {
        return sto.get(row, col);
    }

    /** {@inheritDoc} */
    @Override public Element maxElement() {
        if (maxElm == null) {
            double max = Double.NEGATIVE_INFINITY;
            int row = 0, col = 0;

            int rows = rowSize();
            int cols = columnSize();

            for (int x = 0; x < rows; x++)
                for (int y = 0; y < cols; y++) {
                    double d = storageGet(x, y);

                    if (d > max) {
                        max = d;
                        row = x;
                        col = y;
                    }
                }

            maxElm = mkElement(row, col);
        }

        return maxElm;
    }

    /** {@inheritDoc} */
    @Override public Element minElement() {
        if (minElm == null) {
            // Mirror maxElement(): seeding with MAX_VALUE would miss an element equal to Double.MAX_VALUE.
            double min = Double.POSITIVE_INFINITY;
            int row = 0, col = 0;

            int rows = rowSize();
            int cols = columnSize();

            for (int x = 0; x < rows; x++)
                for (int y = 0; y < cols; y++) {
                    double d = storageGet(x, y);

                    if (d < min) {
                        min = d;
                        row = x;
                        col = y;
                    }
                }

            minElm = mkElement(row, col);
        }

        return minElm;
    }

    /** {@inheritDoc} */
    @Override public double maxValue() {
        return maxElement().get();
    }

    /** {@inheritDoc} */
    @Override public double minValue() {
        return minElement().get();
    }

    /**
     * Creates a live view of a single matrix cell: reads and writes go directly to this matrix.
     *
     * @param row Row index in the matrix.
     * @param col Column index in the matrix.
     * @return Element bound to {@code (row, col)}.
     */
    private Element mkElement(int row, int col) {
        return new Element() {
            /** {@inheritDoc} */
            @Override public double get() {
                return storageGet(row, col);
            }

            /** {@inheritDoc} */
            @Override public int row() {
                return row;
            }

            /** {@inheritDoc} */
            @Override public int column() {
                return col;
            }

            /** {@inheritDoc} */
            @Override public void set(double d) {
                storageSet(row, col, d);
            }
        };
    }

    /** {@inheritDoc} */
    @Override public Element getElement(int row, int col) {
        return mkElement(row, col);
    }

    /** {@inheritDoc} */
    @Override public Matrix swapRows(int row1, int row2) {
        checkRowIndex(row1);
        checkRowIndex(row2);

        int cols = columnSize();

        for (int y = 0; y < cols; y++) {
            double v = getX(row1, y);

            setX(row1, y, getX(row2, y));
            setX(row2, y, v);
        }

        return this;
    }

    /** {@inheritDoc} */
    @Override public Matrix swapColumns(int col1, int col2) {
        checkColumnIndex(col1);
        checkColumnIndex(col2);

        int rows = rowSize();

        for (int x = 0; x < rows; x++) {
            double v = getX(x, col1);

            setX(x, col1, getX(x, col2));
            setX(x, col2, v);
        }

        return this;
    }

    /** {@inheritDoc} */
    @Override public MatrixStorage getStorage() {
        return sto;
    }

    /** {@inheritDoc} */
    @Override public boolean isSequentialAccess() {
        return sto.isSequentialAccess();
    }

    /** {@inheritDoc} */
    @Override public boolean isDense() {
        return sto.isDense();
    }

    /** {@inheritDoc} */
    @Override public boolean isRandomAccess() {
        return sto.isRandomAccess();
    }

    /** {@inheritDoc} */
    @Override public boolean isDistributed() {
        return sto.isDistributed();
    }

    /** {@inheritDoc} */
    @Override public boolean isArrayBased() {
        return sto.isArrayBased();
    }

    /**
     * Check row index bounds.
     *
     * @param row Row index.
     * @throws RowIndexException If {@code row} is outside {@code [0, rowSize())}.
     */
    private void checkRowIndex(int row) {
        if (row < 0 || row >= rowSize())
            throw new RowIndexException(row);
    }

    /**
     * Check column index bounds.
     *
     * @param col Column index.
     * @throws ColumnIndexException If {@code col} is outside {@code [0, columnSize())}.
     */
    private void checkColumnIndex(int col) {
        if (col < 0 || col >= columnSize())
            throw new ColumnIndexException(col);
    }

    /**
     * Check column and row index bounds.
     *
     * @param row Row index.
     * @param col Column index.
     */
    protected void checkIndex(int row, int col) {
        checkRowIndex(row);
        checkColumnIndex(col);
    }

    /** {@inheritDoc} */
    @Override public void writeExternal(ObjectOutput out) throws IOException {
        out.writeObject(sto);
        out.writeObject(meta);
        out.writeObject(guid);
    }

    /** {@inheritDoc} */
    @Override public Map<String, Object> getMetaStorage() {
        return meta;
    }

    /** {@inheritDoc} */
    @SuppressWarnings("unchecked")
    @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
        sto = (MatrixStorage)in.readObject();
        meta = (Map<String, Object>)in.readObject();
        guid = (IgniteUuid)in.readObject();

        // Content was replaced wholesale, so any cached min/max position is meaningless now.
        invalidateCache();
    }

    /** {@inheritDoc} */
    @Override public Matrix assign(double val) {
        if (sto.isArrayBased()) {
            for (double[] arr : sto.data())
                Arrays.fill(arr, val);

            // Direct array writes bypass storageSet(), so the caches must be dropped explicitly.
            invalidateCache();
        }
        else {
            int rows = rowSize();
            int cols = columnSize();

            for (int x = 0; x < rows; x++)
                for (int y = 0; y < cols; y++)
                    storageSet(x, y, val);
        }

        return this;
    }

    /** {@inheritDoc} */
    @Override public Matrix assign(IntIntToDoubleFunction fun) {
        int rows = rowSize();
        int cols = columnSize();

        for (int x = 0; x < rows; x++)
            for (int y = 0; y < cols; y++)
                storageSet(x, y, fun.apply(x, y));

        return this;
    }

    /**
     * Verifies that {@code mtx} has the same dimensions as this matrix.
     *
     * @param mtx Matrix to check.
     * @throws CardinalityException If row or column counts differ.
     */
    private void checkCardinality(Matrix mtx) {
        checkCardinality(mtx.rowSize(), mtx.columnSize());
    }

    /**
     * Verifies that the given dimensions match this matrix.
     *
     * @param rows Expected row count.
     * @param cols Expected column count.
     * @throws CardinalityException If row or column counts differ.
     */
    private void checkCardinality(int rows, int cols) {
        if (rows != rowSize())
            throw new CardinalityException(rowSize(), rows);

        if (cols != columnSize())
            throw new CardinalityException(columnSize(), cols);
    }

    /**
     * {@inheritDoc}
     *
     * All rows of {@code vals} are validated before anything is written, so a size mismatch leaves
     * this matrix untouched.
     */
    @Override public Matrix assign(double[][] vals) {
        // Avoid indexing vals[0] on an empty array; an empty array only matches a matrix with no rows.
        checkCardinality(vals.length, vals.length == 0 ? columnSize() : vals[0].length);

        int rows = rowSize();
        int cols = columnSize();

        // Reject jagged input up front instead of failing half-way through the copy.
        for (int x = 0; x < rows; x++)
            if (vals[x].length != cols)
                throw new CardinalityException(cols, vals[x].length);

        for (int x = 0; x < rows; x++)
            for (int y = 0; y < cols; y++)
                storageSet(x, y, vals[x][y]);

        return this;
    }

    /** {@inheritDoc} */
    @Override public Matrix assign(Matrix mtx) {
        checkCardinality(mtx);

        int rows = rowSize();
        int cols = columnSize();

        for (int x = 0; x < rows; x++)
            for (int y = 0; y < cols; y++)
                storageSet(x, y, mtx.getX(x, y));

        return this;
    }

    /** {@inheritDoc} */
    @Override public Matrix map(IgniteDoubleFunction<Double> fun) {
        int rows = rowSize();
        int cols = columnSize();

        for (int x = 0; x < rows; x++)
            for (int y = 0; y < cols; y++)
                storageSet(x, y, fun.apply(storageGet(x, y)));

        return this;
    }

    /** {@inheritDoc} */
    @Override public Matrix map(Matrix mtx, IgniteBiFunction<Double, Double, Double> fun) {
        checkCardinality(mtx);

        int rows = rowSize();
        int cols = columnSize();

        for (int x = 0; x < rows; x++)
            for (int y = 0; y < cols; y++)
                storageSet(x, y, fun.apply(storageGet(x, y), mtx.getX(x, y)));

        return this;
    }

    /** {@inheritDoc} */
    @Override public Matrix assignColumn(int col, Vector vec) {
        checkColumnIndex(col);

        int rows = rowSize();

        // Same size contract as assignRow() and setColumn().
        if (rows != vec.size())
            throw new CardinalityException(rows, vec.size());

        for (int x = 0; x < rows; x++)
            storageSet(x, col, vec.getX(x));

        return this;
    }

    /** {@inheritDoc} */
    @Override public Matrix assignRow(int row, Vector vec) {
        checkRowIndex(row);

        int cols = columnSize();

        if (cols != vec.size())
            throw new CardinalityException(cols, vec.size());

        if (sto.isArrayBased() && vec.getStorage().isArrayBased()) {
            System.arraycopy(vec.getStorage().data(), 0, sto.data()[row], 0, cols);

            // Direct array writes bypass storageSet(), so the caches must be dropped explicitly.
            invalidateCache();
        }
        else
            for (int y = 0; y < cols; y++)
                storageSet(row, y, vec.getX(y));

        return this;
    }

    /** {@inheritDoc} */
    @Override public Vector foldRows(IgniteFunction<Vector, Double> fun) {
        int rows = rowSize();

        Vector vec = likeVector(rows);

        for (int i = 0; i < rows; i++)
            vec.setX(i, fun.apply(viewRow(i)));

        return vec;
    }

    /** {@inheritDoc} */
    @Override public Vector foldColumns(IgniteFunction<Vector, Double> fun) {
        int cols = columnSize();

        Vector vec = likeVector(cols);

        for (int i = 0; i < cols; i++)
            vec.setX(i, fun.apply(viewColumn(i)));

        return vec;
    }

    /** {@inheritDoc} */
    @Override public <T> T foldMap(IgniteBiFunction<T, Double, T> foldFun, IgniteDoubleFunction<Double> mapFun,
        T zeroVal) {
        T res = zeroVal;

        int rows = rowSize();
        int cols = columnSize();

        for (int x = 0; x < rows; x++)
            for (int y = 0; y < cols; y++)
                res = foldFun.apply(res, mapFun.apply(storageGet(x, y)));

        return res;
    }

    /** {@inheritDoc} */
    @Override public int columnSize() {
        return sto.columnSize();
    }

    /** {@inheritDoc} */
    @Override public int rowSize() {
        return sto.rowSize();
    }

    /** {@inheritDoc} */
    @Override public double determinant() {
        // Same squareness contract as inverse().
        if (rowSize() != columnSize())
            throw new CardinalityException(rowSize(), columnSize());

        //TODO: This decomposition should be cached
        LUDecomposition dec = new LUDecomposition(this);

        try {
            return dec.determinant();
        }
        finally {
            // Release the decomposition even if the computation fails.
            dec.destroy();
        }
    }

    /** {@inheritDoc} */
    @Override public Matrix inverse() {
        if (rowSize() != columnSize())
            throw new CardinalityException(rowSize(), columnSize());

        //TODO: This decomposition should be cached
        LUDecomposition dec = new LUDecomposition(this);

        try {
            return dec.solve(likeIdentity());
        }
        finally {
            // Release the decomposition even if solving fails.
            dec.destroy();
        }
    }

    /**
     * Creates an {@code n x n} matrix of the same flavour as this one (via {@code like}) with {@code 1.0} on the
     * diagonal, where {@code n} is this matrix's row count. Off-diagonal cells keep whatever {@code like} produced.
     *
     * @return Identity-like matrix.
     */
    protected Matrix likeIdentity() {
        int n = rowSize();

        Matrix res = like(n, n);

        for (int i = 0; i < n; i++)
            res.setX(i, i, 1.0);

        return res;
    }

    /** {@inheritDoc} */
    @Override public Matrix divide(double d) {
        int rows = rowSize();
        int cols = columnSize();

        for (int x = 0; x < rows; x++)
            for (int y = 0; y < cols; y++)
                setX(x, y, getX(x, y) / d);

        return this;
    }

    /** {@inheritDoc} */
    @Override public double get(int row, int col) {
        checkIndex(row, col);

        return storageGet(row, col);
    }

    /** {@inheritDoc} */
    @Override public double getX(int row, int col) {
        return storageGet(row, col);
    }

    /** {@inheritDoc} */
    @Override public Matrix minus(Matrix mtx) {
        // Validate the argument's shape (comparing this matrix against its own sizes would never fail).
        checkCardinality(mtx);

        int rows = rowSize();
        int cols = columnSize();

        Matrix res = like(rows, cols);

        for (int x = 0; x < rows; x++)
            for (int y = 0; y < cols; y++)
                res.setX(x, y, getX(x, y) - mtx.getX(x, y));

        return res;
    }

    /** {@inheritDoc} */
    @Override public Matrix plus(double x) {
        Matrix cp = copy();

        cp.map(Functions.plus(x));

        return cp;
    }

    /** {@inheritDoc} */
    @Override public Matrix plus(Matrix mtx) {
        // Validate the argument's shape (comparing this matrix against its own sizes would never fail).
        checkCardinality(mtx);

        int rows = rowSize();
        int cols = columnSize();

        Matrix res = like(rows, cols);

        for (int x = 0; x < rows; x++)
            for (int y = 0; y < cols; y++)
                res.setX(x, y, getX(x, y) + mtx.getX(x, y));

        return res;
    }

    /** {@inheritDoc} */
    @Override public IgniteUuid guid() {
        return guid;
    }

    /** {@inheritDoc} */
    @Override public Matrix set(int row, int col, double val) {
        checkIndex(row, col);

        storageSet(row, col, val);

        return this;
    }

    /** {@inheritDoc} */
    @Override public Matrix setRow(int row, double[] data) {
        checkRowIndex(row);

        int cols = columnSize();

        if (cols != data.length)
            throw new CardinalityException(cols, data.length);

        if (sto.isArrayBased()) {
            System.arraycopy(data, 0, sto.data()[row], 0, cols);

            // Direct array writes bypass storageSet(), so the caches must be dropped explicitly.
            invalidateCache();
        }
        else
            for (int y = 0; y < cols; y++)
                setX(row, y, data[y]);

        return this;
    }

    /** {@inheritDoc} */
    @Override public Matrix setColumn(int col, double[] data) {
        checkColumnIndex(col);

        int rows = rowSize();

        if (rows != data.length)
            throw new CardinalityException(rows, data.length);

        for (int x = 0; x < rows; x++)
            setX(x, col, data[x]);

        return this;
    }

    /** {@inheritDoc} */
    @Override public Matrix setX(int row, int col, double val) {
        storageSet(row, col, val);

        return this;
    }

    /** {@inheritDoc} */
    @Override public Matrix times(double x) {
        Matrix cp = copy();

        cp.map(Functions.mult(x));

        return cp;
    }

    /** {@inheritDoc} */
    @Override public double maxAbsRowSumNorm() {
        double max = 0.0;

        int rows = rowSize();
        int cols = columnSize();

        for (int x = 0; x < rows; x++) {
            double sum = 0;

            for (int y = 0; y < cols; y++)
                sum += Math.abs(getX(x, y));

            if (sum > max)
                max = sum;
        }

        return max;
    }

    /** {@inheritDoc} */
    @Override public Vector times(Vector vec) {
        int cols = columnSize();

        if (cols != vec.size())
            throw new CardinalityException(cols, vec.size());

        int rows = rowSize();

        Vector res = likeVector(rows);

        for (int x = 0; x < rows; x++)
            res.setX(x, vec.dot(viewRow(x)));

        return res;
    }

    /** {@inheritDoc} */
    @Override public Matrix times(Matrix mtx) {
        int cols = columnSize();

        if (cols != mtx.rowSize())
            throw new CardinalityException(cols, mtx.rowSize());

        int rows = rowSize();
        int mtxCols = mtx.columnSize();

        Matrix res = like(rows, mtxCols);

        for (int x = 0; x < rows; x++)
            for (int y = 0; y < mtxCols; y++) {
                double sum = 0.0;

                for (int k = 0; k < cols; k++)
                    sum += getX(x, k) * mtx.getX(k, y);

                res.setX(x, y, sum);
            }

        return res;
    }

    /** {@inheritDoc} */
    @Override public double sum() {
        int rows = rowSize();
        int cols = columnSize();

        double sum = 0.0;

        for (int x = 0; x < rows; x++)
            for (int y = 0; y < cols; y++)
                sum += getX(x, y);

        return sum;
    }

    /** {@inheritDoc} */
    @Override public Matrix transpose() {
        int rows = rowSize();
        int cols = columnSize();

        Matrix mtx = like(cols, rows);

        for (int x = 0; x < rows; x++)
            for (int y = 0; y < cols; y++)
                mtx.setX(y, x, getX(x, y));

        return mtx;
    }

    /**
     * {@inheritDoc}
     *
     * Estimates the fraction of non-zero cells by random sampling: {@link #MIN_SAMPLES} cells first, then
     * more batches (up to {@link #MAX_SAMPLES} in total) until the 80% confidence interval around the
     * estimate excludes {@code threshold}. Because sampling is random, repeated calls near the threshold
     * may disagree. Sampling an empty matrix makes {@link Random#nextInt(int)} throw
     * {@link IllegalArgumentException}.
     */
    @Override public boolean density(double threshold) {
        assert threshold >= 0.0 && threshold <= 1.0;

        int n = MIN_SAMPLES;
        int rows = rowSize();
        int cols = columnSize();
        double mean = 0.0;
        double pq = threshold * (1 - threshold);

        Random rnd = new Random();

        for (int i = 0; i < MIN_SAMPLES; i++)
            if (getX(rnd.nextInt(rows), rnd.nextInt(cols)) != 0.0)
                mean++;

        mean /= MIN_SAMPLES;

        double iv = Z80 * Math.sqrt(pq / n);

        if (mean < threshold - iv)
            return false; // Sparse.
        else if (mean > threshold + iv)
            return true; // Dense.

        while (n < MAX_SAMPLES) {
            // Determine upper bound we may need for 'n' to likely relinquish the uncertainty.
            // Here, we use confidence interval formula but solved for 'n'.
            double ivX = Math.max(Math.abs(threshold - mean), 1e-11);

            double stdErr = ivX / Z80;

            int nX = Math.min(Math.max((int)Math.ceil(pq / (stdErr * stdErr)), n), MAX_SAMPLES) - n;

            if (nX < 1) // IMPL NOTE this can happen with threshold 1.0
                nX = 1;

            double meanNext = 0.0;

            for (int i = 0; i < nX; i++)
                if (getX(rnd.nextInt(rows), rnd.nextInt(cols)) != 0.0)
                    meanNext++;

            // Merge the new batch into the running mean.
            mean = (n * mean + meanNext) / (n + nX);

            n += nX;

            // Are we good now?
            iv = Z80 * Math.sqrt(pq / n);

            if (mean < threshold - iv)
                return false; // Sparse.
            else if (mean > threshold + iv)
                return true; // Dense.
        }

        return mean > threshold; // Dense if mean > threshold.
    }

    /** {@inheritDoc} */
    @Override public Matrix viewPart(int[] off, int[] size) {
        return new MatrixView(this, off[0], off[1], size[0], size[1]);
    }

    /** {@inheritDoc} */
    @Override public Matrix viewPart(int rowOff, int rows, int colOff, int cols) {
        return viewPart(new int[] {rowOff, colOff}, new int[] {rows, cols});
    }

    /** {@inheritDoc} */
    @Override public Vector viewRow(int row) {
        return new MatrixVectorView(this, row, 0, 0, 1);
    }

    /** {@inheritDoc} */
    @Override public Vector viewColumn(int col) {
        return new MatrixVectorView(this, 0, col, 1, 0);
    }

    /** {@inheritDoc} */
    @Override public Vector viewDiagonal() {
        return new MatrixVectorView(this, 0, 0, 1, 1);
    }

    /** {@inheritDoc} */
    @Override public void destroy() {
        getStorage().destroy();
    }

    /** {@inheritDoc} */
    @Override public Matrix copy() {
        Matrix cp = like(rowSize(), columnSize());

        cp.assign(this);

        return cp;
    }

    /**
     * {@inheritDoc}
     *
     * Derived from the storage only, matching {@link #equals(Object)}, which ignores the guid and
     * meta attributes. Hashing those as well would give equal matrices different hash codes.
     */
    @Override public int hashCode() {
        int res = 1;

        res = res * 37 + (sto == null ? 0 : sto.hashCode());

        return res;
    }

    /**
     * {@inheritDoc}
     *
     * We ignore guid's (and meta attributes) for comparisons: two matrices of the same class are equal
     * when their storages are equal.
     */
    @Override public boolean equals(Object o) {
        if (this == o)
            return true;

        if (o == null || getClass() != o.getClass())
            return false;

        AbstractMatrix that = (AbstractMatrix)o;

        MatrixStorage sto = getStorage();

        return (sto != null ? sto.equals(that.getStorage()) : that.getStorage() == null);
    }
}