/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; import java.util.HashMap; import org.apache.commons.lang.ArrayUtils; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.DataGenMethod; import org.apache.sysml.hops.DataGenOp; import org.apache.sysml.hops.Hop.DataOpTypes; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.FileFormatTypes; import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.Hop.OpOp3; import org.apache.sysml.hops.Hop.ParamBuiltinOp; import org.apache.sysml.hops.Hop.ReOrgOp; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LeftIndexingOp; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.MemoTable; import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.ReorgOp; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.parser.DataExpression; import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.parser.Statement; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.util.UtilFunctions; public class HopRewriteUtils { public static boolean isValueTypeCast( OpOp1 op ) { return ( op == OpOp1.CAST_AS_BOOLEAN || op == OpOp1.CAST_AS_INT || op == OpOp1.CAST_AS_DOUBLE ); } ////////////////////////////////// // literal handling public static boolean getBooleanValue( LiteralOp op ) throws HopsException { switch( op.getValueType() ) { case DOUBLE: return op.getDoubleValue() != 0; case INT: return op.getLongValue() != 0; case BOOLEAN: return op.getBooleanValue(); default: throw new HopsException("Invalid boolean value: "+op.getValueType()); } } public static boolean getBooleanValueSafe( LiteralOp op ) { try { switch( op.getValueType() ) { case DOUBLE: return op.getDoubleValue() != 0; case INT: return op.getLongValue() != 0; case BOOLEAN: return op.getBooleanValue(); default: throw new HopsException("Invalid boolean value: "+op.getValueType()); } } catch(Exception ex){ //silently ignore error } return false; } public static double getDoubleValue( LiteralOp op ) throws HopsException { switch( op.getValueType() ) { case DOUBLE: return op.getDoubleValue(); case INT: return op.getLongValue(); case BOOLEAN: return op.getBooleanValue() ? 1 : 0; default: throw new HopsException("Invalid double value: "+op.getValueType()); } } public static double getDoubleValueSafe( LiteralOp op ) { try { switch( op.getValueType() ) { case DOUBLE: return op.getDoubleValue(); case INT: return op.getLongValue(); case BOOLEAN: return op.getBooleanValue() ? 1 : 0; default: throw new HopsException("Invalid double value: "+op.getValueType()); } } catch(Exception ex){ //silently ignore error } return Double.MAX_VALUE; } /** * Return the int value of a LiteralOp (as a long). * * Note: For comparisons, this is *only* to be used in situations * in which the value is absolutely guaranteed to be an integer. * Otherwise, a safer alternative is `getDoubleValue`. * * @param op literal operator * @return long value of literator op * @throws HopsException if HopsException occurs */ public static long getIntValue( LiteralOp op ) throws HopsException { switch( op.getValueType() ) { case DOUBLE: return UtilFunctions.toLong(op.getDoubleValue()); case INT: return op.getLongValue(); case BOOLEAN: return op.getBooleanValue() ? 1 : 0; default: throw new HopsException("Invalid int value: "+op.getValueType()); } } public static long getIntValueSafe( LiteralOp op ) { try { switch( op.getValueType() ) { case DOUBLE: return UtilFunctions.toLong(op.getDoubleValue()); case INT: return op.getLongValue(); case BOOLEAN: return op.getBooleanValue() ? 1 : 0; default: throw new RuntimeException("Invalid int value: "+op.getValueType()); } } catch(Exception ex){ //silently ignore error } return Long.MAX_VALUE; } public static boolean isLiteralOfValue( Hop hop, double val ) { return (hop instanceof LiteralOp && (hop.getValueType()==ValueType.DOUBLE || hop.getValueType()==ValueType.INT) && getDoubleValueSafe((LiteralOp)hop)==val); } public static ScalarObject getScalarObject( LiteralOp op ) { try { return ScalarObjectFactory .createScalarObject(op.getValueType(), op); } catch(Exception ex) { throw new RuntimeException("Failed to create scalar object for constant. Continue.", ex); } } /////////////////////////////////// // hop dag transformations public static int getChildReferencePos( Hop parent, Hop child ) { return parent.getInput().indexOf(child); } public static void removeChildReference( Hop parent, Hop child ) { parent.getInput().remove( child ); child.getParent().remove( parent ); } public static void removeChildReferenceByPos( Hop parent, Hop child, int posChild ) { parent.getInput().remove( posChild ); child.getParent().remove( parent ); } public static void removeAllChildReferences( Hop parent ) { //remove parent reference from all childs for( Hop child : parent.getInput() ) child.getParent().remove(parent); //remove all child references parent.getInput().clear(); } public static void addChildReference( Hop parent, Hop child ) { parent.getInput().add( child ); child.getParent().add( parent ); } public static void addChildReference( Hop parent, Hop child, int pos ){ parent.getInput().add( pos, child ); child.getParent().add( parent ); } public static void rewireAllParentChildReferences( Hop hold, Hop hnew ) { ArrayList<Hop> parents = new ArrayList<Hop>(hold.getParent()); for( Hop lparent : parents ) HopRewriteUtils.replaceChildReference(lparent, hold, hnew); } public static void replaceChildReference( Hop parent, Hop inOld, Hop inNew ) { int pos = getChildReferencePos(parent, inOld); removeChildReferenceByPos(parent, inOld, pos); addChildReference(parent, inNew, pos); parent.refreshSizeInformation(); } public static void replaceChildReference( Hop parent, Hop inOld, Hop inNew, int pos ) { replaceChildReference(parent, inOld, inNew, pos, true); } public static void replaceChildReference( Hop parent, Hop inOld, Hop inNew, int pos, boolean refresh ) { removeChildReferenceByPos(parent, inOld, pos); addChildReference(parent, inNew, pos); if( refresh ) parent.refreshSizeInformation(); } public static void cleanupUnreferenced( Hop... inputs ) { for( Hop input : inputs ) if( input.getParent().isEmpty() ) removeAllChildReferences(input); } public static Hop createDataGenOp( Hop input, double value ) throws HopsException { Hop rows = (input.getDim1()>0) ? new LiteralOp(input.getDim1()) : new UnaryOp("tmprows", DataType.SCALAR, ValueType.INT, OpOp1.NROW, input); Hop cols = (input.getDim2()>0) ? new LiteralOp(input.getDim2()) : new UnaryOp("tmpcols", DataType.SCALAR, ValueType.INT, OpOp1.NCOL, input); Hop val = new LiteralOp(value); HashMap<String, Hop> params = new HashMap<String, Hop>(); params.put(DataExpression.RAND_ROWS, rows); params.put(DataExpression.RAND_COLS, cols); params.put(DataExpression.RAND_MIN, val); params.put(DataExpression.RAND_MAX, val); params.put(DataExpression.RAND_PDF, new LiteralOp(DataExpression.RAND_PDF_UNIFORM)); params.put(DataExpression.RAND_LAMBDA, new LiteralOp(-1.0)); params.put(DataExpression.RAND_SPARSITY, new LiteralOp(1.0)); params.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) ); //note internal refresh size information Hop datagen = new DataGenOp(DataGenMethod.RAND, new DataIdentifier("tmp"), params); datagen.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock()); copyLineNumbers(input, datagen); if( value==0 ) datagen.setNnz(0); return datagen; } /** * Assumes that min and max are literal ops, needs to be checked from outside. * * @param inputGen input data gen op * @param scale the scale * @param shift the shift * @return data gen op * @throws HopsException if HopsException occurs */ public static DataGenOp copyDataGenOp( DataGenOp inputGen, double scale, double shift ) throws HopsException { HashMap<String, Integer> params = inputGen.getParamIndexMap(); Hop rows = inputGen.getInput().get(params.get(DataExpression.RAND_ROWS)); Hop cols = inputGen.getInput().get(params.get(DataExpression.RAND_COLS)); Hop min = inputGen.getInput().get(params.get(DataExpression.RAND_MIN)); Hop max = inputGen.getInput().get(params.get(DataExpression.RAND_MAX)); Hop pdf = inputGen.getInput().get(params.get(DataExpression.RAND_PDF)); Hop mean = inputGen.getInput().get(params.get(DataExpression.RAND_LAMBDA)); Hop sparsity = inputGen.getInput().get(params.get(DataExpression.RAND_SPARSITY)); Hop seed = inputGen.getInput().get(params.get(DataExpression.RAND_SEED)); //check for literal ops if( !(min instanceof LiteralOp) || !(max instanceof LiteralOp)) return null; //scale and shift double smin = getDoubleValue((LiteralOp) min); double smax = getDoubleValue((LiteralOp) max); smin = smin * scale + shift; smax = smax * scale + shift; Hop sminHop = new LiteralOp(smin); Hop smaxHop = new LiteralOp(smax); HashMap<String, Hop> params2 = new HashMap<String, Hop>(); params2.put(DataExpression.RAND_ROWS, rows); params2.put(DataExpression.RAND_COLS, cols); params2.put(DataExpression.RAND_MIN, sminHop); params2.put(DataExpression.RAND_MAX, smaxHop); params2.put(DataExpression.RAND_PDF, pdf); params2.put(DataExpression.RAND_LAMBDA, mean); params2.put(DataExpression.RAND_SPARSITY, sparsity); params2.put(DataExpression.RAND_SEED, seed ); //note internal refresh size information DataGenOp datagen = new DataGenOp(DataGenMethod.RAND, new DataIdentifier("tmp"), params2); datagen.setOutputBlocksizes(inputGen.getRowsInBlock(), inputGen.getColsInBlock()); copyLineNumbers(inputGen, datagen); if( smin==0 && smax==0 ) datagen.setNnz(0); return datagen; } public static Hop createDataGenOp( Hop rowInput, Hop colInput, double value ) throws HopsException { Hop rows = (rowInput.getDim1()>0) ? new LiteralOp(rowInput.getDim1()) : new UnaryOp("tmprows", DataType.SCALAR, ValueType.INT, OpOp1.NROW, rowInput); Hop cols = (colInput.getDim2()>0) ? new LiteralOp(colInput.getDim2()) : new UnaryOp("tmpcols", DataType.SCALAR, ValueType.INT, OpOp1.NCOL, colInput); Hop val = new LiteralOp(value); HashMap<String, Hop> params = new HashMap<String, Hop>(); params.put(DataExpression.RAND_ROWS, rows); params.put(DataExpression.RAND_COLS, cols); params.put(DataExpression.RAND_MIN, val); params.put(DataExpression.RAND_MAX, val); params.put(DataExpression.RAND_PDF, new LiteralOp(DataExpression.RAND_PDF_UNIFORM)); params.put(DataExpression.RAND_LAMBDA, new LiteralOp(-1.0)); params.put(DataExpression.RAND_SPARSITY, new LiteralOp(1.0)); params.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) ); //note internal refresh size information Hop datagen = new DataGenOp(DataGenMethod.RAND, new DataIdentifier("tmp"), params); datagen.setOutputBlocksizes(rowInput.getRowsInBlock(), colInput.getColsInBlock()); copyLineNumbers(rowInput, datagen); if( value==0 ) datagen.setNnz(0); return datagen; } public static Hop createDataGenOp( Hop rowInput, boolean tRowInput, Hop colInput, boolean tColInput, double value ) throws HopsException { long nrow = tRowInput ? rowInput.getDim2() : rowInput.getDim1(); long ncol = tColInput ? colInput.getDim1() : rowInput.getDim2(); Hop rows = (nrow>0) ? new LiteralOp(nrow) : new UnaryOp("tmprows", DataType.SCALAR, ValueType.INT, tRowInput?OpOp1.NCOL:OpOp1.NROW, rowInput); Hop cols = (ncol>0) ? new LiteralOp(ncol) : new UnaryOp("tmpcols", DataType.SCALAR, ValueType.INT, tColInput?OpOp1.NROW:OpOp1.NCOL, colInput); Hop val = new LiteralOp(value); HashMap<String, Hop> params = new HashMap<String, Hop>(); params.put(DataExpression.RAND_ROWS, rows); params.put(DataExpression.RAND_COLS, cols); params.put(DataExpression.RAND_MIN, val); params.put(DataExpression.RAND_MAX, val); params.put(DataExpression.RAND_PDF, new LiteralOp(DataExpression.RAND_PDF_UNIFORM)); params.put(DataExpression.RAND_LAMBDA,new LiteralOp(-1.0)); params.put(DataExpression.RAND_SPARSITY, new LiteralOp(1.0)); params.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) ); //note internal refresh size information Hop datagen = new DataGenOp(DataGenMethod.RAND, new DataIdentifier("tmp"), params); datagen.setOutputBlocksizes(rowInput.getRowsInBlock(), colInput.getColsInBlock()); copyLineNumbers(rowInput, datagen); if( value==0 ) datagen.setNnz(0); return datagen; } public static Hop createDataGenOpByVal( Hop rowInput, Hop colInput, double value ) throws HopsException { Hop val = new LiteralOp(value); HashMap<String, Hop> params = new HashMap<String, Hop>(); params.put(DataExpression.RAND_ROWS, rowInput); params.put(DataExpression.RAND_COLS, colInput); params.put(DataExpression.RAND_MIN, val); params.put(DataExpression.RAND_MAX, val); params.put(DataExpression.RAND_PDF, new LiteralOp(DataExpression.RAND_PDF_UNIFORM)); params.put(DataExpression.RAND_LAMBDA, new LiteralOp(-1.0)); params.put(DataExpression.RAND_SPARSITY, new LiteralOp(1.0)); params.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) ); //note internal refresh size information Hop datagen = new DataGenOp(DataGenMethod.RAND, new DataIdentifier("tmp"), params); datagen.setOutputBlocksizes(rowInput.getRowsInBlock(), colInput.getColsInBlock()); copyLineNumbers(rowInput, datagen); if( value==0 ) datagen.setNnz(0); return datagen; } public static ReorgOp createTranspose(Hop input) { return createReorg(input, ReOrgOp.TRANSPOSE); } public static ReorgOp createReorg(Hop input, ReOrgOp rop) { ReorgOp transpose = new ReorgOp(input.getName(), input.getDataType(), input.getValueType(), rop, input); transpose.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock()); copyLineNumbers(input, transpose); transpose.refreshSizeInformation(); return transpose; } public static UnaryOp createUnary(Hop input, OpOp1 type) { DataType dt = (type==OpOp1.CAST_AS_SCALAR) ? DataType.SCALAR : (type==OpOp1.CAST_AS_MATRIX) ? DataType.MATRIX : input.getDataType(); ValueType vt = (type==OpOp1.CAST_AS_MATRIX) ? ValueType.DOUBLE : input.getValueType(); UnaryOp unary = new UnaryOp(input.getName(), dt, vt, type, input); unary.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock()); if( type == OpOp1.CAST_AS_SCALAR || type == OpOp1.CAST_AS_MATRIX ) { int dim = (type==OpOp1.CAST_AS_SCALAR) ? 0 : 1; int blksz = (type==OpOp1.CAST_AS_SCALAR) ? 0 : ConfigurationManager.getBlocksize(); setOutputParameters(unary, dim, dim, blksz, blksz, -1); } copyLineNumbers(input, unary); unary.refreshSizeInformation(); return unary; } public static BinaryOp createBinaryMinus(Hop input) { return createBinary(new LiteralOp(0), input, OpOp2.MINUS); } public static BinaryOp createBinary(Hop input1, Hop input2, OpOp2 op) { Hop mainInput = input1.getDataType().isMatrix() ? input1 : input2.getDataType().isMatrix() ? input2 : input1; BinaryOp bop = new BinaryOp(mainInput.getName(), mainInput.getDataType(), mainInput.getValueType(), op, input1, input2); //cleanup value type for relational operations if( bop.isPPredOperation() && bop.getDataType().isScalar() ) bop.setValueType(ValueType.BOOLEAN); bop.setOutputBlocksizes(mainInput.getRowsInBlock(), mainInput.getColsInBlock()); copyLineNumbers(mainInput, bop); bop.refreshSizeInformation(); return bop; } public static AggUnaryOp createSum( Hop input ) { return createAggUnaryOp(input, AggOp.SUM, Direction.RowCol); } public static AggUnaryOp createAggUnaryOp( Hop input, AggOp op, Direction dir ) { DataType dt = (dir==Direction.RowCol) ? DataType.SCALAR : input.getDataType(); AggUnaryOp auop = new AggUnaryOp(input.getName(), dt, input.getValueType(), op, dir, input); auop.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock()); copyLineNumbers(input, auop); auop.refreshSizeInformation(); return auop; } public static AggBinaryOp createMatrixMultiply(Hop left, Hop right) { AggBinaryOp mmult = new AggBinaryOp(left.getName(), left.getDataType(), left.getValueType(), OpOp2.MULT, AggOp.SUM, left, right); mmult.setOutputBlocksizes(left.getRowsInBlock(), right.getColsInBlock()); copyLineNumbers(left, mmult); mmult.refreshSizeInformation(); return mmult; } public static ParameterizedBuiltinOp createParameterizedBuiltinOp(Hop input, HashMap<String,Hop> args, ParamBuiltinOp op) { ParameterizedBuiltinOp pbop = new ParameterizedBuiltinOp("tmp", DataType.MATRIX, ValueType.DOUBLE, op, args); pbop.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock()); copyLineNumbers(input, pbop); pbop.refreshSizeInformation(); return pbop; } public static Hop createScalarIndexing(Hop input, long rix, long cix) { LiteralOp row = new LiteralOp(rix); LiteralOp col = new LiteralOp(cix); IndexingOp ix = new IndexingOp("tmp", DataType.MATRIX, ValueType.DOUBLE, input, row, row, col, col, true, true); ix.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock()); copyLineNumbers(input, ix); ix.refreshSizeInformation(); return createUnary(ix, OpOp1.CAST_AS_SCALAR); } public static Hop createValueHop( Hop hop, boolean row ) throws HopsException { Hop ret = null; if( row ){ ret = (hop.getDim1()>0) ? new LiteralOp(hop.getDim1()) : new UnaryOp("tmprows", DataType.SCALAR, ValueType.INT, OpOp1.NROW, hop); } else{ ret = (hop.getDim2()>0) ? new LiteralOp(hop.getDim2()) : new UnaryOp("tmpcols", DataType.SCALAR, ValueType.INT, OpOp1.NCOL, hop); } return ret; } public static DataGenOp createSeqDataGenOp( Hop input ) throws HopsException { return createSeqDataGenOp(input, true); } public static DataGenOp createSeqDataGenOp( Hop input, boolean asc ) throws HopsException { Hop to = (input.getDim1()>0) ? new LiteralOp(input.getDim1()) : new UnaryOp("tmprows", DataType.SCALAR, ValueType.INT, OpOp1.NROW, input); HashMap<String, Hop> params = new HashMap<String, Hop>(); if( asc ) { params.put(Statement.SEQ_FROM, new LiteralOp(1)); params.put(Statement.SEQ_TO, to); params.put(Statement.SEQ_INCR, new LiteralOp(1)); } else { params.put(Statement.SEQ_FROM, to); params.put(Statement.SEQ_TO, new LiteralOp(1)); params.put(Statement.SEQ_INCR, new LiteralOp(-1)); } //note internal refresh size information DataGenOp datagen = new DataGenOp(DataGenMethod.SEQ, new DataIdentifier("tmp"), params); datagen.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock()); copyLineNumbers(input, datagen); return datagen; } public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop mright, OpOp3 op) { TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, op, mleft, smid, mright); ternOp.setOutputBlocksizes(mleft.getRowsInBlock(), mleft.getColsInBlock()); copyLineNumbers(mleft, ternOp); ternOp.refreshSizeInformation(); return ternOp; } public static void setOutputParameters( Hop hop, long rlen, long clen, long brlen, long bclen, long nnz ) { hop.setDim1( rlen ); hop.setDim2( clen ); hop.setOutputBlocksizes(brlen, bclen ); hop.setNnz( nnz ); } public static void setOutputParametersForScalar( Hop hop ) { hop.setDataType(DataType.SCALAR); hop.setDim1( 0 ); hop.setDim2( 0 ); hop.setOutputBlocksizes(-1, -1 ); hop.setNnz( -1 ); } public static void refreshOutputParameters( Hop hnew, Hop hold ) { hnew.setDim1( hold.getDim1() ); hnew.setDim2( hold.getDim2() ); hnew.setOutputBlocksizes(hold.getRowsInBlock(), hold.getColsInBlock()); hnew.refreshSizeInformation(); } public static void copyLineNumbers( Hop src, Hop dest ) { dest.setAllPositions(src.getBeginLine(), src.getBeginColumn(), src.getEndLine(), src.getEndColumn()); } public static void updateHopCharacteristics( Hop hop, long brlen, long bclen, Hop src ) { updateHopCharacteristics(hop, brlen, bclen, new MemoTable(), src); } public static void updateHopCharacteristics( Hop hop, long brlen, long bclen, MemoTable memo, Hop src ) { //update block sizes and dimensions hop.setOutputBlocksizes(brlen, bclen); hop.refreshSizeInformation(); //compute memory estimates (for exec type selection) hop.computeMemEstimate(memo); //update line numbers HopRewriteUtils.copyLineNumbers(src, hop); } /////////////////////////////////// // hop size information public static boolean isDimsKnown( Hop hop ) { return ( hop.getDim1()>0 && hop.getDim2()>0 ); } public static boolean isEmpty( Hop hop ) { return ( hop.getNnz()==0 ); } public static boolean isEqualSize( Hop hop1, Hop hop2 ) { return (hop1.dimsKnown() && hop2.dimsKnown() && hop1.getDim1() == hop2.getDim1() && hop1.getDim2() == hop2.getDim2()); } public static boolean isEqualSize( Hop hop1, Hop... hops ) { boolean ret = hop1.dimsKnown(); for( int i=0; i<hops.length && ret; i++ ) ret &= isEqualSize(hop1, hops[i]); return ret; } public static boolean isSingleBlock( Hop hop ) { return isSingleBlock(hop, true) && isSingleBlock(hop, false); } /** * Checks our BLOCKSIZE CONSTRAINT, w/ awareness of forced single node * execution mode. * * @param hop high-level operator * @param cols true if cols * @return true if single block */ public static boolean isSingleBlock( Hop hop, boolean cols ) { //awareness of forced exec single node (e.g., standalone), where we can //guarantee a single block independent of the size because always in CP. if( DMLScript.rtplatform == RUNTIME_PLATFORM.SINGLE_NODE ) { return true; } //check row- or column-wise single block constraint return cols ? (hop.getDim2()>0 && hop.getDim2()<=hop.getColsInBlock()) : (hop.getDim1()>0 && hop.getDim1()<=hop.getRowsInBlock()); } public static boolean isOuterProductLikeMM( Hop hop ) { return isMatrixMultiply(hop) && hop.dimsKnown() && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown() && hop.getInput().get(0).getDim1() > hop.getInput().get(0).getDim2() && hop.getInput().get(1).getDim1() < hop.getInput().get(1).getDim2(); } public static boolean isSparse( Hop hop ) { return hop.dimsKnown(true) //dims and nnz known && MatrixBlock.evalSparseFormatInMemory(hop.getDim1(), hop.getDim2(), hop.getNnz()); } public static boolean isEqualValue( LiteralOp hop1, LiteralOp hop2 ) throws HopsException { //check for string (no defined double value) if( hop1.getValueType()==ValueType.STRING || hop2.getValueType()==ValueType.STRING ) { return false; } double val1 = getDoubleValue(hop1); double val2 = getDoubleValue(hop2); return ( val1 == val2 ); } public static boolean isNotMatrixVectorBinaryOperation( Hop hop ) { boolean ret = true; if( hop instanceof BinaryOp ) { BinaryOp bop = (BinaryOp) hop; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); boolean mv = (left.getDim1()>1 && right.getDim1()==1) || (left.getDim2()>1 && right.getDim2()==1); ret = isDimsKnown(bop) && !mv; } return ret; } public static boolean isTransposeOperation(Hop hop) { return (hop instanceof ReorgOp && ((ReorgOp)hop).getOp()==ReOrgOp.TRANSPOSE); } public static boolean isTransposeOperation(Hop hop, int maxParents) { return isTransposeOperation(hop) && hop.getParent().size() <= maxParents; } public static boolean containsTransposeOperation(ArrayList<Hop> hops) { boolean ret = false; for( Hop hop : hops ) ret |= isTransposeOperation(hop); return ret; } public static boolean isTransposeOfItself(Hop hop1, Hop hop2) { return isTransposeOperation(hop1) && hop1.getInput().get(0) == hop2 || isTransposeOperation(hop2) && hop2.getInput().get(0) == hop1; } public static boolean isTsmmInput(Hop input) { if( input.getParent().size()==2 ) for(int i=0; i<2; i++) if( isMatrixMultiply(input.getParent().get(i)) && isTransposeOfItself( input.getParent().get(i).getInput().get(0), input.getParent().get(i).getInput().get(1)) ) return true; return false; } public static boolean isBinary(Hop hop, OpOp2 type) { return hop instanceof BinaryOp && ((BinaryOp)hop).getOp()==type; } public static boolean isBinary(Hop hop, OpOp2... types) { return ( hop instanceof BinaryOp && ArrayUtils.contains(types, ((BinaryOp) hop).getOp())); } public static boolean isBinary(Hop hop, OpOp2 type, int maxParents) { return isBinary(hop, type) && hop.getParent().size() <= maxParents; } public static boolean isBinaryMatrixScalarOperation(Hop hop) { return hop instanceof BinaryOp && ((hop.getInput().get(0).getDataType().isMatrix() && hop.getInput().get(1).getDataType().isScalar()) ||(hop.getInput().get(1).getDataType().isMatrix() && hop.getInput().get(0).getDataType().isScalar())); } public static boolean isBinaryMatrixMatrixOperation(Hop hop) { return hop instanceof BinaryOp && hop.getInput().get(0).getDataType().isMatrix() && hop.getInput().get(1).getDataType().isMatrix() && hop.getInput().get(0).dimsKnown() && hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1 && hop.getInput().get(1).dimsKnown() && hop.getInput().get(1).getDim1() > 1 && hop.getInput().get(1).getDim2() > 1; } public static boolean isBinaryMatrixColVectorOperation(Hop hop) { return hop instanceof BinaryOp && hop.getInput().get(0).getDataType().isMatrix() && hop.getInput().get(1).getDataType().isMatrix() && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown() && hop.getInput().get(1).getDim2() == 1; } public static boolean isUnary(Hop hop, OpOp1 type) { return hop instanceof UnaryOp && ((UnaryOp)hop).getOp()==type; } public static boolean isUnary(Hop hop, OpOp1 type, int maxParents) { return isUnary(hop, type) && hop.getParent().size() <= maxParents; } public static boolean isUnary(Hop hop, OpOp1... types) { return ( hop instanceof UnaryOp && ArrayUtils.contains(types, ((UnaryOp) hop).getOp())); } public static boolean isMatrixMultiply(Hop hop) { return hop instanceof AggBinaryOp && ((AggBinaryOp)hop).isMatrixMultiply(); } public static boolean isAggUnaryOp(Hop hop, AggOp...op) { if( !(hop instanceof AggUnaryOp) ) return false; AggOp hopOp = ((AggUnaryOp)hop).getOp(); for( AggOp opi : op ) if( hopOp == opi ) return true; return false; } public static boolean isSum(Hop hop) { return (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp()==AggOp.SUM); } public static boolean isSumSq(Hop hop) { return (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp()==AggOp.SUM_SQ); } public static boolean isNonZeroIndicator(Hop pred, Hop hop ) { if( pred instanceof BinaryOp && ((BinaryOp)pred).getOp()==OpOp2.NOTEQUAL && pred.getInput().get(0) == hop //depend on common subexpression elimination && pred.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) { return true; } return false; } public static boolean checkInputDataTypes(Hop hop, DataType... dt) { for( int i=0; i<hop.getInput().size(); i++ ) if( hop.getInput().get(i).getDataType() != dt[i] ) return false; return true; } public static boolean isFullColumnIndexing(LeftIndexingOp hop) { boolean colPred = hop.getColLowerEqualsUpper(); //single col Hop rl = hop.getInput().get(2); Hop ru = hop.getInput().get(3); return colPred && rl instanceof LiteralOp && getDoubleValueSafe((LiteralOp)rl)==1 && ru instanceof LiteralOp && getDoubleValueSafe((LiteralOp)ru)==hop.getDim1(); } public static boolean isFullRowIndexing(LeftIndexingOp hop) { boolean rowPred = hop.getRowLowerEqualsUpper(); //single row Hop cl = hop.getInput().get(4); Hop cu = hop.getInput().get(5); return rowPred && cl instanceof LiteralOp && getDoubleValueSafe((LiteralOp)cl)==1 && cu instanceof LiteralOp && getDoubleValueSafe((LiteralOp)cu)==hop.getDim2(); } public static boolean isScalarMatrixBinaryMult( Hop hop ) { return hop instanceof BinaryOp && ((BinaryOp)hop).getOp()==OpOp2.MULT && ((hop.getInput().get(0).getDataType()==DataType.SCALAR && hop.getInput().get(1).getDataType()==DataType.MATRIX) || (hop.getInput().get(0).getDataType()==DataType.MATRIX && hop.getInput().get(1).getDataType()==DataType.SCALAR)); } public static boolean isBasic1NSequence(Hop hop) { if( hop instanceof DataGenOp && ((DataGenOp)hop).getOp() == DataGenMethod.SEQ ) { DataGenOp dgop = (DataGenOp) hop; Hop from = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_FROM)); Hop incr = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_INCR)); return (from instanceof LiteralOp && getDoubleValueSafe((LiteralOp)from)==1) &&(incr instanceof LiteralOp && getDoubleValueSafe((LiteralOp)incr)==1); } return false; } public static boolean isBasic1NSequence(Hop seq, Hop input, boolean row) { if( seq instanceof DataGenOp && ((DataGenOp)seq).getOp() == DataGenMethod.SEQ ) { DataGenOp dgop = (DataGenOp) seq; Hop from = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_FROM)); Hop to = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_TO)); Hop incr = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_INCR)); return isLiteralOfValue(from, 1) && isLiteralOfValue(incr, 1) && (isLiteralOfValue(to, row?input.getDim1():input.getDim2()) || (to instanceof UnaryOp && ((UnaryOp)to).getOp()==(row? OpOp1.NROW:OpOp1.NCOL) && to.getInput().get(0)==input)); } return false; } public static boolean isBasicN1Sequence(Hop hop) { boolean ret = false; if( hop instanceof DataGenOp ) { DataGenOp dgop = (DataGenOp) hop; if( dgop.getOp() == DataGenMethod.SEQ ){ Hop to = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_TO)); Hop incr = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_INCR)); ret = (to instanceof LiteralOp && getDoubleValueSafe((LiteralOp)to)==1) &&(incr instanceof LiteralOp && getDoubleValueSafe((LiteralOp)incr)==-1); } } return ret; } public static LiteralOp getBasic1NSequenceMaxLiteral(Hop hop) throws HopsException { if( hop instanceof DataGenOp ) { DataGenOp dgop = (DataGenOp) hop; if( dgop.getOp() == DataGenMethod.SEQ ){ Hop to = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_TO)); if( to instanceof LiteralOp ) return (LiteralOp)to; } } throw new HopsException("Failed to retrieve 'to' argument from basic 1-N sequence."); } public static boolean hasOnlyWriteParents( Hop hop, boolean inclTransient, boolean inclPersistent ) { boolean ret = true; ArrayList<Hop> parents = hop.getParent(); for( Hop p : parents ) { if( inclTransient && inclPersistent ) ret &= ( p instanceof DataOp && (((DataOp)p).getDataOpType()==DataOpTypes.TRANSIENTWRITE || ((DataOp)p).getDataOpType()==DataOpTypes.PERSISTENTWRITE)); else if(inclTransient) ret &= ( p instanceof DataOp && ((DataOp)p).getDataOpType()==DataOpTypes.TRANSIENTWRITE); else if(inclPersistent) ret &= ( p instanceof DataOp && ((DataOp)p).getDataOpType()==DataOpTypes.PERSISTENTWRITE); } return ret; } public static boolean hasTransformParents( Hop hop ) { boolean ret = false; ArrayList<Hop> parents = hop.getParent(); for( Hop p : parents ) { if( p instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)p).getOp()==ParamBuiltinOp.TRANSFORM) { ret = true; } } return ret; } public static boolean alwaysRequiresReblock(Hop hop) { return ( hop instanceof DataOp && ((DataOp)hop).getDataOpType()==DataOpTypes.PERSISTENTREAD && ((DataOp)hop).getInputFormatType()!=FileFormatTypes.BINARY); } public static boolean rHasSimpleReadChain(Hop root, String var) { if( root.isVisited() ) return false; boolean ret = false; //handle leaf node for variable if( root instanceof DataOp && ((DataOp)root).isRead() && root.getName().equals(var) ) { ret = (root.getParent().size()<=1); } //recursively process childs (on the entire path to var, all //intermediates are supposed to have at most one consumer, but //side-ways inputs can have arbitrary dag structures) for( Hop c : root.getInput() ) { if( rHasSimpleReadChain(c, var) ) ret |= root.getParent().size()<=1; } root.setVisited(); return ret; } public static boolean rContainsRead(Hop root, String var, boolean includeMetaOp) { if( root.isVisited() ) return false; boolean ret = false; //handle leaf node for variable if( root instanceof DataOp && ((DataOp)root).isRead() && root.getName().equals(var) ) { boolean onlyMetaOp = true; if( !includeMetaOp ){ for( Hop p : root.getParent() ) { onlyMetaOp &= (p instanceof UnaryOp && (((UnaryOp)p).getOp()==OpOp1.NROW || ((UnaryOp)p).getOp()==OpOp1.NCOL) ); } ret = !onlyMetaOp; } else ret = true; } //recursively process childs for( Hop c : root.getInput() ) ret |= rContainsRead(c, var, includeMetaOp); root.setVisited(); return ret; } ////////////////////////////////////// // utils for lookup tables public static boolean isValidOp( AggOp input, AggOp... validTab ) { return ArrayUtils.contains(validTab, input); } public static boolean isValidOp( OpOp1 input, OpOp1... validTab ) { return ArrayUtils.contains(validTab, input); } public static boolean isValidOp( OpOp2 input, OpOp2... validTab ) { return ArrayUtils.contains(validTab, input); } public static boolean isValidOp( ReOrgOp input, ReOrgOp... validTab ) { return ArrayUtils.contains(validTab, input); } public static boolean isValidOp( ParamBuiltinOp input, ParamBuiltinOp... validTab ) { return ArrayUtils.contains(validTab, input); } public static int getValidOpPos( OpOp2 input, OpOp2... validTab ) { return ArrayUtils.indexOf(validTab, input); } /** * Compares the size of outputs from hop1 and hop2, in terms of number * of matrix cells. Note that this methods throws a RuntimeException * if either hop has unknown dimensions. * * @param hop1 high-level operator 1 * @param hop2 high-level operator 2 * @return 0 if sizes are equal, <0 for hop1<hop2, >0 for hop1>hop2. */ public static int compareSize( Hop hop1, Hop hop2 ) { long size1 = hop1.getDim1() * hop1.getDim2(); long size2 = hop2.getDim1() * hop2.getDim2(); return Long.compare(size1, size2); } }