/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.rewrite;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataGenOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.OpOp1;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.DataGenMethod;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.OpOp3;
import org.apache.sysml.hops.Hop.ParamBuiltinOp;
import org.apache.sysml.hops.Hop.ReOrgOp;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.ReorgOp;
import org.apache.sysml.parser.DataExpression;
import org.apache.sysml.parser.Statement;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
/**
* Rule: Algebraic Simplifications. Simplifies binary expressions
* with two major purposes: (1) rewrite binary operations
* to unary operations when possible (in CP this reduces the memory
* estimate, in MR this allows map-only operations and hence prevents
* unnecessary shuffle and sort) and (2) remove binary operations that
* are in themselves unnecessary (e.g., *1 and /1).
*
*/
public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
{
private static final Log LOG = LogFactory.getLog(RewriteAlgebraicSimplificationStatic.class.getName());
//valid aggregation operation types for rowOp to colOp conversions and vice versa
//(constant lookup tables: declared final so they cannot be reassigned)
private static final AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.VAR};
//valid binary operations for distributive and associative reorderings
private static final OpOp2[] LOOKUP_VALID_DISTRIBUTIVE_BINARY = new OpOp2[]{OpOp2.PLUS, OpOp2.MINUS};
private static final OpOp2[] LOOKUP_VALID_ASSOCIATIVE_BINARY = new OpOp2[]{OpOp2.PLUS, OpOp2.MULT};
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state)
    throws HopsException
{
    //an absent list of roots is passed through unchanged
    if( roots != null ) {
        //pass 1: rewrite-descend, so that patterns created by a rewrite are revisited below it
        for( Hop root : roots )
            rule_AlgebraicSimplification( root, false );
        Hop.resetVisitStatus(roots);

        //pass 2: descend-rewrite, so that simplifications roll up from the leaves
        for( Hop root : roots )
            rule_AlgebraicSimplification( root, true );
    }
    return roots;
}
@Override
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state)
    throws HopsException
{
    //an absent root is passed through unchanged
    if( root != null ) {
        //pass 1: rewrite-descend (revisits patterns newly created by rewrites)
        rule_AlgebraicSimplification( root, false );
        root.resetVisitStatus();

        //pass 2: descend-rewrite (rolls simplifications up the DAG)
        rule_AlgebraicSimplification( root, true );
    }
    return root;
}
/**
* Applies the static algebraic simplification rewrites to all inputs (children) of the
* given hop, and recursively to the hops below them. For every input, the rewrites are
* applied in the fixed order below. Each rewrite receives the current input (and, for most
* rewrites, also its parent and input position). It returns the input as it looks after the
* rewrite (either the same hop or its replacement), which is then passed to the next rewrite.
*
* Note: X/y -> X * 1/y would be useful because * is cheaper than / and sparse-safe; however,
* (1) the results would not be exactly the same (2 roundings instead of 1) and (2) it should
* come before constant folding, while the other simplifications should come after constant
* folding. Hence, it is not applied yet.
*
* @param hop high-level operator whose inputs are simplified
* @param descendFirst if true, process children recursively before applying the rewrites
*        (descend-rewrite, allows roll-up); if false, apply the rewrites first and descend
*        afterwards (rewrite-descend, investigates patterns newly created by rewrites)
* @throws HopsException if HopsException occurs
*/
private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
throws HopsException
{
//skip hops already processed in the current pass (visit status is reset between passes)
if(hop.isVisited())
return;
//recursively process children
for( int i=0; i<hop.getInput().size(); i++)
{
Hop hi = hop.getInput().get(i);
//process children recursively first (to allow roll-up)
if( descendFirst )
rule_AlgebraicSimplification(hi, descendFirst); //see below
//apply actual simplification rewrites (of children incl checks); the order matters
hi = removeUnnecessaryVectorizeOperation(hi); //e.g., matrix(1,nrow(X),ncol(X))/X -> 1/X
hi = removeUnnecessaryBinaryOperation(hop, hi, i); //e.g., X*1 -> X (dep: should come after rm unnecessary vectorize)
hi = fuseDatagenAndBinaryOperation(hi); //e.g., rand(min=-1,max=1)*7 -> rand(min=-7,max=7)
hi = fuseDatagenAndMinusOperation(hi); //e.g., -(rand(min=-2,max=1)) -> rand(min=-1,max=2)
hi = simplifyBinaryToUnaryOperation(hop, hi, i); //e.g., X*X -> X^2 (pow2), X+X -> X*2, (X>0)-(X<0) -> sign(X)
hi = canonicalizeMatrixMultScalarAdd(hi); //e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps)
hi = simplifyReverseOperation(hop, hi, i); //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
//note: the fusion flag guards only the single statement that follows
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y
hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X
hi = simplifyBushyBinaryOperation(hop, hi, i); //e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
hi = simplifyUnaryAggReorgOperation(hop, hi, i); //e.g., sum(t(X)) -> sum(X)
hi = simplifyBinaryMatrixScalarOperation(hop, hi, i);//e.g., as.scalar(X*s) -> as.scalar(X)*s;
hi = pushdownUnaryAggTransposeOperation(hop, hi, i); //e.g., colSums(t(X)) -> t(rowSums(X))
hi = pushdownCSETransposeScalarOperation(hop, hi, i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
hi = pushdownSumBinaryMult(hop, hi, i); //e.g., sum(lamda*X) -> lamda*sum(X)
hi = simplifyUnaryPPredOperation(hop, hi, i); //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
hi = simplifyTransposedAppend(hop, hi, i); //e.g., t(cbind(t(A),t(B))) -> rbind(A,B);
//note: the fusion flag guards only the single statement that follows
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
hi = fuseBinarySubDAGToUnaryOperation(hop, hi, i); //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> selp(X)
hi = simplifyTraceMatrixMult(hop, hi, i); //e.g., trace(X%*%Y)->sum(X*t(Y));
hi = simplifySlicedMatrixMult(hop, hi, i); //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1];
hi = simplifyConstantSort(hop, hi, i); //e.g., order(matrix())->matrix/seq;
hi = simplifyOrderedSort(hop, hi, i); //e.g., order(matrix())->seq;
hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites
hi = simplifyTransposeAggBinBinaryChains(hop, hi, i);//e.g., t(t(A)%*%t(B)+C) -> B%*%A+t(C)
hi = removeUnnecessaryMinus(hop, hi, i); //e.g., -(-X)->X; potentially introduced by simplify binary or dyn rewrites
hi = simplifyGroupedAggregate(hi); //e.g., aggregate(target=X,groups=y,fn="count") -> aggregate(target=y,groups=y,fn="count")
if(OptimizerUtils.ALLOW_OPERATOR_FUSION) {
hi = fuseMinusNzBinaryOperation(hop, hi, i); //e.g., X-mean*ppred(X,0,!=) -> X -nz mean
hi = fuseLogNzUnaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X) -> log_nz(X)
hi = fuseLogNzBinaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5)
}
hi = simplifyOuterSeqExpand(hop, hi, i); //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false)
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
//process children recursively after rewrites (to investigate patterns newly created by rewrites)
if( !descendFirst )
rule_AlgebraicSimplification(hi, descendFirst);
}
hop.setVisited();
}
/**
 * Removes an unnecessarily vectorized scalar operand of a binary matrix operation, i.e.,
 * a constant-valued rand datagen input is replaced by its scalar value,
 * e.g., matrix(1,nrow(X),ncol(X))/X -> 1/X.
 * At most one side is rewritten in order to keep the output semantically equivalent.
 *
 * @param hi high-level operator
 * @return the same high-level operator (only its inputs may have been rewired)
 */
private Hop removeUnnecessaryVectorizeOperation(Hop hi)
{
    //only binary matrix operations that also support matrix-scalar operations qualify
    if( !(hi instanceof BinaryOp) || hi.getDataType()!=DataType.MATRIX
        || !((BinaryOp)hi).supportsMatrixScalarOperations() )
        return hi;

    BinaryOp bop = (BinaryOp)hi;
    Hop lhs = bop.getInput().get(0);
    Hop rhs = bop.getInput().get(1);

    //NOTE: these rewrites of binary cell operations need to be aware that the right input is
    //potentially a vector but the result is of the size of the left input
    //TODO move to dynamic rewrites (since size dependent to account for mv binary cell and outer operations)
    boolean outer = lhs.getDim1()>1 && lhs.getDim2()==1 && rhs.getDim1()==1 && rhs.getDim2()>1;
    if( outer )
        return hi;

    if( lhs.getDataType()==DataType.MATRIX && rhs instanceof DataGenOp ) {
        //right vectorized scalar
        DataGenOp gen = (DataGenOp) rhs;
        if( gen.getOp()==DataGenMethod.RAND && gen.hasConstantValue() ) {
            //the constant is read from the min parameter (presumably min==max for a constant rand)
            Hop scalar = gen.getInput().get(gen.getParamIndex(DataExpression.RAND_MIN));
            HopRewriteUtils.replaceChildReference(bop, gen, scalar, 1);
            HopRewriteUtils.cleanupUnreferenced(gen);
            LOG.debug("Applied removeUnnecessaryVectorizeOperation1");
        }
    }
    else if( rhs.getDataType()==DataType.MATRIX && lhs instanceof DataGenOp ) {
        //left vectorized scalar: since the result takes the size of the left input, per
        //dimension the datagen must have size 1 or the right input must be larger than 1
        DataGenOp gen = (DataGenOp) lhs;
        if( gen.getOp()==DataGenMethod.RAND && gen.hasConstantValue()
            && (lhs.getDim2()==1 || rhs.getDim2()>1)
            && (lhs.getDim1()==1 || rhs.getDim1()>1) ) {
            Hop scalar = gen.getInput().get(gen.getParamIndex(DataExpression.RAND_MIN));
            HopRewriteUtils.replaceChildReference(bop, gen, scalar, 0);
            HopRewriteUtils.cleanupUnreferenced(gen);
            LOG.debug("Applied removeUnnecessaryVectorizeOperation2");
        }
    }
    //Note: future extensions might remove vectors from both sides, compute the binary op on
    //scalars, and finally feed it into a datagen op of the original dimensions.
    return hi;
}
/**
 * Handle removal of unnecessary binary operations
 *
 * X/1 or X*1 or 1*X or X-0 -> X
 * -1*X or X*-1 -> -X (represented as 0-X)
 *
 * When the binary operation is bypassed (first four patterns), it is also detached from
 * its inputs if no other consumer references it. Otherwise, X would keep a stale parent
 * reference to the dead operation, which defeats single-consumer checks of subsequent rewrites.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return high-level operator (X if the operation was bypassed)
 * @throws HopsException if HopsException occurs
 */
private Hop removeUnnecessaryBinaryOperation( Hop parent, Hop hi, int pos )
    throws HopsException
{
    if( hi instanceof BinaryOp )
    {
        BinaryOp bop = (BinaryOp)hi;
        Hop left = bop.getInput().get(0);
        Hop right = bop.getInput().get(1);
        //X/1 or X*1 -> X
        if( left.getDataType()==DataType.MATRIX
            && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==1.0 )
        {
            if( bop.getOp()==OpOp2.DIV || bop.getOp()==OpOp2.MULT )
            {
                HopRewriteUtils.replaceChildReference(parent, bop, left, pos);
                HopRewriteUtils.cleanupUnreferenced(bop);
                hi = left;
                LOG.debug("Applied removeUnnecessaryBinaryOperation1 (line "+bop.getBeginLine()+")");
            }
        }
        //X-0 -> X
        else if( left.getDataType()==DataType.MATRIX
            && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==0.0 )
        {
            if( bop.getOp()==OpOp2.MINUS )
            {
                HopRewriteUtils.replaceChildReference(parent, bop, left, pos);
                HopRewriteUtils.cleanupUnreferenced(bop);
                hi = left;
                LOG.debug("Applied removeUnnecessaryBinaryOperation2 (line "+bop.getBeginLine()+")");
            }
        }
        //1*X -> X
        else if( right.getDataType()==DataType.MATRIX
            && left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==1.0 )
        {
            if( bop.getOp()==OpOp2.MULT )
            {
                HopRewriteUtils.replaceChildReference(parent, bop, right, pos);
                HopRewriteUtils.cleanupUnreferenced(bop);
                hi = right;
                LOG.debug("Applied removeUnnecessaryBinaryOperation3 (line "+bop.getBeginLine()+")");
            }
        }
        //-1*X -> -X
        //note: this rewrite is necessary since the new antlr parser always converts
        //-X to -1*X due to mechanical reasons
        else if( right.getDataType()==DataType.MATRIX
            && left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==-1.0 )
        {
            if( bop.getOp()==OpOp2.MULT )
            {
                //modified in place: -1*X becomes 0-X
                bop.setOp(OpOp2.MINUS);
                HopRewriteUtils.replaceChildReference(bop, left, new LiteralOp(0), 0);
                hi = bop;
                LOG.debug("Applied removeUnnecessaryBinaryOperation4 (line "+bop.getBeginLine()+")");
            }
        }
        //X*-1 -> -X (see comment above)
        else if( left.getDataType()==DataType.MATRIX
            && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==-1.0 )
        {
            if( bop.getOp()==OpOp2.MULT )
            {
                //modified in place: X*-1 becomes 0-X (literal moved to the left)
                bop.setOp(OpOp2.MINUS);
                HopRewriteUtils.removeChildReferenceByPos(bop, right, 1);
                HopRewriteUtils.addChildReference(bop, new LiteralOp(0), 0);
                hi = bop;
                LOG.debug("Applied removeUnnecessaryBinaryOperation5 (line "+bop.getBeginLine()+")");
            }
        }
    }
    return hi;
}
/**
 * Handle removal of unnecessary binary operations over rand data by fusing the scalar
 * operation into the literal min/max parameters of a uniform rand datagen.
 *
 * rand*7 -> rand(min*7,max*7); rand+7 -> rand(min+7,max+7); rand-7 -> rand(min+(-7),max+(-7))
 * 7*rand -> rand(min*7,max*7); 7+rand -> rand(min+7,max+7);
 *
 * The additive patterns (+ and -) are only applied to a fully dense rand (literal sparsity 1).
 * With sparsity &lt; 1, the zero cells of rand(...)+7 become 7, whereas shifting min/max would
 * leave them at 0. Scaling (*) keeps zeros at zero and is therefore applied regardless of sparsity.
 *
 * NOTE: rewrite not applied if more than one datagen consumer because this would lead to
 * the creation of multiple datagen ops and thus potentially different results if seed not specified.
 *
 * @param hi high-level operator
 * @return high-level operator (the fused datagen if the rewrite was applied)
 * @throws HopsException if HopsException occurs
 */
private Hop fuseDatagenAndBinaryOperation( Hop hi )
    throws HopsException
{
    if( !(hi instanceof BinaryOp) )
        return hi;

    BinaryOp bop = (BinaryOp)hi;
    Hop left = bop.getInput().get(0);
    Hop right = bop.getInput().get(1);
    OpOp2 op = bop.getOp();

    //left input rand and hence output matrix double, right scalar literal
    if( left instanceof DataGenOp && ((DataGenOp)left).getOp()==DataGenMethod.RAND
        && right instanceof LiteralOp && left.getParent().size()==1 )
    {
        DataGenOp inputGen = (DataGenOp)left;
        double sval = ((LiteralOp)right).getDoubleValue();
        if( (op==OpOp2.MULT || op==OpOp2.PLUS || op==OpOp2.MINUS)
            && isFusableUniformRand(inputGen, op!=OpOp2.MULT) )
        {
            //create fused data gen operator: scale for *, shift for + and -
            DataGenOp gen = (op==OpOp2.MULT) ?
                HopRewriteUtils.copyDataGenOp(inputGen, sval, 0) :
                HopRewriteUtils.copyDataGenOp(inputGen, 1, (op==OpOp2.MINUS) ? -sval : sval);
            hi = rewireAllParents(bop, gen);
            LOG.debug("Applied fuseDatagenAndBinaryOperation1 (line "+bop.getBeginLine()+").");
        }
    }
    //right input rand and hence output matrix double, left scalar literal
    //(s-rand is not fused since it would require exchanging min and max)
    else if( right instanceof DataGenOp && ((DataGenOp)right).getOp()==DataGenMethod.RAND
        && left instanceof LiteralOp && right.getParent().size()==1 )
    {
        DataGenOp inputGen = (DataGenOp)right;
        double sval = ((LiteralOp)left).getDoubleValue();
        if( (op==OpOp2.MULT || op==OpOp2.PLUS)
            && isFusableUniformRand(inputGen, op!=OpOp2.MULT) )
        {
            //create fused data gen operator: scale for *, shift for +
            DataGenOp gen = (op==OpOp2.MULT) ?
                HopRewriteUtils.copyDataGenOp(inputGen, sval, 0) :
                HopRewriteUtils.copyDataGenOp(inputGen, 1, sval);
            hi = rewireAllParents(bop, gen);
            LOG.debug("Applied fuseDatagenAndBinaryOperation2 (line "+bop.getBeginLine()+").");
        }
    }
    return hi;
}

/**
 * Checks if the given rand datagen uses the uniform pdf with literal min and max and,
 * if requested, a literal sparsity of exactly 1 (i.e., no zero cells are introduced).
 *
 * @param gen rand datagen operator
 * @param requireDense if true, additionally require a literal sparsity of 1
 * @return true if the datagen qualifies for fusing a scalar operation into min/max
 * @throws HopsException if HopsException occurs
 */
private static boolean isFusableUniformRand( DataGenOp gen, boolean requireDense )
    throws HopsException
{
    HashMap<String,Integer> params = gen.getParamIndexMap();
    Hop pdf = gen.getInput().get(params.get(DataExpression.RAND_PDF));
    Hop min = gen.getInput().get(params.get(DataExpression.RAND_MIN));
    Hop max = gen.getInput().get(params.get(DataExpression.RAND_MAX));
    if( !(min instanceof LiteralOp && max instanceof LiteralOp && pdf instanceof LiteralOp
        && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue())) )
        return false;
    if( !requireDense )
        return true;
    Integer ixSparsity = params.get(DataExpression.RAND_SPARSITY);
    if( ixSparsity == null ) //unknown sparsity: conservatively do not fuse
        return false;
    Hop sparsity = gen.getInput().get(ixSparsity);
    return sparsity instanceof LiteralOp
        && HopRewriteUtils.getDoubleValueSafe((LiteralOp)sparsity)==1.0;
}

/**
 * Rewires all consumers of oldHop to newHop (avoid anomalies with replicated datagen).
 *
 * @param oldHop operator to be replaced in all its parents
 * @param newHop replacement operator
 * @return newHop
 * @throws HopsException if HopsException occurs
 */
private static Hop rewireAllParents( Hop oldHop, Hop newHop )
    throws HopsException
{
    //copy the parent list since replaceChildReference modifies it
    List<Hop> parents = new ArrayList<Hop>(oldHop.getParent());
    for( Hop p : parents )
        HopRewriteUtils.replaceChildReference(p, oldHop, newHop);
    return newHop;
}
/**
 * Fuses a negation of rand data into the datagen operator by negating and exchanging its
 * literal min/max: 0-rand(min=a,max=b) -> rand(min=-b,max=-a),
 * e.g., -(rand(min=-2,max=1)) -> rand(min=-1,max=2).
 * The datagen is modified in place, and all consumers of the minus operation are rewired to it.
 *
 * Only the binary MINUS qualifies. Other binary operations with a left literal 0
 * (e.g., 0/rand(...), max(0,rand(...)), 0>rand(...)) must not be rewritten.
 *
 * @param hi high-level operator
 * @return high-level operator (the modified datagen if the rewrite was applied)
 * @throws HopsException if HopsException occurs
 */
private Hop fuseDatagenAndMinusOperation( Hop hi )
    throws HopsException
{
    if( HopRewriteUtils.isBinary(hi, OpOp2.MINUS) )
    {
        BinaryOp bop = (BinaryOp)hi;
        Hop left = bop.getInput().get(0);
        Hop right = bop.getInput().get(1);
        //pattern: 0 - rand(...)
        if( right instanceof DataGenOp && ((DataGenOp)right).getOp()==DataGenMethod.RAND &&
            left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==0.0 )
        {
            DataGenOp inputGen = (DataGenOp)right;
            HashMap<String,Integer> params = inputGen.getParamIndexMap();
            Hop pdf = right.getInput().get(params.get(DataExpression.RAND_PDF));
            int ixMin = params.get(DataExpression.RAND_MIN);
            int ixMax = params.get(DataExpression.RAND_MAX);
            Hop min = right.getInput().get(ixMin);
            Hop max = right.getInput().get(ixMax);
            //apply rewrite under additional conditions (for simplicity); the single-consumer
            //check is required because the datagen is modified in place
            if( inputGen.getParent().size()==1
                && min instanceof LiteralOp && max instanceof LiteralOp && pdf instanceof LiteralOp
                && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) )
            {
                //exchange and *-1 (special case 0 stays 0 instead of -0 for consistency)
                double newMinVal = (((LiteralOp)max).getDoubleValue()==0)?0:(-1 * ((LiteralOp)max).getDoubleValue());
                double newMaxVal = (((LiteralOp)min).getDoubleValue()==0)?0:(-1 * ((LiteralOp)min).getDoubleValue());
                Hop newMin = new LiteralOp(newMinVal);
                Hop newMax = new LiteralOp(newMaxVal);
                HopRewriteUtils.removeChildReferenceByPos(inputGen, min, ixMin);
                HopRewriteUtils.addChildReference(inputGen, newMin, ixMin);
                HopRewriteUtils.removeChildReferenceByPos(inputGen, max, ixMax);
                HopRewriteUtils.addChildReference(inputGen, newMax, ixMax);
                //rewire all parents (avoid anomalies with replicated datagen)
                List<Hop> parents = new ArrayList<Hop>(bop.getParent());
                for( Hop p : parents )
                    HopRewriteUtils.replaceChildReference(p, bop, inputGen);
                hi = inputGen;
                LOG.debug("Applied fuseDatagenAndMinusOperation (line "+bop.getBeginLine()+").");
            }
        }
    }
    return hi;
}
/**
 * Simplifies binary operations whose inputs are related (relies on previous common
 * subexpression elimination, so that identical inputs are the same hop). This also
 * serves as a canonicalization for more complex rewrites.
 *
 * X+X -> X*2, X*X -> X^2, (X>0)-(X<0) -> sign(X)
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return high-level operator (the new sign operator for the third pattern)
 * @throws HopsException if HopsException occurs
 */
private Hop simplifyBinaryToUnaryOperation( Hop parent, Hop hi, int pos )
    throws HopsException
{
    if( !(hi instanceof BinaryOp) )
        return hi;

    BinaryOp bop = (BinaryOp)hi;
    Hop in1 = hi.getInput().get(0);
    Hop in2 = hi.getInput().get(1);

    if( in1 == in2 && in1.getDataType()==DataType.MATRIX )
    {
        //patterns: X+X -> X*2, X*X -> X^2 (modified in place)
        //note: we simplify to matrix-scalar operations first (less mem and better MR plan),
        //however, we later compile specific LOPS for X*2 and X^2
        OpOp2 newOp = null;
        String msg = null;
        if( bop.getOp()==OpOp2.PLUS ) {
            newOp = OpOp2.MULT;
            msg = "Applied simplifyBinaryToUnaryOperation1";
        }
        else if( bop.getOp()==OpOp2.MULT ) {
            newOp = OpOp2.POW;
            msg = "Applied simplifyBinaryToUnaryOperation2";
        }
        if( newOp != null ) {
            bop.setOp(newOp);
            LiteralOp two = new LiteralOp(2);
            //drop the second reference to X (one parent entry), then append the literal 2
            bop.getInput().remove(1);
            in2.getParent().remove(bop);
            HopRewriteUtils.addChildReference(hi, two, 1);
            LOG.debug(msg);
        }
    }
    else if( bop.getOp() == OpOp2.MINUS
        && HopRewriteUtils.isBinary(in1, OpOp2.GREATER)
        && HopRewriteUtils.isBinary(in2, OpOp2.LESS)
        && in1.getInput().get(0) == in2.getInput().get(0)
        && in1.getInput().get(1) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp)in1.getInput().get(1))==0
        && in2.getInput().get(1) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp)in2.getInput().get(1))==0 )
    {
        //pattern: (X>0)-(X<0) -> sign(X)
        Hop x = in1.getInput().get(0);
        UnaryOp sign = HopRewriteUtils.createUnary(x, OpOp1.SIGN);
        HopRewriteUtils.replaceChildReference(parent, hi, sign, pos);
        HopRewriteUtils.cleanupUnreferenced(hi, in1, in2);
        hi = sign;
        LOG.debug("Applied simplifyBinaryToUnaryOperation3");
    }
    return hi;
}
/**
 * Canonicalizes U%*%V+eps, eps+U%*%V, and U%*%V-eps into the common form U%*%V+s,
 * which simplifies subsequent rewrites (e.g., wdivmm or wcemm with epsilon).
 * The binary operation is modified in place.
 *
 * @param hi high-level operator
 * @return the same high-level operator
 * @throws HopsException if HopsException occurs
 */
private Hop canonicalizeMatrixMultScalarAdd( Hop hi )
    throws HopsException
{
    //pattern: binary operation (+ or -) of matrix mult and scalar
    if( !(hi instanceof BinaryOp) )
        return hi;

    BinaryOp bop = (BinaryOp)hi;
    Hop first = hi.getInput().get(0);
    Hop second = hi.getInput().get(1);

    if( first.getDataType().isScalar() && second instanceof AggBinaryOp
        && bop.getOp()==OpOp2.PLUS )
    {
        //(eps + U%*%V) -> (U%*%V + eps): swap both inputs
        HopRewriteUtils.removeAllChildReferences(bop);
        HopRewriteUtils.addChildReference(bop, second, 0);
        HopRewriteUtils.addChildReference(bop, first, 1);
        LOG.debug("Applied canonicalizeMatrixMultScalarAdd1 (line "+hi.getBeginLine()+").");
    }
    else if( second.getDataType().isScalar() && first instanceof AggBinaryOp
        && bop.getOp() == OpOp2.MINUS )
    {
        //(U%*%V - eps) -> (U%*%V + (-eps))
        bop.setOp(OpOp2.PLUS);
        Hop negated = HopRewriteUtils.createBinaryMinus(second);
        HopRewriteUtils.replaceChildReference(bop, second, negated, 1);
        LOG.debug("Applied canonicalizeMatrixMultScalarAdd2 (line "+hi.getBeginLine()+").");
    }
    return hi;
}
/**
 * Simplifies the multiplication with a reverse permutation matrix into a reverse operation:
 * table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X).
 *
 * The permutation table must be unweighted (literal weights 1) and must not specify explicit
 * output dimensions. Otherwise, the table is a scaled and/or zero-padded permutation, and the
 * product is not equivalent to rev(X) (e.g., table(a,b,2) %*% X = 2*rev(X)).
 *
 * NOTE: this would be by definition a dynamic rewrite; however, we apply it as a static
 * rewrite in order to apply it before splitting dags, which would hide the table information
 * if dimensions are not specified.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return high-level operator (the new rev operator if the rewrite was applied)
 * @throws HopsException if HopsException occurs
 */
private Hop simplifyReverseOperation( Hop parent, Hop hi, int pos )
    throws HopsException
{
    if( hi instanceof AggBinaryOp
        && hi.getInput().get(0) instanceof TernaryOp )
    {
        TernaryOp top = (TernaryOp) hi.getInput().get(0);
        if( top.getOp()==OpOp3.CTABLE
            && HopRewriteUtils.isBasic1NSequence(top.getInput().get(0))
            && HopRewriteUtils.isBasicN1Sequence(top.getInput().get(1))
            && top.getInput().get(0).getDim1()==top.getInput().get(1).getDim1()
            && isUnweightedCtableWithoutDims(top) )
        {
            ReorgOp rop = HopRewriteUtils.createReorg(hi.getInput().get(1), ReOrgOp.REV);
            HopRewriteUtils.replaceChildReference(parent, hi, rop, pos);
            HopRewriteUtils.cleanupUnreferenced(hi, top);
            hi = rop;
            LOG.debug("Applied simplifyReverseOperation.");
        }
    }
    return hi;
}

/**
 * Checks that a ctable operator has literal unit weights and no explicit output dimensions.
 * NOTE(review): assumes the ctable input layout [A, B, W] without and [A, B, W, dim1, dim2]
 * with explicit output dimensions -- confirm against TernaryOp.
 *
 * @param ctable ctable operator
 * @return true if the ctable is unweighted and has no explicit output dimensions
 */
private static boolean isUnweightedCtableWithoutDims( TernaryOp ctable )
{
    if( ctable.getInput().size() != 3 )
        return false;
    Hop weights = ctable.getInput().get(2);
    return weights instanceof LiteralOp
        && HopRewriteUtils.getDoubleValueSafe((LiteralOp)weights)==1;
}
/**
 * Fuses 1-(X*Y) into the single binary operation X 1-* Y, which avoids the X*Y intermediate.
 * The minus operation is modified in place. The X*Y intermediate must have this operation as
 * its single consumer; it is detached from X and Y afterwards so that they do not keep a
 * stale parent reference to the dead intermediate.
 *
 * @param hi high-level operator
 * @return the same high-level operator
 */
private Hop simplifyMultiBinaryToBinaryOperation( Hop hi )
{
    //pattern: 1-(X*Y) --> X 1-* Y (avoid intermediate)
    if( HopRewriteUtils.isBinary(hi, OpOp2.MINUS)
        && hi.getDataType() == DataType.MATRIX
        && hi.getInput().get(0) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValueSafe((LiteralOp)hi.getInput().get(0))==1
        && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT)
        && hi.getInput().get(1).getParent().size() == 1 ) //single consumer
    {
        BinaryOp bop = (BinaryOp)hi;
        Hop mult = hi.getInput().get(1);
        Hop left = mult.getInput().get(0);
        Hop right = mult.getInput().get(1);
        //set new binaryop type and rewire inputs
        bop.setOp(OpOp2.MINUS1_MULT);
        HopRewriteUtils.removeAllChildReferences(hi);
        //the X*Y intermediate is now unreferenced: detach it from its inputs
        HopRewriteUtils.cleanupUnreferenced(mult);
        HopRewriteUtils.addChildReference(bop, left);
        HopRewriteUtils.addChildReference(bop, right);
        LOG.debug("Applied simplifyMultiBinaryToBinaryOperation.");
    }
    return hi;
}
/**
 * Simplifies binary +/- operations over a cell-wise multiplication that shares an operand:
 *
 * (X-Y*X) -> (1-Y)*X, (Y*X-X) -> (Y-1)*X
 * (X+Y*X) -> (1+Y)*X, (Y*X+X) -> (Y+1)*X
 *
 * The rewritten expression puts Y on the left side of the multiplication, so the rewrite is
 * only applied if Y is not broadcast against X. The result of a cell-wise operation has the
 * size of its left input, and matrix-vector operations expect the vector on the right.
 * For example, (X*v-X) with a column vector v must not become (v-1)*X.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return high-level operator (the new multiplication if the rewrite was applied)
 */
private Hop simplifyDistributiveBinaryOperation( Hop parent, Hop hi, int pos )
{
    if( !(hi instanceof BinaryOp) )
        return hi;

    BinaryOp bop = (BinaryOp)hi;
    Hop left = bop.getInput().get(0);
    Hop right = bop.getInput().get(1);
    if( left.getDataType()!=DataType.MATRIX || right.getDataType()!=DataType.MATRIX
        || !HopRewriteUtils.isValidOp(bop.getOp(), LOOKUP_VALID_DISTRIBUTIVE_BINARY) )
        return hi;

    boolean applied = false;
    //(Y*X+X) -> (Y+1)*X, (Y*X-X) -> (Y-1)*X
    if( HopRewriteUtils.isBinary(left, OpOp2.MULT) )
    {
        Hop leftC1 = left.getInput().get(0);
        Hop leftC2 = left.getInput().get(1);
        if( leftC1.getDataType()==DataType.MATRIX && leftC2.getDataType()==DataType.MATRIX
            && (right == leftC1 || right == leftC2) && leftC1 != leftC2 ) //any mult order
        {
            Hop X = right;
            Hop Y = ( right == leftC1 ) ? leftC2 : leftC1;
            if( isValidLeftCellwiseOperand(Y, X) ) //rewrite 'binary +/-'
            {
                LiteralOp literal = new LiteralOp(1);
                BinaryOp plus = HopRewriteUtils.createBinary(Y, literal, bop.getOp());
                BinaryOp mult = HopRewriteUtils.createBinary(plus, X, OpOp2.MULT);
                HopRewriteUtils.replaceChildReference(parent, hi, mult, pos);
                HopRewriteUtils.cleanupUnreferenced(hi, left);
                hi = mult;
                applied = true;
                LOG.debug("Applied simplifyDistributiveBinaryOperation1");
            }
        }
    }
    //(X+Y*X) -> (1+Y)*X, (X-Y*X) -> (1-Y)*X
    if( !applied && HopRewriteUtils.isBinary(right, OpOp2.MULT) )
    {
        Hop rightC1 = right.getInput().get(0);
        Hop rightC2 = right.getInput().get(1);
        if( rightC1.getDataType()==DataType.MATRIX && rightC2.getDataType()==DataType.MATRIX
            && (left == rightC1 || left == rightC2) && rightC1 != rightC2 ) //any mult order
        {
            Hop X = left;
            Hop Y = ( left == rightC1 ) ? rightC2 : rightC1;
            if( isValidLeftCellwiseOperand(Y, X) ) //rewrite '+/- binary'
            {
                LiteralOp literal = new LiteralOp(1);
                BinaryOp plus = HopRewriteUtils.createBinary(literal, Y, bop.getOp());
                BinaryOp mult = HopRewriteUtils.createBinary(plus, X, OpOp2.MULT);
                HopRewriteUtils.replaceChildReference(parent, hi, mult, pos);
                HopRewriteUtils.cleanupUnreferenced(hi, right);
                hi = mult;
                LOG.debug("Applied simplifyDistributiveBinaryOperation2");
            }
        }
    }
    return hi;
}

/**
 * Checks if Y can be the left input of a cell-wise binary operation with right input X
 * without being broadcast. Per dimension, Y must have a size larger than 1 or X a size of
 * exactly 1. Unknown dimensions (presumably encoded as -1) conservatively fail this check.
 *
 * @param Y intended left input
 * @param X intended right input
 * @return true if Y op X keeps the output size of the original expression
 */
private static boolean isValidLeftCellwiseOperand( Hop Y, Hop X )
{
    return (Y.getDim2() > 1 || X.getDim2() == 1)
        && (Y.getDim1() > 1 || X.getDim1() == 1);
}
/**
* Reorders chains of associative binary cell operations (* or +) below a matrix
* multiplication into bushy trees:
*
* (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
* (X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v)
* (((Z%*%v)*X)*Y) -> (Z%*%v)*(X*Y)  (and analogously for +)
*
* Note: The rewrite is restricted to a matrix multiplication (ba) at the leaf and at the root
* (the parent), instead of data at the leaf, in order to not reorganize too eagerly, which
* would lose additional rewrite potential. This rewrite has two goals: (1) enable XtwXv, and
* (2) increase piggybacking potential by creating bushy trees.
*
* @param parent parent high-level operator (the rewrite requires a matrix multiplication)
* @param hi high-level operator
* @param pos position of hi in the inputs of parent
* @return high-level operator (the new root of the reordered chain if applied)
*/
private Hop simplifyBushyBinaryOperation( Hop parent, Hop hi, int pos )
{
if( hi instanceof BinaryOp && parent instanceof AggBinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
Hop left = bop.getInput().get(0);
Hop right = bop.getInput().get(1);
OpOp2 op = bop.getOp();
if( left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX &&
HopRewriteUtils.isValidOp(op, LOOKUP_VALID_ASSOCIATIVE_BINARY) )
{
boolean applied = false;
//case 1: right-deep chain X op (Y op ba())
if( right instanceof BinaryOp )
{
BinaryOp bop2 = (BinaryOp)right;
Hop left2 = bop2.getInput().get(0);
Hop right2 = bop2.getInput().get(1);
OpOp2 op2 = bop2.getOp();
if( op==op2 && right2.getDataType()==DataType.MATRIX
&& (right2 instanceof AggBinaryOp) )
{
//(X*(Y*op()) -> (X*Y)*op()
BinaryOp bop3 = HopRewriteUtils.createBinary(left, left2, op);
BinaryOp bop4 = HopRewriteUtils.createBinary(bop3, right2, op);
HopRewriteUtils.replaceChildReference(parent, bop, bop4, pos);
//detach the old operators unless they are still referenced elsewhere
HopRewriteUtils.cleanupUnreferenced(bop, bop2);
hi = bop4;
applied = true;
LOG.debug("Applied simplifyBushyBinaryOperation1");
}
}
//case 2: left-deep chain (ba() op X) op Y, only if X op Y keeps the output size
if( !applied && left instanceof BinaryOp )
{
BinaryOp bop2 = (BinaryOp)left;
Hop left2 = bop2.getInput().get(0);
Hop right2 = bop2.getInput().get(1);
OpOp2 op2 = bop2.getOp();
if( op==op2 && left2.getDataType()==DataType.MATRIX
&& (left2 instanceof AggBinaryOp)
&& (right2.getDim2() > 1 || right.getDim2() == 1) //X not vector, or Y vector
&& (right2.getDim1() > 1 || right.getDim1() == 1) ) //X not vector, or Y vector
{
//((op()*X)*Y) -> op()*(X*Y)
BinaryOp bop3 = HopRewriteUtils.createBinary(right2, right, op);
BinaryOp bop4 = HopRewriteUtils.createBinary(left2, bop3, op);
HopRewriteUtils.replaceChildReference(parent, bop, bop4, pos);
//detach the old operators unless they are still referenced elsewhere
HopRewriteUtils.cleanupUnreferenced(bop, bop2);
hi = bop4;
LOG.debug("Applied simplifyBushyBinaryOperation2");
}
}
}
}
return hi;
}
/**
 * Removes a reorg operation below a full aggregate whose result does not depend on the cell
 * positions, e.g., sum(t(X)) -> sum(X), sum(rev(X)) -> sum(X), sum(reshape(X)) -> sum(X).
 * Trace is an exception because it only reads the diagonal. Below trace, only a transpose
 * (which keeps the diagonal) is removed: trace(t(X)) -> trace(X), but trace(rev(X)) and
 * trace(reshape(X)) are left unchanged.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return the same high-level operator (only its input may have been rewired)
 */
private Hop simplifyUnaryAggReorgOperation( Hop parent, Hop hi, int pos )
{
    if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol //full uagg
        && hi.getInput().get(0) instanceof ReorgOp ) //reorg operation
    {
        AggUnaryOp uagg = (AggUnaryOp)hi;
        ReorgOp rop = (ReorgOp)hi.getInput().get(0);
        boolean validReorg = rop.getOp()==ReOrgOp.TRANSPOSE || rop.getOp()==ReOrgOp.RESHAPE
            || rop.getOp()==ReOrgOp.REV;
        //trace depends on the cell positions: only transpose leaves the diagonal unchanged
        boolean positionInsensitive = uagg.getOp()!=AggOp.TRACE || rop.getOp()==ReOrgOp.TRANSPOSE;
        if( validReorg && positionInsensitive
            && rop.getParent().size()==1 ) //uagg only reorg consumer
        {
            Hop input = rop.getInput().get(0);
            HopRewriteUtils.removeAllChildReferences(hi);
            HopRewriteUtils.removeAllChildReferences(rop);
            HopRewriteUtils.addChildReference(hi, input);
            LOG.debug("Applied simplifyUnaryAggReorgOperation");
        }
    }
    return hi;
}
//binary operations that are cell-wise and defined over scalars: only for these does a 1x1
//result imply 1x1 matrix inputs, so the operation can be applied to the scalar casts instead
//(excludes, e.g., quantile or central moment, whose matrix inputs need not be 1x1)
private static final OpOp2[] LOOKUP_VALID_SCALAR_PUSHDOWN_BINARY = new OpOp2[]{
    OpOp2.PLUS, OpOp2.MINUS, OpOp2.MULT, OpOp2.DIV, OpOp2.MODULUS, OpOp2.INTDIV,
    OpOp2.LESS, OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL, OpOp2.EQUAL, OpOp2.NOTEQUAL,
    OpOp2.MIN, OpOp2.MAX, OpOp2.AND, OpOp2.OR, OpOp2.LOG, OpOp2.POW};

/**
 * Pushes a cast to scalar below a cell-wise binary operation:
 * as.scalar(X*Y) -> as.scalar(X)*as.scalar(Y), as.scalar(X*s) -> as.scalar(X)*s,
 * as.scalar(s*X) -> s*as.scalar(X).
 * The replaced cast (and the binary operation below it) are detached if no longer referenced.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return high-level operator (the new scalar binary operation if the rewrite was applied)
 * @throws HopsException if HopsException occurs
 */
private Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, int pos )
    throws HopsException
{
    if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR)
        && hi.getInput().get(0) instanceof BinaryOp
        && HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(0)).getOp(),
            LOOKUP_VALID_SCALAR_PUSHDOWN_BINARY) )
    {
        BinaryOp bin = (BinaryOp) hi.getInput().get(0);
        BinaryOp bout = null;
        //as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y)
        if( bin.getInput().get(0).getDataType()==DataType.MATRIX
            && bin.getInput().get(1).getDataType()==DataType.MATRIX ) {
            UnaryOp cast1 = HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR);
            UnaryOp cast2 = HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR);
            bout = HopRewriteUtils.createBinary(cast1, cast2, bin.getOp());
        }
        //as.scalar(X*s) -> as.scalar(X) * s
        else if( bin.getInput().get(0).getDataType()==DataType.MATRIX ) {
            UnaryOp cast = HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR);
            bout = HopRewriteUtils.createBinary(cast, bin.getInput().get(1), bin.getOp());
        }
        //as.scalar(s*X) -> s * as.scalar(X)
        else if ( bin.getInput().get(1).getDataType()==DataType.MATRIX ) {
            UnaryOp cast = HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR);
            bout = HopRewriteUtils.createBinary(bin.getInput().get(0), cast, bin.getOp());
        }
        if( bout != null ) {
            HopRewriteUtils.replaceChildReference(parent, hi, bout, pos);
            //detach the old cast and binary operation unless still referenced elsewhere
            HopRewriteUtils.cleanupUnreferenced(hi, bin);
            hi = bout;
            LOG.debug("Applied simplifyBinaryMatrixScalarOperation.");
        }
    }
    return hi;
}
/**
* Pushes a row or column aggregate below a transpose by switching its direction and
* transposing the (smaller) aggregate result instead:
* rowSums(t(X)) -> t(colSums(X)) and colSums(t(X)) -> t(rowSums(X)).
* The rewrite is only applied under three conditions:
* - the aggregate has a single consumer;
* - its function is in LOOKUP_VALID_ROW_COL_AGGREGATE;
* - its input passes HopRewriteUtils.isTransposeOperation(.., 1) (presumably a transpose with
*   at most one consumer -- confirm in HopRewriteUtils).
*
* @param parent parent high-level operator
* @param hi high-level operator
* @param pos position of hi in the inputs of parent
* @return the new outer transpose if the rewrite was applied, otherwise hi
*/
private Hop pushdownUnaryAggTransposeOperation( Hop parent, Hop hi, int pos )
{
if( hi instanceof AggUnaryOp && hi.getParent().size()==1
&& (((AggUnaryOp) hi).getDirection()==Direction.Row || ((AggUnaryOp) hi).getDirection()==Direction.Col)
&& HopRewriteUtils.isTransposeOperation(hi.getInput().get(0), 1)
&& HopRewriteUtils.isValidOp(((AggUnaryOp) hi).getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE) )
{
AggUnaryOp uagg = (AggUnaryOp) hi;
//get input rewire existing operators (remove inner transpose)
Hop input = uagg.getInput().get(0).getInput().get(0);
HopRewriteUtils.removeAllChildReferences(hi.getInput().get(0));
HopRewriteUtils.removeAllChildReferences(hi);
HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
//pattern 1: row-aggregate to col aggregate, e.g., rowSums(t(X))->t(colSums(X))
if( uagg.getDirection()==Direction.Row ) {
uagg.setDirection(Direction.Col);
LOG.debug("Applied pushdownUnaryAggTransposeOperation1 (line "+hi.getBeginLine()+").");
}
//pattern 2: col-aggregate to row aggregate, e.g., colSums(t(X))->t(rowSums(X))
else if( uagg.getDirection()==Direction.Col ) {
uagg.setDirection(Direction.Row);
LOG.debug("Applied pushdownUnaryAggTransposeOperation2 (line "+hi.getBeginLine()+").");
}
//create outer transpose operation and rewire operators
//(the aggregate now reads X directly; its sizes are refreshed before creating the transpose)
HopRewriteUtils.addChildReference(uagg, input); uagg.refreshSizeInformation();
Hop trans = HopRewriteUtils.createTranspose(uagg); //incl refresh size
HopRewriteUtils.addChildReference(parent, trans, pos); //by def, same size
hi = trans;
}
return hi;
}
/**
* Pushes a transpose below a matrix-scalar binary operation if the matrix input X already
* has a transpose among its consumers, so that t(X) becomes a common subexpression:
* a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 (with support for left or right scalar operands).
* Central moment and quantile are excluded. The transpose and the binary operation must each
* have a single consumer (the transpose via HopRewriteUtils.isTransposeOperation(.., 1),
* presumably a max-parents check -- confirm). The existing hops are rewired rather than copied.
*
* @param parent parent high-level operator
* @param hi high-level operator (the transpose, probed at the root node of b)
* @param pos position of hi in the inputs of parent
* @return the binary operation (now above the transpose) if applied, otherwise hi
*/
private Hop pushdownCSETransposeScalarOperation( Hop parent, Hop hi, int pos )
{
// a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
// probed at root node of b in above example
// (with support for left or right scalar operations)
if( HopRewriteUtils.isTransposeOperation(hi, 1)
&& HopRewriteUtils.isBinaryMatrixScalarOperation(hi.getInput().get(0))
&& hi.getInput().get(0).getParent().size()==1)
{
//position of the matrix operand within the binary operation (0: X op s, 1: s op X)
int Xpos = hi.getInput().get(0).getInput().get(0).getDataType().isMatrix() ? 0 : 1;
Hop X = hi.getInput().get(0).getInput().get(Xpos);
BinaryOp binary = (BinaryOp) hi.getInput().get(0);
if( HopRewriteUtils.containsTransposeOperation(X.getParent())
&& !HopRewriteUtils.isValidOp(binary.getOp(), new OpOp2[]{OpOp2.CENTRALMOMENT, OpOp2.QUANTILE}))
{
//clear existing wiring: parent -> t -> binary -> X
HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
HopRewriteUtils.removeChildReference(hi, binary);
HopRewriteUtils.removeChildReference(binary, X);
//re-wire operators: parent -> binary -> t -> X (t placed at the matrix operand position)
HopRewriteUtils.addChildReference(parent, binary, pos);
HopRewriteUtils.addChildReference(binary, hi, Xpos);
HopRewriteUtils.addChildReference(hi, X);
//note: common subexpression later eliminated by dedicated rewrite
hi = binary;
LOG.debug("Applied pushdownCSETransposeScalarOperation (line "+hi.getBeginLine()+").");
}
}
return hi;
}
/**
 * Pulls a scalar factor out of a full sum aggregate:
 * sum(lamda*X) -> lamda*sum(X) and sum(X*lamda) -> lamda*sum(X).
 * Only applies if the sum is the only consumer of the multiply.
 * The replaced sum and multiply are cleaned up if they become unreferenced, so that
 * lamda and X do not keep stale parent references (as in the other rewrites).
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return the new scalar multiply if the rewrite applied, otherwise hi
 * @throws HopsException if HopsException occurs
 */
private Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos ) throws HopsException {
	if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
		&& ((AggUnaryOp)hi).getOp()==Hop.AggOp.SUM
		&& HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.MULT, 1) ) //only one parent which is the sum
	{
		Hop mult = hi.getInput().get(0);
		Hop operand1 = mult.getInput().get(0);
		Hop operand2 = mult.getInput().get(1);
		//exactly one operand must be a scalar and the other a matrix
		boolean scalarLeft = operand1.getDataType()==DataType.SCALAR && operand2.getDataType()==DataType.MATRIX;
		boolean scalarRight = operand1.getDataType()==DataType.MATRIX && operand2.getDataType()==DataType.SCALAR;
		if( scalarLeft || scalarRight )
		{
			Hop lamda = scalarLeft ? operand1 : operand2;
			Hop matrix = scalarLeft ? operand2 : operand1;
			AggUnaryOp aggOp = HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.RowCol);
			Hop bop = HopRewriteUtils.createBinary(lamda, aggOp, OpOp2.MULT);
			HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
			//remove dangling parent references of the replaced sum and multiply (if unreferenced)
			HopRewriteUtils.cleanupUnreferenced(hi, mult);
			LOG.debug("Applied pushdownSumBinaryMult.");
			return bop;
		}
	}
	return hi;
}
/**
 * Removes a redundant unary operation over a ppred comparison, e.g.,
 * abs(ppred(X,Y,">")) -> ppred(X,Y,">"). This relies on ppred producing 0/1
 * values, which abs, sign, selp, ceil, floor, and round leave unchanged.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return the ppred operation if the rewrite applied, otherwise hi
 */
private Hop simplifyUnaryPPredOperation( Hop parent, Hop hi, int pos )
{
	boolean unaryOverPPred = hi instanceof UnaryOp
		&& hi.getDataType()==DataType.MATRIX
		&& hi.getInput().get(0) instanceof BinaryOp
		&& ((BinaryOp)hi.getInput().get(0)).isPPredOperation();
	if( !unaryOverPPred )
		return hi;

	OpOp1 op = ((UnaryOp) hi).getOp();
	boolean redundant = op==OpOp1.ABS || op==OpOp1.SIGN || op==OpOp1.SELP
		|| op==OpOp1.CEIL || op==OpOp1.FLOOR || op==OpOp1.ROUND;
	if( !redundant )
		return hi;

	//bypass the unary operation and link the ppred directly into parent
	Hop ppred = hi.getInput().get(0);
	HopRewriteUtils.replaceChildReference(parent, hi, ppred, pos);
	HopRewriteUtils.cleanupUnreferenced(hi);
	LOG.debug("Applied simplifyUnaryPPredOperation.");
	return ppred;
}
/**
 * Removes transposes around an append of transposed inputs by switching the append
 * direction: t(cbind(t(A),t(B))) -> rbind(A,B) and t(rbind(t(A),t(B))) -> cbind(A,B).
 * Requires the append to have the outer transpose as its single consumer, and both
 * inner transposes to be accepted by isTransposeOperation(..,1). A new append is
 * created instead of updating the dag in place, to prevent anomalies with multiple
 * consumers during the rewrite process.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return the new append if the rewrite applied, otherwise hi
 */
private Hop simplifyTransposedAppend( Hop parent, Hop hi, int pos )
{
	if( !HopRewriteUtils.isTransposeOperation(hi) || !(hi.getInput().get(0) instanceof BinaryOp) )
		return hi;
	BinaryOp append = (BinaryOp) hi.getInput().get(0);
	boolean isAppend = append.getOp()==OpOp2.CBIND || append.getOp()==OpOp2.RBIND;
	if( !isAppend || append.getParent().size() != 1 )
		return hi;

	Hop tLeft = append.getInput().get(0);
	Hop tRight = append.getInput().get(1);
	if( !HopRewriteUtils.isTransposeOperation(tLeft, 1) || !HopRewriteUtils.isTransposeOperation(tRight, 1) )
		return hi;

	OpOp2 flipped = (append.getOp()==OpOp2.CBIND) ? OpOp2.RBIND : OpOp2.CBIND;
	BinaryOp newAppend = HopRewriteUtils.createBinary(tLeft.getInput().get(0), tRight.getInput().get(0), flipped);
	HopRewriteUtils.replaceChildReference(parent, hi, newAppend, pos);
	LOG.debug("Applied simplifyTransposedAppend (line "+newAppend.getBeginLine()+").");
	return newAppend;
}
/**
 * Handles the simplification of more complex binary sub DAGs into a single unary operation.
 * At most one of the following patterns is applied per call, checked in this order:
 *
 * (1-X)*X -> sprop(X)
 * X*(1-X) -> sprop(X)
 * 1/(1+exp(0-X)) -> sigmoid(X)
 * 1/(1+exp(X)) -> sigmoid(-X) (X not a binary minus; the minus is reintroduced)
 * (X>0)*X -> selp(X)
 * X*(X>0) -> selp(X)
 * max(X,0) -> selp(X)
 * max(0,X) -> selp(X)
 *
 * Intermediates with other consumers are kept for those consumers. Intermediates
 * without remaining consumers are unlinked from their inputs.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return the new unary operator if a pattern applied, otherwise hi
 * @throws HopsException if HopsException occurs
 */
private Hop fuseBinarySubDAGToUnaryOperation( Hop parent, Hop hi, int pos )
	throws HopsException
{
	if( hi instanceof BinaryOp )
	{
		BinaryOp bop = (BinaryOp)hi;
		Hop left = hi.getInput().get(0);
		Hop right = hi.getInput().get(1);
		boolean applied = false;

		//sample proportion (sprop) operator
		if( bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX )
		{
			//by definition, either left or right or none applies.
			//note: if there are multiple consumers on the intermediate,
			//we follow the heuristic that redundant computation is more beneficial,
			//i.e., we still fuse but leave the intermediate for the other consumers
			if( left instanceof BinaryOp ) //(1-X)*X
			{
				BinaryOp bleft = (BinaryOp)left;
				Hop left1 = bleft.getInput().get(0);
				Hop left2 = bleft.getInput().get(1);
				if( left1 instanceof LiteralOp &&
					HopRewriteUtils.getDoubleValue((LiteralOp)left1)==1 &&
					left2 == right && bleft.getOp() == OpOp2.MINUS )
				{
					UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SPROP);
					HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
					HopRewriteUtils.cleanupUnreferenced(bop, left);
					hi = unary;
					applied = true;
					LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1");
				}
			}
			if( !applied && right instanceof BinaryOp ) //X*(1-X)
			{
				BinaryOp bright = (BinaryOp)right;
				Hop right1 = bright.getInput().get(0);
				Hop right2 = bright.getInput().get(1);
				if( right1 instanceof LiteralOp &&
					HopRewriteUtils.getDoubleValue((LiteralOp)right1)==1 &&
					right2 == left && bright.getOp() == OpOp2.MINUS )
				{
					UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SPROP);
					HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
					//cleanup the (1-X) intermediate (right), not X itself
					HopRewriteUtils.cleanupUnreferenced(bop, right);
					hi = unary;
					applied = true;
					LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2");
				}
			}
		}

		//sigmoid operator
		if( !applied && bop.getOp() == OpOp2.DIV && left.getDataType()==DataType.SCALAR && right.getDataType()==DataType.MATRIX
			&& left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==1 && right instanceof BinaryOp)
		{
			//note: if there are multiple consumers on the intermediate,
			//we follow the heuristic that redundant computation is more beneficial,
			//i.e., we still fuse but leave the intermediate for the other consumers
			BinaryOp bop2 = (BinaryOp)right;
			Hop left2 = bop2.getInput().get(0);
			Hop right2 = bop2.getInput().get(1);
			if( bop2.getOp() == OpOp2.PLUS && left2.getDataType()==DataType.SCALAR && right2.getDataType()==DataType.MATRIX
				&& left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left2)==1 && right2 instanceof UnaryOp)
			{
				UnaryOp uop = (UnaryOp) right2;
				Hop uopin = uop.getInput().get(0);
				if( uop.getOp()==OpOp1.EXP )
				{
					UnaryOp unary = null;
					//Pattern 1: (1/(1 + exp(-X)), only if the minus has a literal 0 as left operand
					if( HopRewriteUtils.isBinary(uopin, OpOp2.MINUS) ) {
						BinaryOp bop3 = (BinaryOp) uopin;
						Hop left3 = bop3.getInput().get(0);
						Hop right3 = bop3.getInput().get(1);
						if( left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left3)==0 )
							unary = HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID);
					}
					//Pattern 2: (1/(1 + exp(X)), e.g., where -(-X) has been removed by
					//the 'remove unnecessary minus' rewrite --> reintroduce the minus
					else {
						BinaryOp minus = HopRewriteUtils.createBinaryMinus(uopin);
						unary = HopRewriteUtils.createUnary(minus, OpOp1.SIGMOID);
					}
					if( unary != null ) {
						HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
						//uopin is only cleaned if unreferenced (i.e., the (0-X) of pattern 1)
						HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop, uopin);
						hi = unary;
						applied = true;
						LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1");
					}
				}
			}
		}

		//select positive (selp) operator (note: same initial pattern as sprop)
		if( !applied && bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX )
		{
			//by definition, either left or right or none applies.
			//note: if there are multiple consumers on the intermediate tmp=(X>0), it's still beneficial
			//to replace the X*tmp with selp(X) due to lower memory requirements and simply sparsity propagation
			if( left instanceof BinaryOp ) //(X>0)*X
			{
				BinaryOp bleft = (BinaryOp)left;
				Hop left1 = bleft.getInput().get(0);
				Hop left2 = bleft.getInput().get(1);
				if( left2 instanceof LiteralOp &&
					HopRewriteUtils.getDoubleValue((LiteralOp)left2)==0 &&
					left1 == right && bleft.getOp() == OpOp2.GREATER )
				{
					UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SELP);
					HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
					HopRewriteUtils.cleanupUnreferenced(bop, left);
					hi = unary;
					applied = true;
					LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp1");
				}
			}
			if( !applied && right instanceof BinaryOp ) //X*(X>0)
			{
				BinaryOp bright = (BinaryOp)right;
				Hop right1 = bright.getInput().get(0);
				Hop right2 = bright.getInput().get(1);
				if( right2 instanceof LiteralOp &&
					HopRewriteUtils.getDoubleValue((LiteralOp)right2)==0 &&
					right1 == left && bright.getOp() == OpOp2.GREATER )
				{
					UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SELP);
					HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
					//cleanup the (X>0) intermediate (right), not X itself
					HopRewriteUtils.cleanupUnreferenced(bop, right);
					hi = unary;
					applied = true;
					LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp2");
				}
			}
		}

		//select positive (selp) operator; pattern: max(X,0) -> selp+
		if( !applied && bop.getOp() == OpOp2.MAX && left.getDataType()==DataType.MATRIX
			&& right instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)right)==0 )
		{
			UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SELP);
			HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
			HopRewriteUtils.cleanupUnreferenced(bop);
			hi = unary;
			applied = true;
			LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp3");
		}

		//select positive (selp) operator; pattern: max(0,X) -> selp+
		if( !applied && bop.getOp() == OpOp2.MAX && right.getDataType()==DataType.MATRIX
			&& left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==0 )
		{
			UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SELP);
			HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
			HopRewriteUtils.cleanupUnreferenced(bop);
			hi = unary;
			applied = true;
			LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp4");
		}
	}
	return hi;
}
/**
 * Rewrites the trace of a matrix product into a full sum of a cell-wise product:
 * trace(X%*%Y) -> sum(X*t(Y)).
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return the new sum aggregate if the rewrite applied, otherwise hi
 */
private Hop simplifyTraceMatrixMult(Hop parent, Hop hi, int pos)
{
	boolean isTrace = hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.TRACE;
	if( !isTrace )
		return hi;
	Hop product = hi.getInput().get(0);
	if( !HopRewriteUtils.isMatrixMultiply(product) )
		return hi;

	//build sum(X * t(Y)); createTranspose refreshes its size information
	ReorgOp tY = HopRewriteUtils.createTranspose(product.getInput().get(1));
	BinaryOp cellwise = HopRewriteUtils.createBinary(product.getInput().get(0), tY, OpOp2.MULT);
	AggUnaryOp sum = HopRewriteUtils.createSum(cellwise);

	//rehang the new subdag under parent
	HopRewriteUtils.replaceChildReference(parent, hi, sum, pos);
	HopRewriteUtils.cleanupUnreferenced(hi, product);
	LOG.debug("Applied simplifyTraceMatrixMult");
	return sum;
}
/**
 * Pushes a single-cell right indexing into the inputs of a matrix multiply:
 * (X%*%Y)[i,j] -> X[i,] %*% Y[,j]. Requires rl==ru and cl==cu, and that the
 * indexing is the only consumer of the matrix multiply, which is rewired in place.
 * The resulting 1x1 matrix multiply replaces the indexing operation in parent.
 * Otherwise the original indexing (at row i, column j) would remain on top of the
 * 1x1 result, and (parent,pos) would no longer refer to the returned hop.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return the rewired matrix multiply if the rewrite applied, otherwise hi
 * @throws HopsException if HopsException occurs
 */
private Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos)
	throws HopsException
{
	//e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1]
	if( hi instanceof IndexingOp
		&& ((IndexingOp)hi).isRowLowerEqualsUpper()
		&& ((IndexingOp)hi).isColLowerEqualsUpper()
		&& hi.getInput().get(0).getParent().size()==1 //rix is single mm consumer
		&& HopRewriteUtils.isMatrixMultiply(hi.getInput().get(0)) )
	{
		Hop mm = hi.getInput().get(0);
		Hop X = mm.getInput().get(0);
		Hop Y = mm.getInput().get(1);
		Hop rowExpr = hi.getInput().get(1); //rl==ru
		Hop colExpr = hi.getInput().get(3); //cl==cu
		HopRewriteUtils.removeAllChildReferences(mm);

		//create new indexing operations X[i,] and Y[,j]
		IndexingOp ix1 = new IndexingOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, X,
			rowExpr, rowExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(X, false), true, false);
		ix1.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
		ix1.refreshSizeInformation();
		IndexingOp ix2 = new IndexingOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, Y,
			new LiteralOp(1), HopRewriteUtils.createValueHop(Y, true), colExpr, colExpr, false, true);
		ix2.setOutputBlocksizes(Y.getRowsInBlock(), Y.getColsInBlock());
		ix2.refreshSizeInformation();

		//rewire matrix mult over ix1 and ix2
		HopRewriteUtils.addChildReference(mm, ix1, 0);
		HopRewriteUtils.addChildReference(mm, ix2, 1);
		mm.refreshSizeInformation();

		//relink the 1x1 matrix mult into the position of the replaced indexing op
		HopRewriteUtils.replaceChildReference(parent, hi, mm, pos);
		HopRewriteUtils.cleanupUnreferenced(hi);
		hi = mm;
		LOG.debug("Applied simplifySlicedMatrixMult");
	}
	return hi;
}
/**
 * Simplifies sorting of a constant matrix:
 * order(matrix(7), indexreturn=FALSE) -> matrix(7) and
 * order(matrix(7), indexreturn=TRUE) -> seq(1,nrow(X),1).
 * Requires a rand data generator with a constant value as input and a literal
 * indexreturn argument (input 3).
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return the replacement hop if the rewrite applied, otherwise hi
 * @throws HopsException if HopsException occurs
 */
private Hop simplifyConstantSort(Hop parent, Hop hi, int pos)
	throws HopsException
{
	if( !(hi instanceof ReorgOp) || ((ReorgOp)hi).getOp()!=ReOrgOp.SORT )
		return hi;
	Hop data = hi.getInput().get(0);
	boolean constantData = data instanceof DataGenOp
		&& ((DataGenOp)data).getOp()==DataGenMethod.RAND
		&& ((DataGenOp)data).hasConstantValue();
	if( !constantData || !(hi.getInput().get(3) instanceof LiteralOp) )
		return hi;

	boolean indexReturn = HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(3));
	Hop replacement;
	if( indexReturn ) {
		//all values equal: index order is simply 1..nrow
		replacement = HopRewriteUtils.createSeqDataGenOp(data);
		replacement.refreshSizeInformation();
	}
	else {
		//all values equal: sorted data is the data itself
		replacement = data;
	}
	HopRewriteUtils.replaceChildReference(parent, hi, replacement, pos);
	HopRewriteUtils.cleanupUnreferenced(hi);
	LOG.debug(indexReturn ? "Applied simplifyConstantSort1." : "Applied simplifyConstantSort2.");
	return replacement;
}
/**
 * Simplifies sorting of a sequence with literal increment 1, e.g., seq(2,N+1,1):
 * order(seq(2,N+1,1), decreasing=FALSE, indexreturn=FALSE) -> seq(2,N+1,1),
 * order(seq(2,N+1,1), decreasing=FALSE, indexreturn=TRUE) -> seq(1,N,1), and
 * order(seq(2,N+1,1), decreasing=TRUE, indexreturn=TRUE) -> seq(N,1,-1).
 * Requires literal decreasing (input 2) and indexreturn (input 3) arguments.
 * Sorting the data in decreasing order (indexreturn=FALSE) is not rewritten.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return the replacement hop if the rewrite applied, otherwise hi
 * @throws HopsException if HopsException occurs
 */
private Hop simplifyOrderedSort(Hop parent, Hop hi, int pos)
	throws HopsException
{
	if( !(hi instanceof ReorgOp) || ((ReorgOp)hi).getOp()!=ReOrgOp.SORT )
		return hi;
	Hop data = hi.getInput().get(0);
	if( !(data instanceof DataGenOp) || ((DataGenOp)data).getOp()!=DataGenMethod.SEQ )
		return hi;

	//check for known ascending ordering and known decreasing/indexreturn flags
	Hop incr = data.getInput().get(((DataGenOp)data).getParamIndex(Statement.SEQ_INCR));
	boolean unitIncrement = incr instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)incr)==1;
	if( !unitIncrement
		|| !(hi.getInput().get(2) instanceof LiteralOp)
		|| !(hi.getInput().get(3) instanceof LiteralOp) )
		return hi;
	boolean indexReturn = HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(3));
	boolean desc = HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2));

	if( indexReturn ) {
		//IXRET, ASC/DESC: seq(1,N,1) or seq(N,1,-1)
		Hop seq = HopRewriteUtils.createSeqDataGenOp(data, !desc);
		seq.refreshSizeInformation();
		HopRewriteUtils.replaceChildReference(parent, hi, seq, pos);
		HopRewriteUtils.cleanupUnreferenced(hi);
		LOG.debug("Applied simplifyOrderedSort1.");
		return seq;
	}
	if( !desc ) {
		//DATA, ASC: already sorted
		HopRewriteUtils.replaceChildReference(parent, hi, data, pos);
		HopRewriteUtils.cleanupUnreferenced(hi);
		LOG.debug("Applied simplifyOrderedSort2.");
		return data;
	}
	return hi;
}
/**
 * Pushes a transpose through a cell-wise binary operation over a product of transposes:
 * t(t(A)%*%t(B) op C) -> B%*%A op t(C), e.g., t(t(A)%*%t(B)+C) -> B%*%A+t(C).
 * This uses t(M op C) = t(M) op t(C) for cell-wise operations and t(t(A)%*%t(B)) = B%*%A.
 * The original binary operator is preserved; it used to be hard-coded to PLUS, which
 * was wrong for the other accepted operators (e.g., MINUS, MULT).
 * Requires both inner transposes to have the matrix multiply as their single consumer.
 * Outer-vector binary operations are excluded, because the new binary is created via
 * createBinary, which presumably does not carry the outer flag (TODO confirm).
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return high-level operator
 * @throws HopsException if HopsException occurs
 */
private Hop simplifyTransposeAggBinBinaryChains(Hop parent, Hop hi, int pos)
	throws HopsException
{
	if( HopRewriteUtils.isTransposeOperation(hi)
		&& hi.getInput().get(0) instanceof BinaryOp //basic binary
		&& ((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations()
		&& !((BinaryOp)hi.getInput().get(0)).isOuterVectorOperator() )
	{
		BinaryOp binary = (BinaryOp)hi.getInput().get(0);
		Hop left = binary.getInput().get(0);
		Hop C = binary.getInput().get(1);

		//check matrix mult and both inputs transposes w/ single consumer
		if( left instanceof AggBinaryOp && C.getDataType().isMatrix()
			&& HopRewriteUtils.isTransposeOperation(left.getInput().get(0))
			&& left.getInput().get(0).getParent().size()==1
			&& HopRewriteUtils.isTransposeOperation(left.getInput().get(1))
			&& left.getInput().get(1).getParent().size()==1 )
		{
			Hop A = left.getInput().get(0).getInput().get(0);
			Hop B = left.getInput().get(1).getInput().get(0);
			AggBinaryOp abop = HopRewriteUtils.createMatrixMultiply(B, A);
			ReorgOp rop = HopRewriteUtils.createTranspose(C);
			//keep the original cell-wise operator instead of assuming PLUS
			BinaryOp bop = HopRewriteUtils.createBinary(abop, rop, binary.getOp());
			HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
			hi = bop;
			LOG.debug("Applied simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+").");
		}
	}
	return hi;
}
/**
 * Removes a pair of directly nested, self-inverse reorg operations of the same type:
 * t(t(X)) -> X and rev(rev(X)) -> X.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return high-level operator
 */
private Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos)
{
	if( !(hi instanceof ReorgOp) )
		return hi;
	ReOrgOp outerOp = ((ReorgOp)hi).getOp();
	if( !HopRewriteUtils.isValidOp(outerOp, new ReOrgOp[]{ReOrgOp.TRANSPOSE, ReOrgOp.REV}) )
		return hi;
	Hop inner = hi.getInput().get(0);
	if( !(inner instanceof ReorgOp) || ((ReorgOp)inner).getOp()!=outerOp )
		return hi;

	//bypass both reorg operations
	Hop X = inner.getInput().get(0);
	HopRewriteUtils.replaceChildReference(parent, hi, X, pos);
	HopRewriteUtils.cleanupUnreferenced(hi, inner);
	LOG.debug("Applied removeUnecessaryReorgOperation.");
	return X;
}
/**
 * Removes a double negation over matrices, -(-(X)) -> X. Each negation is a matrix
 * binary minus with a literal 0 as the left operand, i.e., 0-(0-X) -> X.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return X if the rewrite applied, otherwise hi
 * @throws HopsException if HopsException occurs
 */
private Hop removeUnnecessaryMinus(Hop parent, Hop hi, int pos)
	throws HopsException
{
	boolean outerNegation = hi.getDataType() == DataType.MATRIX
		&& hi instanceof BinaryOp
		&& ((BinaryOp)hi).getOp()==OpOp2.MINUS
		&& hi.getInput().get(0) instanceof LiteralOp
		&& ((LiteralOp)hi.getInput().get(0)).getDoubleValue()==0;
	if( !outerNegation )
		return hi;

	Hop inner = hi.getInput().get(1);
	boolean innerNegation = inner.getDataType() == DataType.MATRIX
		&& inner instanceof BinaryOp
		&& ((BinaryOp)inner).getOp()==OpOp2.MINUS
		&& inner.getInput().get(0) instanceof LiteralOp
		&& ((LiteralOp)inner.getInput().get(0)).getDoubleValue()==0;
	if( !innerNegation )
		return hi;

	//bypass both negations
	Hop X = inner.getInput().get(1);
	HopRewriteUtils.replaceChildReference(parent, hi, X, pos);
	HopRewriteUtils.cleanupUnreferenced(hi, inner);
	LOG.debug("Applied removeUnecessaryMinus");
	return X;
}
/**
 * For a grouped aggregate with fn="count" over a vector target (one column),
 * replaces the target input by the groups input in place. This avoids
 * unnecessary memory consumption for "count". Nothing changes if the target and
 * groups already share the same index or the same hop.
 *
 * @param hi high-level operator
 * @return hi (possibly modified in place)
 */
private Hop simplifyGroupedAggregate(Hop hi)
{
	if( !(hi instanceof ParameterizedBuiltinOp) )
		return hi;
	ParameterizedBuiltinOp phi = (ParameterizedBuiltinOp)hi;
	if( phi.getOp()!=ParamBuiltinOp.GROUPEDAGG //aggregate
		|| !phi.isCountFunction()             //fn="count"
		|| phi.getTargetHop().getDim2()!=1 )  //only for vector
		return hi;

	HashMap<String, Integer> params = phi.getParamIndexMap();
	int targetIx = params.get(Statement.GAGG_TARGET);
	int groupsIx = params.get(Statement.GAGG_GROUPS);
	if( targetIx == groupsIx )
		return hi;
	Hop target = phi.getInput().get(targetIx);
	Hop groups = phi.getInput().get(groupsIx);
	if( target != groups ) {
		HopRewriteUtils.replaceChildReference(hi, target, groups, targetIx);
		LOG.debug("Applied simplifyGroupedAggregateCount");
	}
	return hi;
}
/**
 * Fuses X - (s * ppred(X,0,"!=")) into X -nz s, i.e., s is subtracted only from
 * the non-zero cells of X. This is a hop rewrite to significantly reduce the
 * memory estimate for X - tmp if X is sparse. Both occurrences of X must be the
 * same hop, so the rewrite depends on prior common subexpression elimination.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return the fused operation if the rewrite applied, otherwise hi
 * @throws HopsException if HopsException occurs
 */
private Hop fuseMinusNzBinaryOperation(Hop parent, Hop hi, int pos)
	throws HopsException
{
	if( !HopRewriteUtils.isBinary(hi, OpOp2.MINUS) )
		return hi;
	Hop X = hi.getInput().get(0);
	Hop scaledPred = hi.getInput().get(1);
	if( X.getDataType()!=DataType.MATRIX
		|| scaledPred.getDataType()!=DataType.MATRIX
		|| !HopRewriteUtils.isBinary(scaledPred, OpOp2.MULT) )
		return hi;

	Hop s = scaledPred.getInput().get(0);
	Hop pred = scaledPred.getInput().get(1);
	boolean matches = s.getDataType()==DataType.SCALAR
		&& pred.getDataType()==DataType.MATRIX
		&& HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL)
		&& pred.getInput().get(0) == X
		&& pred.getInput().get(1) instanceof LiteralOp
		&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0;
	if( !matches )
		return hi;

	Hop fused = HopRewriteUtils.createBinary(X, s, OpOp2.MINUS_NZ);
	HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
	LOG.debug("Applied fuseMinusNzBinaryOperation (line "+fused.getBeginLine()+")");
	return fused;
}
/**
 * Fuses ppred(X,0,"!=")*log(X) into log_nz(X). This is a hop rewrite to
 * significantly reduce the memory estimate and to prevent dense intermediates if
 * X is ultra sparse. Both occurrences of X must be the same hop, so the rewrite
 * depends on prior common subexpression elimination.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return the fused operation if the rewrite applied, otherwise hi
 * @throws HopsException if HopsException occurs
 */
private Hop fuseLogNzUnaryOperation(Hop parent, Hop hi, int pos)
	throws HopsException
{
	if( !HopRewriteUtils.isBinary(hi, OpOp2.MULT) )
		return hi;
	Hop pred = hi.getInput().get(0);
	Hop logX = hi.getInput().get(1);
	if( pred.getDataType()!=DataType.MATRIX
		|| logX.getDataType()!=DataType.MATRIX
		|| !HopRewriteUtils.isUnary(logX, OpOp1.LOG) )
		return hi;

	Hop X = logX.getInput().get(0);
	boolean matches = HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL)
		&& pred.getInput().get(0) == X
		&& pred.getInput().get(1) instanceof LiteralOp
		&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0;
	if( !matches )
		return hi;

	Hop fused = HopRewriteUtils.createUnary(X, OpOp1.LOG_NZ);
	HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
	LOG.debug("Applied fuseLogNzUnaryOperation (line "+fused.getBeginLine()+").");
	return fused;
}
/**
 * Fuses ppred(X,0,"!=")*log(X,b) into log_nz(X,b), e.g., with b=0.5. This is a
 * hop rewrite to significantly reduce the memory estimate and to prevent dense
 * intermediates if X is ultra sparse. Both occurrences of X must be the same hop,
 * so the rewrite depends on prior common subexpression elimination.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return the fused operation if the rewrite applied, otherwise hi
 * @throws HopsException if HopsException occurs
 */
private Hop fuseLogNzBinaryOperation(Hop parent, Hop hi, int pos)
	throws HopsException
{
	if( !HopRewriteUtils.isBinary(hi, OpOp2.MULT) )
		return hi;
	Hop pred = hi.getInput().get(0);
	Hop logXb = hi.getInput().get(1);
	if( pred.getDataType()!=DataType.MATRIX
		|| logXb.getDataType()!=DataType.MATRIX
		|| !HopRewriteUtils.isBinary(logXb, OpOp2.LOG) )
		return hi;

	Hop X = logXb.getInput().get(0);
	Hop base = logXb.getInput().get(1);
	boolean matches = HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL)
		&& pred.getInput().get(0) == X
		&& pred.getInput().get(1) instanceof LiteralOp
		&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0;
	if( !matches )
		return hi;

	Hop fused = HopRewriteUtils.createBinary(X, base, OpOp2.LOG_NZ);
	HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
	LOG.debug("Applied fuseLogNzBinaryOperation (line "+fused.getBeginLine()+")");
	return fused;
}
/**
 * Rewrites an outer equality comparison against a basic 1..m sequence into rexpand:
 * pattern a: outer(v, t(seq(1,m)), "==") -> rexpand(v, max=m, dir="cols", ignore=true, cast=false)
 * pattern b: outer(seq(1,m), t(v), "==") -> rexpand(v, max=m, dir="rows", ignore=true, cast=false)
 * For pattern b, the right operand is unwrapped if it is a transpose t(v);
 * otherwise v is created as a transpose of the right operand.
 * Pattern b takes precedence if both patterns match.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return the new rexpand operation if the rewrite applied, otherwise hi
 * @throws HopsException if HopsException occurs
 */
private Hop simplifyOuterSeqExpand(Hop parent, Hop hi, int pos)
	throws HopsException
{
	if( !HopRewriteUtils.isBinary(hi, OpOp2.EQUAL) || !((BinaryOp)hi).isOuterVectorOperator() )
		return hi;
	Hop in0 = hi.getInput().get(0);
	Hop in1 = hi.getInput().get(1);
	boolean isTransposeRight = HopRewriteUtils.isTransposeOperation(in1);
	boolean isPatternA = isTransposeRight && HopRewriteUtils.isBasic1NSequence(in1.getInput().get(0));
	boolean isPatternB = HopRewriteUtils.isBasic1NSequence(in0);
	if( !isPatternA && !isPatternB )
		return hi;

	//determine target vector and sequence for pattern a/b
	Hop trgt;
	Hop seq;
	if( isPatternB ) {
		trgt = isTransposeRight ? in1.getInput().get(0) : HopRewriteUtils.createTranspose(in1);
		seq = in0;
	}
	else {
		trgt = in0;
		seq = in1.getInput().get(0);
	}
	String direction = isPatternB ? "rows" : "cols";

	//setup rexpand parameters
	HashMap<String,Hop> inputargs = new HashMap<String,Hop>();
	inputargs.put("target", trgt);
	inputargs.put("max", HopRewriteUtils.getBasic1NSequenceMaxLiteral(seq));
	inputargs.put("dir", new LiteralOp(direction));
	inputargs.put("ignore", new LiteralOp(true));
	inputargs.put("cast", new LiteralOp(false));

	ParameterizedBuiltinOp rexpand = new ParameterizedBuiltinOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
		ParamBuiltinOp.REXPAND, inputargs);
	rexpand.setOutputBlocksizes(hi.getRowsInBlock(), hi.getColsInBlock());
	rexpand.refreshSizeInformation();

	//relink new hop into original position
	HopRewriteUtils.replaceChildReference(parent, hi, rexpand, pos);
	LOG.debug("Applied simplifyOuterSeqExpand (line "+rexpand.getBeginLine()+")");
	return rexpand;
}
/**
 * Replaces comparisons of a matrix with itself by constant matrices created via
 * createDataGenOp(X, value), i.e., shaped like X:
 * ppred(X,X,"=="), ppred(X,X,">="), ppred(X,X,"<=") -> matrix(1, rows=nrow(X), cols=ncol(X))
 * ppred(X,X,"!="), ppred(X,X,">"), ppred(X,X,"<") -> matrix(0, rows=nrow(X), cols=ncol(X))
 * Both operands must be the identical hop. The checks are parenthesized accordingly:
 * previously, &&-over-|| precedence made >=, <=, >, < match for arbitrary operands.
 * Only matrix comparisons are rewritten.
 *
 * NOTE: currently disabled since this rewrite is INVALID in the
 * presence of NaNs (because (NaN!=NaN) is true).
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position
 * @return high-level operator
 * @throws HopsException if HopsException occurs
 */
@SuppressWarnings("unused")
private Hop removeUnecessaryPPred(Hop parent, Hop hi, int pos)
	throws HopsException
{
	if( hi instanceof BinaryOp && hi.getDataType()==DataType.MATRIX )
	{
		BinaryOp bop = (BinaryOp)hi;
		Hop left = bop.getInput().get(0);
		Hop right = bop.getInput().get(1);
		Hop datagen = null;
		if( left == right )
		{
			OpOp2 op = bop.getOp();
			//ppred(X,X,"==") -> matrix(1, rows=nrow(X),cols=ncol(X))
			if( op==OpOp2.EQUAL || op==OpOp2.GREATEREQUAL || op==OpOp2.LESSEQUAL )
				datagen = HopRewriteUtils.createDataGenOp(left, 1);
			//ppred(X,X,"!=") -> matrix(0, rows=nrow(X),cols=ncol(X))
			else if( op==OpOp2.NOTEQUAL || op==OpOp2.GREATER || op==OpOp2.LESS )
				datagen = HopRewriteUtils.createDataGenOp(left, 0);
		}
		if( datagen != null ) {
			HopRewriteUtils.replaceChildReference(parent, hi, datagen, pos);
			HopRewriteUtils.cleanupUnreferenced(hi);
			hi = datagen;
			LOG.debug("Applied removeUnecessaryPPred");
		}
	}
	return hi;
}
}