/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.runtime.codegen; import java.io.Serializable; import java.util.ArrayList; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.functionobjects.Builtin; import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode; import org.apache.sysml.runtime.functionobjects.KahanFunction; import org.apache.sysml.runtime.functionobjects.KahanPlus; import org.apache.sysml.runtime.functionobjects.KahanPlusSq; import org.apache.sysml.runtime.functionobjects.ValueFunction; import org.apache.sysml.runtime.instructions.cp.DoubleObject; import org.apache.sysml.runtime.instructions.cp.KahanObject; import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.SparseBlock; import org.apache.sysml.runtime.util.UtilFunctions; public abstract class SpoofCellwise extends SpoofOperator implements Serializable { private static final long serialVersionUID = 3442528770573293590L; private static final long PAR_NUMCELL_THRESHOLD = 1024*1024; //Min 1M elements public enum CellType { NO_AGG, FULL_AGG, ROW_AGG, } //redefinition of Hop.AggOp for cleaner imports in generate class public enum AggOp { SUM, SUM_SQ, MIN, MAX, } private final CellType _type; private final AggOp _aggOp; private final boolean _sparseSafe; public SpoofCellwise(CellType type, AggOp aggOp, boolean sparseSafe) { _type = type; _aggOp = aggOp; _sparseSafe = sparseSafe; } public CellType getCellType() { return _type; } public AggOp getAggOp() { return _aggOp; } public boolean isSparseSafe() { return _sparseSafe; } @Override public String getSpoofType() { return "Cell" + getClass().getName().split("\\.")[1]; } private ValueFunction getAggFunction() { switch( _aggOp ) { case SUM: return KahanPlus.getKahanPlusFnObject(); case SUM_SQ: return KahanPlusSq.getKahanPlusSqFnObject(); case MIN: return Builtin.getBuiltinFnObject(BuiltinCode.MIN); case MAX: return Builtin.getBuiltinFnObject(BuiltinCode.MAX); default: throw new RuntimeException("Unsupported " + "aggregation type: "+_aggOp.name()); } } @Override public ScalarObject execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, int k) throws DMLRuntimeException { //sanity check if( inputs==null || inputs.size() < 1 ) throw new RuntimeException("Invalid input arguments."); if( inputs.get(0).getNumRows()*inputs.get(0).getNumColumns()<PAR_NUMCELL_THRESHOLD ) { k = 1; //serial execution } //input preparation double[][] b = prepInputMatrices(inputs); double[] scalars = prepInputScalars(scalarObjects); final int m = inputs.get(0).getNumRows(); final int n = inputs.get(0).getNumColumns(); //sparse safe check boolean sparseSafe = isSparseSafe() || (b.length == 0 && genexec( 0, b, scalars, m, n, 0, 0 ) == 0); double ret = 0; if( k <= 1 ) //SINGLE-THREADED { ret = ( !inputs.get(0).isInSparseFormat() ) ? executeDenseAndAgg(inputs.get(0).getDenseBlock(), b, scalars, m, n, sparseSafe, 0, m) : executeSparseAndAgg(inputs.get(0).getSparseBlock(), b, scalars, m, n, sparseSafe, 0, m); } else //MULTI-THREADED { try { ExecutorService pool = Executors.newFixedThreadPool( k ); ArrayList<ParAggTask> tasks = new ArrayList<ParAggTask>(); int nk = UtilFunctions.roundToNext(Math.min(8*k,m/32), k); int blklen = (int)(Math.ceil((double)m/nk)); for( int i=0; i<nk & i*blklen<m; i++ ) tasks.add(new ParAggTask(inputs.get(0), b, scalars, m, n, sparseSafe, i*blklen, Math.min((i+1)*blklen, m))); //execute tasks List<Future<Double>> taskret = pool.invokeAll(tasks); pool.shutdown(); //aggregate partial results ValueFunction vfun = getAggFunction(); if( vfun instanceof KahanFunction ) { KahanObject kbuff = new KahanObject(0, 0); KahanPlus kplus = KahanPlus.getKahanPlusFnObject(); for( Future<Double> task : taskret ) kplus.execute2(kbuff, task.get()); ret = kbuff._sum; } else { for( Future<Double> task : taskret ) ret = vfun.execute(ret, task.get()); } } catch(Exception ex) { throw new DMLRuntimeException(ex); } } //correction for min/max if( (_aggOp == AggOp.MIN || _aggOp == AggOp.MAX) && sparseSafe && inputs.get(0).getNonZeros()<inputs.get(0).getNumRows()*inputs.get(0).getNumColumns() ) ret = getAggFunction().execute(ret, 0); //unseen 0 might be max or min value return new DoubleObject(ret); } @Override public void execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out) throws DMLRuntimeException { execute(inputs, scalarObjects, out, 1); } @Override public void execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out, int k) throws DMLRuntimeException { //sanity check if( inputs==null || inputs.size() < 1 || out==null ) throw new RuntimeException("Invalid input arguments."); if( inputs.get(0).getNumRows()*inputs.get(0).getNumColumns()<PAR_NUMCELL_THRESHOLD ) { k = 1; //serial execution } //result allocation and preparations out.reset(inputs.get(0).getNumRows(), _type == CellType.NO_AGG ? inputs.get(0).getNumColumns() : 1, false); out.allocateDenseBlock(); double[] c = out.getDenseBlock(); //input preparation double[][] b = prepInputMatrices(inputs); double[] scalars = prepInputScalars(scalarObjects); final int m = inputs.get(0).getNumRows(); final int n = inputs.get(0).getNumColumns(); //sparse safe check boolean sparseSafe = isSparseSafe() || (b.length == 0 && genexec( 0, b, scalars, m, n, 0, 0 ) == 0); long lnnz = 0; if( k <= 1 ) //SINGLE-THREADED { lnnz = (!inputs.get(0).isInSparseFormat()) ? executeDense(inputs.get(0).getDenseBlock(), b, scalars, c, m, n, sparseSafe, 0, m) : executeSparse(inputs.get(0).getSparseBlock(), b, scalars, c, m, n, sparseSafe, 0, m); } else //MULTI-THREADED { try { ExecutorService pool = Executors.newFixedThreadPool( k ); ArrayList<ParExecTask> tasks = new ArrayList<ParExecTask>(); int nk = UtilFunctions.roundToNext(Math.min(8*k,m/32), k); int blklen = (int)(Math.ceil((double)m/nk)); for( int i=0; i<nk & i*blklen<m; i++ ) tasks.add(new ParExecTask(inputs.get(0), b, scalars, c, m, n, sparseSafe, i*blklen, Math.min((i+1)*blklen, m))); //execute tasks List<Future<Long>> taskret = pool.invokeAll(tasks); pool.shutdown(); //aggregate nnz and error handling for( Future<Long> task : taskret ) lnnz += task.get(); } catch(Exception ex) { throw new DMLRuntimeException(ex); } } //post-processing out.setNonZeros(lnnz); out.examSparsity(); } private double executeDenseAndAgg(double[] a, double[][] b, double[] scalars, int m, int n, boolean sparseSafe, int rl, int ru) throws DMLRuntimeException { ValueFunction vfun = getAggFunction(); double ret = 0; //numerically stable aggregation for sum/sum_sq if( vfun instanceof KahanFunction ) { KahanObject kbuff = new KahanObject(0, 0); KahanFunction kplus = (KahanFunction) vfun; if( a == null && !sparseSafe ) { //empty for( int i=rl; i<ru; i++ ) for( int j=0; j<n; j++ ) kplus.execute2(kbuff, genexec( 0, b, scalars, m, n, i, j )); } else if( a != null ) { //general case for( int i=rl, ix=rl*n; i<ru; i++ ) for( int j=0; j<n; j++, ix++ ) if( a[ix] != 0 || !sparseSafe) kplus.execute2(kbuff, genexec( a[ix], b, scalars, m, n, i, j )); } ret = kbuff._sum; } //safe aggregation for min/max w/ handling of zero entries //note: sparse safe with zero value as min/max handled outside else { ret = (_aggOp==AggOp.MIN) ? Double.MAX_VALUE : -Double.MAX_VALUE; if( a == null && !sparseSafe ) { //empty for( int i=rl; i<ru; i++ ) for( int j=0; j<n; j++ ) ret = vfun.execute(ret, genexec( 0, b, scalars, m, n, i, j )); } else if( a != null ) { //general case for( int i=rl, ix=rl*n; i<ru; i++ ) for( int j=0; j<n; j++, ix++ ) if( a[ix] != 0 || !sparseSafe) ret = vfun.execute(ret, genexec( a[ix], b, scalars, m, n, i, j )); } } return ret; } private long executeDense(double[] a, double[][] b, double[] scalars, double[] c, int m, int n, boolean sparseSafe, int rl, int ru) throws DMLRuntimeException { long lnnz = 0; if( _type == CellType.NO_AGG ) { if( a == null && !sparseSafe ) { //empty //note: we can't determine sparse-safeness by executing the operator once //as the output might change with different row indices for( int i=rl, ix=rl*n; i<ru; i++ ) for( int j=0; j<n; j++, ix++ ) { c[ix] = genexec( 0, b, scalars, m, n, i, j ); lnnz += (c[ix]!=0) ? 1 : 0; } } else if( a != null ) { //general case for( int i=rl, ix=rl*n; i<ru; i++ ) for( int j=0; j<n; j++, ix++ ) if( a[ix] != 0 || !sparseSafe) { c[ix] = genexec( a[ix], b, scalars, m, n, i, j); lnnz += (c[ix]!=0) ? 1 : 0; } } } else if( _type == CellType.ROW_AGG ) { ValueFunction vfun = getAggFunction(); if( vfun instanceof KahanFunction ) { KahanObject kbuff = new KahanObject(0, 0); KahanFunction kplus = (KahanFunction) vfun; if( a == null && !sparseSafe ) { //empty for( int i=rl; i<ru; i++ ) { kbuff.set(0, 0); for( int j=0; j<n; j++ ) kplus.execute2(kbuff, genexec( 0, b, scalars, m, n, i, j )); lnnz += ((c[i] = kbuff._sum)!=0) ? 1 : 0; } } else if( a != null ) { //general case for( int i=rl, ix=rl*n; i<ru; i++ ) { kbuff.set(0, 0); for( int j=0; j<n; j++, ix++ ) if( a[ix] != 0 || !sparseSafe) kplus.execute2(kbuff, genexec( a[ix], b, scalars, m, n, i, j )); lnnz += ((c[i] = kbuff._sum)!=0) ? 1 : 0; } } } else { double initialVal = (_aggOp==AggOp.MIN) ? Double.MAX_VALUE : -Double.MAX_VALUE; if( a == null && !sparseSafe ) { //empty for( int i=rl; i<ru; i++ ) { double tmp = initialVal; for( int j=0; j<n; j++ ) tmp = vfun.execute(tmp, genexec( 0, b, scalars, m, n, i, j )); lnnz += ((c[i] = tmp)!=0) ? 1 : 0; } } else if( a != null ) { //general case for( int i=rl, ix=rl*n; i<ru; i++ ) { double tmp = initialVal; for( int j=0; j<n; j++, ix++ ) if( a[ix] != 0 || !sparseSafe) tmp = vfun.execute(tmp, genexec( a[ix], b, scalars, m, n, i, j )); if( sparseSafe && UtilFunctions.containsZero(a, ix-n, n) ) tmp = vfun.execute(tmp, 0); lnnz += ((c[i] = tmp)!=0) ? 1 : 0; } } } } return lnnz; } private double executeSparseAndAgg(SparseBlock sblock, double[][] b, double[] scalars, int m, int n, boolean sparseSafe, int rl, int ru) throws DMLRuntimeException { if( sparseSafe && sblock == null ) return 0; ValueFunction vfun = getAggFunction(); double ret = 0; //numerically stable aggregation for sum/sum_sq if( vfun instanceof KahanFunction ) { KahanObject kbuff = new KahanObject(0, 0); KahanFunction kplus = (KahanFunction) vfun; //note: sequential scan algorithm for both sparse-safe and -unsafe //in order to avoid binary search for sparse-unsafe for(int i=rl; i<ru; i++) { int lastj = -1; //handle non-empty rows if( sblock != null && !sblock.isEmpty(i) ) { int apos = sblock.pos(i); int alen = sblock.size(i); int[] aix = sblock.indexes(i); double[] avals = sblock.values(i); for(int k=apos; k<apos+alen; k++) { //process zeros before current non-zero if( !sparseSafe ) for(int j=lastj+1; j<aix[k]; j++) kplus.execute2(kbuff, genexec(0, b, scalars, m, n, i, j)); //process current non-zero lastj = aix[k]; kplus.execute2(kbuff, genexec(avals[k], b, scalars, m, n, i, lastj)); } } //process empty rows or remaining zeros if( !sparseSafe ) for(int j=lastj+1; j<n; j++) kplus.execute2(kbuff, genexec(0, b, scalars, m, n, i, j)); } ret = kbuff._sum; } //safe aggregation for min/max w/ handling of zero entries //note: sparse safe with zero value as min/max handled outside else { ret = (_aggOp==AggOp.MIN) ? Double.MAX_VALUE : -Double.MAX_VALUE; ret = (sparseSafe && sblock.size() < (long)m*n) ? 0 : ret; //note: sequential scan algorithm for both sparse-safe and -unsafe //in order to avoid binary search for sparse-unsafe for(int i=rl; i<ru; i++) { int lastj = -1; //handle non-empty rows if( sblock != null && !sblock.isEmpty(i) ) { int apos = sblock.pos(i); int alen = sblock.size(i); int[] aix = sblock.indexes(i); double[] avals = sblock.values(i); for(int k=apos; k<apos+alen; k++) { //process zeros before current non-zero if( !sparseSafe ) for(int j=lastj+1; j<aix[k]; j++) ret = vfun.execute(ret, genexec(0, b, scalars, m, n, i, j)); //process current non-zero lastj = aix[k]; ret = vfun.execute(ret, genexec(avals[k], b, scalars, m, n, i, lastj)); } } //process empty rows or remaining zeros if( !sparseSafe ) for(int j=lastj+1; j<n; j++) ret = vfun.execute(ret, genexec(0, b, scalars, m, n, i, j)); } } return ret; } private long executeSparse(SparseBlock sblock, double[][] b, double[] scalars, double[] c, int m, int n, boolean sparseSafe, int rl, int ru) throws DMLRuntimeException { if( sparseSafe && sblock == null ) return 0; long lnnz = 0; if( _type == CellType.NO_AGG ) { //note: sequential scan algorithm for both sparse-safe and -unsafe //in order to avoid binary search for sparse-unsafe for(int i=rl, cix=rl*n; i<ru; i++, cix+=n) { int lastj = -1; //handle non-empty rows if( sblock != null && !sblock.isEmpty(i) ) { int apos = sblock.pos(i); int alen = sblock.size(i); int[] aix = sblock.indexes(i); double[] avals = sblock.values(i); for(int k=apos; k<apos+alen; k++) { //process zeros before current non-zero if( !sparseSafe ) for(int j=lastj+1; j<aix[k]; j++) lnnz += ((c[cix+j]=genexec(0, b, scalars, m, n, i, j))!=0)?1:0; //process current non-zero lastj = aix[k]; lnnz += ((c[cix+lastj]=genexec(avals[k], b, scalars, m, n, i, lastj))!=0)?1:0; } } //process empty rows or remaining zeros if( !sparseSafe ) for(int j=lastj+1; j<n; j++) lnnz += ((c[cix+j]=genexec(0, b, scalars, m, n, i, j))!=0)?1:0; } } else if( _type == CellType.ROW_AGG ) { ValueFunction vfun = getAggFunction(); if( vfun instanceof KahanFunction ) { KahanObject kbuff = new KahanObject(0, 0); KahanFunction kplus = (KahanFunction) vfun; //note: sequential scan algorithm for both sparse-safe and -unsafe //in order to avoid binary search for sparse-unsafe for(int i=rl; i<ru; i++) { kbuff.set(0, 0); int lastj = -1; //handle non-empty rows if( sblock != null && !sblock.isEmpty(i) ) { int apos = sblock.pos(i); int alen = sblock.size(i); int[] aix = sblock.indexes(i); double[] avals = sblock.values(i); for(int k=apos; k<apos+alen; k++) { //process zeros before current non-zero if( !sparseSafe ) for(int j=lastj+1; j<aix[k]; j++) kplus.execute2(kbuff, genexec(0, b, scalars, m, n, i, j)); //process current non-zero lastj = aix[k]; kplus.execute2(kbuff, genexec(avals[k], b, scalars, m, n, i, lastj)); } } //process empty rows or remaining zeros if( !sparseSafe ) for(int j=lastj+1; j<n; j++) kplus.execute2(kbuff, genexec(0, b, scalars, m, n, i, j)); lnnz += ((c[i] = kbuff._sum)!=0) ? 1 : 0; } } else { double initialVal = (_aggOp==AggOp.MIN) ? Double.MAX_VALUE : -Double.MAX_VALUE; //note: sequential scan algorithm for both sparse-safe and -unsafe //in order to avoid binary search for sparse-unsafe for(int i=rl; i<ru; i++) { double tmp = (sparseSafe && sblock.size(i) < n) ? 0 : initialVal; int lastj = -1; //handle non-empty rows if( sblock != null && !sblock.isEmpty(i) ) { int apos = sblock.pos(i); int alen = sblock.size(i); int[] aix = sblock.indexes(i); double[] avals = sblock.values(i); for(int k=apos; k<apos+alen; k++) { //process zeros before current non-zero if( !sparseSafe ) for(int j=lastj+1; j<aix[k]; j++) tmp = vfun.execute(tmp, genexec(0, b, scalars, m, n, i, j)); //process current non-zero lastj = aix[k]; tmp = vfun.execute( tmp, genexec(avals[k], b, scalars, m, n, i, lastj)); } } //process empty rows or remaining zeros if( !sparseSafe ) for(int j=lastj+1; j<n; j++) tmp = vfun.execute(tmp, genexec(0, b, scalars, m, n, i, j)); lnnz += ((c[i] = tmp)!=0) ? 1 : 0; } } } return lnnz; } protected abstract double genexec( double a, double[][] b, double[] scalars, int m, int n, int rowIndex, int colIndex); private class ParAggTask implements Callable<Double> { private final MatrixBlock _a; private final double[][] _b; private final double[] _scalars; private final int _rlen; private final int _clen; private final boolean _safe; private final int _rl; private final int _ru; protected ParAggTask( MatrixBlock a, double[][] b, double[] scalars, int rlen, int clen, boolean sparseSafe, int rl, int ru ) { _a = a; _b = b; _scalars = scalars; _rlen = rlen; _clen = clen; _safe = sparseSafe; _rl = rl; _ru = ru; } @Override public Double call() throws DMLRuntimeException { return ( !_a.isInSparseFormat()) ? executeDenseAndAgg(_a.getDenseBlock(), _b, _scalars, _rlen, _clen, _safe, _rl, _ru) : executeSparseAndAgg(_a.getSparseBlock(), _b, _scalars, _rlen, _clen, _safe, _rl, _ru); } } private class ParExecTask implements Callable<Long> { private final MatrixBlock _a; private final double[][] _b; private final double[] _scalars; private final double[] _c; private final int _rlen; private final int _clen; private final boolean _safe; private final int _rl; private final int _ru; protected ParExecTask( MatrixBlock a, double[][] b, double[] scalars, double[] c, int rlen, int clen, boolean sparseSafe, int rl, int ru ) { _a = a; _b = b; _scalars = scalars; _c = c; _rlen = rlen; _clen = clen; _safe = sparseSafe; _rl = rl; _ru = ru; } @Override public Long call() throws DMLRuntimeException { return (!_a.isInSparseFormat()) ? executeDense(_a.getDenseBlock(), _b, _scalars, _c, _rlen, _clen, _safe, _rl, _ru) : executeSparse(_a.getSparseBlock(), _b, _scalars, _c, _rlen, _clen, _safe, _rl, _ru); } } }