/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops;

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.AggBinaryOp.SparkAggType;
import org.apache.sysml.hops.Hop.MultiThreadedHop;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.lops.Aggregate;
import org.apache.sysml.lops.Aggregate.OperationTypes;
import org.apache.sysml.lops.Binary;
import org.apache.sysml.lops.Group;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.PartialAggregate;
import org.apache.sysml.lops.PartialAggregate.DirectionTypes;
import org.apache.sysml.lops.TernaryAggregate;
import org.apache.sysml.lops.UAggOuterChain;
import org.apache.sysml.lops.UnaryCP;
import org.apache.sysml.lops.LopProperties.ExecType;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;

/* Aggregate unary (cell) operation: Sum (aij), col_sum, row_sum
 * Properties:
 *   Symbol: +, min, max, ...
 *   1 Operand
 *
 * Semantic: generate indices, align, aggregate
 */
public class AggUnaryOp extends Hop implements MultiThreadedHop
{
    private static final boolean ALLOW_UNARYAGG_WO_FINAL_AGG = true;

    private AggOp _op;
    private Direction _direction;

    private int _maxNumThreads = -1; //-1 for unlimited

    private AggUnaryOp() {
        //default constructor for clone
    }

    public AggUnaryOp(String l, DataType dt, ValueType vt, AggOp o, Direction idx, Hop inp) {
        super(l, dt, vt);
        _op = o;
        _direction = idx;
        getInput().add(0, inp);
        inp.getParent().add(this);
    }

    public AggOp getOp() {
        return _op;
    }

    public void setOp(AggOp op) {
        _op = op;
    }

    public Direction getDirection() {
        return _direction;
    }

    public void setDirection(Direction direction) {
        _direction = direction;
    }

    @Override
    public void setMaxNumThreads( int k ) {
        _maxNumThreads = k;
    }

    @Override
    public int getMaxNumThreads() {
        return _maxNumThreads;
    }

    @Override
    public Lop constructLops()
        throws HopsException, LopsException
    {
        //return already created lops
        if( getLops() != null )
            return getLops();

        try
        {
            ExecType et = optFindExecType();
            Hop input = getInput().get(0);

            if ( et == ExecType.CP )
            {
                Lop agg1 = null;
                if( isTernaryAggregateRewriteApplicable(et) ) {
                    agg1 = constructLopsTernaryAggregateRewrite(et);
                }
                else if( isUnaryAggregateOuterCPRewriteApplicable() )
                {
                    OperationTypes op = HopsAgg2Lops.get(_op);
                    DirectionTypes dir = HopsDirection2Lops.get(_direction);

                    BinaryOp binput = (BinaryOp)getInput().get(0);
                    agg1 = new UAggOuterChain( binput.getInput().get(0).constructLops(),
                        binput.getInput().get(1).constructLops(), op, dir,
                        HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.CP);
                    PartialAggregate.setDimensionsBasedOnDirection(agg1, getDim1(), getDim2(),
                        input.getRowsInBlock(), input.getColsInBlock(), dir);

                    if (getDataType() == DataType.SCALAR) {
                        UnaryCP unary1 = new UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
                            getDataType(), getValueType());
                        unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
                        setLineNumbers(unary1);
                        setLops(unary1);
                    }
                }
                else //general case
                {
                    int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                    if( DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR
                        || getMemEstimate() < OptimizerUtils.GPU_MEMORY_BUDGET) )
                    {
                        // Only implemented methods for GPU
                        if ( (_op == AggOp.SUM    && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
                            || (_op == AggOp.SUM_SQ && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
                            || (_op == AggOp.MAX    && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
                            || (_op == AggOp.MIN    && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
                            || (_op == AggOp.MEAN   && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
                            || (_op == AggOp.VAR    && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
                            || (_op == AggOp.PROD   && (_direction == Direction.RowCol)) )
                        {
                            et = ExecType.GPU;
                            k = 1;
                        }
                    }
                    agg1 = new PartialAggregate(input.constructLops(),
                        HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), getDataType(), getValueType(), et, k);
                }

                setOutputDimensions(agg1);
                setLineNumbers(agg1);
                setLops(agg1);

                if (getDataType() == DataType.SCALAR) {
                    agg1.getOutputParameters().setDimensions(1, 1,
                        getRowsInBlock(), getColsInBlock(), getNnz());
                }
            }
            else if( et == ExecType.MR )
            {
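                //MR: compile a block-local partial aggregate; a grouped global
                //aggregation and a cast to scalar are appended below if required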
                OperationTypes op = HopsAgg2Lops.get(_op);
                DirectionTypes dir = HopsDirection2Lops.get(_direction);

                //unary aggregate operation
                Lop transform1 = null;
                if( isUnaryAggregateOuterRewriteApplicable() )
                {
                    BinaryOp binput = (BinaryOp)getInput().get(0);
                    transform1 = new UAggOuterChain( binput.getInput().get(0).constructLops(),
                        binput.getInput().get(1).constructLops(), op, dir,
                        HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.MR);
                    PartialAggregate.setDimensionsBasedOnDirection(transform1, getDim1(), getDim2(),
                        input.getRowsInBlock(), input.getColsInBlock(), dir);
                }
                else //default
                {
                    transform1 = new PartialAggregate(input.constructLops(), op, dir, DataType.MATRIX, getValueType());
                    ((PartialAggregate) transform1).setDimensionsBasedOnDirection(getDim1(), getDim2(),
                        input.getRowsInBlock(), input.getColsInBlock());
                }
                setLineNumbers(transform1);

                //aggregation if required
                Lop aggregate = null;
                Group group1 = null;
                Aggregate agg1 = null;
                if( requiresAggregation(input, _direction) || transform1 instanceof UAggOuterChain )
                {
                    group1 = new Group(transform1, Group.OperationTypes.Sort, DataType.MATRIX, getValueType());
                    group1.getOutputParameters().setDimensions(getDim1(), getDim2(),
                        input.getRowsInBlock(), input.getColsInBlock(), getNnz());
                    setLineNumbers(group1);

                    agg1 = new Aggregate(group1, HopsAgg2Lops.get(_op), DataType.MATRIX, getValueType(), et);
                    agg1.getOutputParameters().setDimensions(getDim1(), getDim2(),
                        input.getRowsInBlock(), input.getColsInBlock(), getNnz());
                    agg1.setupCorrectionLocation(PartialAggregate.getCorrectionLocation(op, dir));
                    setLineNumbers(agg1);

                    aggregate = agg1;
                }
                else
                {
                    ((PartialAggregate) transform1).setDropCorrection();
                    aggregate = transform1;
                }

                setLops(aggregate);

                //cast if required
                if (getDataType() == DataType.SCALAR) {
                    // Set the dimensions of PartialAggregate LOP based on the
                    // direction in which aggregation is performed
                    PartialAggregate.setDimensionsBasedOnDirection(transform1, input.getDim1(), input.getDim2(),
                        input.getRowsInBlock(), input.getColsInBlock(), dir);

                    if( group1 != null && agg1 != null ) { //if aggregation required
                        group1.getOutputParameters().setDimensions(input.getDim1(), input.getDim2(),
                            input.getRowsInBlock(), input.getColsInBlock(), getNnz());
                        agg1.getOutputParameters().setDimensions(1, 1,
                            input.getRowsInBlock(), input.getColsInBlock(), getNnz());
                    }

                    UnaryCP unary1 = new UnaryCP( aggregate, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
                        getDataType(), getValueType());
                    unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
                    setLineNumbers(unary1);
                    setLops(unary1);
                }
            }
            else if( et == ExecType.SPARK )
            {
                OperationTypes op = HopsAgg2Lops.get(_op);
                DirectionTypes dir = HopsDirection2Lops.get(_direction);

                //unary aggregate
                if( isTernaryAggregateRewriteApplicable(et) ) {
                    Lop aggregate = constructLopsTernaryAggregateRewrite(et);
                    setOutputDimensions(aggregate); //0x0 (scalar)
                    setLineNumbers(aggregate);
                    setLops(aggregate);
                }
                else if( isUnaryAggregateOuterSPRewriteApplicable() )
                {
                    BinaryOp binput = (BinaryOp)getInput().get(0);
                    Lop transform1 = new UAggOuterChain( binput.getInput().get(0).constructLops(),
                        binput.getInput().get(1).constructLops(), op, dir,
                        HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.SPARK);
                    PartialAggregate.setDimensionsBasedOnDirection(transform1, getDim1(), getDim2(),
                        input.getRowsInBlock(), input.getColsInBlock(), dir);
                    setLineNumbers(transform1);
                    setLops(transform1);

                    if (getDataType() == DataType.SCALAR) {
                        UnaryCP unary1 = new UnaryCP(transform1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
                            getDataType(), getValueType());
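                        //scalar lops carry 0x0 dimensions and unknown nnz (-1)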
                        unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
                        setLineNumbers(unary1);
                        setLops(unary1);
                    }
                }
                else //default
                {
                    boolean needAgg = requiresAggregation(input, _direction);
                    SparkAggType aggtype = getSparkUnaryAggregationType(needAgg);

                    PartialAggregate aggregate = new PartialAggregate(input.constructLops(),
                        HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), DataType.MATRIX, getValueType(), aggtype, et);
                    aggregate.setDimensionsBasedOnDirection(getDim1(), getDim2(),
                        input.getRowsInBlock(), input.getColsInBlock());
                    setLineNumbers(aggregate);
                    setLops(aggregate);

                    if (getDataType() == DataType.SCALAR) {
                        UnaryCP unary1 = new UnaryCP(aggregate, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
                            getDataType(), getValueType());
                        unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
                        setLineNumbers(unary1);
                        setLops(unary1);
                    }
                }
            }
        }
        catch (Exception e) {
            throw new HopsException(this.printErrorLocation() + "In AggUnary Hop, error constructing Lops ", e);
        }

        //add reblock/checkpoint lops if necessary
        constructAndSetLopsDataFlowProperties();

        //return created lops
        return getLops();
    }

    @Override
    public String getOpString() {
        //ua - unary aggregate, for consistency with runtime
        String s = "ua(" + HopsAgg2String.get(_op) + HopsDirection2String.get(_direction) + ")";
        return s;
    }

    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    @Override
    protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
    {
        double sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
    }

    @Override
    protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz )
    {
        //default: no additional memory required
        double val = 0;
        double sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);

        switch( _op ) //see MatrixAggLib for runtime operations
        {
            case MAX:
            case MIN:
                //worst-case: column-wise, sparse (temp int count arrays)
                if( _direction == Direction.Col )
                    val = dim2 * OptimizerUtils.INT_SIZE;
                break;
            case SUM:
            case SUM_SQ:
                //worst-case correction LASTROW / LASTCOLUMN
                if( _direction == Direction.Col ) //(potentially sparse)
                    val = OptimizerUtils.estimateSizeExactSparsity(1, dim2, sparsity);
                else if( _direction == Direction.Row ) //(always dense)
                    val = OptimizerUtils.estimateSizeExactSparsity(dim1, 1, 1.0);
                break;
            case MEAN:
                //worst-case correction LASTTWOROWS / LASTTWOCOLUMNS
                if( _direction == Direction.Col ) //(potentially sparse)
                    val = OptimizerUtils.estimateSizeExactSparsity(2, dim2, sparsity);
                else if( _direction == Direction.Row ) //(always dense)
                    val = OptimizerUtils.estimateSizeExactSparsity(dim1, 2, 1.0);
                break;
            case VAR:
                //worst-case correction LASTFOURROWS / LASTFOURCOLUMNS
                if( _direction == Direction.Col ) //(potentially sparse)
                    val = OptimizerUtils.estimateSizeExactSparsity(4, dim2, sparsity);
                else if( _direction == Direction.Row ) //(always dense)
                    val = OptimizerUtils.estimateSizeExactSparsity(dim1, 4, 1.0);
                break;
            case MAXINDEX:
            case MININDEX:
                Hop hop = getInput().get(0);
                if( isUnaryAggregateOuterCPRewriteApplicable() )
                    val = 3 * OptimizerUtils.estimateSizeExactSparsity(1, hop._dim2, 1.0);
                else //worst-case correction LASTCOLUMN
                    val = OptimizerUtils.estimateSizeExactSparsity(dim1, 1, 1.0);
                break;
            default:
                //no intermediate memory consumption
                val = 0;
        }

        return val;
    }
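    /**
     * Infers the output matrix characteristics of row-wise (nrow x 1) and
     * column-wise (1 x ncol) aggregates from the input statistics, if the
     * respective input dimension is known; full aggregates produce scalars
     * and need no inference.
     */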
    @Override
    protected long[] inferOutputCharacteristics( MemoTable memo )
    {
        long[] ret = null;

        Hop input = getInput().get(0);
        MatrixCharacteristics mc = memo.getAllInputStats(input);
        if( _direction == Direction.Col && mc.colsKnown() )
            ret = new long[]{1, mc.getCols(), -1};
        else if( _direction == Direction.Row && mc.rowsKnown() )
            ret = new long[]{mc.getRows(), 1, -1};

        return ret;
    }

    @Override
    protected ExecType optFindExecType()
        throws HopsException
    {
        checkAndSetForcedPlatform();

        ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;

        //forced / memory-based / threshold-based decision
        if( _etypeForced != null ) {
            _etype = _etypeForced;
        }
        else {
            if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
                _etype = findExecTypeByMemEstimate();
            }
            // Choose CP, if the input dimensions are below threshold or if the input is a vector
            else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVector() ) {
                _etype = ExecType.CP;
            }
            else {
                _etype = REMOTE;
            }

            //check for valid CP dimensions and matrix size
            checkAndSetInvalidCPDimsAndSize();
        }

        //spark-specific decision refinement (execute unary aggregate w/ spark input and
        //single parent also in spark because it's likely cheap and reduces data transfer)
        if( _etype == ExecType.CP && _etypeForced != ExecType.CP
            && !(getInput().get(0) instanceof DataOp)                   //input is not checkpoint
            && (getInput().get(0).getParent().size() == 1               //uagg is only parent, or
                || !requiresAggregation(getInput().get(0), _direction)) //w/o agg
            && getInput().get(0).optFindExecType() == ExecType.SPARK )
        {
            //pull unary aggregate into spark
            _etype = ExecType.SPARK;
        }

        //mark for recompile (forever)
        if( ConfigurationManager.isDynamicRecompilation() && !dimsKnown(true) && _etype == REMOTE ) {
            setRequiresRecompile();
        }

        return _etype;
    }

    private boolean requiresAggregation( Hop input, Direction dir )
    {
        if( !ALLOW_UNARYAGG_WO_FINAL_AGG )
            return false; //customization not allowed

        boolean noAggRequired =
            ( input.getDim1() > 1 && input.getDim1() <= input.getRowsInBlock() && dir == Direction.Col ) //e.g., colSums(X) with nrow(X)<=1000
            || ( input.getDim2() > 1 && input.getDim2() <= input.getColsInBlock() && dir == Direction.Row ); //e.g., rowSums(X) with ncol(X)<=1000

        return !noAggRequired;
    }

    private SparkAggType getSparkUnaryAggregationType( boolean agg )
    {
        if( !agg )
            return SparkAggType.NONE;

        if( getDataType() == DataType.SCALAR //in case of scalars the block dims are not set
            || dimsKnown() && getDim1() <= getRowsInBlock() && getDim2() <= getColsInBlock() )
            return SparkAggType.SINGLE_BLOCK;
        else
            return SparkAggType.MULTI_BLOCK;
    }
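    /**
     * Checks if this unary aggregate can be compiled into a single fused ternary
     * aggregate, i.e., patterns of the form sum(X*Y) or sum(X*Y*Z) (and their
     * column-wise variants) over equally-sized inputs, where the multiply is
     * consumed only by this aggregate.
     */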
    private boolean isTernaryAggregateRewriteApplicable(ExecType et)
        throws HopsException
    {
        boolean ret = false;

        //currently we support only sum over binary multiply but potentially
        //it can be generalized to any RC aggregate over two common binary operations
        if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES
            && _op == AggOp.SUM && (_direction == Direction.RowCol || _direction == Direction.Col) )
        {
            Hop input1 = getInput().get(0);
            if( input1.getParent().size() == 1 //sum single consumer
                && input1 instanceof BinaryOp && ((BinaryOp)input1).getOp() == OpOp2.MULT
                //this fused aggregate instruction is not implemented in MR, and MR is in
                //maintenance mode, so the rewrite is not applied for MR inputs
                && input1.optFindExecType() != ExecType.MR )
            {
                Hop input11 = input1.getInput().get(0);
                Hop input12 = input1.getInput().get(1);

                if( input11 instanceof BinaryOp && ((BinaryOp)input11).getOp() == OpOp2.MULT ) {
                    //ternary, arbitrary matrices but no mv/outer operations
                    ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1)
                        && HopRewriteUtils.isEqualSize(input11.getInput().get(1), input1)
                        && HopRewriteUtils.isEqualSize(input12, input1);
                }
                else if( input12 instanceof BinaryOp && ((BinaryOp)input12).getOp() == OpOp2.MULT ) {
                    //ternary, arbitrary matrices but no mv/outer operations
                    ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1)
                        && HopRewriteUtils.isEqualSize(input12.getInput().get(1), input1)
                        && HopRewriteUtils.isEqualSize(input11, input1);
                }
                else {
                    //binary, arbitrary matrices but no mv/outer operations
                    ret = HopRewriteUtils.isEqualSize(input11, input12);
                }
            }
        }

        return ret;
    }

    private static boolean isCompareOperator(OpOp2 opOp2) {
        return (opOp2 == OpOp2.LESS || opOp2 == OpOp2.LESSEQUAL
            || opOp2 == OpOp2.GREATER || opOp2 == OpOp2.GREATEREQUAL
            || opOp2 == OpOp2.EQUAL || opOp2 == OpOp2.NOTEQUAL);
    }
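    /**
     * Checks if a unary aggregate over an outer comparison operation can be
     * compiled into a fused UAggOuterChain in MR; only the right-hand side is
     * checked against the map-task memory budget because the output is
     * guaranteed to be a small aggregate (<=16KB).
     */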
    private boolean isUnaryAggregateOuterRewriteApplicable()
    {
        boolean ret = false;
        Hop input = getInput().get(0);

        if( input instanceof BinaryOp && ((BinaryOp)input).isOuterVectorOperator() )
        {
            //for special cases, we need to hold the broadcast twice in order to allow for
            //an efficient binary search over a plain java array
            double factor = (isCompareOperator(((BinaryOp)input).getOp())
                && (_direction == Direction.Row || _direction == Direction.Col || _direction == Direction.RowCol)
                && (_op == AggOp.SUM)) ? 2.0 : 1.0;

            factor += (isCompareOperator(((BinaryOp)input).getOp())
                && (_direction == Direction.Row || _direction == Direction.Col)
                && (_op == AggOp.MAXINDEX || _op == AggOp.MININDEX)) ? 1.0 : 0.0;

            //note: memory constraint only needs to take the rhs into account because the output
            //is guaranteed to be an aggregate of <=16KB
            Hop right = input.getInput().get(1);
            if( (right.dimsKnown() && factor*OptimizerUtils.estimateSize(right.getDim1(), right.getDim2())
                    < OptimizerUtils.getRemoteMemBudgetMap(true))    //dims known and estimate fits
                || (!right.dimsKnown() && factor*right.getOutputMemEstimate()
                    < OptimizerUtils.getRemoteMemBudgetMap(true)) )  //dims unknown but worst-case estimate fits
            {
                ret = true;
            }
        }

        return ret;
    }

    /**
     * Checks if there is sufficient memory locally (twice the size of the second matrix,
     * for the original and the sorted data) and remotely (size of the second matrix,
     * i.e., the sorted data).
     *
     * @return true if sufficient memory
     */
    private boolean isUnaryAggregateOuterSPRewriteApplicable()
    {
        boolean ret = false;
        Hop input = getInput().get(0);

        if( input instanceof BinaryOp && ((BinaryOp)input).isOuterVectorOperator() )
        {
            //note: both cases (partitioned matrix, and sorted double array) require to
            //fit the broadcast twice into the local memory budget. Also, the memory
            //constraint only needs to take the rhs into account because the output is
            //guaranteed to be an aggregate of <=16KB
            Hop right = input.getInput().get(1);
            double size = right.dimsKnown() ?
                OptimizerUtils.estimateSize(right.getDim1(), right.getDim2()) : //dims known and estimate fits
                right.getOutputMemEstimate(); //dims unknown but worst-case estimate fits

            if( _op == AggOp.MAXINDEX || _op == AggOp.MININDEX ) {
                double memBudgetExec = SparkExecutionContext.getBroadcastMemoryBudget();
                double memBudgetLocal = OptimizerUtils.getLocalMemBudget();

                //basic requirement: the broadcast needs to fit twice in the remote broadcast
                //memory and the local memory budget because we have to create a partitioned
                //broadcast and hand it over to the spark context as in-memory object
                ret = ( 2*size < memBudgetExec && 2*size < memBudgetLocal );
            }
            else {
                if( OptimizerUtils.checkSparkBroadcastMemoryBudget(size) ) {
                    ret = true;
                }
            }
        }

        return ret;
    }

    /**
     * Checks if this is one of the operators supported by the LibMatrixOuterAgg library:
     * an outer vector operation with aggregation type SUM, RowIndexMin, or RowIndexMax,
     * combined with one of the six comparison operators <, <=, >, >=, ==, !=.
     *
     * @return true if unary aggregate outer
     */
    private boolean isUnaryAggregateOuterCPRewriteApplicable()
    {
        boolean ret = false;
        Hop input = getInput().get(0);

        if( (input instanceof BinaryOp && ((BinaryOp)input).isOuterVectorOperator())
            && (_op == AggOp.MAXINDEX || _op == AggOp.MININDEX || _op == AggOp.SUM)
            && isCompareOperator(((BinaryOp)input).getOp()) )
            ret = true;

        return ret;
    }

    private Lop constructLopsTernaryAggregateRewrite(ExecType et)
        throws HopsException, LopsException
    {
        Hop input1 = getInput().get(0);
        Hop input11 = input1.getInput().get(0);
        Hop input12 = input1.getInput().get(1);

        Lop in1 = null;
        Lop in2 = null;
        Lop in3 = null;

        if( input11 instanceof BinaryOp && ((BinaryOp)input11).getOp() == OpOp2.MULT ) {
            in1 = input11.getInput().get(0).constructLops();
            in2 = input11.getInput().get(1).constructLops();
            in3 = input12.constructLops();
        }
        else if( input12 instanceof BinaryOp && ((BinaryOp)input12).getOp() == OpOp2.MULT ) {
            in1 = input11.constructLops();
            in2 = input12.getInput().get(0).constructLops();
            in3 = input12.getInput().get(1).constructLops();
        }
        else {
            in1 = input11.constructLops();
            in2 = input12.constructLops();
            in3 = new LiteralOp(1).constructLops();
        }

        //create new ternary aggregate operator
        int k = OptimizerUtils.getConstrainedNumThreads( _maxNumThreads );
        // The execution type of a unary aggregate instruction should depend on the execution type of inputs to avoid OOM
        // Since we only support matrix-vector and not vector-matrix, checking the execution type of input1 should suffice.
        ExecType et_input = input1.optFindExecType();
        DirectionTypes dir = HopsDirection2Lops.get(_direction);

        return new TernaryAggregate(in1, in2, in3, Aggregate.OperationTypes.KahanSum,
            Binary.OperationTypes.MULTIPLY, dir, getDataType(), ValueType.DOUBLE, et_input, k);
    }
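    /**
     * Updates the output dimensions from the current input dimensions: column-wise
     * aggregates yield 1 x ncol(input), row-wise aggregates nrow(input) x 1;
     * scalar outputs (full aggregates) are left unchanged.
     */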
    @Override
    public void refreshSizeInformation()
    {
        if (getDataType() != DataType.SCALAR)
        {
            Hop input = getInput().get(0);
            if ( _direction == Direction.Col ) //colwise computations
            {
                setDim1(1);
                setDim2(input.getDim2());
            }
            else if ( _direction == Direction.Row )
            {
                setDim1(input.getDim1());
                setDim2(1);
            }
        }
    }

    @Override
    public boolean isTransposeSafe()
    {
        boolean ret = (_direction == Direction.RowCol) //full aggregate
            && (_op == AggOp.SUM || _op == AggOp.SUM_SQ //valid aggregation functions
                || _op == AggOp.MIN || _op == AggOp.MAX
                || _op == AggOp.PROD || _op == AggOp.MEAN || _op == AggOp.VAR);
        //note: trace and maxindex are not transpose-safe

        return ret;
    }

    @Override
    public Object clone() throws CloneNotSupportedException
    {
        AggUnaryOp ret = new AggUnaryOp();

        //copy generic attributes
        ret.clone(this, false);

        //copy specific attributes
        ret._op = _op;
        ret._direction = _direction;
        ret._maxNumThreads = _maxNumThreads;

        return ret;
    }

    @Override
    public boolean compare( Hop that )
    {
        if( !(that instanceof AggUnaryOp) )
            return false;

        AggUnaryOp that2 = (AggUnaryOp)that;
        return ( _op == that2._op
            && _direction == that2._direction
            && _maxNumThreads == that2._maxNumThreads
            && getInput().get(0) == that2.getInput().get(0) );
    }
}