/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.codegen;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.matrix.data.LibMatrixMult;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.SparseBlock;
import org.apache.sysml.runtime.util.UtilFunctions;
/**
 * Base class for generated row-wise operators. The first input matrix is
 * traversed row by row (dense or sparse) and, for every row, the generated
 * method {@link #genexecRowDense} or {@link #genexecRowSparse} is invoked.
 * The shape of the output is determined by the {@link RowType}.
 */
public abstract class SpoofRowwise extends SpoofOperator
{
	private static final long serialVersionUID = 6242910797139642998L;

	/** Minimum number of input cells (rows x columns) for multi-threaded execution. */
	private static final long PAR_NUMCELL_THRESHOLD = 1024*1024;   //Min 1M elements

	public enum RowType {
		NO_AGG,    //no aggregation
		ROW_AGG,   //row aggregation (e.g., rowSums() or X %*% v)
		COL_AGG,   //col aggregation (e.g., colSums() or t(y) %*% X)
		COL_AGG_T; //transposed col aggregation (e.g., t(X) %*% y)

		/**
		 * @return true for {@code COL_AGG} and {@code COL_AGG_T}, false otherwise
		 */
		public boolean isColumnAgg() {
			return (this == COL_AGG || this == COL_AGG_T);
		}
	}

	protected final RowType _type;
	protected final int _reqVectMem;

	/**
	 * @param type aggregation type, which determines the output dimensions
	 * @param reqVectMem value passed to
	 *        {@code LibSpoofPrimitives.setupThreadLocalMemory} together with the
	 *        number of columns (exposed via {@link #getNumIntermediates()})
	 */
	public SpoofRowwise(RowType type, int reqVectMem) {
		_type = type;
		_reqVectMem = reqVectMem;
	}

	public RowType getRowType() {
		return _type;
	}

	public int getNumIntermediates() {
		return _reqVectMem;
	}

	/**
	 * Returns "RA" followed by the second dot-separated segment of the
	 * runtime class name.
	 * NOTE(review): this presumes the concrete class name has at least two
	 * dot-separated segments; confirm against the naming of generated classes.
	 */
	@Override
	public String getSpoofType() {
		return "RA" + getClass().getName().split("\\.")[1];
	}

	/**
	 * Single-threaded execution; delegates to the five-argument variant with
	 * {@code allocTmp=true} and {@code aggIncr=false}.
	 */
	@Override
	public void execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out)
		throws DMLRuntimeException
	{
		execute(inputs, scalarObjects, out, true, false);
	}

	/**
	 * Single-threaded execution over all rows of the first input.
	 *
	 * @param inputs input matrices; the first one is the matrix that is iterated row-wise
	 * @param scalarObjects scalar inputs
	 * @param out output matrix
	 * @param allocTmp if true, thread-local memory is set up before and cleaned up
	 *        after the execution via {@code LibSpoofPrimitives}
	 * @param aggIncr if true and {@code out} is already allocated, the output is
	 *        not reset and re-allocated but used as is
	 * @throws DMLRuntimeException declared for callers and overriding operators
	 * @throws RuntimeException if {@code inputs} is null or empty, or {@code out} is null
	 */
	public void execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out, boolean allocTmp, boolean aggIncr)
		throws DMLRuntimeException
	{
		//sanity check
		checkArguments(inputs, out);

		//result allocation and preparations
		final MatrixBlock a = inputs.get(0);
		final int m = a.getNumRows();
		final int n = a.getNumColumns();
		if( !aggIncr || !out.isAllocated() )
			allocateOutputMatrix(m, n, out);
		double[] c = out.getDenseBlock();

		//input preparation
		double[][] b = prepInputMatrices(inputs);
		double[] scalars = prepInputScalars(scalarObjects);

		//setup thread-local memory if necessary
		if( allocTmp )
			LibSpoofPrimitives.setupThreadLocalMemory(_reqVectMem, n);

		try {
			//core sequential execute
			if( !a.isInSparseFormat() )
				executeDense(a.getDenseBlock(), b, scalars, c, n, 0, m);
			else
				executeSparse(a.getSparseBlock(), b, scalars, c, n, 0, m);
		}
		finally {
			//always release the thread-local memory of the calling thread,
			//even if the generated row method failed
			if( allocTmp )
				LibSpoofPrimitives.cleanupThreadLocalMemory();
		}

		//post-processing
		out.recomputeNonZeros();
		out.examSparsity();
	}

	/**
	 * Multi-threaded execution with up to {@code k} threads. Falls back to the
	 * single-threaded variant if {@code k <= 1} or if the first input has fewer
	 * than {@code PAR_NUMCELL_THRESHOLD} cells. Otherwise the rows are split
	 * into contiguous ranges: for column aggregations each task computes a
	 * private partial result that is added into the output afterwards; for all
	 * other types the tasks write into the shared output dense block and only
	 * return the number of non-zeros of their row range.
	 *
	 * @param inputs input matrices; the first one is the matrix that is iterated row-wise
	 * @param scalarObjects scalar inputs
	 * @param out output matrix
	 * @param k degree of parallelism
	 * @throws DMLRuntimeException wrapping any exception raised during parallel execution
	 * @throws RuntimeException if {@code inputs} is null or empty, or {@code out} is null
	 */
	@Override
	public void execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out, int k)
		throws DMLRuntimeException
	{
		//sanity check (before the first access to inputs.get(0))
		checkArguments(inputs, out);

		final MatrixBlock a = inputs.get(0);
		final int m = a.getNumRows();
		final int n = a.getNumColumns();

		//redirect to serial execution
		if( k <= 1 || (long)m*n < PAR_NUMCELL_THRESHOLD ) {
			execute(inputs, scalarObjects, out);
			return;
		}

		//result allocation and preparations
		allocateOutputMatrix(m, n, out);

		//input preparation
		double[][] b = prepInputMatrices(inputs);
		double[] scalars = prepInputScalars(scalarObjects);

		//row partitioning: nk row blocks of blklen rows each
		int nk = UtilFunctions.roundToNext(Math.min(8*k,m/32), k);
		int blklen = (int)(Math.ceil((double)m/nk));

		//core parallel execute
		ExecutorService pool = Executors.newFixedThreadPool( k );
		try
		{
			if( _type.isColumnAgg() ) {
				//execute tasks
				ArrayList<ParColAggTask> tasks = new ArrayList<ParColAggTask>();
				for( int i=0; i<nk && i*blklen<m; i++ )
					tasks.add(new ParColAggTask(a, b, scalars, n, i*blklen, Math.min((i+1)*blklen, m)));
				List<Future<double[]>> taskret = pool.invokeAll(tasks);
				//aggregate partial results
				for( Future<double[]> task : taskret )
					LibMatrixMult.vectAdd(task.get(), out.getDenseBlock(), 0, 0, n);
				out.recomputeNonZeros();
			}
			else {
				//execute tasks
				ArrayList<ParExecTask> tasks = new ArrayList<ParExecTask>();
				for( int i=0; i<nk && i*blklen<m; i++ )
					tasks.add(new ParExecTask(a, b, out, scalars, n, i*blklen, Math.min((i+1)*blklen, m)));
				List<Future<Long>> taskret = pool.invokeAll(tasks);
				//aggregate nnz, no need to aggregate results
				long nnz = 0;
				for( Future<Long> task : taskret )
					nnz += task.get();
				out.setNonZeros(nnz);
			}
			out.examSparsity();
		}
		catch(InterruptedException ex) {
			//restore the interrupt status for callers further up the stack
			Thread.currentThread().interrupt();
			throw new DMLRuntimeException(ex);
		}
		catch(Exception ex) {
			throw new DMLRuntimeException(ex);
		}
		finally {
			//release worker threads on success and on failure
			pool.shutdown();
		}
	}

	/**
	 * Validates the arguments shared by all execute variants.
	 *
	 * @throws RuntimeException if {@code inputs} is null or empty, or {@code out} is null
	 */
	private static void checkArguments(ArrayList<MatrixBlock> inputs, MatrixBlock out) {
		if( inputs==null || inputs.size() < 1 || out==null )
			throw new RuntimeException("Invalid input arguments.");
	}

	/**
	 * Resets the output to the dimensions implied by the row type, given an
	 * m x n input, and allocates its dense block.
	 */
	private void allocateOutputMatrix(int m, int n, MatrixBlock out) {
		switch( _type ) {
			case NO_AGG: out.reset(m, n, false); break;
			case ROW_AGG: out.reset(m, 1, false); break;
			case COL_AGG: out.reset(1, n, false); break;
			case COL_AGG_T: out.reset(n, 1, false); break;
		}
		out.allocateDenseBlock();
	}

	/**
	 * Invokes the generated dense row method for rows {@code rl} (inclusive) to
	 * {@code ru} (exclusive) of the dense array {@code a}, where row i starts at
	 * offset {@code i*n}. Does nothing if {@code a} is null.
	 */
	private void executeDense(double[] a, double[][] b, double[] scalars, double[] c, int n, int rl, int ru)
	{
		if( a == null )
			return;

		for( int i=rl, aix=rl*n; i<ru; i++, aix+=n ) {
			//call generated method
			genexecRowDense( a, aix, b, scalars, c, n, i );
		}
	}

	/**
	 * Invokes the generated sparse row method for all non-empty rows in
	 * {@code rl} (inclusive) to {@code ru} (exclusive); empty rows are skipped.
	 * Does nothing if {@code sblock} is null.
	 */
	private void executeSparse(SparseBlock sblock, double[][] b, double[] scalars, double[] c, int n, int rl, int ru)
	{
		if( sblock == null )
			return;

		for( int i=rl; i<ru; i++ ) {
			if( !sblock.isEmpty(i) ) {
				double[] avals = sblock.values(i);
				int[] aix = sblock.indexes(i);
				int apos = sblock.pos(i);
				int alen = sblock.size(i);

				//call generated method
				genexecRowSparse(avals, aix, apos, b, scalars, c, alen, i);
			}
		}
	}

	//methods to be implemented by generated operators of type SpoofRowwise

	protected abstract void genexecRowDense( double[] a, int ai, double[][] b, double[] scalars, double[] c, int len, int rowIndex );

	protected abstract void genexecRowSparse( double[] avals, int[] aix, int ai, double[][] b, double[] scalars, double[] c, int len, int rowIndex );

	/**
	 * Task for multi-threaded column aggregation operations. Each task
	 * processes the row range [rl, ru) and returns its own partial result
	 * vector of length clen.
	 */
	private class ParColAggTask implements Callable<double[]>
	{
		private final MatrixBlock _a;
		private final double[][] _b;
		private final double[] _scalars;
		private final int _clen;
		private final int _rl;
		private final int _ru;

		protected ParColAggTask( MatrixBlock a, double[][] b, double[] scalars, int clen, int rl, int ru ) {
			_a = a;
			_b = b;
			_scalars = scalars;
			_clen = clen;
			_rl = rl;
			_ru = ru;
		}

		@Override
		public double[] call() throws DMLRuntimeException {
			//allocate vector intermediates and partial output
			LibSpoofPrimitives.setupThreadLocalMemory(_reqVectMem, _clen);
			double[] c = new double[_clen];
			try {
				if( !_a.isInSparseFormat() )
					executeDense(_a.getDenseBlock(), _b, _scalars, c, _clen, _rl, _ru);
				else
					executeSparse(_a.getSparseBlock(), _b, _scalars, c, _clen, _rl, _ru);
			}
			finally {
				LibSpoofPrimitives.cleanupThreadLocalMemory();
			}
			return c;
		}
	}

	/**
	 * Task for multi-threaded execution with no or row aggregation. Each task
	 * processes the row range [rl, ru), writes into the shared output dense
	 * block, and returns the number of non-zeros of its output row range.
	 */
	private class ParExecTask implements Callable<Long>
	{
		private final MatrixBlock _a;
		private final double[][] _b;
		private final MatrixBlock _c;
		private final double[] _scalars;
		private final int _clen;
		private final int _rl;
		private final int _ru;

		protected ParExecTask( MatrixBlock a, double[][] b, MatrixBlock c, double[] scalars, int clen, int rl, int ru ) {
			_a = a;
			_b = b;
			_c = c;
			_scalars = scalars;
			_clen = clen;
			_rl = rl;
			_ru = ru;
		}

		@Override
		public Long call() throws DMLRuntimeException {
			//allocate vector intermediates
			LibSpoofPrimitives.setupThreadLocalMemory(_reqVectMem, _clen);
			try {
				if( !_a.isInSparseFormat() )
					executeDense(_a.getDenseBlock(), _b, _scalars, _c.getDenseBlock(), _clen, _rl, _ru);
				else
					executeSparse(_a.getSparseBlock(), _b, _scalars, _c.getDenseBlock(), _clen, _rl, _ru);
			}
			finally {
				LibSpoofPrimitives.cleanupThreadLocalMemory();
			}

			//maintain nnz for row partition
			return _c.recomputeNonZeros(_rl, _ru-1, 0, _c.getNumColumns()-1);
		}
	}
}