/** * (C) Copyright IBM Corp. 2010, 2015 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *  */ package com.ibm.bi.dml.hops.rewrite; import java.util.ArrayList; import java.util.HashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import com.ibm.bi.dml.hops.AggBinaryOp; import com.ibm.bi.dml.hops.AggUnaryOp; import com.ibm.bi.dml.hops.BinaryOp; import com.ibm.bi.dml.hops.DataGenOp; import com.ibm.bi.dml.hops.Hop; import com.ibm.bi.dml.hops.Hop.OpOp1; import com.ibm.bi.dml.hops.Hop.OpOp4; import com.ibm.bi.dml.hops.IndexingOp; import com.ibm.bi.dml.hops.QuaternaryOp; import com.ibm.bi.dml.hops.TernaryOp; import com.ibm.bi.dml.hops.UnaryOp; import com.ibm.bi.dml.hops.Hop.AggOp; import com.ibm.bi.dml.hops.Hop.DataGenMethod; import com.ibm.bi.dml.hops.Hop.Direction; import com.ibm.bi.dml.hops.Hop.ParamBuiltinOp; import com.ibm.bi.dml.hops.Hop.ReOrgOp; import com.ibm.bi.dml.hops.HopsException; import com.ibm.bi.dml.hops.LiteralOp; import com.ibm.bi.dml.hops.Hop.OpOp2; import com.ibm.bi.dml.hops.ParameterizedBuiltinOp; import com.ibm.bi.dml.hops.ReorgOp; import com.ibm.bi.dml.lops.MapMultChain.ChainType; import com.ibm.bi.dml.parser.DataExpression; import com.ibm.bi.dml.parser.Statement; import com.ibm.bi.dml.parser.Expression.DataType; import com.ibm.bi.dml.parser.Expression.ValueType; /** * Rule: Algebraic Simplifications. 
Simplifies binary expressions * in terms of two major purposes: (1) rewrite binary operations * to unary operations when possible (in CP this reduces the memory * estimate, in MR this allows map-only operations and hence prevents * unnecessary shuffle and sort) and (2) remove binary operations that * are in itself are unnecessary (e.g., *1 and /1). * */ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { private static final Log LOG = LogFactory.getLog(RewriteAlgebraicSimplificationStatic.class.getName()); private static OpOp2[] LOOKUP_VALID_DISTRIBUTIVE_BINARY = new OpOp2[]{OpOp2.PLUS, OpOp2.MINUS}; private static OpOp2[] LOOKUP_VALID_ASSOCIATIVE_BINARY = new OpOp2[]{OpOp2.PLUS, OpOp2.MULT}; private static OpOp2[] LOOKUP_VALID_WDIVMM_BINARY = new OpOp2[]{OpOp2.MULT, OpOp2.DIV}; @Override public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException { if( roots == null ) return roots; //one pass rewrite-descend (rewrite created pattern) for( Hop h : roots ) rule_AlgebraicSimplification( h, false ); Hop.resetVisitStatus(roots); //one pass descend-rewrite (for rollup) for( Hop h : roots ) rule_AlgebraicSimplification( h, true ); return roots; } @Override public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException { if( root == null ) return root; //one pass rewrite-descend (rewrite created pattern) rule_AlgebraicSimplification( root, false ); root.resetVisitStatus(); //one pass descend-rewrite (for rollup) rule_AlgebraicSimplification( root, true ); return root; } /** * Note: X/y -> X * 1/y would be useful because * cheaper than / and sparsesafe; however, * (1) the results would be not exactly the same (2 rounds instead of 1) and (2) it should * come before constant folding while the other simplifications should come after constant * folding. Hence, not applied yet. 
* * @throws HopsException */ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) throws HopsException { if(hop.getVisited() == Hop.VisitStatus.DONE) return; //recursively process children for( int i=0; i<hop.getInput().size(); i++) { Hop hi = hop.getInput().get(i); //process childs recursively first (to allow roll-up) if( descendFirst ) rule_AlgebraicSimplification(hi, descendFirst); //see below //apply actual simplification rewrites (of childs incl checks) hi = removeUnnecessaryVectorizeOperation(hi); //e.g., matrix(1,nrow(X),ncol(X))/X -> 1/X hi = removeUnnecessaryBinaryOperation(hop, hi, i); //e.g., X*1 -> X (dep: should come after rm unnecessary vectorize) hi = fuseDatagenAndBinaryOperation(hop, hi, i); //e.g., rand(min=-1,max=1)*7 -> rand(min=-7,max=7) hi = fuseDatagenAndMinusOperation(hop, hi, i); //e.g., -(rand(min=-2,max=1)) -> rand(min=-1,max=2) hi = simplifyBinaryToUnaryOperation(hi); //e.g., X*X -> X^2 (pow2) hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X hi = simplifyBushyBinaryOperation(hop, hi, i); //e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v) hi = simplifyUnaryAggReorgOperation(hop, hi, i); //e.g., sum(t(X)) -> sum(X) hi = simplifyTransposedAppend(hop, hi, i); //e.g., t(cbind(t(A),t(B))) -> rbind(A,B); hi = fuseBinarySubDAGToUnaryOperation(hop, hi, i); //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> selp(X) hi = simplifyTraceMatrixMult(hop, hi, i); //e.g., trace(X%*%Y)->sum(X*t(Y)); hi = simplifySlicedMatrixMult(hop, hi, i); //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1]; hi = simplifyConstantSort(hop, hi, i); //e.g., order(matrix())->matrix/seq; hi = simplifyOrderedSort(hop, hi, i); //e.g., order(matrix())->seq; hi = removeUnnecessaryTranspose(hop, hi, i); //e.g., t(t(X))->X; potentially introduced by diag/trace_MM hi = removeUnnecessaryMinus(hop, hi, i); //e.g., -(-X)->X; potentially introduced by simplfiy binary 
or dyn rewrites hi = simplifyGroupedAggregate(hi); //e.g., aggregate(target=X,groups=y,fn="count") -> aggregate(target=y,groups=y,fn="count") hi = simplifyWeightedSquaredLoss(hop, hi, i); //e.g., sum(W * (X - U %*% t(V)) ^ 2) -> wsl(X, U, t(V), W, true) hi = simplifyWeightedSigmoidMMChains(hop, hi, i); //e.g., W * sigmoid(Y%*%t(X)) -> wsigmoid(W, Y, t(X), type) hi = simplifyWeightedDivMM(hop, hi, i); //e.g., t(U) %*% (X/(U%*%t(V))) -> wdivmm(X, U, t(V), left) hi = simplifyWeightedCrossEntropy(hop, hi, i); //e.g., sum(X*log(U%*%t(V))) -> wcemm(X, U, t(V)) hi = fuseMinusNzBinaryOperation(hop, hi, i); //e.g., X-mean*ppred(X,0,!=) -> X -nz mean hi = fuseLogNzBinaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5) hi = simplifyOuterSeqExpand(hop, hi, i); //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) hi = simplifyTableSeqExpand(hop, hi, i); //e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true) //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) //process childs recursively after rewrites (to investigate pattern newly created by rewrites) if( !descendFirst ) rule_AlgebraicSimplification(hi, descendFirst); } hop.setVisited(Hop.VisitStatus.DONE); } /** * * @param hi */ private Hop removeUnnecessaryVectorizeOperation(Hop hi) { //applies to all binary matrix operations, if one input is unnecessarily vectorized if( hi instanceof BinaryOp && hi.getDataType()==DataType.MATRIX && ((BinaryOp)hi).supportsMatrixScalarOperations() ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); //NOTE: these rewrites of binary cell operations need to be aware that right is //potentially a vector but the result is of the size of left //TODO move to dynamic rewrites (since size dependent to account for mv binary cell and outer operations) if( !(left.getDim1()>1 && left.getDim2()==1 && 
right.getDim1()==1 && right.getDim2()>1) ) // no outer { //check and remove right vectorized scalar if( left.getDataType() == DataType.MATRIX && right instanceof DataGenOp ) { DataGenOp dright = (DataGenOp) right; if( dright.getOp()==DataGenMethod.RAND && dright.hasConstantValue() ) { Hop drightIn = dright.getInput().get(dright.getParamIndex(DataExpression.RAND_MIN)); HopRewriteUtils.removeChildReference(bop, dright); HopRewriteUtils.addChildReference(bop, drightIn, 1); //cleanup if only consumer of intermediate if( dright.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences( dright ); LOG.debug("Applied removeUnnecessaryVectorizeOperation1"); } } //check and remove left vectorized scalar else if( right.getDataType() == DataType.MATRIX && left instanceof DataGenOp ) { DataGenOp dleft = (DataGenOp) left; if( dleft.getOp()==DataGenMethod.RAND && dleft.hasConstantValue() && (left.getDim2()==1 || right.getDim2()>1) && (left.getDim1()==1 || right.getDim1()>1)) { Hop dleftIn = dleft.getInput().get(dleft.getParamIndex(DataExpression.RAND_MIN)); HopRewriteUtils.removeChildReference(bop, dleft); HopRewriteUtils.addChildReference(bop, dleftIn, 0); //cleanup if only consumer of intermediate if( dleft.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences( dleft ); LOG.debug("Applied removeUnnecessaryVectorizeOperation2"); } } //Note: we applied this rewrite to at most one side in order to keep the //output semantically equivalent. However, future extensions might consider //to remove vectors from both side, compute the binary op on scalars and //finally feed it into a datagenop of the original dimensions. 
} } return hi; } /** * handle removal of unnecessary binary operations * * X/1 or X*1 or 1*X or X-0 -> X * -1*X or X*-1-> -X * * @param parent * @param hi * @param pos * @throws HopsException */ private Hop removeUnnecessaryBinaryOperation( Hop parent, Hop hi, int pos ) throws HopsException { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); //X/1 or X*1 -> X if( left.getDataType()==DataType.MATRIX && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==1.0 ) { if( bop.getOp()==OpOp2.DIV || bop.getOp()==OpOp2.MULT ) { HopRewriteUtils.removeChildReference(parent, bop); HopRewriteUtils.addChildReference(parent, left, pos); hi = left; LOG.debug("Applied removeUnnecessaryBinaryOperation1 (line "+bop.getBeginLine()+")"); } } //X-0 -> X else if( left.getDataType()==DataType.MATRIX && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==0.0 ) { if( bop.getOp()==OpOp2.MINUS ) { HopRewriteUtils.removeChildReference(parent, bop); HopRewriteUtils.addChildReference(parent, left, pos); hi = left; LOG.debug("Applied removeUnnecessaryBinaryOperation2 (line "+bop.getBeginLine()+")"); } } //1*X -> X else if( right.getDataType()==DataType.MATRIX && left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==1.0 ) { if( bop.getOp()==OpOp2.MULT ) { HopRewriteUtils.removeChildReference(parent, bop); HopRewriteUtils.addChildReference(parent, right, pos); hi = right; LOG.debug("Applied removeUnnecessaryBinaryOperation3 (line "+bop.getBeginLine()+")"); } } //-1*X -> -X //note: this rewrite is necessary since the new antlr parser always converts //-X to -1*X due to mechanical reasons else if( right.getDataType()==DataType.MATRIX && left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==-1.0 ) { if( bop.getOp()==OpOp2.MULT ) { bop.setOp(OpOp2.MINUS); HopRewriteUtils.removeChildReferenceByPos(bop, left, 0); HopRewriteUtils.addChildReference(bop, new LiteralOp(0), 0); hi = 
bop; LOG.debug("Applied removeUnnecessaryBinaryOperation4 (line "+bop.getBeginLine()+")"); } } //X*-1 -> -X (see comment above) else if( left.getDataType()==DataType.MATRIX && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==-1.0 ) { if( bop.getOp()==OpOp2.MULT ) { bop.setOp(OpOp2.MINUS); HopRewriteUtils.removeChildReferenceByPos(bop, right, 1); HopRewriteUtils.addChildReference(bop, new LiteralOp(0), 0); hi = bop; LOG.debug("Applied removeUnnecessaryBinaryOperation5 (line "+bop.getBeginLine()+")"); } } } return hi; } /** * handle removal of unnecessary binary operations over rand data * * rand*7 -> rand(min*7,max*7); rand+7 -> rand(min+7,max+7); * * @param parent * @param hi * @param pos * @return * @throws HopsException */ private Hop fuseDatagenAndBinaryOperation( Hop parent, Hop hi, int pos ) throws HopsException { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); //left input rand and hence output matrix double, right scalar literal if( left instanceof DataGenOp && ((DataGenOp)left).getOp()==DataGenMethod.RAND && right instanceof LiteralOp ) { DataGenOp inputGen = (DataGenOp)left; HashMap<String,Integer> params = inputGen.getParamIndexMap(); Hop min = left.getInput().get(params.get(DataExpression.RAND_MIN)); Hop max = left.getInput().get(params.get(DataExpression.RAND_MAX)); double sval = ((LiteralOp)right).getDoubleValue(); if( (bop.getOp()==OpOp2.MULT || bop.getOp()==OpOp2.PLUS) && min instanceof LiteralOp && max instanceof LiteralOp ) { //create fused data gen operator DataGenOp gen = null; if( bop.getOp()==OpOp2.MULT ) gen = HopRewriteUtils.copyDataGenOp(inputGen, sval, 0); else //if( bop.getOp()==OpOp2.PLUS ) gen = HopRewriteUtils.copyDataGenOp(inputGen, 1, sval); //rewire parents HopRewriteUtils.removeChildReference(parent, bop); HopRewriteUtils.addChildReference(parent, gen, pos); //propagate potentially updated nnz=0 parent.refreshSizeInformation(); hi = 
gen; LOG.debug("Applied fuseDatagenAndBinaryOperation1"); } } //right input rand and hence output matrix double, left scalar literal else if( right instanceof DataGenOp && ((DataGenOp)right).getOp()==DataGenMethod.RAND && left instanceof LiteralOp ) { DataGenOp inputGen = (DataGenOp)right; HashMap<String,Integer> params = inputGen.getParamIndexMap(); Hop min = right.getInput().get(params.get(DataExpression.RAND_MIN)); Hop max = right.getInput().get(params.get(DataExpression.RAND_MAX)); double sval = ((LiteralOp)left).getDoubleValue(); if( (bop.getOp()==OpOp2.MULT || bop.getOp()==OpOp2.PLUS) && min instanceof LiteralOp && max instanceof LiteralOp ) { //create fused data gen operator DataGenOp gen = null; if( bop.getOp()==OpOp2.MULT ) gen = HopRewriteUtils.copyDataGenOp(inputGen, sval, 0); else //if( bop.getOp()==OpOp2.PLUS ) gen = HopRewriteUtils.copyDataGenOp(inputGen, 1, sval); //rewire parents HopRewriteUtils.removeChildReference(parent, bop); HopRewriteUtils.addChildReference(parent, gen, pos); //propagate potentially updated nnz=0 parent.refreshSizeInformation(); hi = gen; LOG.debug("Applied fuseDatagenAndBinaryOperation2"); } } } return hi; } /** * * @param parent * @param hi * @param pos * @return * @throws HopsException */ private Hop fuseDatagenAndMinusOperation( Hop parent, Hop hi, int pos ) throws HopsException { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); if( right instanceof DataGenOp && ((DataGenOp)right).getOp()==DataGenMethod.RAND && left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==0.0 ) { DataGenOp inputGen = (DataGenOp)right; HashMap<String,Integer> params = inputGen.getParamIndexMap(); int ixMin = params.get(DataExpression.RAND_MIN); int ixMax = params.get(DataExpression.RAND_MAX); Hop min = right.getInput().get(ixMin); Hop max = right.getInput().get(ixMax); //apply rewrite under additional conditions (for simplicity) if( 
inputGen.getParent().size()==1 && min instanceof LiteralOp && max instanceof LiteralOp ) { //exchange and *-1 (special case 0 stays 0 instead of -0 for consistency) double newMinVal = (((LiteralOp)max).getDoubleValue()==0)?0:(-1 * ((LiteralOp)max).getDoubleValue()); double newMaxVal = (((LiteralOp)min).getDoubleValue()==0)?0:(-1 * ((LiteralOp)min).getDoubleValue()); Hop newMin = new LiteralOp(newMinVal); Hop newMax = new LiteralOp(newMaxVal); HopRewriteUtils.removeChildReferenceByPos(inputGen, min, ixMin); HopRewriteUtils.addChildReference(inputGen, newMin, ixMin); HopRewriteUtils.removeChildReferenceByPos(inputGen, max, ixMax); HopRewriteUtils.addChildReference(inputGen, newMax, ixMax); HopRewriteUtils.removeChildReference(parent, bop); HopRewriteUtils.addChildReference(parent, inputGen, pos); hi = inputGen; LOG.debug("Applied fuseDatagenAndMinusOperation"); } } } return hi; } /** * handle simplification of binary operations * (relies on previous common subexpression elimination) * * X+X -> X*2 or X*X -> X^2 */ private Hop simplifyBinaryToUnaryOperation( Hop hi ) { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = hi.getInput().get(0); Hop right = hi.getInput().get(1); if( left == right && left.getDataType()==DataType.MATRIX ) { //note: we simplify this to unary operations first (less mem and better MR plan), //however, we later compile specific LOPS for X*2 and X^2 if( bop.getOp()==OpOp2.PLUS ) //X+X -> X*2 { bop.setOp(OpOp2.MULT); LiteralOp tmp = new LiteralOp(2); bop.getInput().remove(1); right.getParent().remove(bop); HopRewriteUtils.addChildReference(hi, tmp, 1); LOG.debug("Applied simplifyBinaryToUnaryOperation1"); } else if ( bop.getOp()==OpOp2.MULT ) //X*X -> X^2 { bop.setOp(OpOp2.POW); LiteralOp tmp = new LiteralOp(2); bop.getInput().remove(1); right.getParent().remove(bop); HopRewriteUtils.addChildReference(hi, tmp, 1); LOG.debug("Applied simplifyBinaryToUnaryOperation2"); } } } return hi; } /** * * @param hi * @return */ private Hop 
simplifyMultiBinaryToBinaryOperation( Hop hi ) { //pattern: 1-(X*Y) --> X 1-* Y (avoid intermediate) if( hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==OpOp2.MINUS && hi.getDataType() == DataType.MATRIX && hi.getInput().get(0) instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)hi.getInput().get(0))==1 && hi.getInput().get(1) instanceof BinaryOp && ((BinaryOp)hi.getInput().get(1)).getOp()==OpOp2.MULT && hi.getInput().get(1).getParent().size() == 1 ) //single consumer { BinaryOp bop = (BinaryOp)hi; Hop left = hi.getInput().get(1).getInput().get(0); Hop right = hi.getInput().get(1).getInput().get(1); //set new binaryop type and rewire inputs bop.setOp(OpOp2.MINUS1_MULT); HopRewriteUtils.removeAllChildReferences(hi); HopRewriteUtils.addChildReference(bop, left); HopRewriteUtils.addChildReference(bop, right); LOG.debug("Applied simplifyMultiBinaryToBinaryOperation."); } return hi; } /** * (X-Y*X) -> (1-Y)*X, (Y*X-X) -> (Y-1)*X * (X+Y*X) -> (1+Y)*X, (Y*X+X) -> (Y+1)*X * * * @param parent * @param hi * @param pos * @return */ private Hop simplifyDistributiveBinaryOperation( Hop parent, Hop hi, int pos ) { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); //(X+Y*X) -> (1+Y)*X, (Y*X+X) -> (Y+1)*X //(X-Y*X) -> (1-Y)*X, (Y*X-X) -> (Y-1)*X boolean applied = false; if( left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX && HopRewriteUtils.isValidOp(bop.getOp(), LOOKUP_VALID_DISTRIBUTIVE_BINARY) ) { Hop X = null; Hop Y = null; if( left instanceof BinaryOp && ((BinaryOp)left).getOp()==OpOp2.MULT ) //(Y*X-X) -> (Y-1)*X { Hop leftC1 = left.getInput().get(0); Hop leftC2 = left.getInput().get(1); //System.out.println("aOp2:"+((BinaryOp)left).getOp()+": "+leftC1.getName()+" "+leftC2.getName()); if( leftC1.getDataType()==DataType.MATRIX && leftC2.getDataType()==DataType.MATRIX && (right == leftC1 || right == leftC2) && leftC1 !=leftC2 ){ //any mult order X = 
right; Y = ( right == leftC1 ) ? leftC2 : leftC1; } if( X != null ){ //rewrite 'binary +/-' HopRewriteUtils.removeChildReference(parent, hi); LiteralOp literal = new LiteralOp(1); BinaryOp plus = new BinaryOp(right.getName(), right.getDataType(), right.getValueType(), bop.getOp(), Y, literal); HopRewriteUtils.refreshOutputParameters(plus, right); BinaryOp mult = new BinaryOp(left.getName(), left.getDataType(), left.getValueType(), OpOp2.MULT, plus, X); HopRewriteUtils.refreshOutputParameters(mult, left); HopRewriteUtils.addChildReference(parent, mult, pos); hi = mult; applied = true; LOG.debug("Applied simplifyDistributiveBinaryOperation1"); } } if( !applied && right instanceof BinaryOp && ((BinaryOp)right).getOp()==OpOp2.MULT ) //(X-Y*X) -> (1-Y)*X { Hop rightC1 = right.getInput().get(0); Hop rightC2 = right.getInput().get(1); if( rightC1.getDataType()==DataType.MATRIX && rightC2.getDataType()==DataType.MATRIX && (left == rightC1 || left == rightC2) && rightC1 !=rightC2 ){ //any mult order X = left; Y = ( left == rightC1 ) ? rightC2 : rightC1; } if( X != null ){ //rewrite '+/- binary' HopRewriteUtils.removeChildReference(parent, hi); LiteralOp literal = new LiteralOp(1); BinaryOp plus = new BinaryOp(left.getName(), left.getDataType(), left.getValueType(), bop.getOp(), literal, Y); HopRewriteUtils.refreshOutputParameters(plus, left); BinaryOp mult = new BinaryOp(right.getName(), right.getDataType(), right.getValueType(), OpOp2.MULT, plus, X); HopRewriteUtils.refreshOutputParameters(mult, right); HopRewriteUtils.addChildReference(parent, mult, pos); hi = mult; LOG.debug("Applied simplifyDistributiveBinaryOperation2"); } } } } return hi; } /** * (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v) * (X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v) * * Note: Restriction ba() at leaf and root instead of data at leaf to not reorganize too * eagerly, which would loose additional rewrite potential. This rewrite has two goals * (1) enable XtwXv, and increase piggybacking potential by creating bushy trees. 
* * @param parent * @param hi * @param pos * @return */ private Hop simplifyBushyBinaryOperation( Hop parent, Hop hi, int pos ) { if( hi instanceof BinaryOp && parent instanceof AggBinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); OpOp2 op = bop.getOp(); if( left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX && HopRewriteUtils.isValidOp(op, LOOKUP_VALID_ASSOCIATIVE_BINARY) ) { boolean applied = false; if( right instanceof BinaryOp ) { BinaryOp bop2 = (BinaryOp)right; Hop left2 = bop2.getInput().get(0); Hop right2 = bop2.getInput().get(1); OpOp2 op2 = bop2.getOp(); if( op==op2 && right2.getDataType()==DataType.MATRIX && (right2 instanceof AggBinaryOp) ) { //(X*(Y*op()) -> (X*Y)*op() HopRewriteUtils.removeChildReference(parent, bop); BinaryOp bop3 = new BinaryOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, op, left, left2); HopRewriteUtils.refreshOutputParameters(bop3, bop); BinaryOp bop4 = new BinaryOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, op, bop3, right2); HopRewriteUtils.refreshOutputParameters(bop4, bop2); HopRewriteUtils.addChildReference(parent, bop4, pos); hi = bop4; applied = true; LOG.debug("Applied simplifyBushyBinaryOperation1"); } } if( !applied && left instanceof BinaryOp ) { BinaryOp bop2 = (BinaryOp)left; Hop left2 = bop2.getInput().get(0); Hop right2 = bop2.getInput().get(1); OpOp2 op2 = bop2.getOp(); if( op==op2 && left2.getDataType()==DataType.MATRIX && (left2 instanceof AggBinaryOp) && (right2.getDim2() > 1 || right.getDim2() == 1) //X not vector, or Y vector && (right2.getDim1() > 1 || right.getDim1() == 1) ) //X not vector, or Y vector { //((op()*X)*Y) -> op()*(X*Y) HopRewriteUtils.removeChildReference(parent, bop); BinaryOp bop3 = new BinaryOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, op, right2, right); HopRewriteUtils.refreshOutputParameters(bop3, bop2); BinaryOp bop4 = new BinaryOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, op, left2, bop3); 
HopRewriteUtils.refreshOutputParameters(bop4, bop); HopRewriteUtils.addChildReference(parent, bop4, pos); hi = bop4; LOG.debug("Applied simplifyBushyBinaryOperation2"); } } } } return hi; } /** * * @param parent * @param hi * @param pos * @return */ private Hop simplifyUnaryAggReorgOperation( Hop parent, Hop hi, int pos ) { if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol //full uagg && hi.getInput().get(0) instanceof ReorgOp ) //reorg operation { ReorgOp rop = (ReorgOp)hi.getInput().get(0); if( (rop.getOp()==ReOrgOp.TRANSPOSE || rop.getOp()==ReOrgOp.RESHAPE) //valid reorg && rop.getParent().size()==1 ) //uagg only reorg consumer { Hop input = rop.getInput().get(0); HopRewriteUtils.removeAllChildReferences(hi); HopRewriteUtils.removeAllChildReferences(rop); HopRewriteUtils.addChildReference(hi, input); LOG.debug("Applied simplifyUnaryAggReorgOperation"); } } return hi; } /** * * @param parent * @param hi * @param pos * @return */ private Hop simplifyTransposedAppend( Hop parent, Hop hi, int pos ) { //e.g., t(cbind(t(A),t(B))) --> rbind(A,B), t(rbind(t(A),t(B))) --> cbind(A,B) if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.TRANSPOSE //t() rooted && hi.getInput().get(0) instanceof BinaryOp && (((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.CBIND //append (cbind/rbind) || ((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.RBIND) && hi.getInput().get(0).getParent().size() == 1 ) //single consumers of append { BinaryOp bop = (BinaryOp)hi.getInput().get(0); if( bop.getInput().get(0) instanceof ReorgOp //both inputs transpose ops && ((ReorgOp)bop.getInput().get(0)).getOp()==ReOrgOp.TRANSPOSE && bop.getInput().get(1) instanceof ReorgOp && ((ReorgOp)bop.getInput().get(1)).getOp()==ReOrgOp.TRANSPOSE ) { Hop left = bop.getInput().get(0).getInput().get(0); Hop right = bop.getInput().get(1).getInput().get(0); //rewire links from parent, transpose, and binary HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); 
HopRewriteUtils.addChildReference(parent, bop, pos); HopRewriteUtils.removeAllChildReferences(hi); HopRewriteUtils.removeAllChildReferences(bop); //change append type (safe due to single parent check) if( bop.getOp()==OpOp2.CBIND ) bop.setOp(OpOp2.RBIND); else bop.setOp(OpOp2.CBIND); //relink new childs to binary op HopRewriteUtils.addChildReference(bop, left, 0); HopRewriteUtils.addChildReference(bop, right, 1); bop.refreshSizeInformation(); hi = bop; LOG.debug("Applied simplifyTransposedAppend (line "+hi.getBeginLine()+")."); } } return hi; } /** * handle simplification of more complex sub DAG to unary operation. * * X*(1-X) -> sprop(X) * (1-X)*X -> sprop(X) * 1/(1+exp(-X)) -> sigmoid(X) * * @param hi * @throws HopsException */ private Hop fuseBinarySubDAGToUnaryOperation( Hop parent, Hop hi, int pos ) throws HopsException { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = hi.getInput().get(0); Hop right = hi.getInput().get(1); //sample proportion (sprop) operator if( bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX ) { //by definition, either left or right or none applies. 
//note: if there are multiple consumers on the intermediate, //we follow the heuristic that redundant computation is more beneficial, //i.e., we still fuse but leave the intermediate for the other consumers if( left instanceof BinaryOp ) //(1-X)*X { BinaryOp bleft = (BinaryOp)left; Hop left1 = bleft.getInput().get(0); Hop left2 = bleft.getInput().get(1); if( left1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left1)==1 && left2 == right && bleft.getOp() == OpOp2.MINUS ) { UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SPROP); HopRewriteUtils.removeChildReferenceByPos(parent, bop, pos); HopRewriteUtils.addChildReference(parent, unary, pos); //cleanup if only consumer of intermediate if( bop.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences(bop); if( left.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences(left); hi = unary; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1"); } } if( right instanceof BinaryOp ) //X*(1-X) { BinaryOp bright = (BinaryOp)right; Hop right1 = bright.getInput().get(0); Hop right2 = bright.getInput().get(1); if( right1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)right1)==1 && right2 == left && bright.getOp() == OpOp2.MINUS ) { UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SPROP); HopRewriteUtils.removeChildReferenceByPos(parent, bop, pos); HopRewriteUtils.addChildReference(parent, unary, pos); //cleanup if only consumer of intermediate if( bop.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences(bop); if( left.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences(right); hi = unary; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2"); } } } //sigmoid operator else if( bop.getOp() == OpOp2.DIV && left.getDataType()==DataType.SCALAR && right.getDataType()==DataType.MATRIX && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==1 && right instanceof BinaryOp) { //note: if there are multiple 
consumers on the intermediate, //we follow the heuristic that redundant computation is more beneficial, //i.e., we still fuse but leave the intermediate for the other consumers BinaryOp bop2 = (BinaryOp)right; Hop left2 = bop2.getInput().get(0); Hop right2 = bop2.getInput().get(1); if( bop2.getOp() == OpOp2.PLUS && left2.getDataType()==DataType.SCALAR && right2.getDataType()==DataType.MATRIX && left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left2)==1 && right2 instanceof UnaryOp) { UnaryOp uop = (UnaryOp) right2; Hop uopin = uop.getInput().get(0); if( uop.getOp()==OpOp1.EXP ) { UnaryOp unary = null; //Pattern 1: (1/(1 + exp(-X)) if( uopin instanceof BinaryOp && ((BinaryOp)uopin).getOp()==OpOp2.MINUS ) { BinaryOp bop3 = (BinaryOp) uopin; Hop left3 = bop3.getInput().get(0); Hop right3 = bop3.getInput().get(1); if( left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left3)==0 ) { unary = HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID); } } //Pattern 2: (1/(1 + exp(X)), e.g., where -(-X) has been removed by //the 'remove unnecessary minus' rewrite --> reintroduce the minus else { BinaryOp minus = HopRewriteUtils.createMinus(uopin); unary = HopRewriteUtils.createUnary(minus, OpOp1.SIGMOID); } if( unary != null ) { HopRewriteUtils.removeChildReferenceByPos(parent, bop, pos); HopRewriteUtils.addChildReference(parent, unary, pos); //cleanup if only consumer of intermediate if( bop.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences(bop); if( bop2.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences(bop2); if( uop.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences(uop); hi = unary; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1"); } } } } //select positive (selp) operator if( bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX ) { //by definition, either left or right or none applies. 
//note: if there are multiple consumers on the intermediate tmp=(X>0), it's still beneficial //to replace the X*tmp with selp(X) due to lower memory requirements and simply sparsity propagation if( left instanceof BinaryOp ) //(X>0)*X { BinaryOp bleft = (BinaryOp)left; Hop left1 = bleft.getInput().get(0); Hop left2 = bleft.getInput().get(1); if( left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left2)==0 && left1 == right && bleft.getOp() == OpOp2.GREATER ) { UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SELP); HopRewriteUtils.removeChildReferenceByPos(parent, bop, pos); HopRewriteUtils.addChildReference(parent, unary, pos); //cleanup if only consumer of intermediate if( bop.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences(bop); if( left.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences(left); hi = unary; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp1"); } } if( right instanceof BinaryOp ) //X*(X>0) { BinaryOp bright = (BinaryOp)right; Hop right1 = bright.getInput().get(0); Hop right2 = bright.getInput().get(1); if( right2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)right2)==0 && right1 == left && bright.getOp() == OpOp2.GREATER ) { UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SELP); HopRewriteUtils.removeChildReferenceByPos(parent, bop, pos); HopRewriteUtils.addChildReference(parent, unary, pos); //cleanup if only consumer of intermediate if( bop.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences(bop); if( left.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences(right); hi = unary; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp2"); } } } } return hi; } /** * * @param parent * @param hi * @param pos * @return */ private Hop simplifyTraceMatrixMult(Hop parent, Hop hi, int pos) { if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.TRACE ) //trace() { Hop hi2 = hi.getInput().get(0); if( hi2 instanceof AggBinaryOp 
&& ((AggBinaryOp)hi2).isMatrixMultiply() ) //X%*%Y { Hop left = hi2.getInput().get(0); Hop right = hi2.getInput().get(1); //remove link from parent to diag HopRewriteUtils.removeChildReference(parent, hi); //remove links to inputs to matrix mult //removeChildReference(hi2, left); //removeChildReference(hi2, right); //create new operators (incl refresh size inside for transpose) ReorgOp trans = HopRewriteUtils.createTranspose(right); BinaryOp mult = new BinaryOp(right.getName(), right.getDataType(), right.getValueType(), OpOp2.MULT, left, trans); mult.setRowsInBlock(right.getRowsInBlock()); mult.setColsInBlock(right.getColsInBlock()); mult.refreshSizeInformation(); AggUnaryOp sum = new AggUnaryOp(right.getName(), DataType.SCALAR, right.getValueType(), AggOp.SUM, Direction.RowCol, mult); sum.refreshSizeInformation(); //rehang new subdag under parent node HopRewriteUtils.addChildReference(parent, sum, pos); //cleanup if only consumer of intermediate if( hi.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences( hi ); if( hi2.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences( hi2 ); hi = sum; LOG.debug("Applied simplifyTraceMatrixMult"); } } return hi; } /** * * @param parent * @param hi * @param pos * @return * @throws HopsException */ private Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos) throws HopsException { //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1] if( hi instanceof IndexingOp && ((IndexingOp)hi).getRowLowerEqualsUpper() && ((IndexingOp)hi).getColLowerEqualsUpper() && hi.getInput().get(0).getParent().size()==1 //rix is single mm consumer && hi.getInput().get(0) instanceof AggBinaryOp && ((AggBinaryOp)hi.getInput().get(0)).isMatrixMultiply() ) { Hop mm = hi.getInput().get(0); Hop X = mm.getInput().get(0); Hop Y = mm.getInput().get(1); Hop rowExpr = hi.getInput().get(1); //rl==ru Hop colExpr = hi.getInput().get(3); //cl==cu HopRewriteUtils.removeAllChildReferences(mm); //create new indexing operations IndexingOp ix1 = new 
IndexingOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, X, rowExpr, rowExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(X, false), true, false);
			HopRewriteUtils.setOutputBlocksizes(ix1, X.getRowsInBlock(), X.getColsInBlock());
			ix1.refreshSizeInformation();
			IndexingOp ix2 = new IndexingOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, Y, new LiteralOp(1), HopRewriteUtils.createValueHop(Y, true), colExpr, colExpr, false, true);
			HopRewriteUtils.setOutputBlocksizes(ix2, Y.getRowsInBlock(), Y.getColsInBlock());
			ix2.refreshSizeInformation();
			//rewire matrix mult over ix1 and ix2
			HopRewriteUtils.addChildReference(mm, ix1, 0);
			HopRewriteUtils.addChildReference(mm, ix2, 1);
			mm.refreshSizeInformation();
			//NOTE(review): parent and pos are not used anywhere in this method, i.e., the link
			//from parent to the indexing hop is not changed here -- confirm this is intended
			hi = mm;
			LOG.debug("Applied simplifySlicedMatrixMult");
		}
		return hi;
	}

	/**
	 * Removes order() over a constant-valued matrix created by rand with a
	 * literal indexreturn argument: the data result is the input itself and
	 * the index result is a generated sequence.
	 * 
	 * @param parent parent hop that references hi as its child at position pos
	 * @param hi current hop, candidate root of the pattern
	 * @param pos position of hi in the inputs of parent
	 * @return the replacement hop if the rewrite applied, otherwise hi
	 * @throws HopsException
	 */
	private Hop simplifyConstantSort(Hop parent, Hop hi, int pos) throws HopsException
	{
		//order(matrix(7), indexreturn=FALSE) -> matrix(7)
		//order(matrix(7), indexreturn=TRUE) -> seq(1,nrow(X),1)

		//guard: pattern is rooted by order
		if( !(hi instanceof ReorgOp) || ((ReorgOp)hi).getOp()!=ReOrgOp.SORT )
			return hi;

		//guard: constant rand input and known indexreturn
		Hop data = hi.getInput().get(0);
		boolean constInput = data instanceof DataGenOp
			&& ((DataGenOp)data).getOp()==DataGenMethod.RAND
			&& ((DataGenOp)data).hasConstantValue()
			&& hi.getInput().get(3) instanceof LiteralOp;
		if( !constInput )
			return hi;

		boolean ixret = HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(3));

		//unlink order and hang in the replacement (input data or generated seq)
		HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
		Hop repl = data;
		if( ixret ) {
			repl = HopRewriteUtils.createSeqDataGenOp(data);
			repl.refreshSizeInformation();
		}
		HopRewriteUtils.addChildReference(parent, repl, pos);

		//detach input from order if order has no remaining consumers
		if( hi.getParent().isEmpty() )
			HopRewriteUtils.removeChildReference(hi, data);

		LOG.debug( ixret ? "Applied simplifyConstantSort1." : "Applied simplifyConstantSort2." );
		return repl;
	}

	/**
	 * Removes order() over an ascending sequence (seq with literal increment 1)
	 * if both the decreasing and indexreturn arguments are literals. The index
	 * result is replaced by a generated sequence; the data result is replaced
	 * by the input sequence for ascending order only (descending data order is
	 * not rewritten).
	 * 
	 * @param parent parent hop that references hi as its child at position pos
	 * @param hi current hop, candidate root of the pattern
	 * @param pos position of hi in the inputs of parent
	 * @return the replacement hop if the rewrite applied, otherwise hi
	 * @throws HopsException
	 */
	private Hop simplifyOrderedSort(Hop parent, Hop hi, int pos) throws HopsException
	{
		//order(seq(2,N+1,1), indexreturn=FALSE) -> seq(2,N+1,1)
		//order(seq(2,N+1,1), indexreturn=TRUE) -> seq(1,N,1)/seq(N,1,-1)

		//guard: pattern is rooted by order over seq
		if( !(hi instanceof ReorgOp) || ((ReorgOp)hi).getOp()!=ReOrgOp.SORT )
			return hi;
		Hop data = hi.getInput().get(0);
		if( !(data instanceof DataGenOp) || ((DataGenOp)data).getOp()!=DataGenMethod.SEQ )
			return hi;

		//guard: known ascending ordering, known decreasing and indexreturn
		Hop incr = data.getInput().get(((DataGenOp)data).getParamIndex(Statement.SEQ_INCR));
		boolean known = incr instanceof LiteralOp
			&& HopRewriteUtils.getDoubleValue((LiteralOp)incr)==1
			&& hi.getInput().get(2) instanceof LiteralOp  //decreasing
			&& hi.getInput().get(3) instanceof LiteralOp; //indexreturn
		if( !known )
			return hi;

		boolean ixret = HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(3));
		boolean desc = HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2));
		if( !ixret && desc ) //DATA, DESC: no rewrite
			return hi;

		//unlink order and hang in the replacement (input seq or generated index seq)
		HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
		Hop repl = data;
		if( ixret ) { //IXRET, ASC/DESC
			repl = HopRewriteUtils.createSeqDataGenOp(data, !desc);
			repl.refreshSizeInformation();
		}
		HopRewriteUtils.addChildReference(parent, repl, pos);

		//detach input from order if order has no remaining consumers
		if( hi.getParent().isEmpty() )
			HopRewriteUtils.removeChildReference(hi, data);

		LOG.debug( ixret ? "Applied simplifyOrderedSort1." : "Applied simplifyOrderedSort2." );
		return repl;
	}

	/**
	 * Replaces a chain of two transpose operations t(t(X)) under parent by X.
	 * 
	 * @param parent parent hop that references hi as its child at position pos
	 * @param hi current hop, candidate root of the pattern
	 * @param pos position of hi in the inputs of parent
	 */
	private Hop removeUnnecessaryTranspose(Hop parent, Hop hi, int pos)
	{
		if(
hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.TRANSPOSE ) //first transpose { Hop hi2 = hi.getInput().get(0); if( hi2 instanceof ReorgOp && ((ReorgOp)hi2).getOp()==ReOrgOp.TRANSPOSE ) //second transpose { Hop hi3 = hi2.getInput().get(0); //remove unnecessary chain of t(t()) HopRewriteUtils.removeChildReference(parent, hi); HopRewriteUtils.addChildReference(parent, hi3, pos); hi = hi3; //cleanup if only consumer of intermediate if( hi.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences( hi ); if( hi2.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences( hi2 ); LOG.debug("Applied removeUnecessaryTranspose"); } } return hi; } /** * * @param parent * @param hi * @param pos * @return * @throws HopsException */ private Hop removeUnnecessaryMinus(Hop parent, Hop hi, int pos) throws HopsException { if( hi.getDataType() == DataType.MATRIX && hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==OpOp2.MINUS //first minus && hi.getInput().get(0) instanceof LiteralOp && ((LiteralOp)hi.getInput().get(0)).getDoubleValue()==0 ) { Hop hi2 = hi.getInput().get(1); if( hi2.getDataType() == DataType.MATRIX && hi2 instanceof BinaryOp && ((BinaryOp)hi2).getOp()==OpOp2.MINUS //second minus && hi2.getInput().get(0) instanceof LiteralOp && ((LiteralOp)hi2.getInput().get(0)).getDoubleValue()==0 ) { Hop hi3 = hi2.getInput().get(1); //remove unnecessary chain of -(-()) HopRewriteUtils.removeChildReference(parent, hi); HopRewriteUtils.addChildReference(parent, hi3, pos); hi = hi3; //cleanup if only consumer of intermediate if( hi.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences( hi ); if( hi2.getParent().isEmpty() ) HopRewriteUtils.removeAllChildReferences( hi2 ); LOG.debug("Applied removeUnecessaryMinus"); } } return hi; } /** * * @param hi * @return */ private Hop simplifyGroupedAggregate(Hop hi) { if( hi instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)hi).getOp()==ParamBuiltinOp.GROUPEDAGG ) //aggregate { ParameterizedBuiltinOp 
phi = (ParameterizedBuiltinOp)hi; if( phi.isCountFunction() ) //aggregate(fn="count") { HashMap<String, Integer> params = phi.getParamIndexMap(); int ix1 = params.get(Statement.GAGG_TARGET); int ix2 = params.get(Statement.GAGG_GROUPS); //check for unnecessary memory consumption for "count" if( ix1 != ix2 && phi.getInput().get(ix1)!=phi.getInput().get(ix2) ) { Hop th = phi.getInput().get(ix1); Hop gh = phi.getInput().get(ix2); HopRewriteUtils.removeChildReference(hi, th); HopRewriteUtils.addChildReference(hi, gh, ix1); LOG.debug("Applied simplifyGroupedAggregateCount"); } } } return hi; } /** * Searches for weighted squared loss expressions and replaces them with a quaternary operator. * Currently, this search includes the following three patterns: * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting) * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting) * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting) * * NOTE: We include transpose into the pattern because during runtime we need to compute * U%*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation * without additional memory requirements for internal transpose. * * This rewrite is conceptually a static rewrite; however, the current MR runtime only supports * U/V factors of rank up to the blocksize (1000). We enforce this contraint here during the general * rewrite because this is an uncommon case. Also, the intention is to remove this constaint as soon * as we generalized the runtime or hop/lop compilation. 
*
	 * @param parent parent hop that references hi as its child at position pos
	 * @param hi current hop, candidate root (sum aggregate) of the patterns
	 * @param pos position of hi in the inputs of parent
	 * @return the new quaternary hop if a pattern matched, otherwise hi
	 * @throws HopsException
	 */
	private Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos) throws HopsException
	{
		//NOTE: there might be also a general simplification without custom operator
		//via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
		Hop hnew = null;
		if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
			&& ((AggUnaryOp)hi).getOp() == AggOp.SUM //all patterns rooted by sum()
			&& hi.getInput().get(0) instanceof BinaryOp //all patterns subrooted by binary op
			&& hi.getInput().get(0).getDim2() > 1 ) //not applied for vector-vector mult
		{
			BinaryOp bop = (BinaryOp) hi.getInput().get(0);
			//patterns are tried in order; the first match wins
			boolean appliedPattern = false;

			//Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
			//alternative pattern: sum (W * (U %*% t(V) - X) ^ 2)
			if( bop.getOp()==OpOp2.MULT && bop.getInput().get(1) instanceof BinaryOp
				&& bop.getInput().get(0).getDataType()==DataType.MATRIX
				&& HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) //prevent mv
				&& ((BinaryOp)bop.getInput().get(1)).getOp()==OpOp2.POW
				&& bop.getInput().get(1).getInput().get(1) instanceof LiteralOp
				&& HopRewriteUtils.getIntValue((LiteralOp)bop.getInput().get(1).getInput().get(1))==2)
			{
				Hop W = bop.getInput().get(0);
				Hop tmp = bop.getInput().get(1).getInput().get(0); //(X - U %*% t(V))
				if( tmp instanceof BinaryOp && ((BinaryOp)tmp).getOp()==OpOp2.MINUS
					&& HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv
					&& tmp.getInput().get(0).getDataType() == DataType.MATRIX )
				{
					//uvIndex is the input position of the matrix multiply below the minus (-1: no match)
					//a) sum (W * (X - U %*% t(V)) ^ 2)
					int uvIndex = -1;
					if( tmp.getInput().get(1) instanceof AggBinaryOp //ba guarantees matrices
						&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
					{
						uvIndex = 1;
					}
					//b) sum (W * (U %*% t(V) - X) ^ 2)
					else if(tmp.getInput().get(0) instanceof AggBinaryOp //ba guarantees matrices
						&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
					{
						uvIndex = 0;
					}
					if( uvIndex >= 0 ) //rewrite match
					{
						//X is the other input of the minus
						Hop X = tmp.getInput().get((uvIndex==0)?1:0);
						Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
						Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
						//hand V (not t(V)) to the operator: strip an existing transpose, otherwise add one
						if( !HopRewriteUtils.isTransposeOperation(V) ) {
							V = HopRewriteUtils.createTranspose(V);
						}
						else{
							V = V.getInput().get(0);
						}
						//handle special case of post_nz
						if( HopRewriteUtils.isNonZeroIndicator(W, X) ){
							W = new LiteralOp(1);
						}
						//construct quaternary hop
						hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, true);
						HopRewriteUtils.setOutputParametersForScalar(hnew);
						appliedPattern = true;
						LOG.debug("Applied simplifyWeightedSquaredLoss1"+uvIndex+" (line "+hi.getBeginLine()+")");
					}
				}
			}

			//Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
			//alternative pattern: sum ((W * (U %*% t(V)) - X) ^ 2)
			if( !appliedPattern
				&& bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
				&& HopRewriteUtils.getIntValue((LiteralOp)bop.getInput().get(1))==2
				&& bop.getInput().get(0) instanceof BinaryOp
				&& bop.getInput().get(0).getDataType()==DataType.MATRIX
				&& ((BinaryOp)bop.getInput().get(0)).getOp()==OpOp2.MINUS
				&& HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv
				&& bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX)
			{
				Hop lleft = bop.getInput().get(0).getInput().get(0);
				Hop lright = bop.getInput().get(0).getInput().get(1);
				//wuvIndex is the input position of W*(U%*%t(V)) below the minus (-1: no match)
				//a) sum ((X - W * (U %*% t(V))) ^ 2)
				int wuvIndex = -1;
				if( lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp ){
					wuvIndex = 1;
				}
				//b) sum ((W * (U %*% t(V)) - X) ^ 2)
				else if( lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp ){
					wuvIndex = 0;
				}
				if( wuvIndex >= 0 ) //rewrite match
				{
					Hop X = bop.getInput().get(0).getInput().get((wuvIndex==0)?1:0);
					Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex); //(W * (U %*% t(V)))
					//cast is safe: the index selection above established tmp instanceof BinaryOp
					if( ((BinaryOp)tmp).getOp()==OpOp2.MULT
						&& tmp.getInput().get(0).getDataType() == DataType.MATRIX
						&& HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv
						&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
					{
						Hop W = tmp.getInput().get(0);
						Hop U = tmp.getInput().get(1).getInput().get(0);
						Hop V = tmp.getInput().get(1).getInput().get(1);
						//hand V (not t(V)) to the operator: strip an existing transpose, otherwise add one
						if( !HopRewriteUtils.isTransposeOperation(V) ) {
							V = HopRewriteUtils.createTranspose(V);
						}
						else {
							V = V.getInput().get(0);
						}
						hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
						HopRewriteUtils.setOutputParametersForScalar(hnew);
						appliedPattern = true;
						LOG.debug("Applied simplifyWeightedSquaredLoss2"+wuvIndex+" (line "+hi.getBeginLine()+")");
					}
				}
			}

			//Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
			//alternative pattern: sum (((U %*% t(V)) - X) ^ 2)
			if( !appliedPattern
				&& bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
				&& HopRewriteUtils.getIntValue((LiteralOp)bop.getInput().get(1))==2
				&& bop.getInput().get(0) instanceof BinaryOp
				&& bop.getInput().get(0).getDataType()==DataType.MATRIX
				&& ((BinaryOp)bop.getInput().get(0)).getOp()==OpOp2.MINUS
				&& HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv
				&& bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX)
			{
				Hop lleft = bop.getInput().get(0).getInput().get(0);
				Hop lright = bop.getInput().get(0).getInput().get(1);
				//a) sum ((X - (U %*% t(V))) ^ 2)
				int uvIndex = -1;
				if( lright instanceof AggBinaryOp //ba guarantees matrices
					&& HopRewriteUtils.isSingleBlock(lright.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
				{
					uvIndex = 1;
				}
				//b) sum (((U %*% t(V)) - X) ^ 2)
				else if( lleft instanceof AggBinaryOp //ba guarantees matrices
					&& HopRewriteUtils.isSingleBlock(lleft.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
				{
					uvIndex = 0;
				}
				if( uvIndex >= 0 ) //rewrite match
				{
					Hop X = bop.getInput().get(0).getInput().get((uvIndex==0)?1:0);
					Hop tmp = bop.getInput().get(0).getInput().get(uvIndex); //(U %*% t(V))
					Hop W = new LiteralOp(1); //no weighting
					Hop U = tmp.getInput().get(0);
					Hop V = tmp.getInput().get(1);
					//hand V (not t(V)) to the operator: strip an existing transpose, otherwise add one
					if( !HopRewriteUtils.isTransposeOperation(V) ) {
						V = HopRewriteUtils.createTranspose(V);
					}
					else {
						V = V.getInput().get(0);
					}
					hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
					HopRewriteUtils.setOutputParametersForScalar(hnew);
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedSquaredLoss3"+uvIndex+" (line "+hi.getBeginLine()+")");
				}
			}
		}

		//relink new hop into original position
		if( hnew != null ) {
			HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
			HopRewriteUtils.addChildReference(parent, hnew, pos);
			hi = hnew;
		}
		return hi;
	}

	/**
	 * Searches for weighted sigmoid expressions of the form W * f(Y%*%t(X))
	 * and replaces them with a quaternary operator; the individual patterns
	 * (basic, minus, log, log_minus) are listed at the respective checks.
	 * 
	 * @param parent parent hop that references hi as its child at position pos
	 * @param hi current hop, candidate root (cell-wise multiply) of the patterns
	 * @param pos position of hi in the inputs of parent
	 * @return the new quaternary hop if a pattern matched, otherwise hi
	 * @throws HopsException
	 */
	private Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos) throws HopsException
	{
		Hop hnew = null;
		if( hi instanceof BinaryOp //all patterns subrooted by W *
			&& ((BinaryOp) hi).getOp()==OpOp2.MULT
			&& hi.getDim2() > 1 //not applied for vector-vector mult
			&& HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
			&& hi.getInput().get(0).getDataType()==DataType.MATRIX
			&& hi.getInput().get(1) instanceof UnaryOp ) //sigmoid/log
		{
			UnaryOp uop = (UnaryOp) hi.getInput().get(1);
			//patterns are tried in order; the first match wins
			boolean appliedPattern = false;

			//Pattern 1) W * sigmoid(Y%*%t(X)) (basic)
			if( uop.getOp() == OpOp1.SIGMOID && uop.getInput().get(0) instanceof AggBinaryOp
				&& HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0),true) )
			{
				Hop W = hi.getInput().get(0);
				Hop Y = uop.getInput().get(0).getInput().get(0);
				Hop tX = uop.getInput().get(0).getInput().get(1);
				//hand X (not t(X)) to the operator: strip an existing transpose, otherwise add one
				if( !HopRewriteUtils.isTransposeOperation(tX)
) {
					tX = HopRewriteUtils.createTranspose(tX);
				}
				else
					tX = tX.getInput().get(0);
				//the two trailing flags are (false,false) here, (false,true) for the minus pattern,
				//(true,false) for the log pattern, and (true,true) for the log_minus pattern
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WSIGMOID, W, Y, tX, false, false);
				HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedSigmoid1 (line "+hi.getBeginLine()+")");
			}

			//Pattern 2) W * sigmoid(-(Y%*%t(X))) (minus)
			if( !appliedPattern
				&& uop.getOp() == OpOp1.SIGMOID && uop.getInput().get(0) instanceof BinaryOp
				&& ((BinaryOp)uop.getInput().get(0)).getOp()==OpOp2.MINUS
				&& uop.getInput().get(0).getInput().get(0) instanceof LiteralOp
				&& HopRewriteUtils.getDoubleValueSafe( (LiteralOp)uop.getInput().get(0).getInput().get(0))==0
				&& uop.getInput().get(0).getInput().get(1) instanceof AggBinaryOp
				&& HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(1).getInput().get(0),true))
			{
				Hop W = hi.getInput().get(0);
				Hop Y = uop.getInput().get(0).getInput().get(1).getInput().get(0);
				Hop tX = uop.getInput().get(0).getInput().get(1).getInput().get(1);
				//hand X (not t(X)) to the operator: strip an existing transpose, otherwise add one
				if( !HopRewriteUtils.isTransposeOperation(tX) ) {
					tX = HopRewriteUtils.createTranspose(tX);
				}
				else
					tX = tX.getInput().get(0);
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WSIGMOID, W, Y, tX, false, true);
				HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedSigmoid2 (line "+hi.getBeginLine()+")");
			}

			//Pattern 3) W * log(sigmoid(Y%*%t(X))) (log)
			if( !appliedPattern
				&& uop.getOp() == OpOp1.LOG && uop.getInput().get(0) instanceof UnaryOp
				&& ((UnaryOp)uop.getInput().get(0)).getOp() == OpOp1.SIGMOID
				&& uop.getInput().get(0).getInput().get(0) instanceof AggBinaryOp
				&& HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0).getInput().get(0),true) )
			{
				Hop W = hi.getInput().get(0);
				Hop Y = uop.getInput().get(0).getInput().get(0).getInput().get(0);
				Hop tX = uop.getInput().get(0).getInput().get(0).getInput().get(1);
				//hand X (not t(X)) to the operator: strip an existing transpose, otherwise add one
				if( !HopRewriteUtils.isTransposeOperation(tX) ) {
					tX = HopRewriteUtils.createTranspose(tX);
				}
				else
					tX = tX.getInput().get(0);
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WSIGMOID, W, Y, tX, true, false);
				HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedSigmoid3 (line "+hi.getBeginLine()+")");
			}

			//Pattern 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus)
			if( !appliedPattern
				&& uop.getOp() == OpOp1.LOG && uop.getInput().get(0) instanceof UnaryOp
				&& ((UnaryOp)uop.getInput().get(0)).getOp() == OpOp1.SIGMOID
				&& uop.getInput().get(0).getInput().get(0) instanceof BinaryOp )
			{
				BinaryOp bop = (BinaryOp) uop.getInput().get(0).getInput().get(0);
				if( bop.getOp() == OpOp2.MINUS && bop.getInput().get(0) instanceof LiteralOp
					&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)bop.getInput().get(0))==0
					&& bop.getInput().get(1) instanceof AggBinaryOp
					&& HopRewriteUtils.isSingleBlock(bop.getInput().get(1).getInput().get(0),true))
				{
					Hop W = hi.getInput().get(0);
					Hop Y = bop.getInput().get(1).getInput().get(0);
					Hop tX = bop.getInput().get(1).getInput().get(1);
					//hand X (not t(X)) to the operator: strip an existing transpose, otherwise add one
					if( !HopRewriteUtils.isTransposeOperation(tX) ) {
						tX = HopRewriteUtils.createTranspose(tX);
					}
					else
						tX = tX.getInput().get(0);
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WSIGMOID, W, Y, tX, true, true);
					HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedSigmoid4 (line "+hi.getBeginLine()+")");
				}
			}
		}

		//relink new hop into original position
		if( hnew != null ) {
			HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
			HopRewriteUtils.addChildReference(parent, hnew, pos);
			hi = hnew;
		}
		return hi;
	}

	/**
	 * Searches for weighted divide/multiply matrix multiplication expressions
	 * around (W/(U%*%t(V))) or (W*(U%*%t(V))) and replaces them with a
	 * quaternary operator; the individual patterns are listed at the
	 * respective checks.
	 * 
	 * @param parent parent hop that references hi as its child at position pos
	 * @param hi current hop, candidate root of the patterns
	 * @param pos position of hi in the inputs of parent
	 * @return the new hop if a pattern matched, otherwise hi
	 * @throws HopsException
	 */
	private Hop simplifyWeightedDivMM(Hop parent, Hop hi,
int pos) throws HopsException
	{
		Hop hnew = null;
		boolean appliedPattern = false;

		//left/right patterns rooted by 'ab - b(div)' or 'ab - b(mult)'
		//note: we do not rewrite t(X)%*%(w*(X%*%v)) where w and v are vectors (see mmchain ops)
		//(&& binds tighter than ||, i.e., the getDim2()>1 check only guards the right-input alternative)
		if( hi instanceof AggBinaryOp && ((AggBinaryOp)hi).isMatrixMultiply()
			&& (hi.getInput().get(0) instanceof BinaryOp
			&& HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(0)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
			|| hi.getInput().get(1) instanceof BinaryOp
			&& hi.getDim2() > 1 //not applied for vector-vector mult
			&& HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(1)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)) )
		{
			Hop left = hi.getInput().get(0);
			Hop right = hi.getInput().get(1);

			//Pattern 1) t(U) %*% (W/(U%*%t(V)))
			//alternative pattern: t(U) %*% (W*(U%*%t(V)))
			if( right instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)right).getOp(),LOOKUP_VALID_WDIVMM_BINARY)
				&& HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) //prevent mv
				&& right.getInput().get(1) instanceof AggBinaryOp
				&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = right.getInput().get(0);
				Hop U = right.getInput().get(1).getInput().get(0);
				Hop V = right.getInput().get(1).getInput().get(1);
				if( HopRewriteUtils.isTransposeOfItself(left, U) )
				{
					//hand V (not t(V)) to the operator: strip an existing transpose, otherwise add one
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = HopRewriteUtils.createTranspose(V);
					else
						V = V.getInput().get(0);
					//mult distinguishes the W*(..) from the W/(..) variant
					boolean mult = ((BinaryOp)right).getOp() == OpOp2.MULT;
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, 1, mult, false);
					HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
					//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
					hnew = HopRewriteUtils.createTranspose(hnew);
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM1 (line "+hi.getBeginLine()+")");
				}
			}

			//Pattern 2) (W/(U%*%t(V))) %*% V
			//alternative pattern: (W*(U%*%t(V))) %*% V
			if( !appliedPattern
				&& left instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)left).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
				&& HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) //prevent mv
				&& left.getInput().get(1) instanceof AggBinaryOp
				&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = left.getInput().get(0);
				Hop U = left.getInput().get(1).getInput().get(0);
				Hop V = left.getInput().get(1).getInput().get(1);
				if( HopRewriteUtils.isTransposeOfItself(right, V) )
				{
					//hand V (not t(V)) to the operator: reuse right if V is no transpose, else strip it
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = right;
					else
						V = V.getInput().get(0);
					boolean mult = ((BinaryOp)left).getOp() == OpOp2.MULT;
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, 2, mult, false);
					HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM2 (line "+hi.getBeginLine()+")");
				}
			}

			//Pattern 3) t(U) %*% ((X!=0)*(U%*%t(V)-X))
			//NOTE(review): unlike patterns 2, 4, and 5, this check is not guarded by !appliedPattern,
			//so a match here would overwrite hnew of a previously applied pattern -- confirm intended
			if( right instanceof BinaryOp && ((BinaryOp)right).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
				&& right.getInput().get(1) instanceof BinaryOp && ((BinaryOp)right.getInput().get(1)).getOp()==OpOp2.MINUS
				&& right.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
				&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
				&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = right.getInput().get(0);
				Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
				Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
				Hop X = right.getInput().get(1).getInput().get(1);
				if( HopRewriteUtils.isNonZeroIndicator(W, X) //W-X constraint
					&& HopRewriteUtils.isTransposeOfItself(left, U) ) //t(U)-U constraint
				{
					//hand V (not t(V)) to the operator: strip an existing transpose, otherwise add one
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = HopRewriteUtils.createTranspose(V);
					else
						V = V.getInput().get(0);
					//X (not the indicator W) is passed as first operator input
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, X, U, V, 1, true, true);
					HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
					//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
					hnew = HopRewriteUtils.createTranspose(hnew);
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM3 (line "+hi.getBeginLine()+")");
				}
			}

			//Pattern 4) ((X!=0)*(U%*%t(V)-X)) %*% V
			if( !appliedPattern
				&& left instanceof BinaryOp && ((BinaryOp)left).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
				&& left.getInput().get(1) instanceof BinaryOp && ((BinaryOp)left.getInput().get(1)).getOp()==OpOp2.MINUS
				&& left.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
				&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
				&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = left.getInput().get(0);
				Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
				Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
				Hop X = left.getInput().get(1).getInput().get(1);
				if( HopRewriteUtils.isNonZeroIndicator(W, X) //W-X constraint
					&& HopRewriteUtils.isTransposeOfItself(right, V) ) //V-t(V) constraint
				{
					//hand V (not t(V)) to the operator: reuse right if V is no transpose, else strip it
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = right;
					else
						V = V.getInput().get(0);
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, X, U, V, 2, true, true);
					HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM4 (line "+hi.getBeginLine()+")");
				}
			}
		}

		//Pattern 5) (W*(U%*%t(V)))
		//(rooted by the cell-wise multiply itself, hence checked outside the matrix multiply branch)
		if( !appliedPattern
			&& hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
			&& HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
			&& hi.getDim2() > 1 //not applied for vector-vector mult
			&& hi.getInput().get(0).getDataType() == DataType.MATRIX
			&& hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock()
			&& hi.getInput().get(1) instanceof AggBinaryOp
			&& (((AggBinaryOp) hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE || hi.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
			&& HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
		{
			Hop W = hi.getInput().get(0);
			Hop U = hi.getInput().get(1).getInput().get(0);
			Hop V = hi.getInput().get(1).getInput().get(1);
			//hand V (not t(V)) to the operator: strip an existing transpose, otherwise add one
			if( !HopRewriteUtils.isTransposeOperation(V) )
				V = HopRewriteUtils.createTranspose(V);
			else
				V = V.getInput().get(0);
			hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, 0, true, false);
			HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
			appliedPattern = true;
			LOG.debug("Applied simplifyWeightedDivMM5 (line "+hi.getBeginLine()+")");
		}

		//relink new hop into original position
		if( hnew != null ) {
			HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
			HopRewriteUtils.addChildReference(parent, hnew, pos);
			hi = hnew;
		}
		return hi;
	}

	/**
	 * Searches for the weighted cross entropy expression sum(X*log(U%*%t(V)))
	 * and replaces it with a quaternary operator.
	 * 
	 * @param parent parent hop that references hi as its child at position pos
	 * @param hi current hop, candidate root (sum aggregate) of the pattern
	 * @param pos position of hi in the inputs of parent
	 * @return the new quaternary hop if the pattern matched, otherwise hi
	 * @throws HopsException
	 */
	private Hop simplifyWeightedCrossEntropy(Hop parent, Hop hi, int pos) throws HopsException
	{
		Hop hnew = null;

		//Pattern 1) sum( X * log(U %*% t(V)))
		if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
			&& ((AggUnaryOp)hi).getOp() == AggOp.SUM //pattern rooted by sum()
			&& hi.getInput().get(0) instanceof BinaryOp //pattern subrooted by binary op
			&& hi.getInput().get(0).getDim2() > 1 ) //not applied for vector-vector mult
		{
			BinaryOp bop = (BinaryOp) hi.getInput().get(0);
			Hop left = bop.getInput().get(0);
			Hop right = bop.getInput().get(1);
			if( bop.getOp()==OpOp2.MULT && left.getDataType()==DataType.MATRIX
				&& HopRewriteUtils.isEqualSize(left, right)
//prevent mb && right instanceof UnaryOp && ((UnaryOp)right).getOp()==OpOp1.LOG && right.getInput().get(0) instanceof AggBinaryOp //ba gurantees matrices && HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT { Hop X = left; Hop U = right.getInput().get(0).getInput().get(0); Hop V = right.getInput().get(0).getInput().get(1); if( !HopRewriteUtils.isTransposeOperation(V) ) V = HopRewriteUtils.createTranspose(V); else V = V.getInput().get(0); hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM, X, U, V); HopRewriteUtils.setOutputBlocksizes(hnew, X.getRowsInBlock(), X.getColsInBlock()); LOG.debug("Applied simplifyWeightedCEMM (line "+hi.getBeginLine()+")"); } } //relink new hop into original position if( hnew != null ) { HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); HopRewriteUtils.addChildReference(parent, hnew, pos); hi = hnew; } return hi; } /** * * @param parent * @param hi * @param pos * @return * @throws HopsException */ private Hop fuseMinusNzBinaryOperation(Hop parent, Hop hi, int pos) throws HopsException { //pattern X - (s * ppred(X,0,!=)) -> X -nz s //note: this is done as a hop rewrite in order to significantly reduce the //memory estimate for X - tmp if X is sparse if( hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==OpOp2.MINUS && hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1).getDataType()==DataType.MATRIX && hi.getInput().get(1) instanceof BinaryOp && ((BinaryOp)hi.getInput().get(1)).getOp()==OpOp2.MULT ) { Hop X = hi.getInput().get(0); Hop s = hi.getInput().get(1).getInput().get(0); Hop pred = hi.getInput().get(1).getInput().get(1); if( s.getDataType()==DataType.SCALAR && pred.getDataType()==DataType.MATRIX && pred instanceof BinaryOp && ((BinaryOp)pred).getOp()==OpOp2.NOTEQUAL && pred.getInput().get(0) == X //depend on common subexpression elimination && pred.getInput().get(1) instanceof LiteralOp && 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) { Hop hnew = new BinaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, OpOp2.MINUS_NZ, X, s); HopRewriteUtils.setOutputBlocksizes(hnew, hi.getRowsInBlock(), hi.getColsInBlock()); hnew.refreshSizeInformation(); //relink new hop into original position HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); HopRewriteUtils.addChildReference(parent, hnew, pos); hi = hnew; LOG.debug("Applied fuseMinusNzBinaryOperation (line "+hi.getBeginLine()+")"); } } return hi; } /** * * @param parent * @param hi * @param pos * @return * @throws HopsException */ private Hop fuseLogNzBinaryOperation(Hop parent, Hop hi, int pos) throws HopsException { //pattern ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5) //note: this is done as a hop rewrite in order to significantly reduce the //memory estimate and to prevent dense intermediates if X is ultra sparse if( hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==OpOp2.MULT && hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1).getDataType()==DataType.MATRIX && hi.getInput().get(1) instanceof BinaryOp && ((BinaryOp)hi.getInput().get(1)).getOp()==OpOp2.LOG ) { Hop pred = hi.getInput().get(0); Hop X = hi.getInput().get(1).getInput().get(0); Hop log = hi.getInput().get(1).getInput().get(1); if( pred instanceof BinaryOp && ((BinaryOp)pred).getOp()==OpOp2.NOTEQUAL && pred.getInput().get(0) == X //depend on common subexpression elimination && pred.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) { Hop hnew = new BinaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, OpOp2.LOG_NZ, X, log); HopRewriteUtils.setOutputBlocksizes(hnew, hi.getRowsInBlock(), hi.getColsInBlock()); hnew.refreshSizeInformation(); //relink new hop into original position HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); HopRewriteUtils.addChildReference(parent, hnew, pos); hi = hnew; LOG.debug("Applied 
fuseLogNzBinaryOperation (line "+hi.getBeginLine()+")"); } } return hi; } /** * * @param parent * @param hi * @param pos * @return * @throws HopsException */ private Hop simplifyOuterSeqExpand(Hop parent, Hop hi, int pos) throws HopsException { //pattern: outer(v, t(seq(1,m)), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) //note: this rewrite supports both left/right sequence if( hi instanceof BinaryOp && ((BinaryOp)hi).isOuterVectorOperator() && ((BinaryOp)hi).getOp()==OpOp2.EQUAL ) { if( ( hi.getInput().get(1) instanceof ReorgOp //pattern a: outer(v, t(seq(1,m)), "==") && ((ReorgOp) hi.getInput().get(1)).getOp()==ReOrgOp.TRANSPOSE && HopRewriteUtils.isBasic1NSequence(hi.getInput().get(1).getInput().get(0))) || HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0))) //pattern b: outer(seq(1,m), t(v) "==") { //determine variable parameters for pattern a/b boolean isPatternB = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)); boolean isTransposeRight = (hi.getInput().get(1) instanceof ReorgOp && ((ReorgOp) hi.getInput().get(1)).getOp()==ReOrgOp.TRANSPOSE); Hop trgt = isPatternB ? (isTransposeRight ? hi.getInput().get(1).getInput().get(0) : //get v from t(v) HopRewriteUtils.createTranspose(hi.getInput().get(1)) ) : //create v via t(v') hi.getInput().get(0); //get v directly Hop seq = isPatternB ? hi.getInput().get(0) : hi.getInput().get(1).getInput().get(0); String direction = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)) ? 
"rows" : "cols"; //setup input parameter hops HashMap<String,Hop> inputargs = new HashMap<String,Hop>(); inputargs.put("target", trgt); inputargs.put("max", HopRewriteUtils.getBasic1NSequenceMaxLiteral(seq)); inputargs.put("dir", new LiteralOp(direction)); inputargs.put("ignore", new LiteralOp(true)); inputargs.put("cast", new LiteralOp(false)); //create new hop ParameterizedBuiltinOp pbop = new ParameterizedBuiltinOp("tmp", DataType.MATRIX, ValueType.DOUBLE, ParamBuiltinOp.REXPAND, inputargs); HopRewriteUtils.setOutputBlocksizes(pbop, hi.getRowsInBlock(), hi.getColsInBlock()); pbop.refreshSizeInformation(); //relink new hop into original position HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); HopRewriteUtils.addChildReference(parent, pbop, pos); hi = pbop; LOG.debug("Applied simplifyOuterSeqExpand (line "+hi.getBeginLine()+")"); } } return hi; } /** * * @param parent * @param hi * @param pos * @return * @throws HopsException */ private Hop simplifyTableSeqExpand(Hop parent, Hop hi, int pos) throws HopsException { //pattern: table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true) //note: this rewrite supports both left/right sequence if( hi instanceof TernaryOp && hi.getInput().size()==5 && hi.getInput().get(3) instanceof LiteralOp && hi.getInput().get(4) instanceof LiteralOp ) { if( (HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)) && hi.getInput().get(3) instanceof LiteralOp) //pattern a: table(seq(1,nrow(v)), v, nrow(v), m) ||(HopRewriteUtils.isBasic1NSequence(hi.getInput().get(1)) && hi.getInput().get(2) instanceof LiteralOp) ) //pattern b: table(v, seq(1,nrow(v)), m, nrow(v)) { //determine variable parameters for pattern a/b int ixTgt = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)) ? 1 : 0; int ixMax = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)) ? 4 : 3; String direction = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)) ? 
"cols" : "rows"; //setup input parameter hops HashMap<String,Hop> inputargs = new HashMap<String,Hop>(); inputargs.put("target", hi.getInput().get(ixTgt)); inputargs.put("max", hi.getInput().get(ixMax)); inputargs.put("dir", new LiteralOp(direction)); inputargs.put("ignore", new LiteralOp(false)); inputargs.put("cast", new LiteralOp(true)); //create new hop ParameterizedBuiltinOp pbop = new ParameterizedBuiltinOp("tmp", DataType.MATRIX, ValueType.DOUBLE, ParamBuiltinOp.REXPAND, inputargs); HopRewriteUtils.setOutputBlocksizes(pbop, hi.getRowsInBlock(), hi.getColsInBlock()); pbop.refreshSizeInformation(); //relink new hop into original position HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); HopRewriteUtils.addChildReference(parent, pbop, pos); hi = pbop; LOG.debug("Applied simplifyTableSeqExpand (line "+hi.getBeginLine()+")"); } } return hi; } /** * NOTE: currently disabled since this rewrite is INVALID in the * presence of NaNs (because (NaN!=NaN) is true). * * @param parent * @param hi * @param pos * @return * @throws HopsException */ @SuppressWarnings("unused") private Hop removeUnecessaryPPred(Hop parent, Hop hi, int pos) throws HopsException { if( hi instanceof BinaryOp ) { BinaryOp bop = (BinaryOp)hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); Hop datagen = null; //ppred(X,X,"==") -> matrix(1, rows=nrow(X),cols=nrow(Y)) if( left==right && bop.getOp()==OpOp2.EQUAL || bop.getOp()==OpOp2.GREATEREQUAL || bop.getOp()==OpOp2.LESSEQUAL ) datagen = HopRewriteUtils.createDataGenOp(left, 1); //ppred(X,X,"!=") -> matrix(0, rows=nrow(X),cols=nrow(Y)) if( left==right && bop.getOp()==OpOp2.NOTEQUAL || bop.getOp()==OpOp2.GREATER || bop.getOp()==OpOp2.LESS ) datagen = HopRewriteUtils.createDataGenOp(left, 0); if( datagen != null ) { HopRewriteUtils.removeChildReference(parent, hi); HopRewriteUtils.addChildReference(parent, datagen, pos); hi = datagen; } } return hi; } }