/*
* Copyright (C) 2008-2015 by Holger Arndt, Frode Carlsen
*
* This file is part of the Universal Java Matrix Package (UJMP).
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* UJMP is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* UJMP is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with UJMP; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package org.ujmp.core.calculation;
import static org.ujmp.core.util.VerifyUtil.verifyTrue;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.DenseMatrix2D;
import org.ujmp.core.Matrix;
import org.ujmp.core.SparseMatrix;
import org.ujmp.core.doublematrix.DenseDoubleMatrix2D;
import org.ujmp.core.doublematrix.impl.BlockDenseDoubleMatrix2D;
import org.ujmp.core.doublematrix.impl.BlockMatrixLayout;
import org.ujmp.core.doublematrix.impl.BlockMatrixLayout.BlockOrder;
import org.ujmp.core.doublematrix.impl.BlockMultiply;
import org.ujmp.core.interfaces.HasColumnMajorDoubleArray1D;
import org.ujmp.core.interfaces.HasRowMajorDoubleArray2D;
import org.ujmp.core.util.AbstractPlugin;
import org.ujmp.core.util.UJMPSettings;
import org.ujmp.core.util.VerifyUtil;
import org.ujmp.core.util.concurrent.PFor;
import org.ujmp.core.util.concurrent.UJMPThreadPoolExecutor;
/**
 * Holder for the shared matrix multiplication strategies. Each constant is a
 * strategy specialised for one combination of operand types; an optional
 * JBlas-backed strategy is looked up reflectively by {@link #init()}.
 */
public class Mtimes {

    /**
     * Size limit that the strategies in this file compare operand dimensions
     * against when choosing between their sequential and their PFor-based
     * code path.
     */
    public static int THRESHOLD = 100;

    /** Strategy for arbitrary {@link Matrix} operands. */
    public static final MtimesCalculation<Matrix, Matrix, Matrix> MATRIX = new MtimesMatrix();

    /** Strategy for {@link DenseMatrix} operands. */
    public static final MtimesCalculation<DenseMatrix, DenseMatrix, DenseMatrix> DENSEMATRIX = new MtimesDenseMatrix();

    /** Strategy for {@link DenseMatrix2D} operands. */
    public static final MtimesCalculation<DenseMatrix2D, DenseMatrix2D, DenseMatrix2D> DENSEMATRIX2D = new MtimesDenseMatrix2D();

    /** Strategy for {@link DenseDoubleMatrix2D} operands. */
    public static final MtimesCalculation<DenseDoubleMatrix2D, DenseDoubleMatrix2D, DenseDoubleMatrix2D> DENSEDOUBLEMATRIX2D = new MtimesDenseDoubleMatrix2D();

    /** Strategy for a sparse left operand. */
    public static final MtimesCalculation<SparseMatrix, Matrix, Matrix> SPARSEMATRIX1 = new MtimesSparseMatrix1();

    /** Strategy for a sparse right operand. */
    public static final MtimesCalculation<Matrix, SparseMatrix, Matrix> SPARSEMATRIX2 = new MtimesSparseMatrix2();

    /** Strategy for two sparse operands. */
    public static final MtimesCalculation<SparseMatrix, SparseMatrix, Matrix> SPARSEMATRIXBOTH = new MtimesSparseMatrixBoth();

    /** Optional JBlas strategy; stays {@code null} unless {@link #init()} finds it. */
    public static MtimesCalculation<Matrix, Matrix, Matrix> MTIMES_JBLAS = null;

    /**
     * Whether the block-matrix multiplication restores the block order of its
     * operands after it has forced the order it needs.
     */
    public static final boolean RESET_BLOCK_ORDER = false;

    static {
        init();
    }

    /**
     * Tries to load the JBlas plugin by class name and, if the plugin reports
     * itself as available, installs its multiplication strategy in
     * {@link #MTIMES_JBLAS}. Any failure leaves {@link #MTIMES_JBLAS} untouched.
     */
    @SuppressWarnings("unchecked")
    public static void init() {
        try {
            final Class<?> pluginClass = Class.forName("org.ujmp.jblas.Plugin");
            final AbstractPlugin plugin = (AbstractPlugin) pluginClass.newInstance();
            if (!plugin.isAvailable()) {
                return;
            }
            final Class<?> strategyClass = Class.forName("org.ujmp.jblas.calculation.Mtimes");
            MTIMES_JBLAS = (MtimesCalculation<Matrix, Matrix, Matrix>) strategyClass
                    .newInstance();
        } catch (Throwable ignored) {
            // The plugin is optional: whatever goes wrong while loading it is
            // deliberately ignored and the built-in strategies are used.
        }
    }
}
/**
 * Multiplication strategy for arbitrary {@link Matrix} operands. It forwards to
 * the most specific strategy that matches the runtime types of the operands and
 * only falls back to its own generic implementation when none applies.
 */
class MtimesMatrix implements MtimesCalculation<Matrix, Matrix, Matrix> {

    /**
     * Computes {@code target = source1 * source2}.
     *
     * @param source1
     *            left operand
     * @param source2
     *            right operand
     * @param target
     *            matrix that receives the product
     */
    public final void calc(final Matrix source1, final Matrix source2, final Matrix target) {
        final boolean leftSparse = source1.isSparse() && source1 instanceof SparseMatrix;
        final boolean rightSparse = source2.isSparse() && source2 instanceof SparseMatrix;

        // sparse operands take precedence over every dense specialisation
        if (leftSparse && rightSparse) {
            Mtimes.SPARSEMATRIXBOTH.calc((SparseMatrix) source1, (SparseMatrix) source2, target);
            return;
        }
        if (leftSparse) {
            Mtimes.SPARSEMATRIX1.calc((SparseMatrix) source1, source2, target);
            return;
        }
        if (rightSparse) {
            Mtimes.SPARSEMATRIX2.calc(source1, (SparseMatrix) source2, target);
            return;
        }

        // dense operands: try the most specific type first
        if (source1 instanceof DenseDoubleMatrix2D && source2 instanceof DenseDoubleMatrix2D
                && target instanceof DenseDoubleMatrix2D) {
            Mtimes.DENSEDOUBLEMATRIX2D.calc((DenseDoubleMatrix2D) source1,
                    (DenseDoubleMatrix2D) source2, (DenseDoubleMatrix2D) target);
            return;
        }
        if (source1 instanceof DenseMatrix2D && source2 instanceof DenseMatrix2D
                && target instanceof DenseMatrix2D) {
            Mtimes.DENSEMATRIX2D.calc((DenseMatrix2D) source1, (DenseMatrix2D) source2,
                    (DenseMatrix2D) target);
            return;
        }
        if (source1 instanceof DenseMatrix && source2 instanceof DenseMatrix
                && target instanceof DenseMatrix) {
            Mtimes.DENSEMATRIX.calc((DenseMatrix) source1, (DenseMatrix) source2,
                    (DenseMatrix) target);
            return;
        }

        gemm(source1, source2, target);
    }

    /**
     * Generic product {@code C = A * B} through the {@code getAsDouble} /
     * {@code setAsDouble} accessors, computed one result column at a time. The
     * columns are handed to a {@link PFor} when all three relevant dimensions
     * reach {@link Mtimes#THRESHOLD}, otherwise they are processed in a plain
     * loop.
     */
    private final void gemm(final Matrix A, final Matrix B, final Matrix C) {
        VerifyUtil.verify2D(A);
        VerifyUtil.verify2D(B);
        VerifyUtil.verify2D(C);

        final int rowsA = (int) A.getRowCount();
        final int colsA = (int) A.getColumnCount();
        final int rowsB = (int) B.getRowCount();
        final int colsB = (int) B.getColumnCount();

        VerifyUtil.verifyEquals(colsA, rowsB, "matrices have wrong sizes");
        VerifyUtil.verifyEquals(rowsA, C.getRowCount(), "matrices have wrong sizes");
        VerifyUtil.verifyEquals(colsB, C.getColumnCount(), "matrices have wrong sizes");

        final boolean large = rowsA >= Mtimes.THRESHOLD && colsA >= Mtimes.THRESHOLD
                && colsB >= Mtimes.THRESHOLD;

        if (!large) {
            for (int col = 0; col < colsB; col++) {
                fillResultColumn(A, B, C, rowsA, colsA, col);
            }
            return;
        }

        new PFor(0, colsB - 1) {
            @Override
            public void step(int i) {
                MtimesMatrix.fillResultColumn(A, B, C, rowsA, colsA, i);
            }
        };
    }

    /**
     * Computes column {@code col} of {@code C = A * B}: the column is zeroed and
     * then, for every non-zero entry of column {@code col} of B, the matching
     * column of A scaled by that entry is added to it.
     */
    private static void fillResultColumn(final Matrix A, final Matrix B, final Matrix C,
            final int rows, final int inner, final int col) {
        for (int r = 0; r < rows; ++r) {
            C.setAsDouble(0.0d, r, col);
        }
        for (int k = 0; k < inner; ++k) {
            final double factor = B.getAsDouble(k, col);
            if (factor == 0.0d) {
                continue;
            }
            for (int r = 0; r < rows; ++r) {
                C.setAsDouble(C.getAsDouble(r, col) + A.getAsDouble(r, k) * factor, r, col);
            }
        }
    }
}
/**
 * Multiplication strategy for {@link DenseMatrix} operands. Two-dimensional
 * dense operands are forwarded to {@link Mtimes#DENSEMATRIX2D}; everything else
 * is multiplied through the generic double accessors.
 */
class MtimesDenseMatrix implements MtimesCalculation<DenseMatrix, DenseMatrix, DenseMatrix> {

    /**
     * Computes {@code target = source1 * source2}.
     *
     * @param source1
     *            left operand
     * @param source2
     *            right operand
     * @param target
     *            matrix that receives the product
     */
    public final void calc(final DenseMatrix source1, final DenseMatrix source2,
            final DenseMatrix target) {
        final boolean allTwoDimensional = source1 instanceof DenseMatrix2D
                && source2 instanceof DenseMatrix2D && target instanceof DenseMatrix2D;
        if (!allTwoDimensional) {
            gemm(source1, source2, target);
            return;
        }
        Mtimes.DENSEMATRIX2D.calc((DenseMatrix2D) source1, (DenseMatrix2D) source2,
                (DenseMatrix2D) target);
    }

    /**
     * Product {@code C = A * B}, computed one result column at a time. The
     * columns are handed to a {@link PFor} when all three relevant dimensions
     * reach {@link Mtimes#THRESHOLD}, otherwise they are processed in a plain
     * loop.
     */
    private final void gemm(final DenseMatrix A, final DenseMatrix B, final DenseMatrix C) {
        VerifyUtil.verify2D(A);
        VerifyUtil.verify2D(B);
        VerifyUtil.verify2D(C);

        final int rowsA = (int) A.getRowCount();
        final int colsA = (int) A.getColumnCount();
        final int rowsB = (int) B.getRowCount();
        final int colsB = (int) B.getColumnCount();

        VerifyUtil.verifyEquals(colsA, rowsB, "matrices have wrong sizes");
        VerifyUtil.verifyEquals(rowsA, C.getRowCount(), "matrices have wrong sizes");
        VerifyUtil.verifyEquals(colsB, C.getColumnCount(), "matrices have wrong sizes");

        final boolean large = rowsA >= Mtimes.THRESHOLD && colsA >= Mtimes.THRESHOLD
                && colsB >= Mtimes.THRESHOLD;

        if (!large) {
            for (int col = 0; col < colsB; col++) {
                fillResultColumn(A, B, C, rowsA, colsA, col);
            }
            return;
        }

        new PFor(0, colsB - 1) {
            @Override
            public void step(int i) {
                MtimesDenseMatrix.fillResultColumn(A, B, C, rowsA, colsA, i);
            }
        };
    }

    /**
     * Computes column {@code col} of {@code C = A * B}: the column is zeroed and
     * then, for every non-zero entry of column {@code col} of B, the matching
     * column of A scaled by that entry is added to it.
     */
    private static void fillResultColumn(final DenseMatrix A, final DenseMatrix B,
            final DenseMatrix C, final int rows, final int inner, final int col) {
        for (int r = 0; r < rows; ++r) {
            C.setAsDouble(0.0d, r, col);
        }
        for (int k = 0; k < inner; ++k) {
            final double factor = B.getAsDouble(k, col);
            if (factor == 0.0d) {
                continue;
            }
            for (int r = 0; r < rows; ++r) {
                C.setAsDouble(C.getAsDouble(r, col) + A.getAsDouble(r, k) * factor, r, col);
            }
        }
    }
}
/**
 * Multiplication strategy for a sparse left operand. Only the coordinates that
 * the sparse matrix reports as available are visited; each of them contributes
 * to one row of the result.
 */
class MtimesSparseMatrix1 implements MtimesCalculation<SparseMatrix, Matrix, Matrix> {

    /**
     * Computes {@code target = source1 * source2}. The target is cleared first
     * and the products are then accumulated into it.
     *
     * @param source1
     *            sparse left operand
     * @param source2
     *            right operand
     * @param target
     *            matrix that receives the product
     */
    public final void calc(final SparseMatrix source1, final Matrix source2, final Matrix target) {
        VerifyUtil.verify2D(source1);
        VerifyUtil.verify2D(source2);
        VerifyUtil.verify2D(target);
        VerifyUtil.verifyEquals(source1.getColumnCount(), source2.getRowCount(),
                "matrices have wrong sizes");
        VerifyUtil.verifyEquals(target.getRowCount(), source1.getRowCount(),
                "matrices have wrong sizes");
        VerifyUtil.verifyEquals(target.getColumnCount(), source2.getColumnCount(),
                "matrices have wrong sizes");

        target.clear();

        for (final long[] entry : source1.availableCoordinates()) {
            final double left = source1.getAsDouble(entry);
            if (left == 0.0d) {
                // explicitly stored zero: contributes nothing
                continue;
            }
            final long row = entry[0];
            final long inner = entry[1];
            // walk the matching row of source2 from the last column to the first
            for (long col = source2.getColumnCount() - 1; col >= 0; col--) {
                final double product = left * source2.getAsDouble(inner, col);
                if (product != 0.0d) {
                    target.setAsDouble(target.getAsDouble(row, col) + product, row, col);
                }
            }
        }
    }
}
/**
 * Multiplication strategy for two sparse operands. Every available coordinate
 * of the left operand is paired with the available coordinates of the right
 * operand whose row index equals the left entry's column index.
 */
class MtimesSparseMatrixBoth implements MtimesCalculation<SparseMatrix, SparseMatrix, Matrix> {

    /**
     * Computes {@code target = source1 * source2}. The target is cleared first
     * and the products are then accumulated into it. Entries whose stored value
     * is zero are skipped in both operands.
     *
     * @param source1
     *            sparse left operand
     * @param source2
     *            sparse right operand
     * @param target
     *            matrix that receives the product
     */
    public final void calc(final SparseMatrix source1, final SparseMatrix source2,
            final Matrix target) {
        VerifyUtil.verify2D(source1);
        VerifyUtil.verify2D(source2);
        VerifyUtil.verify2D(target);
        VerifyUtil.verifyEquals(source1.getColumnCount(), source2.getRowCount(),
                "matrices have wrong sizes");
        VerifyUtil.verifyEquals(target.getRowCount(), source1.getRowCount(),
                "matrices have wrong sizes");
        VerifyUtil.verifyEquals(target.getColumnCount(), source2.getColumnCount(),
                "matrices have wrong sizes");
        target.clear();
        for (long[] c1 : source1.availableCoordinates()) {
            final double v1 = source1.getAsDouble(c1);
            if (v1 != 0.0d) {
                for (long[] c2 : source2.availableCoordinates()) {
                    // only entries of source2 in row c1[1] pair with this entry
                    if (c2[0] == c1[1]) {
                        final double v2 = source2.getAsDouble(c2);
                        // test v2 here: v1 is already known to be non-zero, and
                        // a stored zero in source2 must not touch the target
                        if (v2 != 0.0d) {
                            final double temp = v1 * v2;
                            final double v3 = target.getAsDouble(c1[0], c2[1]);
                            target.setAsDouble(v3 + temp, c1[0], c2[1]);
                        }
                    }
                }
            }
        }
    }
}
/**
 * Multiplication strategy for a sparse right operand. Only the coordinates that
 * the sparse matrix reports as available are visited; each of them contributes
 * to one column of the result.
 */
class MtimesSparseMatrix2 implements MtimesCalculation<Matrix, SparseMatrix, Matrix> {

    /**
     * Computes {@code target = source1 * source2}. The target is cleared first
     * and the products are then accumulated into it.
     *
     * @param source1
     *            left operand
     * @param source2
     *            sparse right operand
     * @param target
     *            matrix that receives the product
     */
    public final void calc(final Matrix source1, final SparseMatrix source2, final Matrix target) {
        VerifyUtil.verify2D(source1);
        VerifyUtil.verify2D(source2);
        VerifyUtil.verify2D(target);
        VerifyUtil.verifyEquals(source1.getColumnCount(), source2.getRowCount(),
                "matrices have wrong sizes");
        VerifyUtil.verifyEquals(target.getRowCount(), source1.getRowCount(),
                "matrices have wrong sizes");
        VerifyUtil.verifyEquals(target.getColumnCount(), source2.getColumnCount(),
                "matrices have wrong sizes");

        target.clear();

        for (final long[] entry : source2.availableCoordinates()) {
            final double right = source2.getAsDouble(entry);
            if (right == 0.0d) {
                // explicitly stored zero: contributes nothing
                continue;
            }
            final long inner = entry[0];
            final long col = entry[1];
            // walk the matching column of source1 from the last row to the first
            for (long row = source1.getRowCount() - 1; row >= 0; row--) {
                final double product = source1.getAsDouble(row, inner) * right;
                if (product != 0.0d) {
                    target.setAsDouble(target.getAsDouble(row, col) + product, row, col);
                }
            }
        }
    }
}
/**
 * Multiplication strategy for {@link DenseMatrix2D} operands. Operands that are
 * all {@link DenseDoubleMatrix2D} are forwarded to
 * {@link Mtimes#DENSEDOUBLEMATRIX2D}; everything else is multiplied through the
 * generic double accessors.
 */
class MtimesDenseMatrix2D implements MtimesCalculation<DenseMatrix2D, DenseMatrix2D, DenseMatrix2D> {

    /**
     * Computes {@code target = source1 * source2}.
     *
     * @param source1
     *            left operand
     * @param source2
     *            right operand
     * @param target
     *            matrix that receives the product
     */
    public final void calc(final DenseMatrix2D source1, final DenseMatrix2D source2,
            final DenseMatrix2D target) {
        final boolean allDouble = source1 instanceof DenseDoubleMatrix2D
                && source2 instanceof DenseDoubleMatrix2D
                && target instanceof DenseDoubleMatrix2D;
        if (!allDouble) {
            gemm(source1, source2, target);
            return;
        }
        Mtimes.DENSEDOUBLEMATRIX2D.calc((DenseDoubleMatrix2D) source1,
                (DenseDoubleMatrix2D) source2, (DenseDoubleMatrix2D) target);
    }

    /**
     * Product {@code C = A * B}, computed one result column at a time. The
     * columns are handed to a {@link PFor} when all three relevant dimensions
     * reach {@link Mtimes#THRESHOLD}, otherwise they are processed in a plain
     * loop.
     */
    private final void gemm(final DenseMatrix2D A, final DenseMatrix2D B, final DenseMatrix2D C) {
        VerifyUtil.verify2D(A);
        VerifyUtil.verify2D(B);
        VerifyUtil.verify2D(C);

        final int rowsA = (int) A.getRowCount();
        final int colsA = (int) A.getColumnCount();
        final int rowsB = (int) B.getRowCount();
        final int colsB = (int) B.getColumnCount();

        VerifyUtil.verifyEquals(colsA, rowsB, "matrices have wrong size");
        VerifyUtil.verifyEquals(rowsA, C.getRowCount(), "matrices have wrong size");
        VerifyUtil.verifyEquals(colsB, C.getColumnCount(), "matrices have wrong size");

        final boolean large = rowsA >= Mtimes.THRESHOLD && colsA >= Mtimes.THRESHOLD
                && colsB >= Mtimes.THRESHOLD;

        if (!large) {
            for (int col = 0; col < colsB; col++) {
                fillResultColumn(A, B, C, rowsA, colsA, col);
            }
            return;
        }

        new PFor(0, colsB - 1) {
            @Override
            public void step(int i) {
                MtimesDenseMatrix2D.fillResultColumn(A, B, C, rowsA, colsA, i);
            }
        };
    }

    /**
     * Computes column {@code col} of {@code C = A * B}: the column is zeroed and
     * then, for every non-zero entry of column {@code col} of B, the matching
     * column of A scaled by that entry is added to it.
     */
    private static void fillResultColumn(final DenseMatrix2D A, final DenseMatrix2D B,
            final DenseMatrix2D C, final int rows, final int inner, final int col) {
        for (int r = 0; r < rows; ++r) {
            C.setAsDouble(0.0d, r, col);
        }
        for (int k = 0; k < inner; ++k) {
            final double factor = B.getAsDouble(k, col);
            if (factor == 0.0d) {
                continue;
            }
            for (int r = 0; r < rows; ++r) {
                C.setAsDouble(C.getAsDouble(r, col) + A.getAsDouble(r, k) * factor, r, col);
            }
        }
    }
}
/**
 * Multiplication strategy for {@link DenseDoubleMatrix2D} operands. Depending on
 * the operand size, the settings and the storage interfaces the operands
 * implement, the product is computed by JBlas, by block matrix multiplication,
 * directly on the backing arrays, or through the {@code getDouble} /
 * {@code setDouble} accessors.
 *
 * @author Holger Arndt
 * @author Frode Carlsen
 *
 */
class MtimesDenseDoubleMatrix2D implements
        MtimesCalculation<DenseDoubleMatrix2D, DenseDoubleMatrix2D, DenseDoubleMatrix2D> {

    /**
     * Computes {@code target = source1 * source2}.
     * <p>
     * When both dimensions of {@code source1} reach {@link Mtimes#THRESHOLD} one
     * of the multi-threaded implementations is used, otherwise a
     * single-threaded one.
     *
     * @param source1
     *            left operand (a)
     * @param source2
     *            right operand (b)
     * @param target
     *            matrix that receives the product (c)
     */
    public final void calc(final DenseDoubleMatrix2D source1, final DenseDoubleMatrix2D source2,
            final DenseDoubleMatrix2D target) {
        verifyTrue(source1 != null, "a == null");
        verifyTrue(source2 != null, "b == null");
        verifyTrue(target != null, "c == null");
        verifyTrue(source1.getColumnCount() == source2.getRowCount(), "a.cols!=b.rows");
        verifyTrue(source1.getRowCount() == target.getRowCount(), "a.rows!=c.rows");
        // this check compares b (source2) with c (target), so the message names b
        verifyTrue(source2.getColumnCount() == target.getColumnCount(), "b.cols!=c.cols");
        if (source1.getRowCount() >= Mtimes.THRESHOLD
                && source1.getColumnCount() >= Mtimes.THRESHOLD) {
            if (Mtimes.MTIMES_JBLAS != null && UJMPSettings.getInstance().isUseJBlas()) {
                Mtimes.MTIMES_JBLAS.calc(source1, source2, target);
            } else if (UJMPSettings.getInstance().isUseBlockMatrixMultiply()) {
                calcBlockMatrixMultiThreaded(source1, source2, target);
            } else if (source1 instanceof HasColumnMajorDoubleArray1D
                    && source2 instanceof HasColumnMajorDoubleArray1D
                    && target instanceof HasColumnMajorDoubleArray1D) {
                calcDoubleArrayMultiThreaded(
                        ((HasColumnMajorDoubleArray1D) source1).getColumnMajorDoubleArray1D(),
                        (int) source1.getRowCount(), (int) source1.getColumnCount(),
                        ((HasColumnMajorDoubleArray1D) source2).getColumnMajorDoubleArray1D(),
                        (int) source2.getRowCount(), (int) source2.getColumnCount(),
                        ((HasColumnMajorDoubleArray1D) target).getColumnMajorDoubleArray1D());
            } else if (source1 instanceof HasRowMajorDoubleArray2D
                    && source2 instanceof HasRowMajorDoubleArray2D
                    && target instanceof HasRowMajorDoubleArray2D) {
                calcDoubleArray2DMultiThreaded(
                        ((HasRowMajorDoubleArray2D) source1).getRowMajorDoubleArray2D(),
                        ((HasRowMajorDoubleArray2D) source2).getRowMajorDoubleArray2D(),
                        ((HasRowMajorDoubleArray2D) target).getRowMajorDoubleArray2D());
            } else {
                calcDenseDoubleMatrix2DMultiThreaded(source1, source2, target);
            }
        } else {
            if (source1 instanceof HasColumnMajorDoubleArray1D
                    && source2 instanceof HasColumnMajorDoubleArray1D
                    && target instanceof HasColumnMajorDoubleArray1D) {
                gemmDoubleArraySingleThreaded(
                        ((HasColumnMajorDoubleArray1D) source1).getColumnMajorDoubleArray1D(),
                        (int) source1.getRowCount(), (int) source1.getColumnCount(),
                        ((HasColumnMajorDoubleArray1D) source2).getColumnMajorDoubleArray1D(),
                        (int) source2.getRowCount(), (int) source2.getColumnCount(),
                        ((HasColumnMajorDoubleArray1D) target).getColumnMajorDoubleArray1D());
            } else if (source1 instanceof HasRowMajorDoubleArray2D
                    && source2 instanceof HasRowMajorDoubleArray2D
                    && target instanceof HasRowMajorDoubleArray2D) {
                calcDoubleArray2DSingleThreaded(
                        ((HasRowMajorDoubleArray2D) source1).getRowMajorDoubleArray2D(),
                        ((HasRowMajorDoubleArray2D) source2).getRowMajorDoubleArray2D(),
                        ((HasRowMajorDoubleArray2D) target).getRowMajorDoubleArray2D());
            } else {
                calcDenseDoubleMatrix2DSingleThreaded(source1, source2, target);
            }
        }
    }

    /**
     * Multiplies via {@link BlockDenseDoubleMatrix2D}. Operands that are not
     * block matrices with a matching block stripe size are wrapped in new block
     * matrices; if the target cannot be used directly, the product is computed
     * into a temporary block matrix and copied into the target element by
     * element.
     */
    private void calcBlockMatrixMultiThreaded(DenseDoubleMatrix2D source1,
            DenseDoubleMatrix2D source2, DenseDoubleMatrix2D target) {
        BlockDenseDoubleMatrix2D a = null;
        BlockDenseDoubleMatrix2D b = null;
        BlockDenseDoubleMatrix2D c = null;

        if (source1 instanceof BlockDenseDoubleMatrix2D) {
            a = (BlockDenseDoubleMatrix2D) source1;
        } else {
            a = new BlockDenseDoubleMatrix2D(source1);
        }

        // b and c are only reused when their block stripe size matches that of a
        if (source2 instanceof BlockDenseDoubleMatrix2D
                && a.getBlockStripeSize() == ((BlockDenseDoubleMatrix2D) source2)
                        .getBlockStripeSize()) {
            b = (BlockDenseDoubleMatrix2D) source2;
        } else {
            b = new BlockDenseDoubleMatrix2D(source2, a.getBlockStripeSize(),
                    BlockOrder.COLUMNMAJOR);
        }

        final int arows = (int) a.getRowCount();
        final int bcols = (int) b.getColumnCount();

        if (target instanceof BlockDenseDoubleMatrix2D
                && a.getBlockStripeSize() == ((BlockDenseDoubleMatrix2D) target)
                        .getBlockStripeSize()) {
            c = (BlockDenseDoubleMatrix2D) target;
        } else {
            c = new BlockDenseDoubleMatrix2D(arows, bcols, a.getBlockStripeSize(),
                    BlockOrder.ROWMAJOR);
        }

        // force optimal block order
        BlockOrder prevA = a.setBlockOrder(BlockOrder.ROWMAJOR);
        BlockOrder prevB = b.setBlockOrder(BlockOrder.COLUMNMAJOR);

        blockMultiplyMultiThreaded(a, b, c);

        // the product was computed into a temporary: copy it into the target
        if (c != target) {
            for (int j = bcols; --j != -1;) {
                for (int i = arows; --i != -1;) {
                    target.setDouble(c.getDouble(i, j), i, j);
                }
            }
        }

        // reset block order
        if (Mtimes.RESET_BLOCK_ORDER) {
            a.setBlockOrder(prevA);
            b.setBlockOrder(prevB);
        }
    }

    /**
     * Sequential product {@code C = A * B} on one-dimensional arrays that are
     * indexed column by column: element (row, col) of a matrix with
     * {@code rows} rows is read at index {@code row + col * rows}.
     */
    private final void gemmDoubleArraySingleThreaded(final double[] A, final int m1RowCount,
            final int m1ColumnCount, final double[] B, final int m2RowCount,
            final int m2ColumnCount, final double[] C) {
        for (int j = 0; j < m2ColumnCount; j++) {
            final int jcolTimesM1RowCount = j * m1RowCount;
            final int jcolTimesM1ColumnCount = j * m1ColumnCount;
            // zero result column j before accumulating into it
            Arrays.fill(C, jcolTimesM1RowCount, jcolTimesM1RowCount + m1RowCount, 0.0d);
            for (int lcol = 0; lcol < m1ColumnCount; ++lcol) {
                final double temp = B[lcol + jcolTimesM1ColumnCount];
                if (temp != 0.0d) {
                    final int lcolTimesM1RowCount = lcol * m1RowCount;
                    calcOneColumn(temp, A, C, m1RowCount, jcolTimesM1RowCount, lcolTimesM1RowCount);
                }
            }
        }
    }

    /**
     * Same computation as {@link #gemmDoubleArraySingleThreaded}, with one
     * {@link PFor} step per result column. Every step writes only to its own
     * column of C.
     */
    private final void calcDoubleArrayMultiThreaded(final double[] A, final int m1RowCount,
            final int m1ColumnCount, final double[] B, final int m2RowCount,
            final int m2ColumnCount, final double[] C) {
        new PFor(0, m2ColumnCount - 1) {

            @Override
            public void step(int i) {
                final int jcolTimesM1RowCount = i * m1RowCount;
                final int jcolTimesM1ColumnCount = i * m1ColumnCount;
                Arrays.fill(C, jcolTimesM1RowCount, jcolTimesM1RowCount + m1RowCount, 0.0d);
                for (int lcol = 0; lcol < m1ColumnCount; ++lcol) {
                    final double temp = B[lcol + jcolTimesM1ColumnCount];
                    if (temp != 0.0d) {
                        final int lcolTimesM1RowCount = lcol * m1RowCount;
                        calcOneColumn(temp, A, C, m1RowCount, jcolTimesM1RowCount,
                                lcolTimesM1RowCount);
                    }
                }
            }
        };
    }

    /**
     * Adds {@code temp} times the {@code m1RowCount} elements of A starting at
     * {@code index2} to the elements of C starting at {@code index1}.
     */
    private final static void calcOneColumn(final double temp, final double[] A, final double[] C,
            final int m1RowCount, int index1, int index2) {
        for (int irow = 0; irow < m1RowCount; ++irow) {
            C[index1++] += A[index2++] * temp;
        }
    }

    /**
     * Sequential product {@code ret = m1 * m2} on {@code [row][column]} arrays.
     * Each column of m2 is first copied into a contiguous buffer, which is then
     * combined with every row of m1.
     */
    private final void calcDoubleArray2DSingleThreaded(final double[][] m1, final double[][] m2,
            final double[][] ret) {
        final int columnCount = m1[0].length;
        final double[] columns = new double[columnCount];
        for (int c = m2[0].length; --c != -1;) {
            for (int k = columnCount; --k != -1;) {
                columns[k] = m2[k][c];
            }
            for (int r = m1.length; --r != -1;) {
                double sum = 0.0d;
                final double[] row = m1[r];
                for (int k = columnCount; --k != -1;) {
                    sum += row[k] * columns[k];
                }
                ret[r][c] = sum;
            }
        }
    }

    /**
     * Same computation as {@link #calcDoubleArray2DSingleThreaded}, with one
     * {@link PFor} step per result column.
     */
    private final void calcDoubleArray2DMultiThreaded(final double[][] m1, final double[][] m2,
            final double[][] ret) {
        final int columnCount = m1[0].length;

        new PFor(0, m2[0].length - 1) {

            @Override
            public void step(int i) {
                // The column buffer must be private to the step: a buffer
                // shared between steps would be overwritten by other steps
                // whenever PFor runs them concurrently.
                final double[] columns = new double[columnCount];
                for (int k = columnCount; --k != -1;) {
                    columns[k] = m2[k][i];
                }
                for (int r = m1.length; --r != -1;) {
                    double sum = 0.0d;
                    final double[] row = m1[r];
                    for (int k = columnCount; --k != -1;) {
                        sum += row[k] * columns[k];
                    }
                    ret[r][i] = sum;
                }
            }
        };
    }

    /**
     * Sequential product {@code C = A * B} through the {@code getDouble} /
     * {@code setDouble} accessors, one result column at a time.
     */
    private final void calcDenseDoubleMatrix2DSingleThreaded(final DenseDoubleMatrix2D A,
            final DenseDoubleMatrix2D B, final DenseDoubleMatrix2D C) {
        final int m1RowCount = (int) A.getRowCount();
        final int m1ColumnCount = (int) A.getColumnCount();
        final int m2ColumnCount = (int) B.getColumnCount();

        for (int i = 0; i < m2ColumnCount; i++) {
            for (int irow = 0; irow < m1RowCount; ++irow) {
                C.setDouble(0.0d, irow, i);
            }
            for (int lcol = 0; lcol < m1ColumnCount; ++lcol) {
                final double temp = B.getDouble(lcol, i);
                if (temp != 0.0d) {
                    for (int irow = 0; irow < m1RowCount; ++irow) {
                        C.setDouble(C.getDouble(irow, i) + A.getDouble(irow, lcol) * temp, irow, i);
                    }
                }
            }
        }
    }

    /**
     * Same computation as {@link #calcDenseDoubleMatrix2DSingleThreaded}, with
     * one {@link PFor} step per result column.
     */
    private final void calcDenseDoubleMatrix2DMultiThreaded(final DenseDoubleMatrix2D A,
            final DenseDoubleMatrix2D B, final DenseDoubleMatrix2D C) {
        final int m1RowCount = (int) A.getRowCount();
        final int m1ColumnCount = (int) A.getColumnCount();
        final int m2ColumnCount = (int) B.getColumnCount();

        new PFor(0, m2ColumnCount - 1) {

            @Override
            public void step(int i) {
                for (int irow = 0; irow < m1RowCount; ++irow) {
                    C.setDouble(0.0d, irow, i);
                }
                for (int lcol = 0; lcol < m1ColumnCount; ++lcol) {
                    final double temp = B.getDouble(lcol, i);
                    if (temp != 0.0d) {
                        for (int irow = 0; irow < m1RowCount; ++irow) {
                            C.setDouble(C.getDouble(irow, i) + A.getDouble(irow, lcol) * temp,
                                    irow, i);
                        }
                    }
                }
            }
        };
    }

    /**
     * Multiplies two block matrices by splitting the index space into
     * {@link BlockMultiply} tasks, submitting all of them to the
     * {@link UJMPThreadPoolExecutor} and waiting for their completion.
     *
     * @param a
     *            left operand
     * @param b
     *            right operand; must have as many rows as {@code a} has columns
     *            and the same block stripe size as {@code a}
     * @param c
     *            block matrix handed to every task as the result matrix
     * @return {@code c}
     * @throws RuntimeException
     *             if a task fails, or if the calling thread is interrupted
     *             while waiting (the interrupt status is restored in that
     *             case)
     */
    private BlockDenseDoubleMatrix2D blockMultiplyMultiThreaded(final BlockDenseDoubleMatrix2D a,
            final BlockDenseDoubleMatrix2D b, final BlockDenseDoubleMatrix2D c) {
        final BlockMatrixLayout al = a.getBlockLayout();
        final BlockMatrixLayout bl = b.getBlockLayout();

        verifyTrue(al.columns == bl.rows, "b.rows != this.columns");
        verifyTrue(al.blockStripe == bl.blockStripe, "block sizes differ: %s != %s",
                al.blockStripe, bl.blockStripe);

        final List<Callable<Void>> tasks = new LinkedList<Callable<Void>>();

        final int kMax = (int) b.getColumnCount();
        final int jMax = (int) a.getColumnCount();
        final int iMax = (int) a.getRowCount();

        final int bColSlice = Math.min(al.blockStripe, kMax);
        final int aColSlice = Math.min(al.blockStripe, jMax);
        final int aRowSlice = Math.min(al.blockStripe, iMax);

        // Number of blocks to take for each concurrent task.
        final int blocksPerTask = 1;
        final int blocksPerTaskDimJ = selectBlocksPerTaskDimJ(al.blockStripe, iMax, jMax, kMax);

        // one task per (i, j, k) sub-range; the last range in each dimension
        // is shortened so that it ends exactly at the matrix border
        for (int k = 0, kStride; k < kMax; k += kStride) {
            kStride = Math.min(blocksPerTask * bColSlice, kMax - k);

            for (int j = 0, jStride; j < jMax; j += jStride) {
                jStride = Math.min(blocksPerTaskDimJ * aColSlice, jMax - j);

                for (int i = 0, iStride; i < iMax; i += iStride) {
                    iStride = Math.min(blocksPerTask * aRowSlice, iMax - i);

                    tasks.add(new BlockMultiply(a, b, c, i, (i + iStride), j, (j + jStride), k,
                            (k + kStride)));
                }
            }
        }

        // wait for all tasks to complete.
        try {
            for (Future<Void> f : UJMPThreadPoolExecutor.getInstance().invokeAll(tasks)) {
                f.get();
            }
        } catch (ExecutionException e) {
            StringBuilder sb = new StringBuilder(
                    "Execution exception - while awaiting completion of matrix multiplication ["
                            + e.getMessage() + "]:");
            if (e.getCause() != null) {
                for (StackTraceElement stackTraceElement : e.getCause().getStackTrace()) {
                    sb.append(stackTraceElement).append(" * ");
                }
            }
            throw new RuntimeException(sb.toString(), e.getCause());
        } catch (final InterruptedException e) {
            // restore the interrupt status so that callers further up the
            // stack can still observe the interruption
            Thread.currentThread().interrupt();
            String msg = "Interrupted - while awaiting completion of matrix multiplication.";
            throw new RuntimeException(msg + ": cause [" + e.getMessage() + "]", e);
        }
        return c;
    }

    /**
     * Picks the number of blocks to process per task for dimension J.
     * <ul>
     * <li>if too small, then incurs extra gc and contention for
     * synchronization</li>
     * <li>if set too large, then may not fully exploit parallelism</li>
     * </ul>
     *
     * @return the number of blocks needed to cover {@code jMax} (rounded up)
     *         when that dimension is not split, otherwise half of it, but at
     *         least 1
     */
    private int selectBlocksPerTaskDimJ(int blockStripe, int iMax, int jMax, int kMax) {
        // one extra block for the partially filled block at the border
        int adjust = (jMax % blockStripe > 0) ? 1 : 0;
        if (jMax < (5 * blockStripe) || jMax <= iMax) {
            // do not break this dimension into parallel tasks
            return jMax / blockStripe + adjust;
        } else {
            // assume 2 parallel tasks
            return Math.max(1, (jMax / blockStripe + adjust) / 2);
        }
        // may need something if jMax >>> iMax
    }
}