package org.la4j.operation.ooplace;
import org.la4j.iterator.MatrixIterator;
import org.la4j.iterator.VectorIterator;
import org.la4j.Matrices;
import org.la4j.Matrix;
import org.la4j.matrix.DenseMatrix;
import org.la4j.matrix.ColumnMajorSparseMatrix;
import org.la4j.matrix.RowMajorSparseMatrix;
import org.la4j.operation.MatrixMatrixOperation;
import org.la4j.Vector;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
public class OoPlaceMatricesMultiplication extends MatrixMatrixOperation<Matrix> {
@Override
public Matrix apply(DenseMatrix a, DenseMatrix b) {
Matrix result = a.blankOfShape(a.rows(), b.columns());
for (int j = 0; j < b.columns(); j++) {
Vector column = b.getColumn(j);
for (int i = 0; i < a.rows(); i++) {
double acc = 0.0;
for (int k = 0; k < a.columns(); k++) {
acc += a.get(i, k) * column.get(k);
}
result.set(i, j, acc);
}
}
return result;
}
@Override
public Matrix apply(DenseMatrix a, RowMajorSparseMatrix b) {
Matrix result = ColumnMajorSparseMatrix.zero(a.rows(), b.columns());
MatrixIterator it = b.nonZeroRowMajorIterator();
while (it.hasNext()) {
double x = it.next();
int i = it.rowIndex();
int j = it.columnIndex();
for (int k = 0; k < a.rows(); k++) {
result.updateAt(k, j, Matrices.asPlusFunction(x * a.get(k, i)));
}
}
return result;
}
@Override
public Matrix apply(DenseMatrix a, ColumnMajorSparseMatrix b) {
Matrix result = b.blankOfShape(a.rows(), b.columns());
Iterator<Integer> nzColumns = b.iteratorOrNonZeroColumns();
while (nzColumns.hasNext()) {
int j = nzColumns.next();
for (int i = 0; i < a.rows(); i++) {
double acc = 0.0;
VectorIterator it = b.nonZeroIteratorOfColumn(j);
while (it.hasNext()) {
double x = it.next();
acc += x * a.get(i, it.index());
}
result.set(i, j, acc);
}
}
return result;
}
@Override
public Matrix apply(RowMajorSparseMatrix a, DenseMatrix b) {
Matrix result = a.blankOfShape(a.rows(), b.columns());
Iterator<Integer> nzRows = a.iteratorOfNonZeroRows();
while (nzRows.hasNext()) {
int i = nzRows.next();
for (int j = 0; j < b.columns(); j++) {
double acc = 0.0;
VectorIterator it = a.nonZeroIteratorOfRow(i);
while (it.hasNext()) {
double x = it.next();
acc += x * b.get(it.index(), j);
}
result.set(i, j, acc);
}
}
return result;
}
@Override
public Matrix apply(RowMajorSparseMatrix a, RowMajorSparseMatrix b) {
// TODO: Can we do it w/o updateAt?
Matrix result = a.blankOfShape(a.rows(), b.columns());
MatrixIterator these = a.nonZeroRowMajorIterator();
while (these.hasNext()) {
double x = these.next();
int i = these.rowIndex();
int j = these.columnIndex();
VectorIterator those = b.nonZeroIteratorOfRow(j);
while (those.hasNext()) {
double y = those.next();
int k = those.index();
result.updateAt(i, k, Matrices.asPlusFunction(x * y));
}
}
return result;
}
@Override
public Matrix apply(RowMajorSparseMatrix a, ColumnMajorSparseMatrix b) {
Matrix result = a.blankOfShape(a.rows(), b.columns());
Iterator<Integer> nzRows = a.iteratorOfNonZeroRows();
Iterator<Integer> nzColumnsIt = b.iteratorOrNonZeroColumns();
List<Integer> nzColumns = new ArrayList<Integer>();
while(nzColumnsIt.hasNext()) {
nzColumns.add(nzColumnsIt.next());
}
while(nzRows.hasNext()) {
int i = nzRows.next();
for (int j: nzColumns) {
result.set(i, j, a.nonZeroIteratorOfRow(i)
.innerProduct(b.nonZeroIteratorOfColumn(j)));
}
}
return result;
}
@Override
public Matrix apply(ColumnMajorSparseMatrix a, DenseMatrix b) {
Matrix result = a.blankOfShape(a.rows(), b.columns());
MatrixIterator it = a.nonZeroColumnMajorIterator();
while (it.hasNext()) {
double x = it.next();
int i = it.rowIndex();
int j = it.columnIndex();
for (int k = 0; k < b.columns(); k++) {
result.updateAt(i, k, Matrices.asPlusFunction(x * b.get(j, k)));
}
}
return result;
}
@Override
public Matrix apply(ColumnMajorSparseMatrix a, RowMajorSparseMatrix b) {
// TODO: Might be improved a bit.
Matrix result = b.blankOfShape(a.rows(), b.columns());
MatrixIterator these = a.nonZeroColumnMajorIterator();
while (these.hasNext()) {
double x = these.next();
int i = these.rowIndex();
int j = these.columnIndex();
VectorIterator those = b.nonZeroIteratorOfRow(j);
while (those.hasNext()) {
double y = those.next();
int k = those.index();
result.updateAt(i, k, Matrices.asPlusFunction(x * y));
}
}
return result;
}
@Override
public Matrix apply(ColumnMajorSparseMatrix a, ColumnMajorSparseMatrix b) {
// TODO: Might be improved a bit.
Matrix result = a.blankOfShape(a.rows(), b.columns());
MatrixIterator these = b.nonZeroColumnMajorIterator();
while (these.hasNext()) {
double x = these.next();
int i = these.rowIndex();
int j = these.columnIndex();
VectorIterator those = a.nonZeroIteratorOfColumn(i);
while (those.hasNext()) {
double y = those.next();
int k = those.index();
result.updateAt(k, j, Matrices.asPlusFunction(x * y));
}
}
return result;
}
@Override
public void ensureApplicableTo(Matrix a, Matrix b) {
if (a.columns() != b.rows()) {
throw new IllegalArgumentException(
"The number of rows in the left-hand matrix should be equal to the number of " +
"columns in the right-hand matrix: " + a.rows() + " does not equal to " + b.columns() + "."
);
}
}
}