/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.DataGenOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.DataGenMethod; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp3; import org.apache.sysml.hops.Hop.ParamBuiltinOp; import org.apache.sysml.hops.Hop.ReOrgOp; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.ReorgOp; import org.apache.sysml.parser.DataExpression; import org.apache.sysml.parser.Statement; import 
org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; /** * Rule: Algebraic Simplifications. Simplifies binary expressions * in terms of two major purposes: (1) rewrite binary operations * to unary operations when possible (in CP this reduces the memory * estimate, in MR this allows map-only operations and hence prevents * unnecessary shuffle and sort) and (2) remove binary operations that * are in itself are unnecessary (e.g., *1 and /1). * */ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { private static final Log LOG = LogFactory.getLog(RewriteAlgebraicSimplificationStatic.class.getName()); //valid aggregation operation types for rowOp to colOp conversions and vice versa private static AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.VAR}; //valid binary operations for distributive and associate reorderings private static OpOp2[] LOOKUP_VALID_DISTRIBUTIVE_BINARY = new OpOp2[]{OpOp2.PLUS, OpOp2.MINUS}; private static OpOp2[] LOOKUP_VALID_ASSOCIATIVE_BINARY = new OpOp2[]{OpOp2.PLUS, OpOp2.MULT}; @Override public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException { if( roots == null ) return roots; //one pass rewrite-descend (rewrite created pattern) for( Hop h : roots ) rule_AlgebraicSimplification( h, false ); Hop.resetVisitStatus(roots); //one pass descend-rewrite (for rollup) for( Hop h : roots ) rule_AlgebraicSimplification( h, true ); return roots; } @Override public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException { if( root == null ) return root; //one pass rewrite-descend (rewrite created pattern) rule_AlgebraicSimplification( root, false ); root.resetVisitStatus(); //one pass descend-rewrite (for rollup) rule_AlgebraicSimplification( root, true ); return root; } /** * Note: X/y -> X * 1/y would be useful because * cheaper than / and 
/**
 * Applies the chain of static simplification rewrites to every input of the
 * given hop and recurses into the inputs. Each rewrite receives the current
 * child and returns the (possibly replaced) child, which is then fed into the
 * next rewrite; the order of the calls below is therefore significant.
 * The hop is marked as visited once all of its inputs have been processed,
 * and already visited hops are skipped.
 *
 * Note: X/y -> X * 1/y would be useful because * cheaper than / and sparsesafe; however,
 * (1) the results would be not exactly the same (2 rounds instead of 1) and (2) it should
 * come before constant folding while the other simplifications should come after constant
 * folding. Hence, not applied yet.
 *
 * @param hop high-level operator
 * @param descendFirst if true, children are processed recursively before the
 *        rewrites are applied; if false, after the rewrites have been applied
 * @throws HopsException if HopsException occurs
 */
private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
	throws HopsException
{
	//each hop is processed at most once per pass
	if(hop.isVisited())
		return;

	//recursively process children
	for( int i=0; i<hop.getInput().size(); i++)
	{
		Hop hi = hop.getInput().get(i);

		//process children recursively first (to allow roll-up)
		if( descendFirst )
			rule_AlgebraicSimplification(hi, descendFirst); //see below

		//apply actual simplification rewrites (of children incl checks);
		//'hop' acts as the parent and 'i' as the position of 'hi' in that parent
		hi = removeUnnecessaryVectorizeOperation(hi);        //e.g., matrix(1,nrow(X),ncol(X))/X -> 1/X
		hi = removeUnnecessaryBinaryOperation(hop, hi, i);   //e.g., X*1 -> X (dep: should come after rm unnecessary vectorize)
		hi = fuseDatagenAndBinaryOperation(hi);              //e.g., rand(min=-1,max=1)*7 -> rand(min=-7,max=7)
		hi = fuseDatagenAndMinusOperation(hi);               //e.g., -(rand(min=-2,max=1)) -> rand(min=-1,max=2)
		hi = simplifyBinaryToUnaryOperation(hop, hi, i);     //e.g., X*X -> X^2 (pow2), X+X -> X*2, (X>0)-(X<0) -> sign(X)
		hi = canonicalizeMatrixMultScalarAdd(hi);            //e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps)
		hi = simplifyReverseOperation(hop, hi, i);           //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
		if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
			hi = simplifyMultiBinaryToBinaryOperation(hi);   //e.g., 1-X*Y -> X 1-* Y
		hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X
		hi = simplifyBushyBinaryOperation(hop, hi, i);       //e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
		hi = simplifyUnaryAggReorgOperation(hop, hi, i);     //e.g., sum(t(X)) -> sum(X)
		hi = simplifyBinaryMatrixScalarOperation(hop, hi, i);//e.g., as.scalar(X*s) -> as.scalar(X)*s;
		hi = pushdownUnaryAggTransposeOperation(hop, hi, i); //e.g., colSums(t(X)) -> t(rowSums(X))
		hi = pushdownCSETransposeScalarOperation(hop, hi, i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
		hi = pushdownSumBinaryMult(hop, hi, i);              //e.g., sum(lamda*X) -> lamda*sum(X)
		hi = simplifyUnaryPPredOperation(hop, hi, i);        //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
		hi = simplifyTransposedAppend(hop, hi, i);           //e.g., t(cbind(t(A),t(B))) -> rbind(A,B);
		if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
			hi = fuseBinarySubDAGToUnaryOperation(hop, hi, i); //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> selp(X)
		hi = simplifyTraceMatrixMult(hop, hi, i);            //e.g., trace(X%*%Y)->sum(X*t(Y));
		hi = simplifySlicedMatrixMult(hop, hi, i);           //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1];
		hi = simplifyConstantSort(hop, hi, i);               //e.g., order(matrix())->matrix/seq;
		hi = simplifyOrderedSort(hop, hi, i);                //e.g., order(matrix())->seq;
		hi = removeUnnecessaryReorgOperation(hop, hi, i);    //e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites
		hi = simplifyTransposeAggBinBinaryChains(hop, hi, i);//e.g., t(t(A)%*%t(B)+C) -> B%*%A+t(C)
		hi = removeUnnecessaryMinus(hop, hi, i);             //e.g., -(-X)->X; potentially introduced by simplify binary or dyn rewrites
		hi = simplifyGroupedAggregate(hi);                   //e.g., aggregate(target=X,groups=y,fn="count") -> aggregate(target=y,groups=y,fn="count")
		if(OptimizerUtils.ALLOW_OPERATOR_FUSION) {
			hi = fuseMinusNzBinaryOperation(hop, hi, i);     //e.g., X-mean*ppred(X,0,!=) -> X -nz mean
			hi = fuseLogNzUnaryOperation(hop, hi, i);        //e.g., ppred(X,0,"!=")*log(X) -> log_nz(X)
			hi = fuseLogNzBinaryOperation(hop, hi, i);       //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5)
		}
		hi = simplifyOuterSeqExpand(hop, hi, i);             //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false)

		//disabled rewrite, kept for reference:
		//hi = removeUnecessaryPPred(hop, hi, i);            //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))

		//process children recursively after rewrites (to investigate pattern newly created by rewrites);
		//note that 'hi' is the child as returned by the last rewrite above
		if( !descendFirst )
			rule_AlgebraicSimplification(hi, descendFirst);
	}

	hop.setVisited();
}
investigate pattern newly created by rewrites) if( !descendFirst ) rule_AlgebraicSimplification(hi, descendFirst); } hop.setVisited(); } private Hop removeUnnecessaryVectorizeOperation(Hop hi) { //applies to all binary matrix operations, if one input is unnecessarily vectorized if( hi instanceof BinaryOp && hi.getDataType()==DataType.MATRIX && ((BinaryOp)hi).supportsMatrixScalarOperations() ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); //NOTE: these rewrites of binary cell operations need to be aware that right is //potentially a vector but the result is of the size of left //TODO move to dynamic rewrites (since size dependent to account for mv binary cell and outer operations) if( !(left.getDim1()>1 && left.getDim2()==1 && right.getDim1()==1 && right.getDim2()>1) ) // no outer { //check and remove right vectorized scalar if( left.getDataType() == DataType.MATRIX && right instanceof DataGenOp ) { DataGenOp dright = (DataGenOp) right; if( dright.getOp()==DataGenMethod.RAND && dright.hasConstantValue() ) { Hop drightIn = dright.getInput().get(dright.getParamIndex(DataExpression.RAND_MIN)); HopRewriteUtils.replaceChildReference(bop, dright, drightIn, 1); HopRewriteUtils.cleanupUnreferenced(dright); LOG.debug("Applied removeUnnecessaryVectorizeOperation1"); } } //check and remove left vectorized scalar else if( right.getDataType() == DataType.MATRIX && left instanceof DataGenOp ) { DataGenOp dleft = (DataGenOp) left; if( dleft.getOp()==DataGenMethod.RAND && dleft.hasConstantValue() && (left.getDim2()==1 || right.getDim2()>1) && (left.getDim1()==1 || right.getDim1()>1)) { Hop dleftIn = dleft.getInput().get(dleft.getParamIndex(DataExpression.RAND_MIN)); HopRewriteUtils.replaceChildReference(bop, dleft, dleftIn, 0); HopRewriteUtils.cleanupUnreferenced(dleft); LOG.debug("Applied removeUnnecessaryVectorizeOperation2"); } } //Note: we applied this rewrite to at most one side in order to keep the //output semantically 
equivalent. However, future extensions might consider //to remove vectors from both side, compute the binary op on scalars and //finally feed it into a datagenop of the original dimensions. } } return hi; } /** * handle removal of unnecessary binary operations * * X/1 or X*1 or 1*X or X-0 -> X * -1*X or X*-1-> -X * * @param parent parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator * @throws HopsException if HopsException occurs */ private Hop removeUnnecessaryBinaryOperation( Hop parent, Hop hi, int pos ) throws HopsException { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); //X/1 or X*1 -> X if( left.getDataType()==DataType.MATRIX && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==1.0 ) { if( bop.getOp()==OpOp2.DIV || bop.getOp()==OpOp2.MULT ) { HopRewriteUtils.replaceChildReference(parent, bop, left, pos); hi = left; LOG.debug("Applied removeUnnecessaryBinaryOperation1 (line "+bop.getBeginLine()+")"); } } //X-0 -> X else if( left.getDataType()==DataType.MATRIX && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==0.0 ) { if( bop.getOp()==OpOp2.MINUS ) { HopRewriteUtils.replaceChildReference(parent, bop, left, pos); hi = left; LOG.debug("Applied removeUnnecessaryBinaryOperation2 (line "+bop.getBeginLine()+")"); } } //1*X -> X else if( right.getDataType()==DataType.MATRIX && left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==1.0 ) { if( bop.getOp()==OpOp2.MULT ) { HopRewriteUtils.replaceChildReference(parent, bop, right, pos); hi = right; LOG.debug("Applied removeUnnecessaryBinaryOperation3 (line "+bop.getBeginLine()+")"); } } //-1*X -> -X //note: this rewrite is necessary since the new antlr parser always converts //-X to -1*X due to mechanical reasons else if( right.getDataType()==DataType.MATRIX && left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==-1.0 ) { if( 
bop.getOp()==OpOp2.MULT ) { bop.setOp(OpOp2.MINUS); HopRewriteUtils.replaceChildReference(bop, left, new LiteralOp(0), 0); hi = bop; LOG.debug("Applied removeUnnecessaryBinaryOperation4 (line "+bop.getBeginLine()+")"); } } //X*-1 -> -X (see comment above) else if( left.getDataType()==DataType.MATRIX && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==-1.0 ) { if( bop.getOp()==OpOp2.MULT ) { bop.setOp(OpOp2.MINUS); HopRewriteUtils.removeChildReferenceByPos(bop, right, 1); HopRewriteUtils.addChildReference(bop, new LiteralOp(0), 0); hi = bop; LOG.debug("Applied removeUnnecessaryBinaryOperation5 (line "+bop.getBeginLine()+")"); } } } return hi; } /** * Handle removal of unnecessary binary operations over rand data * * rand*7 -> rand(min*7,max*7); rand+7 -> rand(min+7,max+7); rand-7 -> rand(min+(-7),max+(-7)) * 7*rand -> rand(min*7,max*7); 7+rand -> rand(min+7,max+7); * * @param hi high-order operaton * @return high-level operator * @throws HopsException if HopsException occurs */ private Hop fuseDatagenAndBinaryOperation( Hop hi ) throws HopsException { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); //NOTE: rewrite not applied if more than one datagen consumer because this would lead to //the creation of multiple datagen ops and thus potentially different results if seed not specified) //left input rand and hence output matrix double, right scalar literal if( left instanceof DataGenOp && ((DataGenOp)left).getOp()==DataGenMethod.RAND && right instanceof LiteralOp && left.getParent().size()==1 ) { DataGenOp inputGen = (DataGenOp)left; HashMap<String,Integer> params = inputGen.getParamIndexMap(); Hop pdf = left.getInput().get(params.get(DataExpression.RAND_PDF)); Hop min = left.getInput().get(params.get(DataExpression.RAND_MIN)); Hop max = left.getInput().get(params.get(DataExpression.RAND_MAX)); double sval = ((LiteralOp)right).getDoubleValue(); if( 
(bop.getOp()==OpOp2.MULT || bop.getOp()==OpOp2.PLUS || bop.getOp() == OpOp2.MINUS) && min instanceof LiteralOp && max instanceof LiteralOp && pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) ) { //create fused data gen operator DataGenOp gen = null; if( bop.getOp()==OpOp2.MULT ) gen = HopRewriteUtils.copyDataGenOp(inputGen, sval, 0); else { //OpOp2.PLUS | OpOp2.MINUS sval *= (bop.getOp()==OpOp2.MINUS) ? -1 : 1; gen = HopRewriteUtils.copyDataGenOp(inputGen, 1, sval); } //rewire all parents (avoid anomalies with replicated datagen) List<Hop> parents = new ArrayList<Hop>(bop.getParent()); for( Hop p : parents ) HopRewriteUtils.replaceChildReference(p, bop, gen); hi = gen; LOG.debug("Applied fuseDatagenAndBinaryOperation1 (line "+bop.getBeginLine()+")."); } } //right input rand and hence output matrix double, left scalar literal else if( right instanceof DataGenOp && ((DataGenOp)right).getOp()==DataGenMethod.RAND && left instanceof LiteralOp && right.getParent().size()==1 ) { DataGenOp inputGen = (DataGenOp)right; HashMap<String,Integer> params = inputGen.getParamIndexMap(); Hop pdf = right.getInput().get(params.get(DataExpression.RAND_PDF)); Hop min = right.getInput().get(params.get(DataExpression.RAND_MIN)); Hop max = right.getInput().get(params.get(DataExpression.RAND_MAX)); double sval = ((LiteralOp)left).getDoubleValue(); if( (bop.getOp()==OpOp2.MULT || bop.getOp()==OpOp2.PLUS) && min instanceof LiteralOp && max instanceof LiteralOp && pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) ) { //create fused data gen operator DataGenOp gen = null; if( bop.getOp()==OpOp2.MULT ) gen = HopRewriteUtils.copyDataGenOp(inputGen, sval, 0); else { //OpOp2.PLUS gen = HopRewriteUtils.copyDataGenOp(inputGen, 1, sval); } //rewire all parents (avoid anomalies with replicated datagen) List<Hop> parents = new ArrayList<Hop>(bop.getParent()); for( Hop p : parents ) 
HopRewriteUtils.replaceChildReference(p, bop, gen); hi = gen; LOG.debug("Applied fuseDatagenAndBinaryOperation2 (line "+bop.getBeginLine()+")."); } } } return hi; } private Hop fuseDatagenAndMinusOperation( Hop hi ) throws HopsException { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); if( right instanceof DataGenOp && ((DataGenOp)right).getOp()==DataGenMethod.RAND && left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==0.0 ) { DataGenOp inputGen = (DataGenOp)right; HashMap<String,Integer> params = inputGen.getParamIndexMap(); Hop pdf = right.getInput().get(params.get(DataExpression.RAND_PDF)); int ixMin = params.get(DataExpression.RAND_MIN); int ixMax = params.get(DataExpression.RAND_MAX); Hop min = right.getInput().get(ixMin); Hop max = right.getInput().get(ixMax); //apply rewrite under additional conditions (for simplicity) if( inputGen.getParent().size()==1 && min instanceof LiteralOp && max instanceof LiteralOp && pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) ) { //exchange and *-1 (special case 0 stays 0 instead of -0 for consistency) double newMinVal = (((LiteralOp)max).getDoubleValue()==0)?0:(-1 * ((LiteralOp)max).getDoubleValue()); double newMaxVal = (((LiteralOp)min).getDoubleValue()==0)?0:(-1 * ((LiteralOp)min).getDoubleValue()); Hop newMin = new LiteralOp(newMinVal); Hop newMax = new LiteralOp(newMaxVal); HopRewriteUtils.removeChildReferenceByPos(inputGen, min, ixMin); HopRewriteUtils.addChildReference(inputGen, newMin, ixMin); HopRewriteUtils.removeChildReferenceByPos(inputGen, max, ixMax); HopRewriteUtils.addChildReference(inputGen, newMax, ixMax); //rewire all parents (avoid anomalies with replicated datagen) List<Hop> parents = new ArrayList<Hop>(bop.getParent()); for( Hop p : parents ) HopRewriteUtils.replaceChildReference(p, bop, inputGen); hi = inputGen; LOG.debug("Applied 
fuseDatagenAndMinusOperation (line "+bop.getBeginLine()+")."); } } } return hi; } /** * Handle simplification of binary operations (relies on previous common subexpression elimination). * At the same time this servers as a canonicalization for more complex rewrites. * * X+X -> X*2, X*X -> X^2, (X>0)-(X<0) -> sign(X) * * @param parent parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator * @throws HopsException if HopsException occurs */ private Hop simplifyBinaryToUnaryOperation( Hop parent, Hop hi, int pos ) throws HopsException { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = hi.getInput().get(0); Hop right = hi.getInput().get(1); //patterns: X+X -> X*2, X*X -> X^2, if( left == right && left.getDataType()==DataType.MATRIX ) { //note: we simplify this to unary operations first (less mem and better MR plan), //however, we later compile specific LOPS for X*2 and X^2 if( bop.getOp()==OpOp2.PLUS ) //X+X -> X*2 { bop.setOp(OpOp2.MULT); LiteralOp tmp = new LiteralOp(2); bop.getInput().remove(1); right.getParent().remove(bop); HopRewriteUtils.addChildReference(hi, tmp, 1); LOG.debug("Applied simplifyBinaryToUnaryOperation1"); } else if ( bop.getOp()==OpOp2.MULT ) //X*X -> X^2 { bop.setOp(OpOp2.POW); LiteralOp tmp = new LiteralOp(2); bop.getInput().remove(1); right.getParent().remove(bop); HopRewriteUtils.addChildReference(hi, tmp, 1); LOG.debug("Applied simplifyBinaryToUnaryOperation2"); } } //patterns: (X>0)-(X<0) -> sign(X) else if( bop.getOp() == OpOp2.MINUS && HopRewriteUtils.isBinary(left, OpOp2.GREATER) && HopRewriteUtils.isBinary(right, OpOp2.LESS) && left.getInput().get(0) == right.getInput().get(0) && left.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left.getInput().get(1))==0 && right.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)right.getInput().get(1))==0 ) { UnaryOp uop = 
HopRewriteUtils.createUnary(left.getInput().get(0), OpOp1.SIGN); HopRewriteUtils.replaceChildReference(parent, hi, uop, pos); HopRewriteUtils.cleanupUnreferenced(hi, left, right); hi = uop; LOG.debug("Applied simplifyBinaryToUnaryOperation3"); } } return hi; } /** * Rewrite to canonicalize all patterns like U%*%V+eps, eps+U%*%V, and * U%*%V-eps into the common representation U%*%V+s which simplifies * subsequent rewrites (e.g., wdivmm or wcemm with epsilon). * * @param hi high-level operator * @return high-level operator * @throws HopsException if HopsException occurs */ private Hop canonicalizeMatrixMultScalarAdd( Hop hi ) throws HopsException { //pattern: binary operation (+ or -) of matrix mult and scalar if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = hi.getInput().get(0); Hop right = hi.getInput().get(1); //pattern: (eps + U%*%V) -> (U%*%V+eps) if( left.getDataType().isScalar() && right instanceof AggBinaryOp && bop.getOp()==OpOp2.PLUS ) { HopRewriteUtils.removeAllChildReferences(bop); HopRewriteUtils.addChildReference(bop, right, 0); HopRewriteUtils.addChildReference(bop, left, 1); LOG.debug("Applied canonicalizeMatrixMultScalarAdd1 (line "+hi.getBeginLine()+")."); } //pattern: (U%*%V - eps) -> (U%*%V + (-eps)) else if( right.getDataType().isScalar() && left instanceof AggBinaryOp && bop.getOp() == OpOp2.MINUS ) { bop.setOp(OpOp2.PLUS); HopRewriteUtils.replaceChildReference(bop, right, HopRewriteUtils.createBinaryMinus(right), 1); LOG.debug("Applied canonicalizeMatrixMultScalarAdd2 (line "+hi.getBeginLine()+")."); } } return hi; } /** * NOTE: this would be by definition a dynamic rewrite; however, we apply it as a static * rewrite in order to apply it before splitting dags which would hide the table information * if dimensions are not specified. 
* * @param parent parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator * @throws HopsException if HopsException occurs */ private Hop simplifyReverseOperation( Hop parent, Hop hi, int pos ) throws HopsException { if( hi instanceof AggBinaryOp && hi.getInput().get(0) instanceof TernaryOp ) { TernaryOp top = (TernaryOp) hi.getInput().get(0); if( top.getOp()==OpOp3.CTABLE && HopRewriteUtils.isBasic1NSequence(top.getInput().get(0)) && HopRewriteUtils.isBasicN1Sequence(top.getInput().get(1)) && top.getInput().get(0).getDim1()==top.getInput().get(1).getDim1()) { ReorgOp rop = HopRewriteUtils.createReorg(hi.getInput().get(1), ReOrgOp.REV); HopRewriteUtils.replaceChildReference(parent, hi, rop, pos); HopRewriteUtils.cleanupUnreferenced(hi, top); hi = rop; LOG.debug("Applied simplifyReverseOperation."); } } return hi; } private Hop simplifyMultiBinaryToBinaryOperation( Hop hi ) { //pattern: 1-(X*Y) --> X 1-* Y (avoid intermediate) if( HopRewriteUtils.isBinary(hi, OpOp2.MINUS) && hi.getDataType() == DataType.MATRIX && hi.getInput().get(0) instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)hi.getInput().get(0))==1 && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) && hi.getInput().get(1).getParent().size() == 1 ) //single consumer { BinaryOp bop = (BinaryOp)hi; Hop left = hi.getInput().get(1).getInput().get(0); Hop right = hi.getInput().get(1).getInput().get(1); //set new binaryop type and rewire inputs bop.setOp(OpOp2.MINUS1_MULT); HopRewriteUtils.removeAllChildReferences(hi); HopRewriteUtils.addChildReference(bop, left); HopRewriteUtils.addChildReference(bop, right); LOG.debug("Applied simplifyMultiBinaryToBinaryOperation."); } return hi; } /** * (X-Y*X) -> (1-Y)*X, (Y*X-X) -> (Y-1)*X * (X+Y*X) -> (1+Y)*X, (Y*X+X) -> (Y+1)*X * * * @param parent parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator */ private Hop 
simplifyDistributiveBinaryOperation( Hop parent, Hop hi, int pos ) { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); //(X+Y*X) -> (1+Y)*X, (Y*X+X) -> (Y+1)*X //(X-Y*X) -> (1-Y)*X, (Y*X-X) -> (Y-1)*X boolean applied = false; if( left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX && HopRewriteUtils.isValidOp(bop.getOp(), LOOKUP_VALID_DISTRIBUTIVE_BINARY) ) { Hop X = null; Hop Y = null; if( HopRewriteUtils.isBinary(left, OpOp2.MULT) ) //(Y*X-X) -> (Y-1)*X { Hop leftC1 = left.getInput().get(0); Hop leftC2 = left.getInput().get(1); if( leftC1.getDataType()==DataType.MATRIX && leftC2.getDataType()==DataType.MATRIX && (right == leftC1 || right == leftC2) && leftC1 !=leftC2 ){ //any mult order X = right; Y = ( right == leftC1 ) ? leftC2 : leftC1; } if( X != null ){ //rewrite 'binary +/-' LiteralOp literal = new LiteralOp(1); BinaryOp plus = HopRewriteUtils.createBinary(Y, literal, bop.getOp()); BinaryOp mult = HopRewriteUtils.createBinary(plus, X, OpOp2.MULT); HopRewriteUtils.replaceChildReference(parent, hi, mult, pos); HopRewriteUtils.cleanupUnreferenced(hi, left); hi = mult; applied = true; LOG.debug("Applied simplifyDistributiveBinaryOperation1"); } } if( !applied && HopRewriteUtils.isBinary(right, OpOp2.MULT) ) //(X-Y*X) -> (1-Y)*X { Hop rightC1 = right.getInput().get(0); Hop rightC2 = right.getInput().get(1); if( rightC1.getDataType()==DataType.MATRIX && rightC2.getDataType()==DataType.MATRIX && (left == rightC1 || left == rightC2) && rightC1 !=rightC2 ){ //any mult order X = left; Y = ( left == rightC1 ) ? 
rightC2 : rightC1; } if( X != null ){ //rewrite '+/- binary' LiteralOp literal = new LiteralOp(1); BinaryOp plus = HopRewriteUtils.createBinary(literal, Y, bop.getOp()); BinaryOp mult = HopRewriteUtils.createBinary(plus, X, OpOp2.MULT); HopRewriteUtils.replaceChildReference(parent, hi, mult, pos); HopRewriteUtils.cleanupUnreferenced(hi, right); hi = mult; LOG.debug("Applied simplifyDistributiveBinaryOperation2"); } } } } return hi; } /** * (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v) * (X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v) * * Note: Restriction ba() at leaf and root instead of data at leaf to not reorganize too * eagerly, which would loose additional rewrite potential. This rewrite has two goals * (1) enable XtwXv, and increase piggybacking potential by creating bushy trees. * * @param parent parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator */ private Hop simplifyBushyBinaryOperation( Hop parent, Hop hi, int pos ) { if( hi instanceof BinaryOp && parent instanceof AggBinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); OpOp2 op = bop.getOp(); if( left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX && HopRewriteUtils.isValidOp(op, LOOKUP_VALID_ASSOCIATIVE_BINARY) ) { boolean applied = false; if( right instanceof BinaryOp ) { BinaryOp bop2 = (BinaryOp)right; Hop left2 = bop2.getInput().get(0); Hop right2 = bop2.getInput().get(1); OpOp2 op2 = bop2.getOp(); if( op==op2 && right2.getDataType()==DataType.MATRIX && (right2 instanceof AggBinaryOp) ) { //(X*(Y*op()) -> (X*Y)*op() BinaryOp bop3 = HopRewriteUtils.createBinary(left, left2, op); BinaryOp bop4 = HopRewriteUtils.createBinary(bop3, right2, op); HopRewriteUtils.replaceChildReference(parent, bop, bop4, pos); HopRewriteUtils.cleanupUnreferenced(bop, bop2); hi = bop4; applied = true; LOG.debug("Applied simplifyBushyBinaryOperation1"); } } if( !applied && left instanceof BinaryOp ) { BinaryOp 
bop2 = (BinaryOp)left; Hop left2 = bop2.getInput().get(0); Hop right2 = bop2.getInput().get(1); OpOp2 op2 = bop2.getOp(); if( op==op2 && left2.getDataType()==DataType.MATRIX && (left2 instanceof AggBinaryOp) && (right2.getDim2() > 1 || right.getDim2() == 1) //X not vector, or Y vector && (right2.getDim1() > 1 || right.getDim1() == 1) ) //X not vector, or Y vector { //((op()*X)*Y) -> op()*(X*Y) BinaryOp bop3 = HopRewriteUtils.createBinary(right2, right, op); BinaryOp bop4 = HopRewriteUtils.createBinary(left2, bop3, op); HopRewriteUtils.replaceChildReference(parent, bop, bop4, pos); HopRewriteUtils.cleanupUnreferenced(bop, bop2); hi = bop4; LOG.debug("Applied simplifyBushyBinaryOperation2"); } } } } return hi; } private Hop simplifyUnaryAggReorgOperation( Hop parent, Hop hi, int pos ) { if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol //full uagg && hi.getInput().get(0) instanceof ReorgOp ) //reorg operation { ReorgOp rop = (ReorgOp)hi.getInput().get(0); if( (rop.getOp()==ReOrgOp.TRANSPOSE || rop.getOp()==ReOrgOp.RESHAPE || rop.getOp() == ReOrgOp.REV ) //valid reorg && rop.getParent().size()==1 ) //uagg only reorg consumer { Hop input = rop.getInput().get(0); HopRewriteUtils.removeAllChildReferences(hi); HopRewriteUtils.removeAllChildReferences(rop); HopRewriteUtils.addChildReference(hi, input); LOG.debug("Applied simplifyUnaryAggReorgOperation"); } } return hi; } private Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, int pos ) throws HopsException { if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR) && hi.getInput().get(0) instanceof BinaryOp ) { BinaryOp bin = (BinaryOp) hi.getInput().get(0); BinaryOp bout = null; //as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y) if( bin.getInput().get(0).getDataType()==DataType.MATRIX && bin.getInput().get(1).getDataType()==DataType.MATRIX ) { UnaryOp cast1 = HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR); UnaryOp cast2 = 
HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR); bout = HopRewriteUtils.createBinary(cast1, cast2, bin.getOp()); } //as.scalar(X*s) -> as.scalar(X) * s else if( bin.getInput().get(0).getDataType()==DataType.MATRIX ) { UnaryOp cast = HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR); bout = HopRewriteUtils.createBinary(cast, bin.getInput().get(1), bin.getOp()); } //as.scalar(s*X) -> s * as.scalar(X) else if ( bin.getInput().get(1).getDataType()==DataType.MATRIX ) { UnaryOp cast = HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR); bout = HopRewriteUtils.createBinary(bin.getInput().get(0), cast, bin.getOp()); } if( bout != null ) { HopRewriteUtils.replaceChildReference(parent, hi, bout, pos); LOG.debug("Applied simplifyBinaryMatrixScalarOperation."); } } return hi; } private Hop pushdownUnaryAggTransposeOperation( Hop parent, Hop hi, int pos ) { if( hi instanceof AggUnaryOp && hi.getParent().size()==1 && (((AggUnaryOp) hi).getDirection()==Direction.Row || ((AggUnaryOp) hi).getDirection()==Direction.Col) && HopRewriteUtils.isTransposeOperation(hi.getInput().get(0), 1) && HopRewriteUtils.isValidOp(((AggUnaryOp) hi).getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE) ) { AggUnaryOp uagg = (AggUnaryOp) hi; //get input rewire existing operators (remove inner transpose) Hop input = uagg.getInput().get(0).getInput().get(0); HopRewriteUtils.removeAllChildReferences(hi.getInput().get(0)); HopRewriteUtils.removeAllChildReferences(hi); HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); //pattern 1: row-aggregate to col aggregate, e.g., rowSums(t(X))->t(colSums(X)) if( uagg.getDirection()==Direction.Row ) { uagg.setDirection(Direction.Col); LOG.debug("Applied pushdownUnaryAggTransposeOperation1 (line "+hi.getBeginLine()+")."); } //pattern 2: col-aggregate to row aggregate, e.g., colSums(t(X))->t(rowSums(X)) else if( uagg.getDirection()==Direction.Col ) { uagg.setDirection(Direction.Row); LOG.debug("Applied 
pushdownUnaryAggTransposeOperation2 (line "+hi.getBeginLine()+")."); } //create outer transpose operation and rewire operators HopRewriteUtils.addChildReference(uagg, input); uagg.refreshSizeInformation(); Hop trans = HopRewriteUtils.createTranspose(uagg); //incl refresh size HopRewriteUtils.addChildReference(parent, trans, pos); //by def, same size hi = trans; } return hi; } private Hop pushdownCSETransposeScalarOperation( Hop parent, Hop hi, int pos ) { // a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X) // probed at root node of b in above example // (with support for left or right scalar operations) if( HopRewriteUtils.isTransposeOperation(hi, 1) && HopRewriteUtils.isBinaryMatrixScalarOperation(hi.getInput().get(0)) && hi.getInput().get(0).getParent().size()==1) { int Xpos = hi.getInput().get(0).getInput().get(0).getDataType().isMatrix() ? 0 : 1; Hop X = hi.getInput().get(0).getInput().get(Xpos); BinaryOp binary = (BinaryOp) hi.getInput().get(0); if( HopRewriteUtils.containsTransposeOperation(X.getParent()) && !HopRewriteUtils.isValidOp(binary.getOp(), new OpOp2[]{OpOp2.CENTRALMOMENT, OpOp2.QUANTILE})) { //clear existing wiring HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); HopRewriteUtils.removeChildReference(hi, binary); HopRewriteUtils.removeChildReference(binary, X); //re-wire operators HopRewriteUtils.addChildReference(parent, binary, pos); HopRewriteUtils.addChildReference(binary, hi, Xpos); HopRewriteUtils.addChildReference(hi, X); //note: common subexpression later eliminated by dedicated rewrite hi = binary; LOG.debug("Applied pushdownCSETransposeScalarOperation (line "+hi.getBeginLine()+")."); } } return hi; } private Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos ) throws HopsException { //pattern: sum(lamda*X) -> lamda*sum(X) if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol && ((AggUnaryOp)hi).getOp()==Hop.AggOp.SUM // only one parent which is the sum && 
HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.MULT, 1) && ((hi.getInput().get(0).getInput().get(0).getDataType()==DataType.SCALAR && hi.getInput().get(0).getInput().get(1).getDataType()==DataType.MATRIX) ||(hi.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(0).getInput().get(1).getDataType()==DataType.SCALAR))) { Hop operand1 = hi.getInput().get(0).getInput().get(0); Hop operand2 = hi.getInput().get(0).getInput().get(1); //check which operand is the Scalar and which is the matrix Hop lamda = (operand1.getDataType()==DataType.SCALAR) ? operand1 : operand2; Hop matrix = (operand1.getDataType()==DataType.MATRIX) ? operand1 : operand2; AggUnaryOp aggOp=HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.RowCol); Hop bop = HopRewriteUtils.createBinary(lamda, aggOp, OpOp2.MULT); HopRewriteUtils.replaceChildReference(parent, hi, bop, pos); LOG.debug("Applied pushdownSumBinaryMult."); return bop; } return hi; } private Hop simplifyUnaryPPredOperation( Hop parent, Hop hi, int pos ) { if( hi instanceof UnaryOp && hi.getDataType()==DataType.MATRIX //unaryop && hi.getInput().get(0) instanceof BinaryOp //binaryop - ppred && ((BinaryOp)hi.getInput().get(0)).isPPredOperation() ) { UnaryOp uop = (UnaryOp) hi; //valid unary op if( uop.getOp()==OpOp1.ABS || uop.getOp()==OpOp1.SIGN || uop.getOp()==OpOp1.SELP || uop.getOp()==OpOp1.CEIL || uop.getOp()==OpOp1.FLOOR || uop.getOp()==OpOp1.ROUND ) { //clear link unary-binary Hop input = uop.getInput().get(0); HopRewriteUtils.replaceChildReference(parent, hi, input, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = input; LOG.debug("Applied simplifyUnaryPPredOperation."); } } return hi; } private Hop simplifyTransposedAppend( Hop parent, Hop hi, int pos ) { //e.g., t(cbind(t(A),t(B))) --> rbind(A,B), t(rbind(t(A),t(B))) --> cbind(A,B) if( HopRewriteUtils.isTransposeOperation(hi) //t() rooted && hi.getInput().get(0) instanceof BinaryOp && 
(((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.CBIND //append (cbind/rbind) || ((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.RBIND) && hi.getInput().get(0).getParent().size() == 1 ) //single consumer of append { BinaryOp bop = (BinaryOp)hi.getInput().get(0); //both inputs transpose ops, where transpose is single consumer if( HopRewriteUtils.isTransposeOperation(bop.getInput().get(0), 1) && HopRewriteUtils.isTransposeOperation(bop.getInput().get(1), 1) ) { Hop left = bop.getInput().get(0).getInput().get(0); Hop right = bop.getInput().get(1).getInput().get(0); //create new subdag (no in-place dag update to prevent anomalies with //multiple consumers during rewrite process) OpOp2 binop = (bop.getOp()==OpOp2.CBIND) ? OpOp2.RBIND : OpOp2.CBIND; BinaryOp bopnew = HopRewriteUtils.createBinary(left, right, binop); HopRewriteUtils.replaceChildReference(parent, hi, bopnew, pos); hi = bopnew; LOG.debug("Applied simplifyTransposedAppend (line "+hi.getBeginLine()+")."); } } return hi; } /** * handle simplification of more complex sub DAG to unary operation. * * X*(1-X) -> sprop(X) * (1-X)*X -> sprop(X) * 1/(1+exp(-X)) -> sigmoid(X) * * @param parent parent high-level operator * @param hi high-level operator * @param pos position * @throws HopsException if HopsException occurs */ private Hop fuseBinarySubDAGToUnaryOperation( Hop parent, Hop hi, int pos ) throws HopsException { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = hi.getInput().get(0); Hop right = hi.getInput().get(1); boolean applied = false; //sample proportion (sprop) operator if( bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX ) { //by definition, either left or right or none applies. 
//note: if there are multiple consumers on the intermediate, //we follow the heuristic that redundant computation is more beneficial, //i.e., we still fuse but leave the intermediate for the other consumers if( left instanceof BinaryOp ) //(1-X)*X { BinaryOp bleft = (BinaryOp)left; Hop left1 = bleft.getInput().get(0); Hop left2 = bleft.getInput().get(1); if( left1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left1)==1 && left2 == right && bleft.getOp() == OpOp2.MINUS ) { UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SPROP); HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop, left); hi = unary; applied = true; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1"); } } if( !applied && right instanceof BinaryOp ) //X*(1-X) { BinaryOp bright = (BinaryOp)right; Hop right1 = bright.getInput().get(0); Hop right2 = bright.getInput().get(1); if( right1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)right1)==1 && right2 == left && bright.getOp() == OpOp2.MINUS ) { UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SPROP); HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop, left); hi = unary; applied = true; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2"); } } } //sigmoid operator if( !applied && bop.getOp() == OpOp2.DIV && left.getDataType()==DataType.SCALAR && right.getDataType()==DataType.MATRIX && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==1 && right instanceof BinaryOp) { //note: if there are multiple consumers on the intermediate, //we follow the heuristic that redundant computation is more beneficial, //i.e., we still fuse but leave the intermediate for the other consumers BinaryOp bop2 = (BinaryOp)right; Hop left2 = bop2.getInput().get(0); Hop right2 = bop2.getInput().get(1); if( bop2.getOp() == OpOp2.PLUS && left2.getDataType()==DataType.SCALAR && 
right2.getDataType()==DataType.MATRIX && left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left2)==1 && right2 instanceof UnaryOp) { UnaryOp uop = (UnaryOp) right2; Hop uopin = uop.getInput().get(0); if( uop.getOp()==OpOp1.EXP ) { UnaryOp unary = null; //Pattern 1: (1/(1 + exp(-X)) if( HopRewriteUtils.isBinary(uopin, OpOp2.MINUS) ) { BinaryOp bop3 = (BinaryOp) uopin; Hop left3 = bop3.getInput().get(0); Hop right3 = bop3.getInput().get(1); if( left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left3)==0 ) unary = HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID); } //Pattern 2: (1/(1 + exp(X)), e.g., where -(-X) has been removed by //the 'remove unnecessary minus' rewrite --> reintroduce the minus else { BinaryOp minus = HopRewriteUtils.createBinaryMinus(uopin); unary = HopRewriteUtils.createUnary(minus, OpOp1.SIGMOID); } if( unary != null ) { HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop); hi = unary; applied = true; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1"); } } } } //select positive (selp) operator (note: same initial pattern as sprop) if( !applied && bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX ) { //by definition, either left or right or none applies. 
//note: if there are multiple consumers on the intermediate tmp=(X>0), it's still beneficial //to replace the X*tmp with selp(X) due to lower memory requirements and simply sparsity propagation if( left instanceof BinaryOp ) //(X>0)*X { BinaryOp bleft = (BinaryOp)left; Hop left1 = bleft.getInput().get(0); Hop left2 = bleft.getInput().get(1); if( left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left2)==0 && left1 == right && bleft.getOp() == OpOp2.GREATER ) { UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SELP); HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop, left); hi = unary; applied = true; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp1"); } } if( !applied && right instanceof BinaryOp ) //X*(X>0) { BinaryOp bright = (BinaryOp)right; Hop right1 = bright.getInput().get(0); Hop right2 = bright.getInput().get(1); if( right2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)right2)==0 && right1 == left && bright.getOp() == OpOp2.GREATER ) { UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SELP); HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop, left); hi = unary; applied= true; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp2"); } } } //select positive (selp) operator; pattern: max(X,0) -> selp+ if( !applied && bop.getOp() == OpOp2.MAX && left.getDataType()==DataType.MATRIX && right instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)right)==0 ) { UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SELP); HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop); hi = unary; applied = true; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp3"); } //select positive (selp) operator; pattern: max(0,X) -> selp+ if( !applied && bop.getOp() == OpOp2.MAX && right.getDataType()==DataType.MATRIX && left instanceof 
LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==0 ) { UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SELP); HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop); hi = unary; applied = true; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp4"); } } return hi; } private Hop simplifyTraceMatrixMult(Hop parent, Hop hi, int pos) { if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.TRACE ) //trace() { Hop hi2 = hi.getInput().get(0); if( HopRewriteUtils.isMatrixMultiply(hi2) ) //X%*%Y { Hop left = hi2.getInput().get(0); Hop right = hi2.getInput().get(1); //create new operators (incl refresh size inside for transpose) ReorgOp trans = HopRewriteUtils.createTranspose(right); BinaryOp mult = HopRewriteUtils.createBinary(left, trans, OpOp2.MULT); AggUnaryOp sum = HopRewriteUtils.createSum(mult); //rehang new subdag under parent node HopRewriteUtils.replaceChildReference(parent, hi, sum, pos); HopRewriteUtils.cleanupUnreferenced(hi, hi2); hi = sum; LOG.debug("Applied simplifyTraceMatrixMult"); } } return hi; } private Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos) throws HopsException { //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1] if( hi instanceof IndexingOp && ((IndexingOp)hi).isRowLowerEqualsUpper() && ((IndexingOp)hi).isColLowerEqualsUpper() && hi.getInput().get(0).getParent().size()==1 //rix is single mm consumer && HopRewriteUtils.isMatrixMultiply(hi.getInput().get(0)) ) { Hop mm = hi.getInput().get(0); Hop X = mm.getInput().get(0); Hop Y = mm.getInput().get(1); Hop rowExpr = hi.getInput().get(1); //rl==ru Hop colExpr = hi.getInput().get(3); //cl==cu HopRewriteUtils.removeAllChildReferences(mm); //create new indexing operations IndexingOp ix1 = new IndexingOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, X, rowExpr, rowExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(X, false), true, false); ix1.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock()); 
ix1.refreshSizeInformation(); IndexingOp ix2 = new IndexingOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, Y, new LiteralOp(1), HopRewriteUtils.createValueHop(Y, true), colExpr, colExpr, false, true); ix2.setOutputBlocksizes(Y.getRowsInBlock(), Y.getColsInBlock()); ix2.refreshSizeInformation(); //rewire matrix mult over ix1 and ix2 HopRewriteUtils.addChildReference(mm, ix1, 0); HopRewriteUtils.addChildReference(mm, ix2, 1); mm.refreshSizeInformation(); hi = mm; LOG.debug("Applied simplifySlicedMatrixMult"); } return hi; } private Hop simplifyConstantSort(Hop parent, Hop hi, int pos) throws HopsException { //order(matrix(7), indexreturn=FALSE) -> matrix(7) //order(matrix(7), indexreturn=TRUE) -> seq(1,nrow(X),1) if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.SORT ) //order { Hop hi2 = hi.getInput().get(0); if( hi2 instanceof DataGenOp && ((DataGenOp)hi2).getOp()==DataGenMethod.RAND && ((DataGenOp)hi2).hasConstantValue() && hi.getInput().get(3) instanceof LiteralOp ) //known indexreturn { if( HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(3)) ) { //order(matrix(7), indexreturn=TRUE) -> seq(1,nrow(X),1) Hop seq = HopRewriteUtils.createSeqDataGenOp(hi2); seq.refreshSizeInformation(); HopRewriteUtils.replaceChildReference(parent, hi, seq, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = seq; LOG.debug("Applied simplifyConstantSort1."); } else { //order(matrix(7), indexreturn=FALSE) -> matrix(7) HopRewriteUtils.replaceChildReference(parent, hi, hi2, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = hi2; LOG.debug("Applied simplifyConstantSort2."); } } } return hi; } private Hop simplifyOrderedSort(Hop parent, Hop hi, int pos) throws HopsException { //order(seq(2,N+1,1), indexreturn=FALSE) -> matrix(7) //order(seq(2,N+1,1), indexreturn=TRUE) -> seq(1,N,1)/seq(N,1,-1) if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.SORT ) //order { Hop hi2 = hi.getInput().get(0); if( hi2 instanceof DataGenOp && 
((DataGenOp)hi2).getOp()==DataGenMethod.SEQ ) { Hop incr = hi2.getInput().get(((DataGenOp)hi2).getParamIndex(Statement.SEQ_INCR)); //check for known ascending ordering and known indexreturn if( incr instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)incr)==1 && hi.getInput().get(2) instanceof LiteralOp //decreasing && hi.getInput().get(3) instanceof LiteralOp ) //indexreturn { if( HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(3)) ) //IXRET, ASC/DESC { //order(seq(2,N+1,1), indexreturn=TRUE) -> seq(1,N,1)/seq(N,1,-1) boolean desc = HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2)); Hop seq = HopRewriteUtils.createSeqDataGenOp(hi2, !desc); seq.refreshSizeInformation(); HopRewriteUtils.replaceChildReference(parent, hi, seq, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = seq; LOG.debug("Applied simplifyOrderedSort1."); } else if( !HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2)) ) //DATA, ASC { //order(seq(2,N+1,1), indexreturn=FALSE) -> seq(2,N+1,1) HopRewriteUtils.replaceChildReference(parent, hi, hi2, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = hi2; LOG.debug("Applied simplifyOrderedSort2."); } } } } return hi; } /** * Patterns: t(t(A)%*%t(B)+C) -> B%*%A+t(C) * * @param parent parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator * @throws HopsException if HopsException occurs */ private Hop simplifyTransposeAggBinBinaryChains(Hop parent, Hop hi, int pos) throws HopsException { if( HopRewriteUtils.isTransposeOperation(hi) && hi.getInput().get(0) instanceof BinaryOp //basic binary && ((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations()) { Hop left = hi.getInput().get(0).getInput().get(0); Hop C = hi.getInput().get(0).getInput().get(1); //check matrix mult and both inputs transposes w/ single consumer if( left instanceof AggBinaryOp && C.getDataType().isMatrix() && HopRewriteUtils.isTransposeOperation(left.getInput().get(0)) 
&& left.getInput().get(0).getParent().size()==1 && HopRewriteUtils.isTransposeOperation(left.getInput().get(1)) && left.getInput().get(1).getParent().size()==1 ) { Hop A = left.getInput().get(0).getInput().get(0); Hop B = left.getInput().get(1).getInput().get(0); AggBinaryOp abop = HopRewriteUtils.createMatrixMultiply(B, A); ReorgOp rop = HopRewriteUtils.createTranspose(C); BinaryOp bop = HopRewriteUtils.createBinary(abop, rop, OpOp2.PLUS); HopRewriteUtils.replaceChildReference(parent, hi, bop, pos); hi = bop; LOG.debug("Applied simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+")."); } } return hi; } /** * Pattners: t(t(X)) -> X, rev(rev(X)) -> X * * @param parent parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator */ private Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos) { ReOrgOp[] lookup = new ReOrgOp[]{ReOrgOp.TRANSPOSE, ReOrgOp.REV}; if( hi instanceof ReorgOp && HopRewriteUtils.isValidOp(((ReorgOp)hi).getOp(), lookup) ) //first reorg { ReOrgOp firstOp = ((ReorgOp)hi).getOp(); Hop hi2 = hi.getInput().get(0); if( hi2 instanceof ReorgOp && ((ReorgOp)hi2).getOp()==firstOp ) //second reorg w/ same type { Hop hi3 = hi2.getInput().get(0); //remove unnecessary chain of t(t()) HopRewriteUtils.replaceChildReference(parent, hi, hi3, pos); HopRewriteUtils.cleanupUnreferenced(hi, hi2); hi = hi3; LOG.debug("Applied removeUnecessaryReorgOperation."); } } return hi; } private Hop removeUnnecessaryMinus(Hop parent, Hop hi, int pos) throws HopsException { if( hi.getDataType() == DataType.MATRIX && hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==OpOp2.MINUS //first minus && hi.getInput().get(0) instanceof LiteralOp && ((LiteralOp)hi.getInput().get(0)).getDoubleValue()==0 ) { Hop hi2 = hi.getInput().get(1); if( hi2.getDataType() == DataType.MATRIX && hi2 instanceof BinaryOp && ((BinaryOp)hi2).getOp()==OpOp2.MINUS //second minus && hi2.getInput().get(0) instanceof LiteralOp && 
((LiteralOp)hi2.getInput().get(0)).getDoubleValue()==0 ) { Hop hi3 = hi2.getInput().get(1); //remove unnecessary chain of -(-()) HopRewriteUtils.replaceChildReference(parent, hi, hi3, pos); HopRewriteUtils.cleanupUnreferenced(hi, hi2); hi = hi3; LOG.debug("Applied removeUnecessaryMinus"); } } return hi; } private Hop simplifyGroupedAggregate(Hop hi) { if( hi instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)hi).getOp()==ParamBuiltinOp.GROUPEDAGG ) //aggregate { ParameterizedBuiltinOp phi = (ParameterizedBuiltinOp)hi; if( phi.isCountFunction() //aggregate(fn="count") && phi.getTargetHop().getDim2()==1 ) //only for vector { HashMap<String, Integer> params = phi.getParamIndexMap(); int ix1 = params.get(Statement.GAGG_TARGET); int ix2 = params.get(Statement.GAGG_GROUPS); //check for unnecessary memory consumption for "count" if( ix1 != ix2 && phi.getInput().get(ix1)!=phi.getInput().get(ix2) ) { Hop th = phi.getInput().get(ix1); Hop gh = phi.getInput().get(ix2); HopRewriteUtils.replaceChildReference(hi, th, gh, ix1); LOG.debug("Applied simplifyGroupedAggregateCount"); } } } return hi; } private Hop fuseMinusNzBinaryOperation(Hop parent, Hop hi, int pos) throws HopsException { //pattern X - (s * ppred(X,0,!=)) -> X -nz s //note: this is done as a hop rewrite in order to significantly reduce the //memory estimate for X - tmp if X is sparse if( HopRewriteUtils.isBinary(hi, OpOp2.MINUS) && hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1).getDataType()==DataType.MATRIX && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) ) { Hop X = hi.getInput().get(0); Hop s = hi.getInput().get(1).getInput().get(0); Hop pred = hi.getInput().get(1).getInput().get(1); if( s.getDataType()==DataType.SCALAR && pred.getDataType()==DataType.MATRIX && HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL) && pred.getInput().get(0) == X //depend on common subexpression elimination && pred.getInput().get(1) instanceof LiteralOp && 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) { Hop hnew = HopRewriteUtils.createBinary(X, s, OpOp2.MINUS_NZ); //relink new hop into original position HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; LOG.debug("Applied fuseMinusNzBinaryOperation (line "+hi.getBeginLine()+")"); } } return hi; } private Hop fuseLogNzUnaryOperation(Hop parent, Hop hi, int pos) throws HopsException { //pattern ppred(X,0,"!=")*log(X) -> log_nz(X) //note: this is done as a hop rewrite in order to significantly reduce the //memory estimate and to prevent dense intermediates if X is ultra sparse if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) && hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1).getDataType()==DataType.MATRIX && HopRewriteUtils.isUnary(hi.getInput().get(1), OpOp1.LOG) ) { Hop pred = hi.getInput().get(0); Hop X = hi.getInput().get(1).getInput().get(0); if( HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL) && pred.getInput().get(0) == X //depend on common subexpression elimination && pred.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) { Hop hnew = HopRewriteUtils.createUnary(X, OpOp1.LOG_NZ); //relink new hop into original position HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; LOG.debug("Applied fuseLogNzUnaryOperation (line "+hi.getBeginLine()+")."); } } return hi; } private Hop fuseLogNzBinaryOperation(Hop parent, Hop hi, int pos) throws HopsException { //pattern ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5) //note: this is done as a hop rewrite in order to significantly reduce the //memory estimate and to prevent dense intermediates if X is ultra sparse if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) && hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1).getDataType()==DataType.MATRIX && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.LOG) ) { Hop pred = hi.getInput().get(0); Hop X 
= hi.getInput().get(1).getInput().get(0); Hop log = hi.getInput().get(1).getInput().get(1); if( HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL) && pred.getInput().get(0) == X //depend on common subexpression elimination && pred.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) { Hop hnew = HopRewriteUtils.createBinary(X, log, OpOp2.LOG_NZ); //relink new hop into original position HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; LOG.debug("Applied fuseLogNzBinaryOperation (line "+hi.getBeginLine()+")"); } } return hi; } private Hop simplifyOuterSeqExpand(Hop parent, Hop hi, int pos) throws HopsException { //pattern: outer(v, t(seq(1,m)), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) //note: this rewrite supports both left/right sequence if( HopRewriteUtils.isBinary(hi, OpOp2.EQUAL) && ((BinaryOp)hi).isOuterVectorOperator() ) { if( ( HopRewriteUtils.isTransposeOperation(hi.getInput().get(1)) //pattern a: outer(v, t(seq(1,m)), "==") && HopRewriteUtils.isBasic1NSequence(hi.getInput().get(1).getInput().get(0))) || HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0))) //pattern b: outer(seq(1,m), t(v) "==") { //determine variable parameters for pattern a/b boolean isPatternB = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)); boolean isTransposeRight = HopRewriteUtils.isTransposeOperation(hi.getInput().get(1)); Hop trgt = isPatternB ? (isTransposeRight ? hi.getInput().get(1).getInput().get(0) : //get v from t(v) HopRewriteUtils.createTranspose(hi.getInput().get(1)) ) : //create v via t(v') hi.getInput().get(0); //get v directly Hop seq = isPatternB ? hi.getInput().get(0) : hi.getInput().get(1).getInput().get(0); String direction = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)) ? 
"rows" : "cols"; //setup input parameter hops HashMap<String,Hop> inputargs = new HashMap<String,Hop>(); inputargs.put("target", trgt); inputargs.put("max", HopRewriteUtils.getBasic1NSequenceMaxLiteral(seq)); inputargs.put("dir", new LiteralOp(direction)); inputargs.put("ignore", new LiteralOp(true)); inputargs.put("cast", new LiteralOp(false)); //create new hop ParameterizedBuiltinOp pbop = new ParameterizedBuiltinOp("tmp", DataType.MATRIX, ValueType.DOUBLE, ParamBuiltinOp.REXPAND, inputargs); pbop.setOutputBlocksizes(hi.getRowsInBlock(), hi.getColsInBlock()); pbop.refreshSizeInformation(); //relink new hop into original position HopRewriteUtils.replaceChildReference(parent, hi, pbop, pos); hi = pbop; LOG.debug("Applied simplifyOuterSeqExpand (line "+hi.getBeginLine()+")"); } } return hi; } /** * NOTE: currently disabled since this rewrite is INVALID in the * presence of NaNs (because (NaN!=NaN) is true). * * @param parent parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator * @throws HopsException if HopsException occurs */ @SuppressWarnings("unused") private Hop removeUnecessaryPPred(Hop parent, Hop hi, int pos) throws HopsException { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); Hop datagen = null; //ppred(X,X,"==") -> matrix(1, rows=nrow(X),cols=nrow(Y)) if( left==right && bop.getOp()==OpOp2.EQUAL || bop.getOp()==OpOp2.GREATEREQUAL || bop.getOp()==OpOp2.LESSEQUAL ) datagen = HopRewriteUtils.createDataGenOp(left, 1); //ppred(X,X,"!=") -> matrix(0, rows=nrow(X),cols=nrow(Y)) if( left==right && bop.getOp()==OpOp2.NOTEQUAL || bop.getOp()==OpOp2.GREATER || bop.getOp()==OpOp2.LESS ) datagen = HopRewriteUtils.createDataGenOp(left, 0); if( datagen != null ) { HopRewriteUtils.replaceChildReference(parent, hi, datagen, pos); hi = datagen; } } return hi; } }