/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.codegen;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.codegen.SpoofCellwise.AggOp;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysml.runtime.functionobjects.KahanFunction;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.functionobjects.KahanPlusSq;
import org.apache.sysml.runtime.functionobjects.ValueFunction;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.SparseBlock;
import org.apache.sysml.runtime.util.UtilFunctions;
/**
 * Base class of fused multi-aggregate operators: a single pass over the
 * cells of the first input matrix feeds every cell into {@link #genexec},
 * which updates a 1 x #aggregates output vector. Each output position is
 * associated with one {@link AggOp} passed to the constructor.
 *
 * Execution is either single-threaded or partitioned by row ranges over a
 * fixed thread pool; in the latter case each task aggregates into a private
 * result vector and the partial vectors are merged at the end.
 */
public abstract class SpoofMultiAggregate extends SpoofOperator implements Serializable
{
    private static final long serialVersionUID = -6164871955591089349L;

    /** Inputs with fewer cells than this are always processed single-threaded. */
    private static final long PAR_NUMCELL_THRESHOLD = 1024*1024; //Min 1M elements

    /** Aggregation type per output position. */
    private final AggOp[] _aggOps;

    /**
     * @param aggOps aggregation types, one per output column
     */
    public SpoofMultiAggregate(AggOp... aggOps) {
        _aggOps = aggOps;
    }

    /**
     * @return the aggregation types of this operator (the internal array, not a copy)
     */
    public AggOp[] getAggOps() {
        return _aggOps;
    }

    /**
     * Returns "MA" followed by the second dot-separated segment of the
     * runtime class name. NOTE(review): this assumes the class name has at
     * least two segments (e.g., package.Class) - confirm for all subclasses.
     */
    @Override
    public String getSpoofType() {
        return "MA" + getClass().getName().split("\\.")[1];
    }

    /**
     * Single-threaded execution; delegates to the multi-threaded variant with k=1.
     */
    @Override
    public void execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out)
        throws DMLRuntimeException
    {
        execute(inputs, scalarObjects, out, 1);
    }

    /**
     * Computes all aggregates over the first input and writes them into out
     * as a dense 1 x #aggregates block.
     *
     * @param inputs        input matrices; the first one is the main input that is scanned,
     *                      all inputs are handed to prepInputMatrices
     * @param scalarObjects scalar inputs, handed to prepInputScalars
     * @param out           output block, reset and overwritten
     * @param k             requested degree of parallelism; forced to 1 if the main input
     *                      has fewer than PAR_NUMCELL_THRESHOLD cells
     * @throws DMLRuntimeException if the parallel execution fails or is interrupted
     */
    @Override
    public void execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out, int k)
        throws DMLRuntimeException
    {
        //sanity check
        if( inputs==null || inputs.size() < 1 )
            throw new RuntimeException("Invalid input arguments.");

        //cell count in long arithmetic: rows*cols can exceed the int range
        long ncells = (long)inputs.get(0).getNumRows() * inputs.get(0).getNumColumns();
        if( ncells < PAR_NUMCELL_THRESHOLD ) {
            k = 1; //serial execution
        }

        //result allocation and preparations
        out.reset(1, _aggOps.length, false);
        out.allocateDenseBlock();
        double[] c = out.getDenseBlock();
        setInitialOutputValues(c);

        //input preparation
        double[][] b = prepInputMatrices(inputs);
        double[] scalars = prepInputScalars(scalarObjects);
        final int m = inputs.get(0).getNumRows();
        final int n = inputs.get(0).getNumColumns();

        if( k <= 1 ) //SINGLE-THREADED
        {
            if( !inputs.get(0).isInSparseFormat() )
                executeDense(inputs.get(0).getDenseBlock(), b, scalars, c, m, n, 0, m);
            else
                executeSparse(inputs.get(0).getSparseBlock(), b, scalars, c, m, n, 0, m);
        }
        else //MULTI-THREADED
        {
            ExecutorService pool = null;
            try {
                pool = Executors.newFixedThreadPool( k );
                ArrayList<ParAggTask> tasks = new ArrayList<ParAggTask>();
                //request at least one partition, even for m<32 (wide matrices
                //can pass the cell threshold with very few rows)
                int nk = UtilFunctions.roundToNext(Math.max(Math.min(8*k,m/32), 1), k);
                int blklen = (int)(Math.ceil((double)m/nk));
                //long products guard against overflow of the row offsets
                for( int i=0; i<nk && (long)i*blklen<m; i++ ) {
                    int rl = i*blklen;
                    int ru = (int)Math.min((long)(i+1)*blklen, m);
                    tasks.add(new ParAggTask(inputs.get(0), b, scalars, m, n, rl, ru));
                }

                //execute tasks (invokeAll returns once all tasks completed)
                List<Future<double[]>> taskret = pool.invokeAll(tasks);

                //aggregate partial results
                ArrayList<double[]> pret = new ArrayList<double[]>();
                for( Future<double[]> task : taskret )
                    pret.add(task.get());
                aggregatePartialResults(c, pret);
            }
            catch(InterruptedException ex) {
                //restore the interrupt status for callers up the stack
                Thread.currentThread().interrupt();
                throw new DMLRuntimeException(ex);
            }
            catch(Exception ex) {
                throw new DMLRuntimeException(ex);
            }
            finally {
                //always release the worker threads, also on failures
                if( pool != null )
                    pool.shutdown();
            }
        }

        //post-processing
        out.recomputeNonZeros();
        out.examSparsity();
    }

    /**
     * Scans rows [rl,ru) of a dense row-major array and calls genexec per cell.
     * A null array is treated as all zeros.
     */
    private void executeDense(double[] a, double[][] b, double[] scalars, double[] c, int m, int n, int rl, int ru) throws DMLRuntimeException
    {
        //core dense aggregation operation
        for( int i=rl, ix=rl*n; i<ru; i++ ) {
            for( int j=0; j<n; j++, ix++ ) {
                double in = (a != null) ? a[ix] : 0;
                genexec( in, b, scalars, c, m, n, i, j );
            }
        }
    }

    /**
     * Scans rows [rl,ru) of a sparse block and calls genexec for every cell,
     * including cells without a stored value (looked up via SparseBlock.get).
     * A null block is treated as all zeros.
     */
    private void executeSparse(SparseBlock sblock, double[][] b, double[] scalars, double[] c, int m, int n, int rl, int ru)
        throws DMLRuntimeException
    {
        //core sparse aggregation operation (iterates all cells, not only non-zeros)
        for( int i=rl; i<ru; i++ ) {
            for( int j=0; j<n; j++ ) {
                double in = (sblock != null) ? sblock.get(i, j) : 0;
                genexec( in, b, scalars, c, m, n, i, j );
            }
        }
    }

    /**
     * Per-cell aggregation hook implemented by subclasses.
     *
     * @param a        value of the main input at (rowIndex, colIndex)
     * @param b        side inputs as returned by prepInputMatrices
     * @param scalars  scalar inputs as returned by prepInputScalars
     * @param c        aggregate vector to update, one entry per aggregation type
     * @param m        number of rows of the main input
     * @param n        number of columns of the main input
     * @param rowIndex row index of the current cell
     * @param colIndex column index of the current cell
     */
    protected abstract void genexec( double a, double[][] b, double[] scalars, double[] c, int m, int n, int rowIndex, int colIndex);

    /**
     * Initializes every position of c with the initial value of its aggregation type.
     */
    private void setInitialOutputValues(double[] c) {
        for( int k=0; k<_aggOps.length; k++ )
            c[k] = getInitialValue(_aggOps[k]);
    }

    /**
     * Returns the start value of an aggregate: 0 for sums, +Infinity for MIN
     * and -Infinity for MAX. The infinities (rather than +/-Double.MAX_VALUE)
     * are the true identities of min/max, so inputs that consist only of
     * infinite values still produce the correct result.
     *
     * @param aggop aggregation type
     * @return initial value; 0 for any type not listed above
     */
    public static double getInitialValue(AggOp aggop) {
        switch( aggop ) {
            case SUM:
            case SUM_SQ: return 0;
            case MIN: return Double.POSITIVE_INFINITY;
            case MAX: return Double.NEGATIVE_INFINITY;
        }
        return 0;
    }

    /**
     * Merges the per-task partial result vectors into c.
     *
     * For Kahan-based aggregates (SUM, SUM_SQ) the partials are added with a
     * plain Kahan sum starting from zero - for SUM_SQ the partials are added
     * as-is, not squared again. All other aggregates are folded into the
     * existing value of c[k] via their value function.
     */
    private void aggregatePartialResults(double[] c, ArrayList<double[]> pret)
        throws DMLRuntimeException
    {
        ValueFunction[] vfun = getAggFunctions(_aggOps);
        for( int k=0; k<_aggOps.length; k++ ) {
            if( vfun[k] instanceof KahanFunction ) {
                KahanObject kbuff = new KahanObject(0, 0);
                KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
                for(double[] tmp : pret)
                    kplus.execute2(kbuff, tmp[k]);
                c[k] = kbuff._sum;
            }
            else {
                for(double[] tmp : pret)
                    c[k] = vfun[k].execute(c[k], tmp[k]);
            }
        }
    }

    /**
     * Merges the partial result row b into the result row c (both read and
     * written at row 0), position by position according to aggOps.
     *
     * @param aggOps aggregation type per column
     * @param c      accumulated result, updated in place
     * @param b      partial result to merge into c
     * @throws DMLRuntimeException declared for the underlying function objects
     */
    public static void aggregatePartialResults(AggOp[] aggOps, MatrixBlock c, MatrixBlock b)
        throws DMLRuntimeException
    {
        ValueFunction[] vfun = getAggFunctions(aggOps);
        for( int k=0; k< aggOps.length; k++ ) {
            if( vfun[k] instanceof KahanFunction ) {
                //plain Kahan addition of the two partial sums
                KahanObject kbuff = new KahanObject(c.quickGetValue(0, k), 0);
                KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
                kplus.execute2(kbuff, b.quickGetValue(0, k));
                c.quickSetValue(0, k, kbuff._sum);
            }
            else {
                double cval = c.quickGetValue(0, k);
                double bval = b.quickGetValue(0, k);
                c.quickSetValue(0, k, vfun[k].execute(cval, bval));
            }
        }
    }

    /**
     * Maps each aggregation type to its function object.
     *
     * @param aggOps aggregation types
     * @return function objects in the same order as aggOps
     * @throws RuntimeException for aggregation types other than SUM, SUM_SQ, MIN, MAX
     */
    public static ValueFunction[] getAggFunctions(AggOp[] aggOps) {
        ValueFunction[] fun = new ValueFunction[aggOps.length];
        for( int i=0; i<aggOps.length; i++ ) {
            switch( aggOps[i] ) {
                case SUM: fun[i] = KahanPlus.getKahanPlusFnObject(); break;
                case SUM_SQ: fun[i] = KahanPlusSq.getKahanPlusSqFnObject(); break;
                case MIN: fun[i] = Builtin.getBuiltinFnObject(BuiltinCode.MIN); break;
                case MAX: fun[i] = Builtin.getBuiltinFnObject(BuiltinCode.MAX); break;
                default:
                    throw new RuntimeException("Unsupported "
                        + "aggregation type: "+aggOps[i].name());
            }
        }
        return fun;
    }

    /**
     * Task that aggregates rows [_rl,_ru) of the main input into a private
     * result vector, which is returned for the final merge. Intentionally a
     * non-static inner class because it uses the enclosing operator's
     * aggregation types and execute methods.
     */
    private class ParAggTask implements Callable<double[]>
    {
        private final MatrixBlock _a;
        private final double[][] _b;
        private final double[] _scalars;
        private final int _rlen;
        private final int _clen;
        private final int _rl;
        private final int _ru;

        protected ParAggTask( MatrixBlock a, double[][] b, double[] scalars,
            int rlen, int clen, int rl, int ru ) {
            _a = a;
            _b = b;
            _scalars = scalars;
            _rlen = rlen;
            _clen = clen;
            _rl = rl;
            _ru = ru;
        }

        @Override
        public double[] call() throws DMLRuntimeException {
            //task-local result vector, no shared writes across tasks
            double[] c = new double[_aggOps.length];
            setInitialOutputValues(c);
            if( !_a.isInSparseFormat() )
                executeDense(_a.getDenseBlock(), _b, _scalars, c, _rlen, _clen, _rl, _ru);
            else
                executeSparse(_a.getSparseBlock(), _b, _scalars, c, _rlen, _clen, _rl, _ru);
            return c;
        }
    }
}