/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.hops.recompile; import java.util.ArrayList; import org.apache.sysml.api.DMLScript; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.DataOpTypes; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.utils.Statistics; public class LiteralReplacement { //internal configuration parameters private static final long REPLACE_LITERALS_MAX_MATRIX_SIZE = 1000000; //10^6 cells (8MB) private static final boolean REPORT_LITERAL_REPLACE_OPS_STATS = true; protected static void rReplaceLiterals( Hop hop, LocalVariableMap vars, boolean scalarsOnly ) throws DMLRuntimeException { if( hop.isVisited() ) return; if( hop.getInput() != null ) { //indexed access to allow parent-child modifications for( int i=0; i<hop.getInput().size(); i++ ) { Hop c = hop.getInput().get(i); Hop lit = null; //conditional apply of literal replacements lit = (lit==null) ? replaceLiteralScalarRead(c, vars) : lit; lit = (lit==null) ? replaceLiteralValueTypeCastScalarRead(c, vars) : lit; lit = (lit==null) ? replaceLiteralValueTypeCastLiteral(c, vars) : lit; if( !scalarsOnly ) { lit = (lit==null) ? replaceLiteralDataTypeCastMatrixRead(c, vars) : lit; lit = (lit==null) ? replaceLiteralValueTypeCastRightIndexing(c, vars) : lit; lit = (lit==null) ? replaceLiteralFullUnaryAggregate(c, vars) : lit; lit = (lit==null) ? replaceLiteralFullUnaryAggregateRightIndexing(c, vars) : lit; } //replace hop w/ literal on demand if( lit != null ) { //replace hop c by literal, for all parents to prevent (1) missed opportunities //because hop c marked as visited, and (2) repeated evaluation of uagg ops if( c.getParent().size() > 1 ) { //multiple parents ArrayList<Hop> parents = new ArrayList<Hop>(c.getParent()); for( Hop p : parents ) { int pos = HopRewriteUtils.getChildReferencePos(p, c); HopRewriteUtils.removeChildReferenceByPos(p, c, pos); HopRewriteUtils.addChildReference(p, lit, pos); } } else { //current hop is only parent HopRewriteUtils.replaceChildReference(hop, c, lit, i); } } //recursively process children else { rReplaceLiterals(c, vars, scalarsOnly); } } } hop.setVisited(); } /////////////////////////////// // Literal replacement rules /////////////////////////////// private static LiteralOp replaceLiteralScalarRead(Hop c, LocalVariableMap vars) { LiteralOp ret = null; //scalar read - literal replacement if( c instanceof DataOp && ((DataOp)c).getDataOpType() != DataOpTypes.PERSISTENTREAD && c.getDataType()==DataType.SCALAR ) { Data dat = vars.get(c.getName()); if( dat != null ) //required for selective constant propagation { ScalarObject sdat = (ScalarObject)dat; switch( sdat.getValueType() ) { case INT: ret = new LiteralOp(sdat.getLongValue()); break; case DOUBLE: ret = new LiteralOp(sdat.getDoubleValue()); break; case BOOLEAN: ret = new LiteralOp(sdat.getBooleanValue()); break; default: //otherwise: do nothing } } } return ret; } private static LiteralOp replaceLiteralValueTypeCastScalarRead( Hop c, LocalVariableMap vars ) { LiteralOp ret = null; //as.double/as.integer/as.boolean over scalar read - literal replacement if( c instanceof UnaryOp && (((UnaryOp)c).getOp() == OpOp1.CAST_AS_DOUBLE || ((UnaryOp)c).getOp() == OpOp1.CAST_AS_INT || ((UnaryOp)c).getOp() == OpOp1.CAST_AS_BOOLEAN ) && c.getInput().get(0) instanceof DataOp && c.getDataType()==DataType.SCALAR ) { Data dat = vars.get(c.getInput().get(0).getName()); if( dat != null ) //required for selective constant propagation { ScalarObject sdat = (ScalarObject)dat; UnaryOp cast = (UnaryOp) c; switch( cast.getOp() ) { case CAST_AS_INT: ret = new LiteralOp(sdat.getLongValue()); break; case CAST_AS_DOUBLE: ret = new LiteralOp(sdat.getDoubleValue()); break; case CAST_AS_BOOLEAN: ret = new LiteralOp(sdat.getBooleanValue()); break; default: //otherwise: do nothing } } } return ret; } private static LiteralOp replaceLiteralValueTypeCastLiteral( Hop c, LocalVariableMap vars ) throws DMLRuntimeException { LiteralOp ret = null; //as.double/as.integer/as.boolean over scalar literal (potentially created by other replacement //rewrite in same dag) - literal replacement if( c instanceof UnaryOp && (((UnaryOp)c).getOp() == OpOp1.CAST_AS_DOUBLE || ((UnaryOp)c).getOp() == OpOp1.CAST_AS_INT || ((UnaryOp)c).getOp() == OpOp1.CAST_AS_BOOLEAN ) && c.getInput().get(0) instanceof LiteralOp ) { LiteralOp sdat = (LiteralOp)c.getInput().get(0); UnaryOp cast = (UnaryOp) c; try { switch( cast.getOp() ) { case CAST_AS_INT: long ival = HopRewriteUtils.getIntValue(sdat); ret = new LiteralOp(ival); break; case CAST_AS_DOUBLE: double dval = HopRewriteUtils.getDoubleValue(sdat); ret = new LiteralOp(dval); break; case CAST_AS_BOOLEAN: boolean bval = HopRewriteUtils.getBooleanValue(sdat); ret = new LiteralOp(bval); break; default: //otherwise: do nothing } } catch(HopsException ex) { throw new DMLRuntimeException(ex); } } return ret; } private static LiteralOp replaceLiteralDataTypeCastMatrixRead( Hop c, LocalVariableMap vars ) throws DMLRuntimeException { LiteralOp ret = null; //as.scalar/matrix read - literal replacement if( c instanceof UnaryOp && ((UnaryOp)c).getOp() == OpOp1.CAST_AS_SCALAR && c.getInput().get(0) instanceof DataOp && c.getInput().get(0).getDataType() == DataType.MATRIX ) { Data dat = vars.get(c.getInput().get(0).getName()); if( dat != null ) //required for selective constant propagation { //cast as scalar (see VariableCPInstruction) MatrixObject mo = (MatrixObject)dat; MatrixBlock mBlock = mo.acquireRead(); if( mBlock.getNumRows()!=1 || mBlock.getNumColumns()!=1 ) throw new DMLRuntimeException("Dimension mismatch - unable to cast matrix of dimension ("+mBlock.getNumRows()+" x "+mBlock.getNumColumns()+") to scalar."); double value = mBlock.getValue(0,0); mo.release(); //literal substitution (always double) ret = new LiteralOp(value); } } return ret; } private static LiteralOp replaceLiteralValueTypeCastRightIndexing( Hop c, LocalVariableMap vars ) throws DMLRuntimeException { LiteralOp ret = null; //as.scalar/right indexing w/ literals/vars and matrix less than 10^6 cells if( c instanceof UnaryOp && ((UnaryOp)c).getOp() == OpOp1.CAST_AS_SCALAR && c.getInput().get(0) instanceof IndexingOp && c.getInput().get(0).getDataType() == DataType.MATRIX) { IndexingOp rix = (IndexingOp)c.getInput().get(0); Hop data = rix.getInput().get(0); Hop rl = rix.getInput().get(1); Hop ru = rix.getInput().get(2); Hop cl = rix.getInput().get(3); Hop cu = rix.getInput().get(4); if( rix.dimsKnown() && rix.getDim1()==1 && rix.getDim2()==1 && data instanceof DataOp && vars.keySet().contains(data.getName()) && isIntValueDataLiteral(rl, vars) && isIntValueDataLiteral(ru, vars) && isIntValueDataLiteral(cl, vars) && isIntValueDataLiteral(cu, vars) ) { long rlval = getIntValueDataLiteral(rl, vars); long clval = getIntValueDataLiteral(cl, vars); MatrixObject mo = (MatrixObject)vars.get(data.getName()); //get the dimension information from the matrix object because the hop //dimensions might not have been updated during recompile if( mo.getNumRows()*mo.getNumColumns() < REPLACE_LITERALS_MAX_MATRIX_SIZE ) { MatrixBlock mBlock = mo.acquireRead(); double value = mBlock.getValue((int)rlval-1,(int)clval-1); mo.release(); //literal substitution (always double) ret = new LiteralOp(value); } } } return ret; } private static LiteralOp replaceLiteralFullUnaryAggregate( Hop c, LocalVariableMap vars ) throws DMLRuntimeException { LiteralOp ret = null; //full unary aggregate w/ matrix less than 10^6 cells if( c instanceof AggUnaryOp && isReplaceableUnaryAggregate((AggUnaryOp)c) && c.getInput().get(0) instanceof DataOp && vars.keySet().contains(c.getInput().get(0).getName()) ) { Hop data = c.getInput().get(0); MatrixObject mo = (MatrixObject) vars.get(data.getName()); //get the dimension information from the matrix object because the hop //dimensions might not have been updated during recompile if( mo.getNumRows()*mo.getNumColumns() < REPLACE_LITERALS_MAX_MATRIX_SIZE ) { MatrixBlock mBlock = mo.acquireRead(); double value = replaceUnaryAggregate((AggUnaryOp)c, mBlock); mo.release(); //literal substitution (always double) ret = new LiteralOp(value); } } return ret; } private static LiteralOp replaceLiteralFullUnaryAggregateRightIndexing( Hop c, LocalVariableMap vars ) throws DMLRuntimeException { LiteralOp ret = null; //full unary aggregate w/ indexed matrix less than 10^6 cells if( c instanceof AggUnaryOp && isReplaceableUnaryAggregate((AggUnaryOp)c) && c.getInput().get(0) instanceof IndexingOp && c.getInput().get(0).getInput().get(0) instanceof DataOp ) { IndexingOp rix = (IndexingOp)c.getInput().get(0); Hop data = rix.getInput().get(0); Hop rl = rix.getInput().get(1); Hop ru = rix.getInput().get(2); Hop cl = rix.getInput().get(3); Hop cu = rix.getInput().get(4); if( data instanceof DataOp && vars.keySet().contains(data.getName()) && isIntValueDataLiteral(rl, vars) && isIntValueDataLiteral(ru, vars) && isIntValueDataLiteral(cl, vars) && isIntValueDataLiteral(cu, vars) ) { long rlval = getIntValueDataLiteral(rl, vars); long ruval = getIntValueDataLiteral(ru, vars); long clval = getIntValueDataLiteral(cl, vars); long cuval = getIntValueDataLiteral(cu, vars); MatrixObject mo = (MatrixObject) vars.get(data.getName()); //get the dimension information from the matrix object because the hop //dimensions might not have been updated during recompile if( mo.getNumRows()*mo.getNumColumns() < REPLACE_LITERALS_MAX_MATRIX_SIZE ) { MatrixBlock mBlock = mo.acquireRead(); MatrixBlock mBlock2 = mBlock.sliceOperations((int)(rlval-1), (int)(ruval-1), (int)(clval-1), (int)(cuval-1), new MatrixBlock()); double value = replaceUnaryAggregate((AggUnaryOp)c, mBlock2); mo.release(); //literal substitution (always double) ret = new LiteralOp(value); } } } return ret; } /////////////////////////////// // Utility functions /////////////////////////////// private static boolean isIntValueDataLiteral(Hop h, LocalVariableMap vars) { return ( (h instanceof DataOp && vars.keySet().contains(h.getName())) || h instanceof LiteralOp ||(h instanceof UnaryOp && (((UnaryOp)h).getOp()==OpOp1.NROW || ((UnaryOp)h).getOp()==OpOp1.NCOL) && h.getInput().get(0) instanceof DataOp && vars.keySet().contains(h.getInput().get(0).getName())) ); } private static long getIntValueDataLiteral(Hop hop, LocalVariableMap vars) throws DMLRuntimeException { long value = -1; try { if( hop instanceof LiteralOp ) { value = HopRewriteUtils.getIntValue((LiteralOp)hop); } else if( hop instanceof UnaryOp && ((UnaryOp)hop).getOp()==OpOp1.NROW ) { //get the dimension information from the matrix object because the hop //dimensions might not have been updated during recompile MatrixObject mo = (MatrixObject)vars.get(hop.getInput().get(0).getName()); value = mo.getNumRows(); } else if( hop instanceof UnaryOp && ((UnaryOp)hop).getOp()==OpOp1.NCOL ) { //get the dimension information from the matrix object because the hop //dimensions might not have been updated during recompile MatrixObject mo = (MatrixObject)vars.get(hop.getInput().get(0).getName()); value = mo.getNumColumns(); } else { ScalarObject sdat = (ScalarObject) vars.get(hop.getName()); value = sdat.getLongValue(); } } catch(HopsException ex) { throw new DMLRuntimeException("Failed to get int value for literal replacement", ex); } return value; } private static boolean isReplaceableUnaryAggregate( AggUnaryOp auop ) { boolean cdir = (auop.getDirection() == Direction.RowCol); boolean cop = ( auop.getOp() == AggOp.SUM || auop.getOp() == AggOp.SUM_SQ || auop.getOp() == AggOp.MIN || auop.getOp() == AggOp.MAX ); return cdir && cop; } private static double replaceUnaryAggregate( AggUnaryOp auop, MatrixBlock mb ) throws DMLRuntimeException { //setup stats reporting if necessary boolean REPORT_STATS = (DMLScript.STATISTICS && REPORT_LITERAL_REPLACE_OPS_STATS); long t0 = REPORT_STATS ? System.nanoTime() : 0; //compute required unary aggregate double val = Double.MAX_VALUE; switch( auop.getOp() ) { case SUM: val = mb.sum(); break; case SUM_SQ: val = mb.sumSq(); break; case MIN: val = mb.min(); break; case MAX: val = mb.max(); break; default: throw new DMLRuntimeException("Unsupported unary aggregate replacement: "+auop.getOp()); } //report statistics if necessary if( REPORT_STATS ){ long t1 = System.nanoTime(); Statistics.maintainCPHeavyHitters("rlit", t1-t0); } return val; } }