/**
 * (C) Copyright IBM Corp. 2010, 2015
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.ibm.bi.dml.runtime.matrix.data;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.apache.commons.math3.util.FastMath;

import com.ibm.bi.dml.lops.MapMultChain.ChainType;
import com.ibm.bi.dml.lops.WeightedCrossEntropy.WCeMMType;
import com.ibm.bi.dml.lops.WeightedDivMM.WDivMMType;
import com.ibm.bi.dml.lops.WeightedSigmoid.WSigmoidType;
import com.ibm.bi.dml.lops.WeightedSquaredLoss.WeightsType;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;
import com.ibm.bi.dml.runtime.functionobjects.SwapIndex;
import com.ibm.bi.dml.runtime.matrix.operators.ReorgOperator;
import com.ibm.bi.dml.runtime.util.UtilFunctions;

/**
 * MB:
 * Library for matrix multiplications including MM, MV, VV for all
 * combinations of dense, sparse, ultra-sparse representations and special
 * operations such as transpose-self matrix multiplication.
 *
 * In general, all implementations internally use dense outputs
 * for direct access, but change the final result to sparse if necessary.
 * The only exceptions are ultra-sparse matrix mult, wsloss, and wsigmoid.
 *
 * NOTES on BLAS:
 *
 * Experiments in 04/2013 showed that even on dense-dense this implementation
 * is 3x faster than f2j-BLAS-DGEMM, 2x faster than f2c-BLAS-DGEMM, and
 * roughly on par (within 10% after JIT) with a native C implementation.
 *
 * Calling native BLAS would lose platform independence and would require
 * JNI calls including data transfer. Furthermore, BLAS does not support sparse
 * matrices (except Sparse BLAS, with dedicated function calls and matrix formats)
 * and would be an external dependency.
 *
 * Experiments in 02/2014 showed that on dense-dense this implementation now achieves
 * almost 30% of peak FP performance. Compared to Intel MKL 11.1 (dgemm, N=1000) it is
 * just 3.2x (sparsity=1.0) and 1.9x (sparsity=0.5) slower, respectively.
 *
 */
public class LibMatrixMult
{
	//internal configuration
	public static final boolean LOW_LEVEL_OPTIMIZATION = true;
	public static final long MEM_OVERHEAD_THRESHOLD = 2L*1024*1024; //MAX 2 MB
	private static final long PAR_MINFLOP_THRESHOLD = 2L*1024*1024; //MIN 2 MFLOP

	private LibMatrixMult() {
		//prevent instantiation via private constructor
	}

	////////////////////////////////
	// public matrix mult interface
	////////////////////////////////

	/**
	 * Performs a matrix multiplication and stores the result in the output matrix.
	 *
	 * All variants use an IKJ access pattern, and internally use dense output. After the
	 * actual computation, we recompute nnz and check for sparse/dense representation.
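	 *
	 * A minimal usage sketch (hypothetical dimensions; assumes the
	 * MatrixBlock(rlen, clen, sparse) constructor and that the inputs have been populated):
	 * <pre>
	 * MatrixBlock a = new MatrixBlock(1000, 500, false); //dense 1000 x 500
	 * MatrixBlock b = new MatrixBlock(500, 100, false);  //dense 500 x 100
	 * MatrixBlock c = new MatrixBlock(1000, 100, false); //output 1000 x 100
	 * LibMatrixMult.matrixMult(a, b, c);                 //c = a %*% b
	 * </pre>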
* * * @param m1 first matrix * @param m2 second matrix * @param ret result matrix * @throws DMLRuntimeException */ public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret) throws DMLRuntimeException { //check inputs / outputs if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) { ret.examSparsity(); //turn empty dense into sparse return; } //Timing time = new Timing(true); //pre-processing: output allocation boolean tm2 = checkPrepMatrixMultRightInput(m1,m2); m2 = prepMatrixMultRightInput(m1, m2); ret.sparse = (m1.isUltraSparse() || m2.isUltraSparse()); if( !ret.sparse ) ret.allocateDenseBlock(); //prepare row-upper for special cases of vector-matrix boolean pm2 = checkParMatrixMultRightInput(m1, m2, Integer.MAX_VALUE); int ru = pm2 ? m2.rlen : m1.rlen; //core matrix mult computation if( m1.isUltraSparse() || m2.isUltraSparse() ) matrixMultUltraSparse(m1, m2, ret, 0, ru); else if(!m1.sparse && !m2.sparse) matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru); else if(m1.sparse && m2.sparse) matrixMultSparseSparse(m1, m2, ret, pm2, 0, ru); else if(m1.sparse) matrixMultSparseDense(m1, m2, ret, pm2, 0, ru); else matrixMultDenseSparse(m1, m2, ret, pm2, 0, ru); //post-processing: nnz/representation if( !ret.sparse ) ret.recomputeNonZeros(); ret.examSparsity(); //System.out.println("MM ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" + // "("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop()); } /** * Performs a multi-threaded matrix multiplication and stores the result in the output matrix. * The parameter k (k>=1) determines the max parallelism k' with k'=min(k, vcores, m1.rlen). * * @param m1 * @param m2 * @param ret * @param k * @throws DMLRuntimeException */ public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) throws DMLRuntimeException { //check inputs / outputs if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) { ret.examSparsity(); //turn empty dense into sparse return; } //check too high additional vector-matrix memory requirements (fallback to sequential) //check too small workload in terms of flops (fallback to sequential too) if( m1.rlen == 1 && (8L * m2.clen * k > MEM_OVERHEAD_THRESHOLD || !LOW_LEVEL_OPTIMIZATION || m2.clen==1 || m1.isUltraSparse() || m2.isUltraSparse()) || 2L * m1.rlen * m1.clen * m2.clen < PAR_MINFLOP_THRESHOLD ) { matrixMult(m1, m2, ret); return; } //Timing time = new Timing(true); //pre-processing: output allocation (in contrast to single-threaded, //we need to allocate sparse as well in order to prevent synchronization) boolean tm2 = checkPrepMatrixMultRightInput(m1,m2); m2 = prepMatrixMultRightInput(m1, m2); ret.sparse = (m1.isUltraSparse() || m2.isUltraSparse()); if( !ret.sparse ) ret.allocateDenseBlock(); else ret.allocateSparseRowsBlock(); //prepare row-upper for special cases of vector-matrix / matrix-matrix boolean pm2 = checkParMatrixMultRightInput(m1, m2, k); int ru = pm2 ? 
m2.rlen : m1.rlen;

		//core multi-threaded matrix mult computation
		//(currently: always parallelization over number of rows)
		try {
			ExecutorService pool = Executors.newFixedThreadPool( k );
			ArrayList<MatrixMultTask> tasks = new ArrayList<MatrixMultTask>();
			int blklen = (int)(Math.ceil((double)ru/k));
			for( int i=0; i<k & i*blklen<ru; i++ )
				tasks.add(new MatrixMultTask(m1, m2, ret, tm2, pm2, i*blklen, Math.min((i+1)*blklen, ru)));
			pool.invokeAll(tasks);
			pool.shutdown();
			//aggregate partial results (nnz, ret for vector/matrix)
			ret.nonZeros = 0; //reset after execute
			for( MatrixMultTask task : tasks ) {
				if( pm2 )
					vectAdd(task.getResult().denseBlock, ret.denseBlock, 0, 0, ret.rlen*ret.clen);
				else
					ret.nonZeros += task.getPartialNnz();
			}
			if( pm2 )
				ret.recomputeNonZeros();
		}
		catch(Exception ex) {
			throw new DMLRuntimeException(ex);
		}

		//post-processing (nnz maintained in parallel)
		ret.examSparsity();

		//System.out.println("MM k="+k+" ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" +
		//		"("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop());
	}

	/**
	 * Performs a matrix multiplication chain operation of type t(X)%*%(X%*%v) or t(X)%*%(w*(X%*%v)).
	 *
	 * All variants use an IKJ access pattern, and internally use dense output. After the
	 * actual computation, we recompute nnz and check for sparse/dense representation.
	 *
	 * @param mX
	 * @param mV
	 * @param mW
	 * @param ret
	 * @param ct
	 * @throws DMLRuntimeException
	 * @throws DMLUnsupportedOperationException
	 */
	public static void matrixMultChain(MatrixBlock mX, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, ChainType ct)
		throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		//check inputs / outputs (after that mV and mW guaranteed to be dense)
		if( mX.isEmptyBlock(false) || mV.isEmptyBlock(false) || (mW !=null && mW.isEmptyBlock(false)) ) {
			ret.examSparsity(); //turn empty dense into sparse
			return;
		}

		//Timing time = new Timing(true);

		//pre-processing: output allocation
		ret.sparse = false;
		ret.allocateDenseBlock();

		//core matrix mult chain computation
		if( mX.sparse )
			matrixMultChainSparse(mX, mV, mW, ret, ct, 0, mX.rlen);
		else
			matrixMultChainDense(mX, mV, mW, ret, ct, 0, mX.rlen);

		//post-processing
		ret.recomputeNonZeros();
		ret.examSparsity();

		//System.out.println("MMChain "+ct.toString()+" ("+mX.isInSparseFormat()+","+mX.getNumRows()+","+mX.getNumColumns()+","+mX.getNonZeros()+")x" +
		//		"("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
	}

	/**
	 * Performs a parallel matrix multiplication chain operation of type t(X)%*%(X%*%v) or t(X)%*%(w*(X%*%v)).
	 * The parameter k (k>=1) determines the max parallelism k' with k'=min(k, vcores, mX.rlen).
	 *
	 * NOTE: This multi-threaded mmchain operation has additional memory requirements of k*ncol(X)*8bytes
	 * for partial aggregation. Current max memory: 256KB; otherwise we fall back to sequential execution.
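	 * For example (hypothetical sizes), with k=16 threads and ncol(X)=2048 features, the
	 * partial aggregation buffers take 16*2048*8B = 256KB, exactly the stated maximum.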
* * @param mX * @param mV * @param mW * @param ret * @param ct * @param k * @throws DMLRuntimeException * @throws DMLUnsupportedOperationException */ public static void matrixMultChain(MatrixBlock mX, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, ChainType ct, int k) throws DMLRuntimeException, DMLUnsupportedOperationException { //check inputs / outputs (after that mV and mW guaranteed to be dense) if( mX.isEmptyBlock(false) || mV.isEmptyBlock(false) || (mW !=null && mW.isEmptyBlock(false)) ) { ret.examSparsity(); //turn empty dense into sparse return; } //check too high additional memory requirements (fallback to sequential) //check too small workload in terms of flops (fallback to sequential too) if( 8L * mV.rlen * k > MEM_OVERHEAD_THRESHOLD || 4L * mX.rlen * mX.clen < PAR_MINFLOP_THRESHOLD ) { matrixMultChain(mX, mV, mW, ret, ct); return; } //Timing time = new Timing(true); //pre-processing ret.sparse = false; ret.allocateDenseBlock(); //core matrix mult chain computation //(currently: always parallelization over number of rows) try { ExecutorService pool = Executors.newFixedThreadPool( k ); ArrayList<MatrixMultChainTask> tasks = new ArrayList<MatrixMultChainTask>(); int blklen = (int)(Math.ceil((double)mX.rlen/k)); blklen += (blklen%24 != 0)?24-blklen%24:0; for( int i=0; i<k & i*blklen<mX.rlen; i++ ) tasks.add(new MatrixMultChainTask(mX, mV, mW, ret, ct, i*blklen, Math.min((i+1)*blklen, mX.rlen))); pool.invokeAll(tasks); pool.shutdown(); //aggregate partial results for( MatrixMultChainTask task : tasks ) vectAdd(task.getResult().denseBlock, ret.denseBlock, 0, 0, mX.clen); } catch(Exception ex) { throw new DMLRuntimeException(ex); } //post-processing ret.recomputeNonZeros(); ret.examSparsity(); //System.out.println("MMChain "+ct.toString()+" k="+k+" ("+mX.isInSparseFormat()+","+mX.getNumRows()+","+mX.getNumColumns()+","+mX.getNonZeros()+")x" + // "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop()); } /** * * @param m1 * @param ret * @param leftTranspose * @throws DMLRuntimeException * @throws DMLUnsupportedOperationException */ public static void matrixMultTransposeSelf( MatrixBlock m1, MatrixBlock ret, boolean leftTranspose ) throws DMLUnsupportedOperationException, DMLRuntimeException { //check inputs / outputs if( m1.isEmptyBlock(false) ) { ret.examSparsity(); //turn empty dense into sparse return; } //Timing time = new Timing(true); //pre-processing m1 = prepMatrixMultTransposeSelfInput(m1, leftTranspose); ret.sparse = false; ret.allocateDenseBlock(); if( m1.sparse ) matrixMultTransposeSelfSparse(m1, ret, leftTranspose, 0, ret.rlen); else matrixMultTransposeSelfDense(m1, ret, leftTranspose, 0, ret.rlen ); //post-processing copyUpperToLowerTriangle( ret ); ret.recomputeNonZeros(); ret.examSparsity(); //System.out.println("TSMM ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+","+leftTranspose+") in "+time.stop()); } /** * * @param m1 * @param ret * @param leftTranspose * @param k * @throws DMLUnsupportedOperationException * @throws DMLRuntimeException */ public static void matrixMultTransposeSelf( MatrixBlock m1, MatrixBlock ret, boolean leftTranspose, int k ) throws DMLUnsupportedOperationException, DMLRuntimeException { //check inputs / outputs if( m1.isEmptyBlock(false) ) { ret.examSparsity(); //turn empty dense into sparse return; } //check no parallelization benefit (fallback to sequential) //check too small workload in terms of flops (fallback to sequential too) if( ret.rlen == 
1 || leftTranspose && 1L * m1.rlen * m1.clen * m1.clen < PAR_MINFLOP_THRESHOLD
			|| !leftTranspose && 1L * m1.clen * m1.rlen * m1.rlen < PAR_MINFLOP_THRESHOLD )
		{
			matrixMultTransposeSelf(m1, ret, leftTranspose);
			return;
		}

		//Timing time = new Timing(true);

		//pre-processing
		m1 = prepMatrixMultTransposeSelfInput(m1, leftTranspose);
		ret.sparse = false;
		ret.allocateDenseBlock();

		//core multi-threaded matrix mult computation
		try {
			ExecutorService pool = Executors.newFixedThreadPool( k );
			ArrayList<MatrixMultTransposeTask> tasks = new ArrayList<MatrixMultTransposeTask>();
			//load balance via #tasks=2k due to triangular shape
			int blklen = (int)(Math.ceil((double)ret.rlen/(2*k)));
			for( int i=0; i<2*k & i*blklen<ret.rlen; i++ )
				tasks.add(new MatrixMultTransposeTask(m1, ret, leftTranspose, i*blklen, Math.min((i+1)*blklen, ret.rlen)));
			pool.invokeAll(tasks);
			pool.shutdown();
		}
		catch(Exception ex) {
			throw new DMLRuntimeException(ex);
		}

		//post-processing
		copyUpperToLowerTriangle( ret );
		ret.recomputeNonZeros();
		ret.examSparsity();

		//System.out.println("TSMM k="+k+" ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+","+leftTranspose+") in "+time.stop());
	}

	/**
	 *
	 * @param pm1
	 * @param m2
	 * @param ret1
	 * @param ret2
	 * @throws DMLUnsupportedOperationException
	 * @throws DMLRuntimeException
	 */
	public static void matrixMultPermute( MatrixBlock pm1, MatrixBlock m2, MatrixBlock ret1, MatrixBlock ret2 )
		throws DMLUnsupportedOperationException, DMLRuntimeException
	{
		//check inputs / outputs
		if( pm1.isEmptyBlock(false) || m2.isEmptyBlock(false) )
			return;

		//Timing time = new Timing(true);

		//pre-processing
		ret1.sparse = (m2.sparse || ret1.sparse);
		if( ret1.sparse )
			ret1.allocateSparseRowsBlock();
		else
			ret1.allocateDenseBlock();

		//core permutation mm computation
		if( m2.sparse )
			matrixMultPermuteSparse(pm1, m2, ret1, ret2, 0, pm1.rlen);
		else if( ret1.sparse )
			matrixMultPermuteDenseSparse(pm1, m2, ret1, ret2, 0, pm1.rlen);
		else
			matrixMultPermuteDense(pm1, m2, ret1, ret2, 0, pm1.rlen);

		//post-processing
		ret1.recomputeNonZeros();
		ret1.examSparsity();
		if( ret2 != null ) { //optional second output
			ret2.recomputeNonZeros();
			ret2.examSparsity();
		}

		//System.out.println("PMM Seq ("+pm1.isInSparseFormat()+","+pm1.getNumRows()+","+pm1.getNumColumns()+","+pm1.getNonZeros()+")x" +
		//		"("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop());
	}

	/**
	 *
	 * @param pm1
	 * @param m2
	 * @param ret1
	 * @param ret2
	 * @param k
	 * @throws DMLUnsupportedOperationException
	 * @throws DMLRuntimeException
	 */
	public static void matrixMultPermute( MatrixBlock pm1, MatrixBlock m2, MatrixBlock ret1, MatrixBlock ret2, int k)
		throws DMLUnsupportedOperationException, DMLRuntimeException
	{
		//check inputs / outputs
		if( pm1.isEmptyBlock(false) || m2.isEmptyBlock(false) )
			return;

		//check no parallelization benefit (fallback to sequential)
		if (pm1.rlen == 1) {
			matrixMultPermute(pm1, m2, ret1, ret2);
			return;
		}

		//Timing time = new Timing(true);

		//allocate first output block (second allocated if needed)
		ret1.sparse = false;
		ret1.allocateDenseBlock();

		try {
			ExecutorService pool = Executors.newFixedThreadPool(k);
			ArrayList<MatrixMultPermuteTask> tasks = new ArrayList<MatrixMultPermuteTask>();
			int blklen = (int)(Math.ceil((double)pm1.rlen/k));
			for( int i=0; i<k & i*blklen<pm1.rlen; i++ )
				tasks.add(new MatrixMultPermuteTask(pm1, m2, ret1, ret2, i*blklen, Math.min((i+1)*blklen, pm1.rlen)));
			pool.invokeAll(tasks);
			pool.shutdown();
		}
		catch (InterruptedException e) {
throw new DMLRuntimeException(e); } //post-processing ret1.recomputeNonZeros(); ret1.examSparsity(); if( ret2 != null ) { //optional second output ret2.recomputeNonZeros(); ret2.examSparsity(); } // System.out.println("PMM Par ("+pm1.isInSparseFormat()+","+pm1.getNumRows()+","+pm1.getNumColumns()+","+pm1.getNonZeros()+")x" + // "("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop()); } /** * * @param mX * @param mU * @param mV * @param mW * @param ret * @param wt * @throws DMLRuntimeException */ public static void matrixMultWSLoss(MatrixBlock mX, MatrixBlock mU, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, WeightsType wt) throws DMLRuntimeException { //check for empty result if( wt==WeightsType.POST && mW.isEmptyBlock(false) || wt==WeightsType.POST_NZ && mX.isEmptyBlock(false) ) { ret.examSparsity(); //turn empty dense into sparse return; } //Timing time = new Timing(true); //core weighted square sum mm computation if( !mX.sparse && !mU.sparse && !mV.sparse && (mW==null || !mW.sparse) && !mX.isEmptyBlock() && !mU.isEmptyBlock() && !mV.isEmptyBlock() && (mW==null || !mW.isEmptyBlock())) matrixMultWSLossDense(mX, mU, mV, mW, ret, wt, 0, mX.rlen); else if( mX.sparse && !mU.sparse && !mV.sparse && (mW==null || mW.sparse) && !mX.isEmptyBlock() && !mU.isEmptyBlock() && !mV.isEmptyBlock() && (mW==null || !mW.isEmptyBlock())) matrixMultWSLossSparseDense(mX, mU, mV, mW, ret, wt, 0, mX.rlen); else matrixMultWSLossGeneric(mX, mU, mV, mW, ret, wt, 0, mX.rlen); //System.out.println("MMWSLoss " +wt.toString()+ " ("+mX.isInSparseFormat()+","+mX.getNumRows()+","+mX.getNumColumns()+","+mX.getNonZeros()+")x" + // "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop()); } /** * * @param mX * @param mU * @param mV * @param mW * @param ret * @param wt * @throws DMLRuntimeException */ public static void matrixMultWSLoss(MatrixBlock mX, MatrixBlock mU, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, WeightsType wt, int k) throws DMLRuntimeException { //check for empty result if( wt==WeightsType.POST && mW.isEmptyBlock(false) || wt==WeightsType.POST_NZ && mX.isEmptyBlock(false) ) { ret.examSparsity(); //turn empty dense into sparse return; } //check no parallelization benefit (fallback to sequential) if (mX.rlen == 1) { matrixMultWSLoss(mX, mU, mV, mW, ret, wt); return; } //Timing time = new Timing(true); try { ExecutorService pool = Executors.newFixedThreadPool(k); ArrayList<ScalarResultTask> tasks = new ArrayList<ScalarResultTask>(); int blklen = (int)(Math.ceil((double)mX.rlen/k)); for( int i=0; i<k & i*blklen<mX.rlen; i++ ) tasks.add(new MatrixMultWSLossTask(mX, mU, mV, mW, wt, i*blklen, Math.min((i+1)*blklen, mX.rlen))); pool.invokeAll(tasks); pool.shutdown(); //aggregate partial results sumScalarResults(tasks, ret); } catch (InterruptedException e) { throw new DMLRuntimeException(e); } //System.out.println("MMWSLoss "+wt.toString()+" k="+k+" ("+mX.isInSparseFormat()+","+mX.getNumRows()+","+mX.getNumColumns()+","+mX.getNonZeros()+")x" + // "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop()); } /** * * @param mW * @param mU * @param mV * @param ret * @param wt * @throws DMLRuntimeException */ public static void matrixMultWSigmoid(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt) throws DMLRuntimeException { //check for empty result if( mW.isEmptyBlock(false) ) { ret.examSparsity(); //turn empty dense into 
sparse
			return;
		}

		//Timing time = new Timing(true);

		//pre-processing
		ret.sparse = mW.sparse;
		ret.allocateDenseOrSparseBlock();

		//core weighted sigmoid mm computation
		if( !mW.sparse && !mU.sparse && !mV.sparse && !mU.isEmptyBlock() && !mV.isEmptyBlock() )
			matrixMultWSigmoidDense(mW, mU, mV, ret, wt, 0, mW.rlen);
		else if( mW.sparse && !mU.sparse && !mV.sparse && !mU.isEmptyBlock() && !mV.isEmptyBlock())
			matrixMultWSigmoidSparseDense(mW, mU, mV, ret, wt, 0, mW.rlen);
		else
			matrixMultWSigmoidGeneric(mW, mU, mV, ret, wt, 0, mW.rlen);

		//post-processing
		ret.recomputeNonZeros();
		ret.examSparsity();

		//System.out.println("MMWSig "+wt.toString()+" ("+mW.isInSparseFormat()+","+mW.getNumRows()+","+mW.getNumColumns()+","+mW.getNonZeros()+")x" +
		//		"("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
	}

	/**
	 *
	 * @param mW
	 * @param mU
	 * @param mV
	 * @param ret
	 * @param wt
	 * @param k
	 * @throws DMLRuntimeException
	 */
	public static void matrixMultWSigmoid(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt, int k)
		throws DMLRuntimeException
	{
		//check for empty result
		if( mW.isEmptyBlock(false) ) {
			ret.examSparsity(); //turn empty dense into sparse
			return;
		}

		//check no parallelization benefit (fallback to sequential)
		if (mW.rlen == 1) {
			matrixMultWSigmoid(mW, mU, mV, ret, wt);
			return;
		}

		//Timing time = new Timing(true);

		//pre-processing
		ret.sparse = mW.sparse;
		ret.allocateDenseOrSparseBlock();

		try {
			ExecutorService pool = Executors.newFixedThreadPool(k);
			ArrayList<MatrixMultWSigmoidTask> tasks = new ArrayList<MatrixMultWSigmoidTask>();
			int blklen = (int)(Math.ceil((double)mW.rlen/k));
			for( int i=0; i<k & i*blklen<mW.rlen; i++ )
				tasks.add(new MatrixMultWSigmoidTask(mW, mU, mV, ret, wt, i*blklen, Math.min((i+1)*blklen, mW.rlen)));
			pool.invokeAll(tasks);
			pool.shutdown();
			ret.nonZeros = 0; //reset after execute
			for( MatrixMultWSigmoidTask task : tasks )
				ret.nonZeros += task.getPartialNnz();
		}
		catch (InterruptedException e) {
			throw new DMLRuntimeException(e);
		}

		//post-processing (nnz maintained in parallel)
		ret.examSparsity();

		//System.out.println("MMWSig "+wt.toString()+" k="+k+" ("+mW.isInSparseFormat()+","+mW.getNumRows()+","+mW.getNumColumns()+","+mW.getNonZeros()+")x" +
		//		"("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop() + ".");
	}

	/**
	 * NOTE: This operation has limited NaN support, which is acceptable because all our sparse-safe operations
	 * have only limited NaN support. If this is not intended behavior, please disable the rewrite. In detail,
	 * this operator will produce for W/(U%*%t(V)) a zero intermediate for each zero in W (even if UVij is zero,
	 * which would give 0/0=NaN) but INF/-INF for non-zero entries in W where the corresponding cell in (U%*%t(V))
	 * is zero.
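	 * For example (a scalar sketch of the boundary cases): wij=0 with uvij=0 yields 0
	 * rather than 0/0=NaN, whereas wij=2 with uvij=0 yields 2/0=Infinity.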
	 *
	 * @param mW
	 * @param mU
	 * @param mV
	 * @param ret
	 * @param wt
	 * @throws DMLRuntimeException
	 */
	public static void matrixMultWDivMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WDivMMType wt)
		throws DMLRuntimeException
	{
		//check for empty result
		if( mW.isEmptyBlock(false)
			|| (wt.isLeft() && mU.isEmptyBlock(false))
			|| (wt.isRight() && mV.isEmptyBlock(false))
			|| (wt.isBasic() && mW.isEmptyBlock(false))) {
			ret.examSparsity(); //turn empty dense into sparse
			return;
		}

		//Timing time = new Timing(true);

		//pre-processing
		ret.sparse = wt.isBasic()?mW.sparse:false;
		ret.allocateDenseOrSparseBlock();

		//core weighted div mm computation
		if( !mW.sparse && !mU.sparse && !mV.sparse && !mU.isEmptyBlock() && !mV.isEmptyBlock() )
			matrixMultWDivMMDense(mW, mU, mV, ret, wt, 0, mW.rlen, 0, mW.clen);
		else if( mW.sparse && !mU.sparse && !mV.sparse && !mU.isEmptyBlock() && !mV.isEmptyBlock())
			matrixMultWDivMMSparseDense(mW, mU, mV, ret, wt, 0, mW.rlen, 0, mW.clen);
		else
			matrixMultWDivMMGeneric(mW, mU, mV, ret, wt, 0, mW.rlen, 0, mW.clen);

		//post-processing
		ret.recomputeNonZeros();
		ret.examSparsity();

		//System.out.println("MMWDiv "+wt.toString()+" ("+mW.isInSparseFormat()+","+mW.getNumRows()+","+mW.getNumColumns()+","+mW.getNonZeros()+")x" +
		//		"("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
	}

	/**
	 * NOTE: This operation has limited NaN support, which is acceptable because all our sparse-safe operations
	 * have only limited NaN support. If this is not intended behavior, please disable the rewrite. In detail,
	 * this operator will produce for W/(U%*%t(V)) a zero intermediate for each zero in W (even if UVij is zero,
	 * which would give 0/0=NaN) but INF/-INF for non-zero entries in W where the corresponding cell in (U%*%t(V))
	 * is zero.
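	 *
	 * Parallelization is over the columns of W for left variants and over the rows of W for
	 * right/basic variants, so that the partial results of all tasks are disjoint.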
	 *
	 * @param mW
	 * @param mU
	 * @param mV
	 * @param ret
	 * @param wt
	 * @param k
	 * @throws DMLRuntimeException
	 */
	public static void matrixMultWDivMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WDivMMType wt, int k)
		throws DMLRuntimeException
	{
		//check for empty result
		if( mW.isEmptyBlock(false)
			|| (wt.isLeft() && mU.isEmptyBlock(false))
			|| (wt.isRight() && mV.isEmptyBlock(false))
			|| (wt.isBasic() && mW.isEmptyBlock(false))) {
			ret.examSparsity(); //turn empty dense into sparse
			return;
		}

		//Timing time = new Timing(true);

		//pre-processing
		ret.sparse = wt.isBasic()?mW.sparse:false;
		ret.allocateDenseOrSparseBlock();

		try {
			ExecutorService pool = Executors.newFixedThreadPool(k);
			ArrayList<MatrixMultWDivTask> tasks = new ArrayList<MatrixMultWDivTask>();
			//create tasks (for wdivmm-left, parallelization over columns;
			//for wdivmm-right, parallelization over rows; both ensure disjoint results)
			if( wt.isLeft() ) {
				int blklen = (int)(Math.ceil((double)mW.clen/k));
				for( int j=0; j<k & j*blklen<mW.clen; j++ )
					tasks.add(new MatrixMultWDivTask(mW, mU, mV, ret, wt, 0, mW.rlen, j*blklen, Math.min((j+1)*blklen, mW.clen)));
			}
			else { //basic/right
				int blklen = (int)(Math.ceil((double)mW.rlen/k));
				for( int i=0; i<k & i*blklen<mW.rlen; i++ )
					tasks.add(new MatrixMultWDivTask(mW, mU, mV, ret, wt, i*blklen, Math.min((i+1)*blklen, mW.rlen), 0, mW.clen));
			}
			//execute tasks
			pool.invokeAll(tasks);
			pool.shutdown();
			//aggregate partial nnz
			for( MatrixMultWDivTask task : tasks )
				ret.nonZeros += task.getPartialNnz();
		}
		catch (InterruptedException e) {
			throw new DMLRuntimeException(e);
		}

		//post-processing
		ret.examSparsity();

		//System.out.println("MMWDiv "+wt.toString()+" k="+k+" ("+mW.isInSparseFormat()+","+mW.getNumRows()+","+mW.getNumColumns()+","+mW.getNonZeros()+")x" +
		//		"("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
	}

	/**
	 *
	 * @param mW
	 * @param mU
	 * @param mV
	 * @param ret
	 * @param wt
	 * @throws DMLRuntimeException
	 */
	public static void matrixMultWCeMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WCeMMType wt)
		throws DMLRuntimeException
	{
		//check for empty result
		if( mW.isEmptyBlock(false) ) {
			ret.examSparsity(); //turn empty dense into sparse
			return;
		}

		//Timing time = new Timing(true);

		//pre-processing
		ret.sparse = false;
		ret.allocateDenseBlock();

		//core weighted cross entropy mm computation
		if( !mW.sparse && !mU.sparse && !mV.sparse && !mU.isEmptyBlock() && !mV.isEmptyBlock() )
			matrixMultWCeMMDense(mW, mU, mV, ret, wt, 0, mW.rlen);
		else if( mW.sparse && !mU.sparse && !mV.sparse && !mU.isEmptyBlock() && !mV.isEmptyBlock())
			matrixMultWCeMMSparseDense(mW, mU, mV, ret, wt, 0, mW.rlen);
		else
			matrixMultWCeMMGeneric(mW, mU, mV, ret, wt, 0, mW.rlen);

		//System.out.println("MMWCe "+wt.toString()+" ("+mW.isInSparseFormat()+","+mW.getNumRows()+","+mW.getNumColumns()+","+mW.getNonZeros()+")x" +
		//		"("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
	}

	/**
	 *
	 * @param mW
	 * @param mU
	 * @param mV
	 * @param ret
	 * @param wt
	 * @param k
	 * @throws DMLRuntimeException
	 */
	public static void matrixMultWCeMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WCeMMType wt, int k)
		throws DMLRuntimeException
	{
		//check for empty result
		if( mW.isEmptyBlock(false) ) {
			ret.examSparsity(); //turn empty dense into sparse
			return;
		}

		//Timing time = new Timing(true);

		//pre-processing
		ret.sparse = false;
		ret.allocateDenseBlock();

		try {
			ExecutorService pool = Executors.newFixedThreadPool(k);
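			//note: each task computes a partial scalar result over its row block of W;
			//the partial sums are aggregated into ret via sumScalarResults below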
ArrayList<ScalarResultTask> tasks = new ArrayList<ScalarResultTask>(); int blklen = (int)(Math.ceil((double)mW.rlen/k)); for( int i=0; i<k & i*blklen<mW.rlen; i++ ) tasks.add(new MatrixMultWCeTask(mW, mU, mV, wt, i*blklen, Math.min((i+1)*blklen, mW.rlen))); pool.invokeAll(tasks); pool.shutdown(); //aggregate partial results sumScalarResults(tasks, ret); } catch (InterruptedException e) { throw new DMLRuntimeException(e); } //System.out.println("MMWCe "+wt.toString()+" k="+k+" ("+mW.isInSparseFormat()+","+mW.getNumRows()+","+mW.getNumColumns()+","+mW.getNonZeros()+")x" + // "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop()); } ////////////////////////////////////////// // optimized matrix mult implementation // ////////////////////////////////////////// /** * * @param m1 * @param m2 * @param ret * @throws DMLRuntimeException */ private static void matrixMultDenseDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean tm2, boolean pm2, int rl, int ru) throws DMLRuntimeException { double[] a = m1.denseBlock; double[] b = m2.denseBlock; double[] c = ret.denseBlock; final int m = m1.rlen; final int n = m2.clen; final int cd = m1.clen; if( LOW_LEVEL_OPTIMIZATION ) { if( m==1 && n==1 ) //DOT PRODUCT { c[0] = dotProduct(a, b, cd); } else if( n>1 && cd == 1 ) //OUTER PRODUCT { for( int i=rl, cix=rl*n; i < ru; i++, cix+=n) { if( a[i] == 1 ) System.arraycopy(b, 0, c, cix, n); else if( a[i] != 0 ) vectMultiplyWrite(a[i], b, c, 0, cix, n); else Arrays.fill(c, cix, cix+n, 0); } } else if( n==1 && cd == 1 ) //VECTOR-SCALAR { vectMultiplyWrite(b[0], a, c, rl, rl, ru-rl); } else if( n==1 ) //MATRIX-VECTOR { for( int i=rl, aix=rl*cd; i < ru; i++, aix+=cd) c[ i ] = dotProduct(a, b, aix, 0, cd); } else if( pm2 && m==1 ) //VECTOR-MATRIX { //parallelization over rows in rhs matrix //rest not aligned to blocks of 2 rows final int kn = (ru-rl)%2; if( kn == 1 && a[rl] != 0 ) vectMultiplyAdd(a[rl], b, c, rl*n, 0, n); //compute blocks of 2 rows (2 instead of 4 for small n<64) for( int k=rl+kn, bix=(rl+kn)*n; k<ru; k+=2, bix+=2*n ){ if( a[k] != 0 && a[k+1] != 0 ) vectMultiplyAdd2(a[k], a[k+1], b, c, bix, bix+n, 0, n); else if( a[k] != 0 ) vectMultiplyAdd(a[k], b, c, bix, 0, n); else if( a[k+1] != 0 ) vectMultiplyAdd(a[k+1], b, c, bix+n, 0, n); } } else if( pm2 && m<=16 ) //MATRIX-MATRIX (short lhs) { //parallelization over rows in rhs matrix final int kn = (ru-rl)%2; //rest not aligned to blocks of 2 rows if( kn == 1 ) for( int i=0, aix=0, cix=0; i<m; i++, aix+=cd, cix+=n ) if( a[aix+rl] != 0 ) vectMultiplyAdd(a[aix+rl], b, c, rl*n, cix, n); //compute blocks of 2 rows (w/ repeated scan for each row in lhs) for( int k=rl+kn, bix=(rl+kn)*n; k<ru; k+=2, bix+=2*n ) for( int i=0, aix=0, cix=0; i<m; i++, aix+=cd, cix+=n ){ if( a[aix+k] != 0 && a[aix+k+1] != 0 ) vectMultiplyAdd2(a[aix+k], a[aix+k+1], b, c, bix, bix+n, cix, n); else if( a[aix+k] != 0 ) vectMultiplyAdd(a[aix+k], b, c, bix, cix, n); else if( a[aix+k+1] != 0 ) vectMultiplyAdd(a[aix+k+1], b, c, bix+n, cix, n); } } else if( tm2 ) //MATRIX-MATRIX (skinny rhs) { //note: prepared rhs input via transpose for: m > n && cd > 64 && n < 64 //however, explicit flag required since dimension change m2 final int n2 = m2.rlen; for( int i=rl, aix=rl*cd, cix=rl*n2; i < ru; i++, aix+=cd, cix+=n2 ) for( int j=0, bix=0; j<n2; j++, bix+=cd ) c[cix+j] = dotProduct(a, b, aix, bix, cd); } else //MATRIX-MATRIX { //1) Unrolled inner loop (for better instruction-level parallelism) //2) Blocked execution (for less cache 
trashing in parallel exec) //3) Asymmetric block sizes (for less misses in inner loop, yet blocks in L1/L2) final int blocksizeI = 32; //64//256KB c block (typical L2 size per core), 32KB a block final int blocksizeK = 24; //64//256KB b block (typical L2 size per core), used while read 512B of a / read/write 4KB of c final int blocksizeJ = 1024; //512//4KB (typical main-memory page size), for scan //temporary arrays (nnz a, b index) double[] ta = new double[ blocksizeK ]; int[] tbi = new int[ blocksizeK ]; //blocked execution for( int bi = rl; bi < ru; bi+=blocksizeI ) for( int bk = 0, bimin = Math.min(ru, bi+blocksizeI); bk < cd; bk+=blocksizeK ) for( int bj = 0, bkmin = Math.min(cd, bk+blocksizeK); bj < n; bj+=blocksizeJ ) { int bklen = bkmin-bk; int bjlen = Math.min(n, bj+blocksizeJ)-bj; //core sub block matrix multiplication for( int i = bi; i < bimin; i++) { int aixi = i * cd + bk; //start index on a int cixj = i * n + bj; //scan index on c //determine nnz of a (for sparsity-aware skipping of rows) int knnz = copyNonZeroElements(a, aixi, bk, bj, n, ta, tbi, bklen); //if( knnz > 0 ) //for skipping empty rows //rest not aligned to blocks of 4 rows final int bn = knnz % 4; switch( bn ){ case 1: vectMultiplyAdd(ta[0], b, c, tbi[0], cixj, bjlen); break; case 2: vectMultiplyAdd2(ta[0],ta[1], b, c, tbi[0], tbi[1], cixj, bjlen); break; case 3: vectMultiplyAdd3(ta[0],ta[1],ta[2], b, c, tbi[0], tbi[1],tbi[2], cixj, bjlen); break; } //compute blocks of 4 rows (core inner loop) for( int k = bn; k<knnz; k+=4 ){ vectMultiplyAdd4( ta[k], ta[k+1], ta[k+2], ta[k+3], b, c, tbi[k], tbi[k+1], tbi[k+2], tbi[k+3], cixj, bjlen ); } } } } } else { double val; for( int i = rl, aix=rl*cd, cix=rl*n; i < ru; i++, cix+=n) for( int k = 0, bix=0; k < cd; k++, aix++, bix+=n) { val = a[ aix ]; if( val != 0 ) for( int j = 0; j < n; j++) c[ cix+j ] += val * b[ bix+j ]; } } } /** * * @param m1 * @param m2 * @param ret * @throws DMLRuntimeException */ private static void matrixMultDenseSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean pm2, int rl, int ru) throws DMLRuntimeException { double[] a = m1.denseBlock; double[] c = ret.denseBlock; int m = m1.rlen; int cd = m1.clen; int n = m2.clen; // MATRIX-MATRIX (VV, MV not applicable here because V always dense) if( LOW_LEVEL_OPTIMIZATION ) { final int blocksizeI = 32; //256KB c block (typical L2 size per core), 32KB a block final int blocksizeK = 32; //note: in contrast to dense-dense, no blocking over j (would require maintaining blocksizeK indexes, counter-productive on skew) SparseRow[] b = m2.sparseRows; if( pm2 && m==1 ) //VECTOR-MATRIX { //parallelization over rows in rhs matrix for( int k=rl; k<ru; k++ ) if( a[k] != 0 && b[k] != null && !b[k].isEmpty() ) { int[] bix = b[k].getIndexContainer(); double[] bvals = b[k].getValueContainer(); vectMultiplyAdd(a[k], bvals, c, bix, 0, b[k].size()); } } else //MATRIX-MATRIX { //blocked execution for( int bi = rl; bi < ru; bi+=blocksizeI ) for( int bk = 0, bimin = Math.min(ru, bi+blocksizeI); bk < cd; bk+=blocksizeK ) { int bklen = Math.min(cd, bk+blocksizeK)-bk; //core sub block matrix multiplication for( int i = bi; i < bimin; i++) { int aixi = i * cd + bk; //start index on a int cixj = i * n + 0; //scan index on c for( int k = 0; k < bklen; k++ ) { double val = a[aixi+k]; SparseRow brow = b[ bk+k ]; if( val != 0 && brow != null && !brow.isEmpty() ) { int blen = brow.size(); int[] bix = brow.getIndexContainer(); double[] bvals = brow.getValueContainer(); vectMultiplyAdd(val, bvals, c, bix, cixj, blen); } } } } } 
} else { for( int i=rl, aix=rl*cd, cix=rl*n; i < ru; i++, cix+=n ) for(int k = 0; k < cd; k++, aix++ ) { double val = a[aix]; if( val!=0 ) { SparseRow brow = m2.sparseRows[ k ]; if( brow != null && !brow.isEmpty() ) { int blen = brow.size(); int[] bix = brow.getIndexContainer(); double[] bvals = brow.getValueContainer(); for(int j = 0; j < blen; j++) c[cix+bix[j]] += val * bvals[j]; } } } } } /** * * @param m1 * @param m2 * @param ret * @throws DMLRuntimeException */ private static void matrixMultSparseDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean pm2, int rl, int ru) throws DMLRuntimeException { double[] b = m2.denseBlock; double[] c = ret.denseBlock; final int m = m1.rlen; final int n = m2.clen; if( LOW_LEVEL_OPTIMIZATION ) { if( m==1 && n==1 ) //DOT PRODUCT { SparseRow arow = m1.sparseRows[0]; if( arow != null && !arow.isEmpty() ) { int alen = arow.size(); int[] aix = arow.getIndexContainer(); double[] avals = arow.getValueContainer(); c[0] = dotProduct(avals, b, aix, 0, alen); } } else if( n==1 ) //MATRIX-VECTOR { for( int i=rl; i<Math.min(ru, m1.sparseRows.length); i++ ) { SparseRow arow = m1.sparseRows[i]; if( arow != null && !arow.isEmpty() ) { int alen = arow.size(); int[] aix = arow.getIndexContainer(); double[] avals = arow.getValueContainer(); c[i] = dotProduct(avals, b, aix, 0, alen); } } } else if( pm2 && m==1 ) //VECTOR-MATRIX { //parallelization over rows in rhs matrix SparseRow arow = m1.sparseRows[0]; if( arow != null && !arow.isEmpty() ) { int alen = arow.size(); int[] aix = arow.getIndexContainer(); double[] avals = arow.getValueContainer(); int rlix = (rl==0) ? 0 : arow.searchIndexesFirstGTE(rl); rlix = (rlix>=0) ? rlix : alen; for( int k=rlix; k<alen && aix[k]<ru; k++ ) { if( k+1<alen && aix[k+1]<ru ) vectMultiplyAdd2(avals[k], avals[k+1], b, c, aix[k]*n, aix[++k]*n, 0, n); else vectMultiplyAdd(avals[k], b, c, aix[k]*n, 0, n); } } } else //MATRIX-MATRIX { for( int i=rl, cix=rl*n; i<Math.min(ru, m1.sparseRows.length); i++, cix+=n ) { SparseRow arow = m1.sparseRows[i]; if( arow != null && !arow.isEmpty() ) { int alen = arow.size(); int[] aix = arow.getIndexContainer(); double[] avals = arow.getValueContainer(); if( alen==1 && avals[0]==1 ) //ROW SELECTION { //plain memcopy for permutation matrices System.arraycopy(b, aix[0]*n, c, cix, n); } else //GENERAL CASE { //rest not aligned to blocks of 4 rows final int bn = alen % 4; switch( bn ){ case 1: vectMultiplyAdd(avals[0], b, c, aix[0]*n, cix, n); break; case 2: vectMultiplyAdd2(avals[0],avals[1], b, c, aix[0]*n, aix[1]*n, cix, n); break; case 3: vectMultiplyAdd3(avals[0],avals[1],avals[2], b, c, aix[0]*n, aix[1]*n, aix[2]*n, cix, n); break; } //compute blocks of 4 rows (core inner loop) for( int k = bn; k<alen; k+=4 ) { vectMultiplyAdd4( avals[k], avals[k+1], avals[k+2], avals[k+3], b, c, aix[k]*n, aix[k+1]*n, aix[k+2]*n, aix[k+3]*n, cix, n ); } } } } } } else { for( int i=rl, cix=rl*n; i<Math.min(ru, m1.sparseRows.length); i++, cix+=n ) { SparseRow arow = m1.sparseRows[i]; if( arow != null && !arow.isEmpty() ) { int alen = arow.size(); int[] aix = arow.getIndexContainer(); double[] avals = arow.getValueContainer(); for(int k = 0; k < alen; k++) { double val = avals[k]; for(int j = 0, bix=aix[k]*n; j < n; j++) c[cix+j] += val * b[bix+j]; } } } } } /** * * @param m1 * @param m2 * @param ret * @throws DMLRuntimeException */ private static void matrixMultSparseSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean pm2, int rl, int ru) throws DMLRuntimeException { SparseRow[] b = m2.sparseRows; 
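		//approach (matrix-matrix case): for each nnz a(i,k) of a sparse row of m1, the
		//sparse row b(k) is scaled by a(i,k) and added into the dense output row c(i)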
		double[] c = ret.denseBlock;
		int m = m1.rlen;
		int n = m2.clen;

		// MATRIX-MATRIX (VV, MV not applicable here because V always dense)
		if(LOW_LEVEL_OPTIMIZATION)
		{
			if( pm2 && m==1 ) //VECTOR-MATRIX
			{
				//parallelization over rows in rhs matrix
				SparseRow arow = m1.sparseRows[0];
				if( arow != null && !arow.isEmpty() ) {
					int alen = arow.size();
					int[] aix = arow.getIndexContainer();
					double[] avals = arow.getValueContainer();
					int rlix = (rl==0) ? 0 : arow.searchIndexesFirstGTE(rl);
					rlix = (rlix>=0) ? rlix : alen;
					for( int k=rlix; k<alen && aix[k]<ru; k++ )
						if( b[aix[k]] != null && !b[aix[k]].isEmpty() ) {
							SparseRow brow = b[aix[k]];
							int blen = brow.size();
							int[] bix = brow.getIndexContainer();
							double[] bvals = brow.getValueContainer();
							vectMultiplyAdd(avals[k], bvals, c, bix, 0, blen);
						}
				}
			}
			else //MATRIX-MATRIX
			{
				for( int i=rl, cix=rl*n; i<Math.min(ru, m1.sparseRows.length); i++, cix+=n ) {
					SparseRow arow = m1.sparseRows[i];
					if( arow != null && !arow.isEmpty() ) {
						int alen = arow.size();
						int[] aix = arow.getIndexContainer();
						double[] avals = arow.getValueContainer();
						for(int k = 0; k < alen; k++) {
							double val = avals[k];
							SparseRow brow = b[ aix[k] ];
							if( brow != null && !brow.isEmpty() ) {
								int blen = brow.size();
								int[] bix = brow.getIndexContainer();
								double[] bvals = brow.getValueContainer();
								vectMultiplyAdd(val, bvals, c, bix, cix, blen);
							}
						}
					}
				}
			}
		}
		else
		{
			for( int i=rl, cix=rl*n; i<Math.min(ru, m1.sparseRows.length); i++, cix+=n ) {
				SparseRow arow = m1.sparseRows[i];
				if( arow != null && !arow.isEmpty() ) {
					int alen = arow.size();
					int[] aix = arow.getIndexContainer();
					double[] avals = arow.getValueContainer();
					for(int k = 0; k < alen; k++) {
						double val = avals[k];
						SparseRow brow = m2.sparseRows[ aix[k] ];
						if( brow != null && !brow.isEmpty() ) {
							int blen = brow.size();
							int[] bix = brow.getIndexContainer();
							double[] bvals = brow.getValueContainer();
							for(int j = 0; j < blen; j++)
								c[cix+bix[j]] += val * bvals[j];
						}
					}
				}
			}
		}
	}

	/**
	 * This implementation applies to any combination of dense/sparse if at least one
	 * input is ultra-sparse (sparse and very few nnz). In that case, most importantly,
	 * we want to create a sparse output and only iterate over the few nnz as the major
	 * dimension. Low-level optimizations are less important in that case, and having
	 * this generic implementation helps reduce the number of implementations from (2+1)^2
	 * to 2^2+1.
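	 * That is, instead of (2+1)^2 = 9 operators for all combinations of {dense, sparse,
	 * ultra-sparse} inputs, we need only the 2^2 = 4 dense/sparse combinations plus this
	 * one generic ultra-sparse operator, i.e., 5 in total.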
* * @param m1 * @param m2 * @param ret * @throws DMLRuntimeException */ private static void matrixMultUltraSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru) throws DMLRuntimeException { boolean leftUS = m1.isUltraSparse(); final int m = m1.rlen; final int cd = m1.clen; final int n = m2.clen; if( leftUS ) //left is ultra-sparse (IKJ) { boolean rightSparse = m2.sparse; for( int i=rl; i<ru; i++ ) { SparseRow arow = m1.sparseRows[ i ]; if( arow != null && !arow.isEmpty() ) { int alen = arow.size(); int[] aixs = arow.getIndexContainer(); double[] avals = arow.getValueContainer(); if( alen==1 && avals[0]==1 ) //ROW SELECTION (no aggregation) { int aix = aixs[0]; if( rightSparse ) { //sparse right matrix (full row copy) if( m2.sparseRows!=null && m2.sparseRows[aix]!=null ) { ret.rlen=m; ret.allocateSparseRowsBlock(false); //allocation on demand ret.sparseRows[i] = new SparseRow(m2.sparseRows[aix]); ret.nonZeros += ret.sparseRows[i].size(); } } else { //dense right matrix (append all values) for( int j=0; j<n; j++ ) ret.appendValue(i, j, m2.quickGetValue(aix, j)); } } else //GENERAL CASE { for( int k=0; k<alen; k++ ) { double aval = avals[k]; int aix = aixs[k]; for( int j=0; j<n; j++ ) { double cval = ret.quickGetValue(i, j); double cvald = aval*m2.quickGetValue(aix, j); if( cvald != 0 ) ret.quickSetValue(i, j, cval+cvald); } } } } } } else //right is ultra-sparse (KJI) { for(int k = 0; k < cd; k++ ) { SparseRow brow = m2.sparseRows[ k ]; if( brow != null && !brow.isEmpty() ) { int blen = brow.size(); int[] bixs = brow.getIndexContainer(); double[] bvals = brow.getValueContainer(); for( int j=0; j<blen; j++ ) { double bval = bvals[j]; int bix = bixs[j]; for( int i=rl; i<ru; i++ ) { double cvald = bval*m1.quickGetValue(i, k); if( cvald != 0 ){ double cval = ret.quickGetValue(i, bix); ret.quickSetValue(i, bix, cval+cvald); } } } } } } //no need to recompute nonzeros because maintained internally } /** * * @param mX * @param mV * @param mW * @param ret * @param ct * @throws DMLRuntimeException * @throws DMLUnsupportedOperationException */ private static void matrixMultChainDense(MatrixBlock mX, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, ChainType ct, int rl, int ru) { double[] a = mX.denseBlock; double[] b = mV.denseBlock; double[] w = (mW!=null) ? mW.denseBlock : null; double[] c = ret.denseBlock; final int cd = mX.clen; //features in X boolean weights = (ct == ChainType.XtwXv); //temporary array for cache blocking //(blocksize chosen to fit b+v in L2 (256KB) for default 1k blocks) final int blocksize = 24; // constraint: factor of 4 double[] tmp = new double[blocksize]; //blockwise mmchain computation final int bn = ru - ru % blocksize; //rl blocksize aligned for( int bi=rl; bi < bn; bi+=blocksize ) { //compute 1st matrix-vector for row block for( int j=0, aix=bi*cd; j < blocksize; j++, aix+=cd) tmp[j] = dotProduct(a, b, aix, 0, cd); //multiply weights (in-place), if required if( weights ) vectMultiply(w, tmp, bi, 0, blocksize); //compute 2nd matrix vector for row block and aggregate for (int j=0, aix=bi*cd; j < blocksize; j+=4, aix+=4*cd) vectMultiplyAdd4(tmp[j], tmp[j+1], tmp[j+2], tmp[j+3], a, c, aix, aix+cd, aix+2*cd, aix+3*cd, 0, cd); } //compute rest (not aligned to blocksize) for( int i=bn, aix=bn*cd; i < ru; i++, aix+=cd ) { double val = dotProduct(a, b, aix, 0, cd); val *= (weights) ? 
w[i] : 1; vectMultiplyAdd(val, a, c, aix, 0, cd); } } /** * * @param mX * @param mV * @param mW * @param ret * @param ct * @param rl * @param ru * @throws DMLRuntimeException * @throws DMLUnsupportedOperationException */ private static void matrixMultChainSparse(MatrixBlock mX, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, ChainType ct, int rl, int ru) { SparseRow[] a = mX.sparseRows; double[] b = mV.denseBlock; double[] w = (mW!=null) ? mW.denseBlock : null; double[] c = ret.denseBlock; boolean weights = (ct == ChainType.XtwXv); //temporary array for cache blocking //(blocksize chosen to fit b+v in L2 (256KB) for default 1k blocks) final int blocksize = 24; double[] tmp = new double[blocksize]; //blockwise mmchain computation for( int bi=rl; bi < ru; bi+=blocksize ) { //reset row block intermediate int tmplen = Math.min(blocksize, ru-bi); //compute 1st matrix-vector for row block for( int j=0; j < tmplen; j++) { SparseRow arow = a[bi+j]; if( arow != null && !arow.isEmpty() ) { int alen = arow.size(); int[] aix = arow.getIndexContainer(); double[] avals = arow.getValueContainer(); tmp[j] = dotProduct(avals, b, aix, 0, alen); } } //multiply weights (in-place), if required if( weights ) vectMultiply(w, tmp, bi, 0, tmplen); //compute 2nd matrix vector for row block and aggregate for( int j=0; j < tmplen; j++) { SparseRow arow = a[bi+j]; if( arow != null && !arow.isEmpty() && tmp[j] != 0 ) { int alen = arow.size(); int[] aix = arow.getIndexContainer(); double[] avals = arow.getValueContainer(); vectMultiplyAdd(tmp[j], avals, c, aix, 0, alen); } } } } /** * * @param m1 * @param ret * @param leftTranspose * @throws DMLRuntimeException * @throws DMLUnsupportedOperationException */ private static void matrixMultTransposeSelfDense( MatrixBlock m1, MatrixBlock ret, boolean leftTranspose, int rl, int ru ) throws DMLRuntimeException { //2) transpose self matrix multiply dense // (compute only upper-triangular matrix due to symmetry) double[] a = m1.denseBlock; double[] c = ret.denseBlock; int m = m1.rlen; int n = m1.clen; if( leftTranspose ) // t(X)%*%X { if( LOW_LEVEL_OPTIMIZATION ) { if( n==1 ) //VECTOR (col) { c[0] = dotProduct(a, a, m); } else //MATRIX { //1) Unrolled inner loop (for better instruction-level parallelism) //2) Blocked execution (for less cache trashing in parallel exec) //3) Asymmetric block sizes (for less misses in inner loop, yet blocks in L1/L2) final int blocksizeI = 32; //64//256KB c block (typical L2 size per core), 32KB a block final int blocksizeK = 24; //64//256KB b block (typical L2 size per core), used while read 512B of a / read/write 4KB of c final int blocksizeJ = 1024; //512//4KB (typical main-memory page size), for scan //temporary arrays (nnz a, b index) double[] ta = new double[ blocksizeK ]; int[] tbi = new int[ blocksizeK ]; final int mx = ru; final int cdx = m; final int nx = n; //blocked execution for( int bi = rl; bi < mx; bi+=blocksizeI ) //from bi due to symmetry for( int bk = 0, bimin = Math.min(mx, bi+blocksizeI); bk < cdx; bk+=blocksizeK ) for( int bj = bi, bkmin = Math.min(cdx, bk+blocksizeK); bj < nx; bj+=blocksizeJ ) { int bklen = bkmin-bk; int bjlen = Math.min(nx, bj+blocksizeJ)-bj; //core sub block matrix multiplication for( int i = bi; i < bimin; i++) { int aixi = bk*n +i; //start index on a (logical t(X)) int cixj = i * nx + bj; //scan index on c //determine nnz of a (for sparsity-aware skipping of rows) int knnz = copyNonZeroElements(a, aixi, bk, bj, n, nx, ta, tbi, bklen); //rest not aligned to blocks of 4 rows final int bn = knnz % 4; 
switch( bn ){ case 1: vectMultiplyAdd(ta[0], a, c, tbi[0], cixj, bjlen); break; case 2: vectMultiplyAdd2(ta[0],ta[1], a, c, tbi[0], tbi[1], cixj, bjlen); break; case 3: vectMultiplyAdd3(ta[0],ta[1],ta[2], a, c, tbi[0], tbi[1],tbi[2], cixj, bjlen); break; } //compute blocks of 4 rows (core inner loop) for( int k = bn; k<knnz; k+=4 ){ vectMultiplyAdd4( ta[k], ta[k+1], ta[k+2], ta[k+3], a, c, tbi[k], tbi[k+1], tbi[k+2], tbi[k+3], cixj, bjlen ); } } } } } else { for(int k = 0, ix1 = 0; k < m; k++, ix1+=n) for(int i = rl, ix3 = 0; i < ru; i++, ix3+=n) { double val = a[ ix1+i ]; if( val != 0 ) { for(int j = i; j < n; j++) //from i due to symmetry c[ ix3+j ] += val * a[ ix1+j ]; } } } } else // X%*%t(X) { if(LOW_LEVEL_OPTIMIZATION) { if( m==1 ) //VECTOR { c[0] = dotProduct(a, a, n); } else //MATRIX { //algorithm: scan c, foreach ci,j: scan row of a and t(a) (IJK) //1) Unrolled inner loop, for better ILP //2) Blocked execution, for less cache trashing in parallel exec // (smaller block sizes would be slightly better, but consistent as is) //3) Single write in inner loop (transient intermediates) int blocksize = 64; for( int bi = rl; bi<ru; bi+=blocksize ) for( int bj = bi; bj<m; bj+=blocksize ) { final int bimin = Math.min(ru, bi+blocksize); final int bjmin = Math.min(m, bj+blocksize); for(int i = bi, ix1 = bi*n, ix3 = bi*m; i < bimin; i++, ix1+=n, ix3+=m) { final int bjmax = Math.max(i,bj); for(int j = bjmax, ix2 = bjmax*n; j <bjmin; j++, ix2+=n) //from i due to symmetry { c[ ix3+j ] = dotProduct(a, a, ix1, ix2, n); } } } } } else { for(int i = rl, ix1 = 0, ix3 = 0; i < ru; i++, ix1+=n, ix3+=m) for(int j = i, ix2 = i*n; j < m; j++, ix2+=n) //from i due to symmetry { double val = 0; for(int k = 0; k < n; k++) val += a[ ix1+k ] * a[ix2+k]; c[ ix3+j ] = val; } } } } /** * * @param out * @param leftTranspose * @throws DMLUnsupportedOperationException * @throws DMLRuntimeException */ private static void matrixMultTransposeSelfSparse( MatrixBlock m1, MatrixBlock ret, boolean leftTranspose, int rl, int ru ) throws DMLRuntimeException { //2) transpose self matrix multiply sparse // (compute only upper-triangular matrix due to symmetry) double[] c = ret.denseBlock; int m = m1.rlen; int n = m1.clen; if( leftTranspose ) // t(X)%*%X { //only general case (because vectors always dense) //algorithm: scan rows, foreach row self join (KIJ) if( LOW_LEVEL_OPTIMIZATION ) { for( SparseRow arow : m1.sparseRows ) if( arow != null && !arow.isEmpty() ) { int alen = arow.size(); int[] aix = arow.getIndexContainer(); double[] avals = arow.getValueContainer(); int rlix = (rl==0) ? 0 : arow.searchIndexesFirstGTE(rl); rlix = (rlix>=0) ? rlix : alen; for(int i = rlix; i < alen && aix[i]<ru; i++) { double val = avals[i]; if( val != 0 ) { int ix2 = aix[i]*n; vectMultiplyAdd(val, avals, c, aix, i, ix2, alen); } } } } else { for( SparseRow arow : m1.sparseRows ) if( arow != null && !arow.isEmpty() ) { int alen = arow.size(); int[] aix = arow.getIndexContainer(); double[] avals = arow.getValueContainer(); int rlix = (rl==0) ? 0 : arow.searchIndexesFirstGTE(rl); rlix = (rlix>=0) ? 
rlix : alen; for(int i = rlix; i < alen && aix[i]<ru; i++) { double val = avals[i]; if( val != 0 ) for(int j = i, ix2 = aix[i]*n; j < alen; j++) c[ix2+aix[j]] += val * avals[j]; } } } } else // X%*%t(X) { if( m==1 ) //VECTOR { SparseRow arow = m1.sparseRows[0]; if( arow !=null && !arow.isEmpty() ) { int alen = arow.size(); double[] avals = arow.getValueContainer(); c[0] = dotProduct(avals, avals, alen); } } else //MATRIX { //note: reorg to similar layout as t(X)%*%X because faster than //direct computation with IJK (no dependencies/branches in inner loop) //see preprocessMatrixMultTransposeSelf m1<-tmpBlock m = m1.clen; n = m1.rlen; //algorithm: scan rows, foreach row self join (KIJ) if( LOW_LEVEL_OPTIMIZATION ) { for( SparseRow arow : m1.sparseRows ) if( arow != null && !arow.isEmpty() ) { int alen = arow.size(); int[] aix = arow.getIndexContainer(); double[] avals = arow.getValueContainer(); int rlix = (rl==0) ? 0 : arow.searchIndexesFirstGTE(rl); rlix = (rlix>=0) ? rlix : alen; for(int i = rlix; i < alen && aix[i]<ru; i++) { double val = avals[i]; if( val != 0 ) { int ix2 = aix[i]*m; vectMultiplyAdd(val, avals, c, aix, i, ix2, alen); } } } } else { for( SparseRow arow : m1.sparseRows ) if( arow != null && !arow.isEmpty() ) { int alen = arow.size(); int[] aix = arow.getIndexContainer(); double[] avals = arow.getValueContainer(); int rlix = (rl==0) ? 0 : arow.searchIndexesFirstGTE(rl); rlix = (rlix>=0) ? rlix : alen; for(int i = rlix; i < alen && aix[i]<ru; i++) { double val = avals[i]; if( val != 0 ) for(int j = i, ix2 = aix[i]*m; j < alen; j++) c[ix2+aix[j]] += val * avals[j]; } } } } } } /** * * @param pm1 * @param m2 * @param ret1 * @param ret2 * @param rl * @param ru * @throws DMLRuntimeException */ private static void matrixMultPermuteDense(MatrixBlock pm1, MatrixBlock m2, MatrixBlock ret1, MatrixBlock ret2, int rl, int ru) throws DMLRuntimeException { double[] a = pm1.denseBlock; double[] b = m2.denseBlock; double[] c = ret1.denseBlock; final int n = m2.clen; final int brlen = ret1.getNumRows(); int lastblk = -1; for( int i=rl, bix=rl*n; i<ru; i++, bix+=n ) { //compute block index and in-block indexes int pos = UtilFunctions.toInt( a[ i ]); //safe cast if( pos > 0 ) //selected row { int bpos = (pos-1) % brlen; int blk = (pos-1) / brlen; //allocate and switch to second output block //(never happens in cp, correct for multi-threaded usage) if( lastblk!=-1 && lastblk<blk ){ ret2.sparse = false; ret2.allocateDenseBlock(); c = ret2.denseBlock; } //memcopy entire dense row into target position System.arraycopy(b, bix, c, bpos*n, n); lastblk = blk; } } } /** * * @param pm1 * @param m2 * @param ret1 * @param ret2 * @param rl * @param ru */ private static void matrixMultPermuteDenseSparse( MatrixBlock pm1, MatrixBlock m2, MatrixBlock ret1, MatrixBlock ret2, int rl, int ru) { double[] a = pm1.denseBlock; double[] b = m2.denseBlock; SparseRow[] c = ret1.sparseRows; final int n = m2.clen; final int brlen = ret1.getNumRows(); int lastblk = -1; for( int i=rl, bix=rl*n; i<ru; i++, bix+=n ) { //compute block index and in-block indexes int pos = UtilFunctions.toInt( a[ i ]); //safe cast if( pos > 0 ) //selected row { int bpos = (pos-1) % brlen; int blk = (pos-1) / brlen; //allocate and switch to second output block //(never happens in cp, correct for multi-threaded usage) if( lastblk!=-1 && lastblk<blk ){ ret2.sparse = true; ret2.rlen=ret1.rlen; ret2.allocateSparseRowsBlock(); c = ret2.sparseRows; } //append entire dense row into sparse target position c[bpos] = new SparseRow( n ); for( int j=0; 
j<n; j++ ) c[bpos].append(j, b[bix+j]); lastblk = blk; } } } /** * * @param pm1 * @param m2 * @param ret1 * @param ret2 * @param rl * @param ru */ private static void matrixMultPermuteSparse( MatrixBlock pm1, MatrixBlock m2, MatrixBlock ret1, MatrixBlock ret2, int rl, int ru) { double[] a = pm1.denseBlock; SparseRow[] b = m2.sparseRows; SparseRow[] c = ret1.sparseRows; final int brlen = ret1.getNumRows(); int lastblk = -1; for( int i=rl; i<ru; i++ ) { //compute block index and in-block indexes int pos = UtilFunctions.toInt( a[ i ]); //safe cast if( pos > 0 ) //selected row { int bpos = (pos-1) % brlen; int blk = (pos-1) / brlen; //allocate and switch to second output block //(never happens in cp, correct for multi-threaded usage) if( lastblk!=-1 && lastblk<blk ){ ret2.sparse = true; ret2.allocateSparseRowsBlock(); c = ret2.sparseRows; } //memcopy entire sparse row into target position if( b[i] != null ) c[bpos] = new SparseRow( b[i] ); lastblk = blk; } } } /** * * @param mX * @param mU * @param mV * @param mW * @param ret * @param wt * @param rl * @param ru */ private static void matrixMultWSLossDense(MatrixBlock mX, MatrixBlock mU, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, WeightsType wt, int rl, int ru) { double[] x = mX.denseBlock; double[] u = mU.denseBlock; double[] v = mV.denseBlock; double[] w = (mW!=null)? mW.denseBlock : null; final int n = mX.clen; final int cd = mU.clen; double wsloss = 0; // approach: iterate over all cells of X //cache-conscious blocking: due to blocksize constraint (default 1000), //a blocksize of 16 allows to fit blocks of UV into L2 cache (256KB) final int blocksizeIJ = 16; //u/v block (max at typical L2 size) //blocked execution for( int bi = rl; bi < ru; bi+=blocksizeIJ ) for( int bj = 0, bimin = Math.min(ru, bi+blocksizeIJ); bj < n; bj+=blocksizeIJ ) { int bjmin = Math.min(n, bj+blocksizeIJ); // Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting) if( wt==WeightsType.POST ) { for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd ) for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd) { double wij = w[ix+j]; if( wij != 0 ) { double uvij = dotProduct(u, v, uix, vix, cd); wsloss += wij*(x[ix+j]-uvij)*(x[ix+j]-uvij); //^2 } } } // Pattern 1b) sum ((X!=0) * (X - U %*% t(V)) ^ 2) (post_nz weighting) else if( wt==WeightsType.POST_NZ ) { for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd ) for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd) { double xij = x[ix+j]; if( xij != 0 ) { double uvij = dotProduct(u, v, uix, vix, cd); wsloss += (xij-uvij)*(xij-uvij); //^2 } } } // Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting) else if( wt==WeightsType.PRE ) { for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd ) for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd) { double wij = w[ix+j]; double uvij = 0; if( wij != 0 ) uvij = dotProduct(u, v, uix, vix, cd); wsloss += (x[ix+j]-wij*uvij)*(x[ix+j]-wij*uvij); //^2 } } // Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting) else if( wt==WeightsType.NONE ) { for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd ) for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd) { double uvij = dotProduct(u, v, uix, vix, cd); wsloss += (x[ix+j]-uvij)*(x[ix+j]-uvij); //^2 } } } ret.quickSetValue(0, 0, wsloss); } /** * * @param mX * @param mU * @param mV * @param mW * @param ret * @param wt * @param rl * @param ru */ private static void matrixMultWSLossSparseDense(MatrixBlock mX, MatrixBlock mU, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, WeightsType wt, int rl, int ru) { 
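		//note: which input drives the iteration depends on the weighting type: the nnz of W
		//for post weighting, the nnz of X for post_nz, and all cells of X otherwise
		//(see the per-pattern comments below)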
SparseRow[] x = mX.sparseRows;
SparseRow[] w = (mW!=null)? mW.sparseRows : null;
double[] u = mU.denseBlock;
double[] v = mV.denseBlock;
final int n = mX.clen;
final int cd = mU.clen;
double wsloss = 0;

// Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
if( wt==WeightsType.POST )
{
  // approach: iterate over W, point-wise in order to exploit sparsity
  for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd )
    if( w[i] != null && !w[i].isEmpty() ) {
      int wlen = w[i].size();
      int[] wix = w[i].getIndexContainer();
      double[] wval = w[i].getValueContainer();
      for( int k=0; k<wlen; k++ ) {
        double xi = mX.quickGetValue(i, wix[k]);
        double uvij = dotProduct(u, v, uix, wix[k]*cd, cd);
        wsloss += wval[k]*(xi-uvij)*(xi-uvij);
      }
    }
}
// Pattern 1b) sum ((X!=0) * (X - U %*% t(V)) ^ 2) (post_nz weighting)
else if( wt==WeightsType.POST_NZ )
{
  // approach: iterate over X (the implicit weights), point-wise in order to exploit sparsity
  for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd )
    if( x[i] != null && !x[i].isEmpty() ) {
      int xlen = x[i].size();
      int[] xix = x[i].getIndexContainer();
      double[] xval = x[i].getValueContainer();
      for( int k=0; k<xlen; k++ ) {
        double uvij = dotProduct(u, v, uix, xix[k]*cd, cd);
        wsloss += (xval[k]-uvij)*(xval[k]-uvij);
      }
    }
}
// Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
else if( wt==WeightsType.PRE )
{
  // approach: iterate over all cells of X, whether sparse or dense
  // (note: tuning similar to pattern 3 possible but more complex)
  for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd )
    for( int j=0, vix=0; j<n; j++, vix+=cd) {
      double xij = mX.quickGetValue(i, j);
      double wij = mW.quickGetValue(i, j);
      double uvij = 0;
      if( wij != 0 )
        uvij = dotProduct(u, v, uix, vix, cd);
      wsloss += (xij-wij*uvij)*(xij-wij*uvij);
    }
}
// Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
else if( wt==WeightsType.NONE )
{
  // approach: iterate over all cells of X, including zero cells
  for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd )
  {
    if( x[i]==null || x[i].isEmpty() ) { //empty row
      for( int j=0, vix=0; j<n; j++, vix+=cd) {
        double uvij = dotProduct(u, v, uix, vix, cd);
        wsloss += (-uvij)*(-uvij);
      }
    }
    else { //non-empty row
      int xlen = x[i].size();
      int[] xix = x[i].getIndexContainer();
      double[] xval = x[i].getValueContainer();
      int last = -1;
      for( int k=0; k<xlen; k++ ) {
        //process zeros between last nnz and current nnz
        for( int k2=last+1; k2<xix[k]; k2++ ) {
          double uvij = dotProduct(u, v, uix, k2*cd, cd);
          wsloss += (-uvij)*(-uvij);
        }
        //process current nnz
        double uvij = dotProduct(u, v, uix, xix[k]*cd, cd);
        wsloss += (xval[k]-uvij)*(xval[k]-uvij);
        last = xix[k];
      }
      //process zeros between last nnz and end of row
      for( int k2=last+1; k2<n; k2++ ) {
        double uvij = dotProduct(u, v, uix, k2*cd, cd);
        wsloss += (-uvij)*(-uvij);
      }
    }
  }
}
ret.quickSetValue(0, 0, wsloss);
}

/**
 *
 * @param mX
 * @param mU
 * @param mV
 * @param mW
 * @param ret
 * @param wt
 * @param rl
 * @param ru
 */
private static void matrixMultWSLossGeneric (MatrixBlock mX, MatrixBlock mU, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, WeightsType wt, int rl, int ru)
{
  final int n = mX.clen;
  final int cd = mU.clen;
  double wsloss = 0;

  // Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
  if( wt==WeightsType.POST )
  {
    // approach: iterate over W, point-wise in order to exploit sparsity
    if( mW.sparse ) //SPARSE
    {
      SparseRow[] wrows = mW.sparseRows;
      for( int i=rl; i<ru; i++ )
        if( wrows[i] != null && !wrows[i].isEmpty() ){
          int wlen = wrows[i].size();
          int[] wix = wrows[i].getIndexContainer();
          double[] wval = wrows[i].getValueContainer();
          for( int k=0; k<wlen; k++ ) {
            double uvij = dotProductGeneric(mU, mV, i, wix[k], cd);
            double xi = mX.quickGetValue(i, wix[k]);
            wsloss += wval[k]*(xi-uvij)*(xi-uvij);
          }
        }
    }
    else //DENSE
    {
      double[] w = mW.denseBlock;
      for( int i=rl, wix=rl*n; i<ru; i++, wix+=n )
        for( int j=0; j<n; j++)
          if( w[wix+j] != 0 ) {
            double uvij = dotProductGeneric(mU, mV, i, j, cd);
            double xij = mX.quickGetValue(i, j);
            wsloss += w[wix+j]*(xij-uvij)*(xij-uvij);
          }
    }
  }
  // Pattern 1b) sum ((X!=0) * (X - U %*% t(V)) ^ 2) (post_nz weighting)
  else if( wt==WeightsType.POST_NZ )
  {
    // approach: iterate over X (the implicit weights), point-wise in order to exploit sparsity
    if( mX.sparse ) //SPARSE (fixed: branch on the representation of X, which is read below, not of W)
    {
      SparseRow[] xrows = mX.sparseRows;
      for( int i=rl; i<ru; i++ )
        if( xrows[i] != null && !xrows[i].isEmpty() ){
          int xlen = xrows[i].size();
          int[] xix = xrows[i].getIndexContainer();
          double[] xval = xrows[i].getValueContainer();
          for( int k=0; k<xlen; k++ ) {
            double uvij = dotProductGeneric(mU, mV, i, xix[k], cd);
            wsloss += (xval[k]-uvij)*(xval[k]-uvij);
          }
        }
    }
    else //DENSE
    {
      double[] x = mX.denseBlock;
      for( int i=rl, xix=rl*n; i<ru; i++, xix+=n )
        for( int j=0; j<n; j++)
          if( x[xix+j] != 0 ) {
            double uvij = dotProductGeneric(mU, mV, i, j, cd);
            wsloss += (x[xix+j]-uvij)*(x[xix+j]-uvij);
          }
    }
  }
  // Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
  else if( wt==WeightsType.PRE )
  {
    // approach: iterate over all cells of X, whether sparse or dense
    for( int i=rl; i<ru; i++ )
      for( int j=0; j<n; j++) {
        double xij = mX.quickGetValue(i, j);
        double wij = mW.quickGetValue(i, j);
        double uvij = 0;
        if( wij != 0 )
          uvij = dotProductGeneric(mU, mV, i, j, cd);
        wsloss += (xij-wij*uvij)*(xij-wij*uvij);
      }
  }
  // Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
  else if( wt==WeightsType.NONE )
  {
    // approach: iterate over all cells of X, including zero cells
    for( int i=rl; i<ru; i++ )
      for( int j=0; j<n; j++) {
        double xij = mX.quickGetValue(i, j);
        double uvij = dotProductGeneric(mU, mV, i, j, cd);
        wsloss += (xij-uvij)*(xij-uvij);
      }
  }
  ret.quickSetValue(0, 0, wsloss);
}

/**
 *
 * @param mW
 * @param mU
 * @param mV
 * @param ret
 * @param wt
 * @param rl
 * @param ru
 * @throws DMLRuntimeException
 */
private static void matrixMultWSigmoidDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt, int rl, int ru)
  throws DMLRuntimeException
{
  double[] w = mW.denseBlock;
  double[] c = ret.denseBlock;
  double[] u = mU.denseBlock;
  double[] v = mV.denseBlock;
  final int n = mW.clen;
  final int cd = mU.clen;

  //note: cannot compute U %*% t(V) in-place of result w/ regular mm because
  //t(V) comes in transformed form and hence would require additional memory
  boolean flagminus = (wt==WSigmoidType.MINUS || wt==WSigmoidType.LOG_MINUS);
  boolean flaglog = (wt==WSigmoidType.LOG || wt==WSigmoidType.LOG_MINUS);

  //approach: iterate over non-zeros of w, selective mm computation
  //cache-conscious blocking: due to blocksize constraint (default 1000),
  //a blocksize of 16 allows to fit blocks of UV into L2 cache (256KB)
  final int blocksizeIJ = 16; //u/v block (max at typical L2 size)

  //blocked execution
  for( int bi = rl; bi < ru; bi+=blocksizeIJ )
    for( int bj = 0, bimin = Math.min(ru, bi+blocksizeIJ); bj < n; bj+=blocksizeIJ )
    {
      int bjmin = Math.min(n, bj+blocksizeIJ);
      //core wsigmoid computation
      for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd )
        for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd) {
          double wij = w[ix+j];
          if( wij != 0 )
            c[ix+j] = wsigmoid(wij, u, v, uix, vix, flagminus, flaglog, cd);
        }
    }
}

/**
 *
 * @param mW
 * @param mU
 * @param mV
 * @param ret
 * @param wt
 * @param rl
 * @param ru
 * @throws DMLRuntimeException
 */
private static void matrixMultWSigmoidSparseDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt, int rl, int ru)
  throws DMLRuntimeException
{
  SparseRow[] w = mW.sparseRows;
  SparseRow[] c = ret.sparseRows;
  double[] u = mU.denseBlock;
  double[] v = mV.denseBlock;
  final int n = mW.clen;
  final int cd = mU.clen;

  boolean flagminus = (wt==WSigmoidType.MINUS || wt==WSigmoidType.LOG_MINUS);
  boolean flaglog = (wt==WSigmoidType.LOG || wt==WSigmoidType.LOG_MINUS);

  //approach: iterate over non-zeros of w, selective mm computation
  for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd )
    if( w[i] != null && !w[i].isEmpty() ) {
      int wlen = w[i].size();
      int[] wix = w[i].getIndexContainer();
      double[] wval = w[i].getValueContainer();
      c[i] = new SparseRow(wlen, n);
      for( int k=0; k<wlen; k++ ) {
        double cval = wsigmoid(wval[k], u, v, uix, wix[k]*cd, flagminus, flaglog, cd);
        c[i].append(wix[k], cval);
      }
    }
}

/**
 *
 * @param mW
 * @param mU
 * @param mV
 * @param ret
 * @param wt
 * @param rl
 * @param ru
 * @throws DMLRuntimeException
 */
private static void matrixMultWSigmoidGeneric (MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt, int rl, int ru)
  throws DMLRuntimeException
{
  final int n = mW.clen;
  final int cd = mU.clen;

  boolean flagminus = (wt==WSigmoidType.MINUS || wt==WSigmoidType.LOG_MINUS);
  boolean flaglog = (wt==WSigmoidType.LOG || wt==WSigmoidType.LOG_MINUS);

  //approach: iterate over non-zeros of w, selective mm computation
  if( mW.sparse ) //SPARSE
  {
    //w and c always in same representation
    SparseRow[] w = mW.sparseRows;
    SparseRow[] c = ret.sparseRows;
    for( int i=rl; i<ru; i++ )
      if( w[i] != null && !w[i].isEmpty() ) {
        int wlen = w[i].size();
        int[] wix = w[i].getIndexContainer();
        double[] wval = w[i].getValueContainer();
        c[i] = new SparseRow(wlen, n);
        for( int k=0; k<wlen; k++ ) {
          double cval = wsigmoid(wval[k], mU, mV, i, wix[k], flagminus, flaglog, cd);
          c[i].append(wix[k], cval);
        }
      }
  }
  else //DENSE
  {
    //w and c always in same representation
    double[] w = mW.denseBlock;
    double[] c = ret.denseBlock;
    for( int i=rl, ix=rl*n; i<ru; i++ )
      for( int j=0; j<n; j++, ix++) {
        double wij = w[ix];
        if( wij != 0 )
          c[ix] = wsigmoid(wij, mU, mV, i, j, flagminus, flaglog, cd);
      }
  }
}

/**
 *
 * @param mW
 * @param mU
 * @param mV
 * @param ret
 * @param wt
 * @param rl
 * @param ru
 * @param cl
 * @param cu
 * @throws DMLRuntimeException
 */
private static void matrixMultWDivMMDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WDivMMType wt, int rl, int ru, int cl, int cu)
  throws DMLRuntimeException
{
  final boolean basic = wt.isBasic();
  final boolean left = wt.isLeft();
  final boolean mult = wt.isMult();
  final boolean minus = wt.isMinus();
  final int n = mW.clen;
  final int cd = mU.clen;

  double[] w = mW.denseBlock;
  double[] u = mU.denseBlock;
  double[] v = mV.denseBlock;
  double[] c = ret.denseBlock;

  //approach: iterate over non-zeros of w, selective mm computation
  //cache-conscious blocking: due to blocksize constraint (default 1000),
  //a blocksize of 16 allows to fit blocks of UV into L2 cache (256KB)
  final int blocksizeIJ = 16; //u/v block (max at typical L2 size)

  //blocked execution
  for( int bi = rl; bi < ru; bi+=blocksizeIJ )
    for( int bj = cl, bimin = Math.min(ru, bi+blocksizeIJ); bj < cu; bj+=blocksizeIJ )
    {
      int bjmin = Math.min(cu, bj+blocksizeIJ);
      //core wdivmm computation
      for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd )
        for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd)
          if( w[ix+j] != 0 ) {
            if( basic )
              c[ix+j] = w[ix+j] * dotProduct(u, v, uix, vix, cd);
            else //left/right minus/default
              wdivmm(w[ix+j], u, v, c, uix, vix, left, mult, minus, cd);
          }
    }
}

/**
 *
 * @param mW
 * @param mU
 * @param mV
 * @param ret
 * @param wt
 * @param rl
 * @param ru
 * @param cl
 * @param cu
 * @throws DMLRuntimeException
 */
private static void matrixMultWDivMMSparseDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WDivMMType wt, int rl, int ru, int cl, int cu)
  throws DMLRuntimeException
{
  final boolean basic = wt.isBasic();
  final boolean left = wt.isLeft();
  final boolean mult = wt.isMult();
  final boolean minus = wt.isMinus();
  final int cd = mU.clen;

  SparseRow[] w = mW.sparseRows;
  double[] u = mU.denseBlock;
  double[] v = mV.denseBlock;
  double[] c = ret.denseBlock;

  //approach: iterate over non-zeros of w, selective mm computation
  for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd )
  {
    if( w[i] != null && !w[i].isEmpty() ) {
      int wlen = w[i].size();
      int[] wix = w[i].getIndexContainer();
      double[] wval = w[i].getValueContainer();
      if( basic ) {
        for( int k=0; k<wlen; k++ )
          ret.appendValue( i, wix[k], wval[k] * dotProduct(u, v, uix, wix[k]*cd, cd));
      }
      else { //left/right minus/default
        int k = (cl==0) ? 0 : w[i].searchIndexesFirstGTE(cl);
        k = (k>=0) ? k : wlen;
        for( ; k<wlen && wix[k]<cu; k++ )
          wdivmm(wval[k], u, v, c, uix, wix[k]*cd, left, mult, minus, cd);
      }
    }
  }
}

/**
 *
 * @param mW
 * @param mU
 * @param mV
 * @param ret
 * @param wt
 * @param rl
 * @param ru
 * @param cl
 * @param cu
 * @throws DMLRuntimeException
 */
private static void matrixMultWDivMMGeneric(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WDivMMType wt, int rl, int ru, int cl, int cu)
  throws DMLRuntimeException
{
  final boolean basic = wt.isBasic();
  final boolean left = wt.isLeft();
  final boolean mult = wt.isMult();
  final boolean minus = wt.isMinus();
  final int n = mW.clen;
  final int cd = mU.clen;

  //output always in dense representation
  double[] c = ret.denseBlock;

  //approach: iterate over non-zeros of w, selective mm computation
  if( mW.sparse ) //SPARSE
  {
    SparseRow[] w = mW.sparseRows;
    for( int i=rl; i<ru; i++ )
    {
      if( w[i] != null && !w[i].isEmpty() ) {
        int wlen = w[i].size();
        int[] wix = w[i].getIndexContainer();
        double[] wval = w[i].getValueContainer();
        int k = (cl==0) ? 0 : w[i].searchIndexesFirstGTE(cl);
        k = (k>=0) ? k : wlen;
        for( ; k<wlen && wix[k]<cu; k++ ) {
          if( basic ) {
            double uvij = dotProductGeneric(mU, mV, i, wix[k], cd);
            ret.appendValue(i, wix[k], uvij);
          }
          else { //left/right minus/default
            wdivmm(wval[k], mU, mV, c, i, wix[k], left, mult, minus, cd);
          }
        }
      }
    }
  }
  else //DENSE
  {
    double[] w = mW.denseBlock;
    for( int i=rl, ix=rl*n; i<ru; i++, ix+=n )
      for( int j=cl; j<cu; j++)
        if( w[ix+j] != 0 ) {
          if( basic ) {
            c[ix+j] = dotProductGeneric(mU, mV, i, j, cd);
          }
          else { //left/right minus/default
            wdivmm(w[ix+j], mU, mV, c, i, j, left, mult, minus, cd);
          }
        }
  }
}

/**
 *
 * @param mW
 * @param mU
 * @param mV
 * @param ret
 * @param wt
 * @param rl
 * @param ru
 */
private static void matrixMultWCeMMDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WCeMMType wt, int rl, int ru)
{
  double[] w = mW.denseBlock;
  double[] u = mU.denseBlock;
  double[] v = mV.denseBlock;
  final int n = mW.clen;
  final int cd = mU.clen;
  double wceval = 0;

  // approach: iterate over all cells of W
  //cache-conscious blocking: due to blocksize constraint (default 1000),
  //a blocksize of 16 allows to fit blocks of UV into L2 cache (256KB)
  final int blocksizeIJ = 16; //u/v block (max at typical L2 size)

  //blocked execution
  for( int bi = rl; bi < ru; bi+=blocksizeIJ )
    for( int bj = 0, bimin = Math.min(ru, bi+blocksizeIJ); bj < n; bj+=blocksizeIJ )
    {
      int bjmin = Math.min(n, bj+blocksizeIJ);
      for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd )
        for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd) {
          double wij = w[ix+j];
          if( wij != 0 ) {
            double uvij = dotProduct(u, v, uix, vix, cd);
            wceval += wij * FastMath.log(uvij);
          }
        }
    }
  ret.quickSetValue(0, 0, wceval);
}

/**
 *
 * @param mW
 * @param mU
 * @param mV
 * @param ret
 * @param wt
 * @param rl
 * @param ru
 */
private static void matrixMultWCeMMSparseDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WCeMMType wt, int rl, int ru)
{
  SparseRow[] w = mW.sparseRows;
  double[] u = mU.denseBlock;
  double[] v = mV.denseBlock;
  final int cd = mU.clen;
  double wceval = 0;

  // approach: iterate over non-zeros of w, selective mm computation
  for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd ) {
    if( w[i]!=null && !w[i].isEmpty() ) {
      int wlen = w[i].size();
      int[] wix = w[i].getIndexContainer();
      double[] wval = w[i].getValueContainer();
      for( int k=0; k<wlen; k++ ) {
        double uvij = dotProduct(u, v, uix, wix[k]*cd, cd);
        wceval += wval[k] * FastMath.log(uvij);
      }
    }
  }
  ret.quickSetValue(0, 0, wceval);
}

/**
 *
 * @param mW
 * @param mU
 * @param mV
 * @param ret
 * @param wt
 * @param rl
 * @param ru
 */
private static void matrixMultWCeMMGeneric(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WCeMMType wt, int rl, int ru)
{
  final int n = mW.clen;
  final int cd = mU.clen;
  double wceval = 0;

  //approach: iterate over non-zeros of w, selective mm computation
  if( mW.sparse ) //SPARSE
  {
    SparseRow[] w = mW.sparseRows;
    for( int i=rl; i<ru; i++ )
      if( w[i] != null && !w[i].isEmpty() ) {
        int wlen = w[i].size();
        int[] wix = w[i].getIndexContainer();
        double[] wval = w[i].getValueContainer();
        for( int k=0; k<wlen; k++ ) {
          double uvij = dotProductGeneric(mU, mV, i, wix[k], cd);
          wceval += wval[k] * FastMath.log(uvij);
        }
      }
  }
  else //DENSE
  {
    double[] w = mW.denseBlock;
    for( int i=rl, ix=rl*n; i<ru; i++ )
      for( int j=0; j<n; j++, ix++) {
        double wij = w[ix];
        if( wij != 0 ) {
          double uvij = dotProductGeneric(mU, mV, i, j, cd);
          wceval += wij * FastMath.log(uvij);
        }
      }
  }
  ret.quickSetValue(0, 0, wceval);
}

////////////////////////////////////////////
// performance-relevant utility functions //
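////////////////////////////////////////////

//NOTE (illustrative sketch, not original code): the vector primitives below all
//share the same structure: a scalar remainder loop over len%8 elements followed
//by an 8-way unrolled main loop, e.g. for a self-contained dot product:
//  static double dot8(double[] a, double[] b, int len) {
//    double val = 0;
//    final int bn = len % 8;
//    for( int i=0; i<bn; i++ )    //remainder, not aligned to 8-blocks
//      val += a[i] * b[i];
//    for( int i=bn; i<len; i+=8 ) //unrolled 8-block (one 64B cacheline of a and b)
//      val += a[i]*b[i] + a[i+1]*b[i+1] + a[i+2]*b[i+2] + a[i+3]*b[i+3]
//           + a[i+4]*b[i+4] + a[i+5]*b[i+5] + a[i+6]*b[i+6] + a[i+7]*b[i+7];
//    return val;
//  }
//'dot8' is a hypothetical name; the real entry points are the private
//dotProduct/vectMultiplyAdd/vectAdd methods below.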
//////////////////////////////////////////// /** * Computes the dot-product of two vectors. Experiments (on long vectors of * 10^7 values) showed that this generic function provides equivalent performance * even for the specific case of dotProduct(a,a,len) as used for TSMM. * * @param a * @param b * @param len * @return */ private static double dotProduct( double[] a, double[] b, final int len ) { double val = 0; final int bn = len%8; //compute rest for( int i = 0; i < bn; i++ ) val += a[ i ] * b[ i ]; //unrolled 8-block (for better instruction-level parallelism) for( int i = bn; i < len; i+=8 ) { //read 64B cachelines of a and b //compute cval' = sum(a * b) + cval val += a[ i+0 ] * b[ i+0 ] + a[ i+1 ] * b[ i+1 ] + a[ i+2 ] * b[ i+2 ] + a[ i+3 ] * b[ i+3 ] + a[ i+4 ] * b[ i+4 ] + a[ i+5 ] * b[ i+5 ] + a[ i+6 ] * b[ i+6 ] + a[ i+7 ] * b[ i+7 ]; } //scalar result return val; } /** * * @param a * @param b * @param ai * @param bi * @param len * @return */ private static double dotProduct( double[] a, double[] b, int ai, int bi, final int len ) { double val = 0; final int bn = len%8; //compute rest for( int i = 0; i < bn; i++, ai++, bi++ ) val += a[ ai ] * b[ bi ]; //unrolled 8-block (for better instruction-level parallelism) for( int i = bn; i < len; i+=8, ai+=8, bi+=8 ) { //read 64B cachelines of a and b //compute cval' = sum(a * b) + cval val += a[ ai+0 ] * b[ bi+0 ] + a[ ai+1 ] * b[ bi+1 ] + a[ ai+2 ] * b[ bi+2 ] + a[ ai+3 ] * b[ bi+3 ] + a[ ai+4 ] * b[ bi+4 ] + a[ ai+5 ] * b[ bi+5 ] + a[ ai+6 ] * b[ bi+6 ] + a[ ai+7 ] * b[ bi+7 ]; } //scalar result return val; } private static double dotProduct( double[] a, double[] b, int[] aix, final int bi, final int len ) { double val = 0; final int bn = len%8; //compute rest for( int i = 0; i < bn; i++ ) val += a[ i ] * b[ bi+aix[i] ]; //unrolled 8-block (for better instruction-level parallelism) for( int i = bn; i < len; i+=8 ) { //read 64B cacheline of a //read 64B of b via 'gather' //compute cval' = sum(a * b) + cval val += a[ i+0 ] * b[ bi+aix[i+0] ] + a[ i+1 ] * b[ bi+aix[i+1] ] + a[ i+2 ] * b[ bi+aix[i+2] ] + a[ i+3 ] * b[ bi+aix[i+3] ] + a[ i+4 ] * b[ bi+aix[i+4] ] + a[ i+5 ] * b[ bi+aix[i+5] ] + a[ i+6 ] * b[ bi+aix[i+6] ] + a[ i+7 ] * b[ bi+aix[i+7] ]; } //scalar result return val; } /** * * @param aval * @param b * @param c * @param bi * @param ci * @param len */ private static void vectMultiplyAdd( final double aval, double[] b, double[] c, int bi, int ci, final int len ) { final int bn = len%8; //rest, not aligned to 8-blocks for( int j = 0; j < bn; j++, bi++, ci++) c[ ci ] += aval * b[ bi ]; //unrolled 8-block (for better instruction-level parallelism) for( int j = bn; j < len; j+=8, bi+=8, ci+=8) { //read 64B cachelines of b and c //compute c' = aval * b + c //write back 64B cacheline of c = c' c[ ci+0 ] += aval * b[ bi+0 ]; c[ ci+1 ] += aval * b[ bi+1 ]; c[ ci+2 ] += aval * b[ bi+2 ]; c[ ci+3 ] += aval * b[ bi+3 ]; c[ ci+4 ] += aval * b[ bi+4 ]; c[ ci+5 ] += aval * b[ bi+5 ]; c[ ci+6 ] += aval * b[ bi+6 ]; c[ ci+7 ] += aval * b[ bi+7 ]; } } /** * * @param aval1 * @param aval2 * @param b * @param c * @param bi * @param bi2 * @param ci * @param len */ private static void vectMultiplyAdd2( final double aval1, final double aval2, double[] b, double[] c, int bi1, int bi2, int ci, final int len ) { final int bn = len%8; //rest, not aligned to 8-blocks for( int j = 0; j < bn; j++, bi1++, bi2++, ci++ ) c[ ci ] += aval1 * b[ bi1 ] + aval2 * b[ bi2 ]; //unrolled 8-block (for better instruction-level parallelism) for( int j = bn; j < len; j+=8, 
bi1+=8, bi2+=8, ci+=8 ) { //read 64B cachelines of b (2x) and c //compute c' = aval_1 * b_1 + aval_2 * b_2 + c //write back 64B cacheline of c = c' c[ ci+0 ] += aval1 * b[ bi1+0 ] + aval2 * b[ bi2+0 ]; c[ ci+1 ] += aval1 * b[ bi1+1 ] + aval2 * b[ bi2+1 ]; c[ ci+2 ] += aval1 * b[ bi1+2 ] + aval2 * b[ bi2+2 ]; c[ ci+3 ] += aval1 * b[ bi1+3 ] + aval2 * b[ bi2+3 ]; c[ ci+4 ] += aval1 * b[ bi1+4 ] + aval2 * b[ bi2+4 ]; c[ ci+5 ] += aval1 * b[ bi1+5 ] + aval2 * b[ bi2+5 ]; c[ ci+6 ] += aval1 * b[ bi1+6 ] + aval2 * b[ bi2+6 ]; c[ ci+7 ] += aval1 * b[ bi1+7 ] + aval2 * b[ bi2+7 ]; } } /** * * @param aval1 * @param aval2 * @param aval3 * @param b * @param c * @param bi1 * @param bi2 * @param bi3 * @param ci * @param len */ private static void vectMultiplyAdd3( final double aval1, final double aval2, final double aval3, double[] b, double[] c, int bi1, int bi2, int bi3, int ci, final int len ) { final int bn = len%8; //rest, not aligned to 8-blocks for( int j = 0; j < bn; j++, bi1++, bi2++, bi3++, ci++ ) c[ ci ] += aval1 * b[ bi1 ] + aval2 * b[ bi2 ] + aval3 * b[ bi3 ]; //unrolled 8-block (for better instruction-level parallelism) for( int j = bn; j < len; j+=8, bi1+=8, bi2+=8, bi3+=8, ci+=8 ) { //read 64B cachelines of b (3x) and c //compute c' = aval_1 * b_1 + aval_2 * b_2 + c //write back 64B cacheline of c = c' c[ ci+0 ] += aval1 * b[ bi1+0 ] + aval2 * b[ bi2+0 ] + aval3 * b[ bi3+0 ]; c[ ci+1 ] += aval1 * b[ bi1+1 ] + aval2 * b[ bi2+1 ] + aval3 * b[ bi3+1 ]; c[ ci+2 ] += aval1 * b[ bi1+2 ] + aval2 * b[ bi2+2 ] + aval3 * b[ bi3+2 ]; c[ ci+3 ] += aval1 * b[ bi1+3 ] + aval2 * b[ bi2+3 ] + aval3 * b[ bi3+3 ]; c[ ci+4 ] += aval1 * b[ bi1+4 ] + aval2 * b[ bi2+4 ] + aval3 * b[ bi3+4 ]; c[ ci+5 ] += aval1 * b[ bi1+5 ] + aval2 * b[ bi2+5 ] + aval3 * b[ bi3+5 ]; c[ ci+6 ] += aval1 * b[ bi1+6 ] + aval2 * b[ bi2+6 ] + aval3 * b[ bi3+6 ]; c[ ci+7 ] += aval1 * b[ bi1+7 ] + aval2 * b[ bi2+7 ] + aval3 * b[ bi3+7 ]; } } /** * * @param aval1 * @param aval2 * @param aval3 * @param aval4 * @param b * @param c * @param bi1 * @param bi2 * @param bi3 * @param bi4 * @param ci * @param len */ private static void vectMultiplyAdd4( final double aval1, final double aval2, final double aval3, final double aval4, double[] b, double[] c, int bi1, int bi2, int bi3, int bi4, int ci, final int len ) { final int bn = len%8; //rest, not aligned to 8-blocks for( int j = 0; j < bn; j++, bi1++, bi2++, bi3++, bi4++, ci++ ) c[ ci ] += aval1 * b[ bi1 ] + aval2 * b[ bi2 ] + aval3 * b[ bi3 ] + aval4 * b[ bi4 ]; //unrolled 8-block (for better instruction-level parallelism) for( int j = bn; j < len; j+=8, bi1+=8, bi2+=8, bi3+=8, bi4+=8, ci+=8) { //read 64B cachelines of b (4x) and c //compute c' = aval_1 * b_1 + aval_2 * b_2 + c //write back 64B cacheline of c = c' c[ ci+0 ] += aval1 * b[ bi1+0 ] + aval2 * b[ bi2+0 ] + aval3 * b[ bi3+0 ] + aval4 * b[ bi4+0 ]; c[ ci+1 ] += aval1 * b[ bi1+1 ] + aval2 * b[ bi2+1 ] + aval3 * b[ bi3+1 ] + aval4 * b[ bi4+1 ]; c[ ci+2 ] += aval1 * b[ bi1+2 ] + aval2 * b[ bi2+2 ] + aval3 * b[ bi3+2 ] + aval4 * b[ bi4+2 ]; c[ ci+3 ] += aval1 * b[ bi1+3 ] + aval2 * b[ bi2+3 ] + aval3 * b[ bi3+3 ] + aval4 * b[ bi4+3 ]; c[ ci+4 ] += aval1 * b[ bi1+4 ] + aval2 * b[ bi2+4 ] + aval3 * b[ bi3+4 ] + aval4 * b[ bi4+4 ]; c[ ci+5 ] += aval1 * b[ bi1+5 ] + aval2 * b[ bi2+5 ] + aval3 * b[ bi3+5 ] + aval4 * b[ bi4+5 ]; c[ ci+6 ] += aval1 * b[ bi1+6 ] + aval2 * b[ bi2+6 ] + aval3 * b[ bi3+6 ] + aval4 * b[ bi4+6 ]; c[ ci+7 ] += aval1 * b[ bi1+7 ] + aval2 * b[ bi2+7 ] + aval3 * b[ bi3+7 ] + aval4 * b[ bi4+7 ]; } } /** * * @param 
aval
 * @param b
 * @param c
 * @param bix
 * @param ci
 * @param len
 */
private static void vectMultiplyAdd( final double aval, double[] b, double[] c, int[] bix, final int ci, final int len )
{
  final int bn = len%8; //rest, not aligned to 8-blocks
  for( int j = 0; j < bn; j++ )
    c[ ci + bix[j] ] += aval * b[ j ];
  //unrolled 8-block (for better instruction-level parallelism)
  for( int j = bn; j < len; j+=8 )
  {
    //read 64B cacheline of b
    //read 64B of c via 'gather'
    //compute c' = aval * b + c
    //write back 64B of c = c' via 'scatter'
    c[ ci+bix[j+0] ] += aval * b[ j+0 ];
    c[ ci+bix[j+1] ] += aval * b[ j+1 ];
    c[ ci+bix[j+2] ] += aval * b[ j+2 ];
    c[ ci+bix[j+3] ] += aval * b[ j+3 ];
    c[ ci+bix[j+4] ] += aval * b[ j+4 ];
    c[ ci+bix[j+5] ] += aval * b[ j+5 ];
    c[ ci+bix[j+6] ] += aval * b[ j+6 ];
    c[ ci+bix[j+7] ] += aval * b[ j+7 ];
  }
}

private static void vectMultiplyAdd( final double aval, double[] b, double[] c, int[] bix, final int bi, final int ci, final int len )
{
  final int bn = (len-bi)%8; //rest, not aligned to 8-blocks
  for( int j = bi; j < bi+bn; j++ )
    c[ ci + bix[j] ] += aval * b[ j ];
  //unrolled 8-block (for better instruction-level parallelism)
  for( int j = bi+bn; j < len; j+=8 )
  {
    //read 64B cacheline of b
    //read 64B of c via 'gather'
    //compute c' = aval * b + c
    //write back 64B of c = c' via 'scatter'
    c[ ci+bix[j+0] ] += aval * b[ j+0 ];
    c[ ci+bix[j+1] ] += aval * b[ j+1 ];
    c[ ci+bix[j+2] ] += aval * b[ j+2 ];
    c[ ci+bix[j+3] ] += aval * b[ j+3 ];
    c[ ci+bix[j+4] ] += aval * b[ j+4 ];
    c[ ci+bix[j+5] ] += aval * b[ j+5 ];
    c[ ci+bix[j+6] ] += aval * b[ j+6 ];
    c[ ci+bix[j+7] ] += aval * b[ j+7 ];
  }
}

/**
 *
 * @param aval
 * @param b
 * @param c
 * @param bi
 * @param ci
 * @param len
 */
private static void vectMultiplyWrite( final double aval, double[] b, double[] c, int bi, int ci, final int len )
{
  final int bn = len%8; //rest, not aligned to 8-blocks
  for( int j = 0; j < bn; j++, bi++, ci++)
    c[ ci ] = aval * b[ bi ];
  //unrolled 8-block (for better instruction-level parallelism)
  for( int j = bn; j < len; j+=8, bi+=8, ci+=8)
  {
    //read 64B cacheline of b
    //compute c' = aval * b
    //write back 64B cacheline of c = c'
    c[ ci+0 ] = aval * b[ bi+0 ];
    c[ ci+1 ] = aval * b[ bi+1 ];
    c[ ci+2 ] = aval * b[ bi+2 ];
    c[ ci+3 ] = aval * b[ bi+3 ];
    c[ ci+4 ] = aval * b[ bi+4 ];
    c[ ci+5 ] = aval * b[ bi+5 ];
    c[ ci+6 ] = aval * b[ bi+6 ];
    c[ ci+7 ] = aval * b[ bi+7 ];
  }
}

/**
 *
 * @param a
 * @param b
 * @param c
 * @param ai
 * @param bi
 * @param ci
 * @param len
 */
@SuppressWarnings("unused")
private static void vectMultiplyWrite( double[] a, double[] b, double[] c, int ai, int bi, int ci, final int len )
{
  final int bn = len%8; //rest, not aligned to 8-blocks
  for( int j = 0; j < bn; j++, ai++, bi++, ci++)
    c[ ci ] = a[ ai ] * b[ bi ];
  //unrolled 8-block (for better instruction-level parallelism)
  for( int j = bn; j < len; j+=8, ai+=8, bi+=8, ci+=8)
  {
    //read 64B cachelines of a and b
    //compute c' = a * b
    //write back 64B cacheline of c = c'
    c[ ci+0 ] = a[ ai+0 ] * b[ bi+0 ];
    c[ ci+1 ] = a[ ai+1 ] * b[ bi+1 ];
    c[ ci+2 ] = a[ ai+2 ] * b[ bi+2 ];
    c[ ci+3 ] = a[ ai+3 ] * b[ bi+3 ];
    c[ ci+4 ] = a[ ai+4 ] * b[ bi+4 ];
    c[ ci+5 ] = a[ ai+5 ] * b[ bi+5 ];
    c[ ci+6 ] = a[ ai+6 ] * b[ bi+6 ];
    c[ ci+7 ] = a[ ai+7 ] * b[ bi+7 ];
  }
}

/**
 *
 * @param a
 * @param c
 * @param ai
 * @param ci
 * @param len
 */
private static void vectMultiply( double[] a, double[] c, int ai, int ci, final int len )
{
  final int bn = len%8; //rest, not aligned to 8-blocks
  for( int j = 0; j < bn; j++, ai++, ci++)
    c[ ci ] *= a[ ai ];
  //unrolled 8-block (for better instruction-level parallelism)
  for( int j = bn; j < len; j+=8, ai+=8, ci+=8)
  {
    //read 64B cachelines of a and c
    //compute c' = c * a
    //write back 64B cacheline of c = c'
    c[ ci+0 ] *= a[ ai+0 ];
    c[ ci+1 ] *= a[ ai+1 ];
    c[ ci+2 ] *= a[ ai+2 ];
    c[ ci+3 ] *= a[ ai+3 ];
    c[ ci+4 ] *= a[ ai+4 ];
    c[ ci+5 ] *= a[ ai+5 ];
    c[ ci+6 ] *= a[ ai+6 ];
    c[ ci+7 ] *= a[ ai+7 ];
  }
}

/**
 *
 * @param a
 * @param c
 * @param ai
 * @param ci
 * @param len
 */
private static void vectAdd( double[] a, double[] c, int ai, int ci, final int len )
{
  final int bn = len%8; //rest, not aligned to 8-blocks
  for( int j = 0; j < bn; j++, ai++, ci++)
    c[ ci ] += a[ ai ];
  //unrolled 8-block (for better instruction-level parallelism)
  for( int j = bn; j < len; j+=8, ai+=8, ci+=8)
  {
    //read 64B cachelines of a and c
    //compute c' = c + a
    //write back 64B cacheline of c = c'
    c[ ci+0 ] += a[ ai+0 ];
    c[ ci+1 ] += a[ ai+1 ];
    c[ ci+2 ] += a[ ai+2 ];
    c[ ci+3 ] += a[ ai+3 ];
    c[ ci+4 ] += a[ ai+4 ];
    c[ ci+5 ] += a[ ai+5 ];
    c[ ci+6 ] += a[ ai+6 ];
    c[ ci+7 ] += a[ ai+7 ];
  }
}

/**
 *
 * @param a1
 * @param a2
 * @param a3
 * @param a4
 * @param c
 * @param ai
 * @param ci
 * @param len
 */
private static void vectAdd4( double[] a1, double[] a2, double[] a3, double[] a4, double[] c, int ai, int ci, final int len )
{
  final int bn = len%8; //rest, not aligned to 8-blocks
  for( int j = 0; j < bn; j++, ai++, ci++)
    c[ ci ] += a1[ ai ] + a2[ ai ] + a3[ ai ] + a4[ ai ];
  //unrolled 8-block (for better instruction-level parallelism)
  for( int j = bn; j < len; j+=8, ai+=8, ci+=8)
  {
    //read 64B cachelines of a (4x) and c
    //compute c' = c + a1 + a2 + a3 + a4
    //write back 64B cacheline of c = c'
    c[ ci+0 ] += a1[ ai+0 ] + a2[ ai+0 ] + a3[ ai+0 ] + a4[ ai+0 ];
    c[ ci+1 ] += a1[ ai+1 ] + a2[ ai+1 ] + a3[ ai+1 ] + a4[ ai+1 ];
    c[ ci+2 ] += a1[ ai+2 ] + a2[ ai+2 ] + a3[ ai+2 ] + a4[ ai+2 ];
    c[ ci+3 ] += a1[ ai+3 ] + a2[ ai+3 ] + a3[ ai+3 ] + a4[ ai+3 ];
    c[ ci+4 ] += a1[ ai+4 ] + a2[ ai+4 ] + a3[ ai+4 ] + a4[ ai+4 ];
    c[ ci+5 ] += a1[ ai+5 ] + a2[ ai+5 ] + a3[ ai+5 ] + a4[ ai+5 ];
    c[ ci+6 ] += a1[ ai+6 ] + a2[ ai+6 ] + a3[ ai+6 ] + a4[ ai+6 ];
    c[ ci+7 ] += a1[ ai+7 ] + a2[ ai+7 ] + a3[ ai+7 ] + a4[ ai+7 ];
  }
}

/**
 *
 * @param wij
 * @param u
 * @param v
 * @param uix
 * @param vix
 * @param flagminus
 * @param flaglog
 * @param len
 * @return
 */
private static double wsigmoid( final double wij, double[] u, double[] v, final int uix, final int vix, final boolean flagminus, final boolean flaglog, final int len )
{
  //compute dot product over ui vj
  double uvij = dotProduct(u, v, uix, vix, len);

  //compute core sigmoid function
  double cval = flagminus ?
    1 / (1 + FastMath.exp(uvij)) :
    1 / (1 + FastMath.exp(-uvij));

  //compute weighted output
  return wij * ((flaglog) ? FastMath.log(cval) : cval);
}

/**
 *
 * @param wij
 * @param u
 * @param v
 * @param uix
 * @param vix
 * @param flagminus
 * @param flaglog
 * @param len
 * @return
 */
private static double wsigmoid( final double wij, MatrixBlock u, MatrixBlock v, final int uix, final int vix, final boolean flagminus, final boolean flaglog, final int len )
{
  //compute dot product over ui vj
  double uvij = dotProductGeneric(u, v, uix, vix, len);

  //compute core sigmoid function
  double cval = flagminus ?
    1 / (1 + FastMath.exp(uvij)) :
    1 / (1 + FastMath.exp(-uvij));

  //compute weighted output
  return wij * ((flaglog) ?
FastMath.log(cval) : cval); } /** * * @param wij * @param u * @param v * @param c * @param uix * @param vix * @param flagleft * @param len */ private static void wdivmm( final double wij, double[] u, double[] v, double[] c, final int uix, final int vix, final boolean left, final boolean mult, final boolean minus, final int len ) { //compute dot product over ui vj double uvij = dotProduct(u, v, uix, vix, len); //compute core wdivmm double tmpval = minus ? uvij - wij : mult ? wij * uvij : wij / uvij; //prepare inputs for final mm int bix = left ? uix : vix; int cix = left ? vix : uix; double[] b = left ? u : v; //compute final mm output vectMultiplyAdd(tmpval, b, c, bix, cix, len); } /** * * @param wij * @param u * @param v * @param c * @param uix * @param vix * @param flagleft * @param len */ private static void wdivmm( final double wij, MatrixBlock u, MatrixBlock v, double[] c, final int uix, final int vix, final boolean left, boolean mult, final boolean minus, final int len ) { //compute dot product over ui vj double uvij = dotProductGeneric(u, v, uix, vix, len); //compute core wdivmm double wtmp = minus ? uvij - wij : mult ? wij * uvij : wij / uvij; //prepare inputs for final mm int bix = left ? uix : vix; int cix = left ? vix*len : uix*len; MatrixBlock b = left ? u : v; //compute final mm for( int k2=0; k2<len; k2++ ) c[cix+k2] += b.quickGetValue(bix, k2) * wtmp; } /** * * @param a * @param b * @param ai * @param bi * @param len * @return */ private static double dotProductGeneric(MatrixBlock a, MatrixBlock b, final int ai, final int bi, int len) { double val = 0; for( int k2=0; k2<len; k2++ ) val += a.quickGetValue(ai, k2) * b.quickGetValue(bi, k2); return val; } /** * Used for all version of TSMM where the result is known to be symmetric. * Hence, we compute only the upper triangular matrix and copy this partial * result down to lower triangular matrix once. 
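 * For example (added for clarity), for a 3x3 result c in row-major layout,
 * the copy assigns c[1*3+0]=c[0*3+1], c[2*3+0]=c[0*3+2], and c[2*3+1]=c[1*3+2].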
* * @param ret */ private static void copyUpperToLowerTriangle( MatrixBlock ret ) { double[] c = ret.denseBlock; final int m = ret.rlen; final int n = ret.clen; //copy symmetric values for( int i=0, uix=0; i<m; i++, uix+=n ) for( int j=i+1, lix=j*n+i; j<n; j++, lix+=n ) c[ lix ] = c[ uix+j ]; } /** * * @param m1 * @param leftTranspose * @return * @throws DMLRuntimeException */ private static MatrixBlock prepMatrixMultTransposeSelfInput( MatrixBlock m1, boolean leftTranspose ) throws DMLRuntimeException { MatrixBlock ret = m1; if( !leftTranspose && m1.sparse && m1.rlen > 1) //X%*%t(X) SPARSE MATRIX { //directly via LibMatrixReorg in order to prevent sparsity change MatrixBlock tmpBlock = new MatrixBlock(m1.clen, m1.rlen, m1.sparse); LibMatrixReorg.reorg(m1, tmpBlock, new ReorgOperator(SwapIndex.getSwapIndexFnObject())); ret = tmpBlock; } return ret; } /** * * @param m1 * @param m2 * @return */ private static boolean checkPrepMatrixMultRightInput( MatrixBlock m1, MatrixBlock m2 ) { //transpose if dense-dense, skinny rhs matrix (not vector), and memory guarded by output return (!m1.sparse && !m2.sparse && m1.rlen>m2.clen && m2.rlen > 64 && m2.clen > 1 && m2.clen < 64); } /** * * @param m1 * @param m2 * @return */ private static boolean checkParMatrixMultRightInput( MatrixBlock m1, MatrixBlock m2, int k ) { //parallelize over rows in rhs matrix if number of rows in lhs/output is very small return (m1.rlen==1 && LOW_LEVEL_OPTIMIZATION && m2.clen>1 && !(m1.isUltraSparse()||m2.isUltraSparse())) || (m1.rlen<=16 && LOW_LEVEL_OPTIMIZATION && m2.clen>1 && m2.rlen > m1.rlen && !m1.sparse && !m2.sparse && (long)k * 8 * m1.rlen * m2.clen < MEM_OVERHEAD_THRESHOLD ); } /** * * @param m1 * @param m2 * @return * @throws DMLRuntimeException */ private static MatrixBlock prepMatrixMultRightInput( MatrixBlock m1, MatrixBlock m2 ) throws DMLRuntimeException { MatrixBlock ret = m2; //transpose if dense-dense, skinny rhs matrix (not vector), and memory guarded by output if( checkPrepMatrixMultRightInput(m1, m2) ) { MatrixBlock tmpBlock = new MatrixBlock(m2.clen, m2.rlen, m2.sparse); LibMatrixReorg.reorg(m2, tmpBlock, new ReorgOperator(SwapIndex.getSwapIndexFnObject())); ret = tmpBlock; } return ret; } /** * * @param a * @param aixi * @param bk * @param bj * @param n * @param tmpa * @param tmpbi * @param bklen * @return */ private static int copyNonZeroElements( double[] a, final int aixi, final int bk, final int bj, final int n, double[] tmpa, int[] tmpbi, final int bklen ) { int knnz = 0; for( int k = 0; k < bklen; k++ ) if( a[ aixi+k ] != 0 ) { tmpa[ knnz ] = a[ aixi+k ]; tmpbi[ knnz ] = (bk+k) * n + bj; //scan index on b knnz ++; } return knnz; } /** * * @param a * @param aixi * @param bk * @param bj * @param n * @param nx * @param tmpa * @param tmpbi * @param bklen * @return */ private static int copyNonZeroElements( double[] a, int aixi, final int bk, final int bj, final int n, final int nx, double[] tmpa, int[] tmpbi, final int bklen ) { int knnz = 0; for( int k = 0; k < bklen; k++, aixi+=n ) if( a[ aixi ] != 0 ) { tmpa[ knnz ] = a[ aixi ]; tmpbi[ knnz ] = (bk+k) * nx + bj; //scan index on b knnz ++; } return knnz; } /** * * @param tasks * @param ret */ private static void sumScalarResults(ArrayList<ScalarResultTask> tasks, MatrixBlock ret) { //aggregate partial results double val = 0; for(ScalarResultTask task : tasks) val += task.getScalarResult(); ret.quickSetValue(0, 0, val); } /** * * @param partret * @param ret */ @SuppressWarnings("unused") private static void sumDenseResults( double[][] partret, 
double[] ret )
{
  final int len = ret.length;
  final int k = partret.length;
  final int bk = k % 4;
  final int blocksize = 2 * 1024; //16KB (half of common L1 data)

  //cache-conscious aggregation to prevent repeated scans/writes of ret
  for( int bi=0; bi<len; bi+=blocksize ) {
    int llen = Math.min(len-bi, blocksize);
    //aggregate next block from all partial results
    for( int j=0; j<bk; j++ ) //rest (not aligned to 4)
      vectAdd(partret[j], ret, bi, bi, llen);
    for( int j=bk; j<k; j+=4 ) //4 partial results at a time
      vectAdd4(partret[j], partret[j+1], partret[j+2], partret[j+3], ret, bi, bi, llen);
  }
}

/////////////////////////////////////////////////////////
// Task Implementations for Multi-Threaded Operations  //
/////////////////////////////////////////////////////////

/**
 *
 */
private static class MatrixMultTask implements Callable<Object>
{
  private MatrixBlock _m1 = null;
  private MatrixBlock _m2 = null;
  private MatrixBlock _ret = null;
  private boolean _tm2 = false; //transposed m2
  private boolean _pm2 = false; //par over m2
  private int _rl = -1;
  private int _ru = -1;
  private long _nnz = -1;

  protected MatrixMultTask( MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean tm2, boolean pm2, int rl, int ru )
  {
    _m1 = m1;
    _m2 = m2;
    _tm2 = tm2;
    _pm2 = pm2;
    _rl = rl;
    _ru = ru;

    if( pm2 ) { //vector-matrix / matrix-matrix
      //allocate local result for partial aggregation
      _ret = new MatrixBlock(ret.rlen, ret.clen, false);
    }
    else { //default case
      _ret = ret;
    }
  }

  @Override
  public Object call() throws DMLRuntimeException
  {
    //thread-local allocation
    if( _pm2 )
      _ret.allocateDenseBlock();

    //compute block matrix multiplication
    if( _m1.isUltraSparse() || _m2.isUltraSparse() )
      matrixMultUltraSparse(_m1, _m2, _ret, _rl, _ru);
    else if(!_m1.sparse && !_m2.sparse)
      matrixMultDenseDense(_m1, _m2, _ret, _tm2, _pm2, _rl, _ru);
    else if(_m1.sparse && _m2.sparse)
      matrixMultSparseSparse(_m1, _m2, _ret, _pm2, _rl, _ru);
    else if(_m1.sparse)
      matrixMultSparseDense(_m1, _m2, _ret, _pm2, _rl, _ru);
    else
      matrixMultDenseSparse(_m1, _m2, _ret, _pm2, _rl, _ru);

    //maintain block nnz (upper bounds inclusive)
    if( !_pm2 )
      _nnz = _ret.recomputeNonZeros(_rl, _ru-1, 0, _ret.getNumColumns()-1);

    return null;
  }

  public long getPartialNnz(){
    return _nnz;
  }

  public MatrixBlock getResult(){
    return _ret;
  }
}

/**
 *
 */
private static class MatrixMultChainTask implements Callable<Object>
{
  private MatrixBlock _m1 = null;
  private MatrixBlock _m2 = null;
  private MatrixBlock _m3 = null;
  private MatrixBlock _ret = null;
  private ChainType _ct = null;
  private int _rl = -1;
  private int _ru = -1;

  protected MatrixMultChainTask( MatrixBlock mX, MatrixBlock mV, MatrixBlock mW, MatrixBlock ret, ChainType ct, int rl, int ru )
    throws DMLRuntimeException
  {
    _m1 = mX;
    _m2 = mV;
    _m3 = mW;
    _ct = ct;
    _rl = rl;
    _ru = ru;

    //allocate local result for partial aggregation
    _ret = new MatrixBlock(ret.rlen, ret.clen, false);
    _ret.allocateDenseBlock();
  }

  @Override
  public Object call() throws DMLRuntimeException
  {
    if( _m1.sparse )
      matrixMultChainSparse(_m1, _m2, _m3, _ret, _ct, _rl, _ru);
    else
      matrixMultChainDense(_m1, _m2, _m3, _ret, _ct, _rl, _ru);

    //NOTE: we don't do global aggregation from concurrent tasks in order
    //to prevent synchronization (sequential aggregation led to better
    //performance after JIT)
    return null;
  }

  public MatrixBlock getResult() {
    return _ret;
  }
}

private static class MatrixMultTransposeTask implements Callable<Object>
{
  private MatrixBlock _m1 = null;
  private MatrixBlock _ret = null;
  private boolean _left = true;
  private int _rl = -1;
  private int _ru = -1;
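  //NOTE (illustrative sketch, not original code): these tasks are typically
  //driven by a fixed thread pool that partitions rows into k contiguous blocks;
  //a minimal driver, with hypothetical local names (pool, tasks, blklen):
  //  ExecutorService pool = Executors.newFixedThreadPool(k);
  //  ArrayList<MatrixMultTransposeTask> tasks = new ArrayList<MatrixMultTransposeTask>();
  //  int blklen = (int)Math.ceil((double)len / k);
  //  for( int i=0; i<k && i*blklen<len; i++ )
  //    tasks.add(new MatrixMultTransposeTask(m1, ret, leftTranspose, i*blklen, Math.min((i+1)*blklen, len)));
  //  pool.invokeAll(tasks); //blocks until all partial computations finish
  //  pool.shutdown();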
  protected MatrixMultTransposeTask( MatrixBlock m1, MatrixBlock ret, boolean left, int rl, int ru )
  {
    _m1 = m1;
    _ret = ret;
    _left = left;
    _rl = rl;
    _ru = ru;
  }

  @Override
  public Object call() throws DMLRuntimeException
  {
    if( _m1.sparse )
      matrixMultTransposeSelfSparse(_m1, _ret, _left, _rl, _ru);
    else
      matrixMultTransposeSelfDense(_m1, _ret, _left, _rl, _ru);
    return null;
  }
}

/**
 *
 */
private static class MatrixMultPermuteTask implements Callable<Object>
{
  private MatrixBlock _pm1 = null;
  private MatrixBlock _m2 = null;
  private MatrixBlock _ret1 = null;
  private MatrixBlock _ret2 = null;
  private int _rl = -1;
  private int _ru = -1;

  protected MatrixMultPermuteTask( MatrixBlock pm1, MatrixBlock m2, MatrixBlock ret1, MatrixBlock ret2, int rl, int ru)
  {
    _pm1 = pm1;
    _m2 = m2;
    _ret1 = ret1;
    _ret2 = ret2;
    _rl = rl;
    _ru = ru;
  }

  @Override
  public Object call() throws DMLRuntimeException
  {
    if( _m2.sparse )
      matrixMultPermuteSparse(_pm1, _m2, _ret1, _ret2, _rl, _ru);
    else if( _ret1.sparse )
      matrixMultPermuteDenseSparse(_pm1, _m2, _ret1, _ret2, _rl, _ru);
    else
      matrixMultPermuteDense(_pm1, _m2, _ret1, _ret2, _rl, _ru);
    return null;
  }
}

/**
 *
 */
private static interface ScalarResultTask extends Callable<Object> {
  public double getScalarResult();
}

/**
 *
 */
private static class MatrixMultWSLossTask implements ScalarResultTask
{
  private MatrixBlock _mX = null;
  private MatrixBlock _mU = null;
  private MatrixBlock _mV = null;
  private MatrixBlock _mW = null;
  private MatrixBlock _ret = null;
  private WeightsType _wt = null;
  private int _rl = -1;
  private int _ru = -1;

  protected MatrixMultWSLossTask(MatrixBlock mX, MatrixBlock mU, MatrixBlock mV, MatrixBlock mW, WeightsType wt, int rl, int ru)
    throws DMLRuntimeException
  {
    _mX = mX;
    _mU = mU;
    _mV = mV;
    _mW = mW;
    _wt = wt;
    _rl = rl;
    _ru = ru;

    //allocate local result for partial aggregation
    _ret = new MatrixBlock(1, 1, false);
    _ret.allocateDenseBlock();
  }

  @Override
  public Object call() throws DMLRuntimeException
  {
    if( !_mX.sparse && !_mU.sparse && !_mV.sparse && (_mW==null || !_mW.sparse)
      && !_mX.isEmptyBlock() && !_mU.isEmptyBlock() && !_mV.isEmptyBlock()
      && (_mW==null || !_mW.isEmptyBlock()))
      matrixMultWSLossDense(_mX, _mU, _mV, _mW, _ret, _wt, _rl, _ru);
    else if( _mX.sparse && !_mU.sparse && !_mV.sparse && (_mW==null || _mW.sparse)
      && !_mX.isEmptyBlock() && !_mU.isEmptyBlock() && !_mV.isEmptyBlock()
      && (_mW==null || !_mW.isEmptyBlock()))
      matrixMultWSLossSparseDense(_mX, _mU, _mV, _mW, _ret, _wt, _rl, _ru);
    else
      matrixMultWSLossGeneric(_mX, _mU, _mV, _mW, _ret, _wt, _rl, _ru);
    return null;
  }

  @Override
  public double getScalarResult() {
    return _ret.quickGetValue(0, 0);
  }
}

/**
 *
 */
private static class MatrixMultWSigmoidTask implements Callable<Object>
{
  private MatrixBlock _mW = null;
  private MatrixBlock _mU = null;
  private MatrixBlock _mV = null;
  private MatrixBlock _ret = null;
  private WSigmoidType _wt = null;
  private int _rl = -1;
  private int _ru = -1;
  private long _nnz = -1;

  protected MatrixMultWSigmoidTask(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt, int rl, int ru)
    throws DMLRuntimeException
  {
    _mW = mW;
    _mU = mU;
    _mV = mV;
    _ret = ret;
    _wt = wt;
    _rl = rl;
    _ru = ru;
  }

  @Override
  public Object call() throws DMLRuntimeException
  {
    //core weighted sigmoid mm computation
    if( !_mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock() )
      matrixMultWSigmoidDense(_mW, _mU, _mV, _ret, _wt, _rl, _ru);
    else if( _mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock())
      matrixMultWSigmoidSparseDense(_mW, _mU, _mV, _ret, _wt, _rl, _ru);
    else
      matrixMultWSigmoidGeneric(_mW, _mU, _mV, _ret, _wt, _rl, _ru);

    //maintain block nnz (upper bounds inclusive)
    _nnz = _ret.recomputeNonZeros(_rl, _ru-1, 0, _ret.getNumColumns()-1);
    return null;
  }

  public long getPartialNnz(){
    return _nnz;
  }
}

/**
 *
 */
private static class MatrixMultWDivTask implements Callable<Object>
{
  private MatrixBlock _mW = null;
  private MatrixBlock _mU = null;
  private MatrixBlock _mV = null;
  private MatrixBlock _ret = null;
  private WDivMMType _wt = null;
  private int _rl = -1;
  private int _ru = -1;
  private int _cl = -1;
  private int _cu = -1;
  private long _nnz = -1;

  protected MatrixMultWDivTask(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WDivMMType wt, int rl, int ru, int cl, int cu)
    throws DMLRuntimeException
  {
    _mW = mW;
    _mU = mU;
    _mV = mV;
    _wt = wt;
    _rl = rl;
    _ru = ru;
    _cl = cl;
    _cu = cu;
    _ret = ret;
  }

  @Override
  public Object call() throws DMLRuntimeException
  {
    //core weighted div mm computation
    if( !_mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock() )
      matrixMultWDivMMDense(_mW, _mU, _mV, _ret, _wt, _rl, _ru, _cl, _cu);
    else if( _mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock())
      matrixMultWDivMMSparseDense(_mW, _mU, _mV, _ret, _wt, _rl, _ru, _cl, _cu);
    else
      matrixMultWDivMMGeneric(_mW, _mU, _mV, _ret, _wt, _rl, _ru, _cl, _cu);

    //maintain partial nnz for right (upper bounds inclusive)
    int rl = _wt.isLeft() ? _cl : _rl;
    int ru = _wt.isLeft() ? _cu : _ru;
    _nnz = _ret.recomputeNonZeros(rl, ru-1, 0, _ret.getNumColumns()-1);
    return null;
  }

  /**
   * Partial nnz, maintained for wdivmm right.
   * @return partial number of non-zeros
   */
  public long getPartialNnz(){
    return _nnz;
  }
}

private static class MatrixMultWCeTask implements ScalarResultTask
{
  private MatrixBlock _mW = null;
  private MatrixBlock _mU = null;
  private MatrixBlock _mV = null;
  private MatrixBlock _ret = null;
  private WCeMMType _wt = null;
  private int _rl = -1;
  private int _ru = -1;

  protected MatrixMultWCeTask(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, WCeMMType wt, int rl, int ru)
    throws DMLRuntimeException
  {
    _mW = mW;
    _mU = mU;
    _mV = mV;
    _wt = wt;
    _rl = rl;
    _ru = ru;

    //allocate local result for partial aggregation
    _ret = new MatrixBlock(1, 1, false);
    _ret.allocateDenseBlock();
  }

  @Override
  public Object call() throws DMLRuntimeException
  {
    //core weighted cross-entropy mm computation
    if( !_mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock() )
      matrixMultWCeMMDense(_mW, _mU, _mV, _ret, _wt, _rl, _ru);
    else if( _mW.sparse && !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock())
      matrixMultWCeMMSparseDense(_mW, _mU, _mV, _ret, _wt, _rl, _ru);
    else
      matrixMultWCeMMGeneric(_mW, _mU, _mV, _ret, _wt, _rl, _ru);
    return null;
  }

  @Override
  public double getScalarResult() {
    return _ret.quickGetValue(0, 0);
  }
}
}
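//NOTE (illustrative usage sketch, not original code): typical invocation of the
//public matrix mult interface on small dense blocks; values are placeholders:
//  MatrixBlock m1 = new MatrixBlock(2, 3, false); //2x3, dense
//  MatrixBlock m2 = new MatrixBlock(3, 2, false); //3x2, dense
//  m1.allocateDenseBlock(); m2.allocateDenseBlock();
//  m1.quickSetValue(0, 0, 1.0); //... fill remaining cells of m1 and m2
//  MatrixBlock ret = new MatrixBlock(2, 2, false);
//  LibMatrixMult.matrixMult(m1, m2, ret);    //single-threaded
//  LibMatrixMult.matrixMult(m1, m2, ret, 4); //multi-threaded, up to k=4 threads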