/**
* (C) Copyright IBM Corp. 2010, 2015
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package com.ibm.bi.dml.hops.rewrite;
import java.util.ArrayList;
import java.util.HashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import com.ibm.bi.dml.hops.AggBinaryOp;
import com.ibm.bi.dml.hops.AggUnaryOp;
import com.ibm.bi.dml.hops.BinaryOp;
import com.ibm.bi.dml.hops.DataGenOp;
import com.ibm.bi.dml.hops.Hop;
import com.ibm.bi.dml.hops.Hop.AggOp;
import com.ibm.bi.dml.hops.Hop.DataGenMethod;
import com.ibm.bi.dml.hops.Hop.Direction;
import com.ibm.bi.dml.hops.Hop.OpOp1;
import com.ibm.bi.dml.hops.Hop.ReOrgOp;
import com.ibm.bi.dml.hops.HopsException;
import com.ibm.bi.dml.hops.IndexingOp;
import com.ibm.bi.dml.hops.LeftIndexingOp;
import com.ibm.bi.dml.hops.LiteralOp;
import com.ibm.bi.dml.hops.Hop.OpOp2;
import com.ibm.bi.dml.hops.ReorgOp;
import com.ibm.bi.dml.hops.UnaryOp;
import com.ibm.bi.dml.parser.DMLTranslator;
import com.ibm.bi.dml.parser.DataExpression;
import com.ibm.bi.dml.parser.Expression.DataType;
import com.ibm.bi.dml.parser.Expression.ValueType;
/**
* Rule: Algebraic Simplifications. Simplifies binary expressions
* in terms of two major purposes: (1) rewrite binary operations
* to unary operations when possible (in CP this reduces the memory
* estimate, in MR this allows map-only operations and hence prevents
* unnecessary shuffle and sort) and (2) remove binary operations that
* are themselves unnecessary (e.g., *1 and /1).
*
*/
public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
{
private static final Log LOG = LogFactory.getLog(RewriteAlgebraicSimplificationDynamic.class.getName());

//valid aggregation operation types for row/col aggregate simplifications, e.g.,
//colSums(X) -> sum(X) or X for vectors (not all operations apply)
private static final AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN};

//valid aggregation operation types for empty (sparse-safe) operations (not all operations apply);
//AggOp.MEAN is currently not supported due to missing count/corrections
private static final AggOp[] LOOKUP_VALID_EMPTY_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.PROD, AggOp.TRACE};

//valid unary operation types for empty (sparse-safe) operations (not all operations apply)
private static final OpOp1[] LOOKUP_VALID_EMPTY_UNARY = new OpOp1[]{OpOp1.ABS, OpOp1.SIN, OpOp1.TAN, OpOp1.SQRT, OpOp1.ROUND, OpOp1.CUMSUM};
/**
 * Applies the dynamic algebraic simplifications to all given DAG roots in two passes:
 * first rewrite-then-descend (to pick up patterns created by rewrites), then, after
 * resetting the visit status, descend-then-rewrite (to allow bottom-up roll-up).
 *
 * @param roots DAG roots to rewrite (may be null)
 * @param state program rewrite status (unused here)
 * @return the same list of roots
 * @throws HopsException
 */
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) 
	throws HopsException
{
	if( roots != null )
	{
		//pass 1: rewrite first, then descend
		for( Hop root : roots )
			rule_AlgebraicSimplification( root, false );
		
		Hop.resetVisitStatus(roots);
		
		//pass 2: descend first, then rewrite
		for( Hop root : roots )
			rule_AlgebraicSimplification( root, true );
	}
	return roots;
}
/**
 * Applies the dynamic algebraic simplifications to a single DAG root in two passes
 * (rewrite-descend, then descend-rewrite after resetting the visit status).
 *
 * @param root DAG root to rewrite (may be null)
 * @param state program rewrite status (unused here)
 * @return the same root
 * @throws HopsException
 */
@Override
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) 
	throws HopsException
{
	if( root != null )
	{
		rule_AlgebraicSimplification( root, false ); //rewrite-descend (rewrite created pattern)
		root.resetVisitStatus();
		rule_AlgebraicSimplification( root, true );  //descend-rewrite (for rollup)
	}
	return root;
}
/**
 * Applies all dynamic algebraic simplification rewrites to the inputs of the given
 * hop and recursively to the DAG below it. For every input position i, the child is
 * passed through the rewrites below in the listed order; each rewrite returns the hop
 * that is (now) at that position, which is then fed into the next rewrite. Hops already
 * marked DONE are skipped, so each hop is processed at most once per pass.
 * 
 * Note: X/y -> X * 1/y would be useful because * is cheaper than / and sparse-safe; however,
 * (1) the results would not be exactly the same (2 rounds instead of 1) and (2) it should
 * come before constant folding while the other simplifications should come after constant
 * folding. Hence, not applied yet.
 *
 * @param hop root of the (sub-)DAG whose inputs are simplified
 * @param descendFirst if true, each child is processed recursively before the rewrites
 *        are applied to it (enables roll-up); if false, the rewrites are applied first
 *        and the resulting child is processed recursively afterwards
 * @throws HopsException
 */
private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) 
throws HopsException 
{
if(hop.getVisited() == Hop.VisitStatus.DONE)
return;
//recursively process children
for( int i=0; i<hop.getInput().size(); i++)
{
Hop hi = hop.getInput().get(i);
//process children recursively first (to allow roll-up)
if( descendFirst )
rule_AlgebraicSimplification(hi, descendFirst); //see below
//apply actual simplification rewrites (of children incl checks); the order matters
//because each rewrite receives the (possibly replaced) child from the previous one
hi = removeEmptyRightIndexing(hop, hi, i); //e.g., X[,1] -> matrix(0,ru-rl+1,cu-cl+1), if nnz(X)==0 
hi = removeUnnecessaryRightIndexing(hop, hi, i); //e.g., X[,1] -> X, if output == input size 
hi = removeEmptyLeftIndexing(hop, hi, i); //e.g., X[,1]=Y -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0 and nnz(Y)==0 
hi = removeUnnecessaryLeftIndexing(hop, hi, i); //e.g., X[,1]=Y -> Y, if output == input dims 
hi = fuseLeftIndexingChainToAppend(hop, hi, i); //e.g., X[,1]=A; X[,2]=B -> X=cbind(A,B), iff ncol(X)==2 and col1/2 lix
hi = removeUnnecessaryCumulativeOp(hop, hi, i); //e.g., cumsum(X) -> X, if nrow(X)==1;
hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., matrix(X) -> X, if output == input dims
hi = removeUnnecessaryOuterProduct(hop, hi, i); //e.g., X*(Y%*%matrix(1,...) -> X*Y, if Y col vector
hi = fuseDatagenAndReorgOperation(hop, hi, i); //e.g., t(rand(rows=10,cols=1)) -> rand(rows=1,cols=10), if one dim=1
hi = simplifyColwiseAggregate(hop, hi, i); //e.g., colsums(X) -> sum(X) or X, if col/row vector
hi = simplifyRowwiseAggregate(hop, hi, i); //e.g., rowsums(X) -> sum(X) or X, if row/col vector
hi = simplifyColSumsMVMult(hop, hi, i); //e.g., colSums(X*Y) -> t(Y) %*% X, if Y col vector
hi = simplifyRowSumsMVMult(hop, hi, i); //e.g., rowSums(X*Y) -> X %*% t(Y), if Y row vector
hi = simplifyEmptyAggregate(hop, hi, i); //e.g., sum(X) -> 0, if nnz(X)==0
hi = simplifyEmptyUnaryOperation(hop, hi, i); //e.g., round(X) -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0 
hi = simplifyEmptyReorgOperation(hop, hi, i); //e.g., t(X) -> matrix(0, ncol(X), nrow(X)) 
hi = simplifyEmptySortOperation(hop, hi, i); //e.g., order(X) -> seq(1, nrow(X)), if nnz(X)==0 
hi = simplifyEmptyMatrixMult(hop, hi, i); //e.g., X%*%Y -> matrix(0,...), if nnz(Y)==0 | X if Y==matrix(1,1,1)
hi = simplifyIdentityRepMatrixMult(hop, hi, i); //e.g., X%*%y -> X if y matrix(1,1,1);
hi = simplifyScalarMatrixMult(hop, hi, i); //e.g., X%*%y -> X*as.scalar(y), if y is a 1-1 matrix
hi = simplifyMatrixMultDiag(hop, hi, i); //e.g., diag(X)%*%Y -> X*Y, if ncol(Y)==1 / -> Y*X if ncol(Y)>1 
hi = simplifyDiagMatrixMult(hop, hi, i); //e.g., diag(X%*%Y)->rowSums(X*t(Y)); if col vector
//note: the following rewrite only receives the child (not parent/position)
hi = simplifySumDiagToTrace(hi); //e.g., sum(diag(X)) -> trace(X); if col vector
hi = pushdownBinaryOperationOnDiag(hop, hi, i); //e.g., diag(X)*7 -> diag(X*7); if col vector
hi = simplifyDotProductSum(hop, hi, i); //e.g., sum(v^2) -> t(v)%*%v if ncol(v)==1 
hi = fuseSumSquared(hop, hi, i); //e.g., sum(X^2) -> sumSq(X), if ncol(X)>1
hi = reorderMinusMatrixMult(hop, hi, i); //e.g., (-t(X))%*%y->-(t(X)%*%y), TODO size
hi = simplifySumMatrixMult(hop, hi, i); //e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product
hi = simplifyEmptyBinaryOperation(hop, hi, i); //e.g., X*Y -> matrix(0,nrow(X), ncol(X)) / X+Y->X / X-Y -> X
//note: the following rewrite only receives the child (not parent/position)
hi = simplifyScalarMVBinaryOperation(hi); //e.g., X*y -> X*as.scalar(y), if y is a 1-1 matrix
hi = simplifyNnzComputation(hop, hi, i); //e.g., sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known
//process children recursively after rewrites (to investigate patterns newly created by rewrites)
if( !descendFirst )
rule_AlgebraicSimplification(hi, descendFirst);
}
//mark as processed; the check at the top skips this hop if it is reached again in this pass
hop.setVisited(Hop.VisitStatus.DONE);
}
/**
 * Replaces a right indexing operation over an input with zero non-zeros by an
 * all-zero datagen of the indexing output dimensions, e.g.,
 * X[,1] -> matrix(0, ru-rl+1, cu-cl+1) if nnz(X)==0. Only applied if the output
 * dimensions of the indexing operation are known.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 * @throws HopsException
 */
private Hop removeEmptyRightIndexing(Hop parent, Hop hi, int pos) 
	throws HopsException
{
	if( !(hi instanceof IndexingOp) )
		return hi;
	
	Hop data = hi.getInput().get(0);
	boolean emptyInput = (data.getNnz() == 0);
	if( !emptyInput || !HopRewriteUtils.isDimsKnown(hi) )
		return hi;
	
	//swap the indexing operation for an all-zero matrix of the same output size
	HopRewriteUtils.removeChildReference(parent, hi);
	Hop zeros = HopRewriteUtils.createDataGenOpByVal(
		new LiteralOp(hi.getDim1()), new LiteralOp(hi.getDim2()), 0);
	HopRewriteUtils.addChildReference(parent, zeros, pos);
	parent.refreshSizeInformation();
	LOG.debug("Applied removeEmptyRightIndexing");
	return zeros;
}
/**
 * Removes a right indexing operation whose output has the same size as its input
 * (as determined by HopRewriteUtils.isEqualSize), i.e., the indexing selects the
 * entire input: X[,] -> X.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 */
private Hop removeUnnecessaryRightIndexing(Hop parent, Hop hi, int pos)
{
	if( hi instanceof IndexingOp )
	{
		Hop source = hi.getInput().get(0);
		if( HopRewriteUtils.isEqualSize(hi, source) )
		{
			//bypass the indexing operation: parent directly consumes the indexed input
			HopRewriteUtils.removeChildReference(parent, hi);
			HopRewriteUtils.addChildReference(parent, source, pos);
			parent.refreshSizeInformation();
			LOG.debug("Applied removeUnnecessaryRightIndexing");
			return source;
		}
	}
	return hi;
}
/**
 * Replaces a left indexing operation X[...]=Y by an all-zero datagen shaped like X
 * if both the target X and the assigned values Y have zero non-zeros, e.g.,
 * X[,1]=Y -> matrix(0, nrow(X), ncol(X)).
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 * @throws HopsException
 */
private Hop removeEmptyLeftIndexing(Hop parent, Hop hi, int pos) 
	throws HopsException
{
	if( !(hi instanceof LeftIndexingOp) )
		return hi;
	
	Hop target = hi.getInput().get(0); //lhs matrix (being updated)
	Hop source = hi.getInput().get(1); //rhs matrix (assigned values)
	if( target.getNnz() != 0 || source.getNnz() != 0 )
		return hi;
	
	//both operands empty: remove the left indexing in favor of an all-zero matrix
	HopRewriteUtils.removeChildReference(parent, hi);
	Hop zeros = HopRewriteUtils.createDataGenOp( target, 0);
	HopRewriteUtils.addChildReference(parent, zeros, pos);
	parent.refreshSizeInformation();
	LOG.debug("Applied removeEmptyLeftIndexing");
	return zeros;
}
/**
 * Removes a left indexing operation X[...]=Y whose output has the same size as the
 * assigned values Y (as determined by HopRewriteUtils.isEqualSize), i.e., Y overwrites
 * the entire target: X[,1]=Y -> Y.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 */
private Hop removeUnnecessaryLeftIndexing(Hop parent, Hop hi, int pos)
{
	if( !(hi instanceof LeftIndexingOp) )
		return hi;
	
	Hop rhs = hi.getInput().get(1); //assigned values
	if( !HopRewriteUtils.isEqualSize(hi, rhs) )
		return hi;
	
	//rhs covers the whole output -> parent directly consumes the rhs
	HopRewriteUtils.removeChildReference(parent, hi);
	HopRewriteUtils.addChildReference(parent, rhs, pos);
	parent.refreshSizeInformation();
	LOG.debug("Applied removeUnnecessaryLeftIndexing");
	return rhs;
}
/**
 * Fuses a chain of two left indexing operations that fill both columns (rows) of a
 * two-column (two-row) matrix into a single append operation:
 *   pattern 1: X[,1]=A; X[,2]=B -> X=cbind(A,B), if ncol(X)==2
 *   pattern 2: X[1,]=A; X[2,]=B -> X=rbind(A,B), if nrow(X)==2
 * Only applied if the inner left indexing (X[,1]=A) has the outer one as its single
 * consumer, the column (row) positions are the literals 1 and 2, and both assigned
 * values are matrices.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 */
private Hop fuseLeftIndexingChainToAppend(Hop parent, Hop hi, int pos) 
{
	boolean applied = false;
	
	//pattern 1: X[,1]=A; X[,2]=B -> X=cbind(A,B)
	if( hi instanceof LeftIndexingOp                      //outer lix (X[,2]=B)
		&& HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi)
		&& hi.getInput().get(0) instanceof LeftIndexingOp //inner lix (X[,1]=A)
		&& HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi.getInput().get(0))
		&& hi.getInput().get(0).getParent().size()==1     //inner lix has single consumer
		&& hi.getInput().get(0).getInput().get(0).getDim2() == 2 ) //two-column target matrix
	{
		Hop input2 = hi.getInput().get(1);                   //rhs matrix of outer lix
		Hop pred2 = hi.getInput().get(4);                    //cl=cu of outer lix
		Hop input1 = hi.getInput().get(0).getInput().get(1); //rhs matrix of inner lix
		Hop pred1 = hi.getInput().get(0).getInput().get(4);  //cl=cu of inner lix
		
		if( pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1
			&& pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2
			&& input1.getDataType()==DataType.MATRIX && input2.getDataType()==DataType.MATRIX )
		{
			//create new cbind operation and rewrite inputs
			HopRewriteUtils.removeChildReference(parent, hi);
			BinaryOp bop = HopRewriteUtils.createBinary(input1, input2, OpOp2.CBIND);
			HopRewriteUtils.addChildReference(parent, bop, pos);
			hi = bop;
			applied = true;
			LOG.debug("Applied fuseLeftIndexingChainToAppend1 (line "+hi.getBeginLine()+")");
		}
	}
	
	//pattern 2: X[1,]=A; X[2,]=B -> X=rbind(A,B)
	if( !applied && hi instanceof LeftIndexingOp          //outer lix (X[2,]=B)
		&& HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi)
		&& hi.getInput().get(0) instanceof LeftIndexingOp //inner lix (X[1,]=A)
		&& HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi.getInput().get(0))
		&& hi.getInput().get(0).getParent().size()==1     //inner lix has single consumer
		&& hi.getInput().get(0).getInput().get(0).getDim1() == 2 ) //two-row target matrix
	{
		Hop input2 = hi.getInput().get(1);                   //rhs matrix of outer lix
		Hop pred2 = hi.getInput().get(2);                    //rl=ru of outer lix
		Hop input1 = hi.getInput().get(0).getInput().get(1); //rhs matrix of inner lix
		Hop pred1 = hi.getInput().get(0).getInput().get(2);  //rl=ru of inner lix
		
		if( pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1
			&& pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2
			&& input1.getDataType()==DataType.MATRIX && input2.getDataType()==DataType.MATRIX )
		{
			//create new rbind operation and rewrite inputs
			HopRewriteUtils.removeChildReference(parent, hi);
			BinaryOp bop = HopRewriteUtils.createBinary(input1, input2, OpOp2.RBIND);
			HopRewriteUtils.addChildReference(parent, bop, pos);
			hi = bop;
			LOG.debug("Applied fuseLeftIndexingChainToAppend2 (line "+hi.getBeginLine()+")");
		}
	}
	
	return hi;
}
/**
 * Removes a cumulative unary aggregate (e.g., cumsum) over an input with known
 * dimensions and a single row, because a column-wise cumulative aggregate of a
 * single row returns that row unchanged: cumsum(X) -> X, if nrow(X)==1.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 */
private Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos)
{
	boolean isCumulative = (hi instanceof UnaryOp) 
		&& ((UnaryOp)hi).isCumulativeUnaryOperation();
	if( !isCumulative )
		return hi;
	
	Hop data = hi.getInput().get(0);
	boolean singleRow = HopRewriteUtils.isDimsKnown(data) && data.getDim1() == 1;
	if( !singleRow )
		return hi;
	
	//remember the operator for logging before detaching it
	OpOp1 cumOp = ((UnaryOp)hi).getOp();
	HopRewriteUtils.removeChildReference(parent, hi);
	HopRewriteUtils.addChildReference(parent, data, pos);
	parent.refreshSizeInformation();
	LOG.debug("Applied removeUnnecessaryCumulativeOp: "+cumOp);
	return data;
}
/**
 * Removes a reshape operation whose output has the same size as its input:
 * matrix(X, rows=nrow(X), cols=ncol(X)) -> X.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 */
private Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos)
{
	boolean isReshape = hi instanceof ReorgOp && ((ReorgOp)hi).getOp() == ReOrgOp.RESHAPE;
	if( isReshape )
	{
		Hop data = hi.getInput().get(0);
		//a reshape to the input's own dimensions is the identity regardless of byrow,
		//because byrow applies to both the input and the output layout
		if( HopRewriteUtils.isEqualSize(hi, data) )
		{
			HopRewriteUtils.removeChildReference(parent, hi);
			HopRewriteUtils.addChildReference(parent, data, pos);
			parent.refreshSizeInformation();
			hi = data;
			LOG.debug("Applied removeUnnecessaryReshape");
		}
	}
	return hi;
}
/**
 * Removes an unnecessary outer product that only replicates a vector in front of a
 * cell-wise binary operation, relying on matrix-vector binary operations instead:
 *   X * (Y %*% matrix(1,1,n)) -> X * Y, if Y is a column vector (column replication)
 *   X * (matrix(1,m,1) %*% Y) -> X * Y, if Y is a row vector (row replication)
 * Only applied if the binary operation is not an append (cbind/rbind) and its left
 * input is a matrix of the same size as the outer product. Otherwise, removing the
 * replication would produce a result of different dimensions, e.g., for a scalar or
 * vector left operand, or for cbind(X, Y %*% matrix(1,1,n)).
 * The rewrite modifies hi in place. The outer product is detached from hi, and its
 * input references are removed if it has no remaining consumers.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return hi (modified in place if the rewrite applied)
 */
private Hop removeUnnecessaryOuterProduct(Hop parent, Hop hi, int pos)
{
	if( hi instanceof BinaryOp                     //binary cell operation
		&& ((BinaryOp)hi).getOp() != OpOp2.CBIND   //appends do not broadcast vectors
		&& ((BinaryOp)hi).getOp() != OpOp2.RBIND ) 
	{
		Hop left = hi.getInput().get(0);
		Hop right = hi.getInput().get(1);
		
		//the vector only reproduces the replicated result if the left input already
		//has the full dimensions of the outer product
		boolean validLeft = left.getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isEqualSize(left, right);
		
		//check for column replication
		if( validLeft 
			&& right instanceof AggBinaryOp //matrix mult with datagen
			&& right.getInput().get(1) instanceof DataGenOp 
			&& ((DataGenOp)right.getInput().get(1)).hasConstantValue(1d)
			&& right.getInput().get(1).getDim1() == 1 //row vector for replication
			&& right.getInput().get(0).getDim2() == 1 ) //column vector for mv binary
		{
			//remove unnecessary outer product
			HopRewriteUtils.removeChildReference(hi, right);
			HopRewriteUtils.addChildReference(hi, right.getInput().get(0) );
			hi.refreshSizeInformation();
			
			//cleanup refs to matrix mult if no remaining consumers
			if( right.getParent().isEmpty() )
				HopRewriteUtils.removeAllChildReferences( right );
			
			LOG.debug("Applied removeUnnecessaryOuterProduct1 (line "+right.getBeginLine()+")");
		}
		//check for row replication
		else if( validLeft 
			&& right instanceof AggBinaryOp //matrix mult with datagen
			&& right.getInput().get(0) instanceof DataGenOp 
			&& ((DataGenOp)right.getInput().get(0)).hasConstantValue(1d)
			&& right.getInput().get(0).getDim2() == 1 //column vector for replication
			&& right.getInput().get(1).getDim1() == 1 ) //row vector for mv binary
		{
			//remove unnecessary outer product
			HopRewriteUtils.removeChildReference(hi, right);
			HopRewriteUtils.addChildReference(hi, right.getInput().get(1) );
			hi.refreshSizeInformation();
			
			//cleanup refs to matrix mult if no remaining consumers
			if( right.getParent().isEmpty() )
				HopRewriteUtils.removeAllChildReferences( right );
			
			LOG.debug("Applied removeUnnecessaryOuterProduct2 (line "+right.getBeginLine()+")");
		}
	}
	
	return hi;
}
/**
 * Fuses a transpose over a rand/sinit datagen that produces a vector into the datagen
 * itself by swapping the positions of its rows/cols parameters, e.g.,
 * t(rand(rows=10,cols=1)) -> rand(rows=1,cols=10). Only applied if the transpose is
 * the single consumer of the datagen. All consumers of the transpose are relinked to
 * the datagen at the input position where each of them referenced the transpose.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 */
@SuppressWarnings("unchecked")
private Hop fuseDatagenAndReorgOperation(Hop parent, Hop hi, int pos)
{
	if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.TRANSPOSE  //transpose
		&& hi.getInput().get(0) instanceof DataGenOp                      //datagen
		&& hi.getInput().get(0).getParent().size()==1 )                   //transpose only consumer
	{
		DataGenOp dop = (DataGenOp)hi.getInput().get(0);
		if( (dop.getOp() == DataGenMethod.RAND || dop.getOp() == DataGenMethod.SINIT)
			&& (dop.getDim1()==1 || dop.getDim2()==1 )) 
		{
			//relink all parents and dataop (remove transpose)
			HopRewriteUtils.removeAllChildReferences(hi);
			//copy, because removing child references modifies hi's parent list
			ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone();
			for( Hop lparent : parents ) {
				//each consumer may reference the transpose at a different input position,
				//hence reinsert the datagen at that consumer's own position (not pos)
				int ppos = HopRewriteUtils.getChildReferencePos(lparent, hi);
				HopRewriteUtils.removeChildReferenceByPos(lparent, hi, ppos);
				HopRewriteUtils.addChildReference(lparent, dop, ppos);
			}
			
			//flip rows/cols attributes in datagen
			HashMap<String, Integer> rparams = dop.getParamIndexMap();
			int pos1 = rparams.get(DataExpression.RAND_ROWS);
			int pos2 = rparams.get(DataExpression.RAND_COLS);
			rparams.put(DataExpression.RAND_ROWS, pos2);
			rparams.put(DataExpression.RAND_COLS, pos1);
			dop.refreshSizeInformation();
			
			hi = dop;
			
			LOG.debug("Applied fuseDatagenReorgOperation.");
		}
	}
	
	return hi;
}
/**
 * Simplifies column-wise aggregates (of the types in LOOKUP_VALID_ROW_COL_AGGREGATE)
 * over vectors:
 *   colAgg(X) -> X,                   if nrow(X)==1 (aggregate over a single row)
 *   colAgg(X) -> as.matrix(agg(X)),   if ncol(X)==1 (full aggregate, cast to 1x1 matrix)
 * The identity rewrite is not applied to SUM_SQ, because the column-wise sum of
 * squares of a single row is X^2 and not X.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 * @throws HopsException
 */
@SuppressWarnings("unchecked")
private Hop simplifyColwiseAggregate( Hop parent, Hop hi, int pos ) 
	throws HopsException
{
	if( hi instanceof AggUnaryOp  ) 
	{
		AggUnaryOp uhi = (AggUnaryOp)hi;
		Hop input = uhi.getInput().get(0);
		
		if( HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE) ) {
			if( uhi.getDirection() == Direction.Col  )
			{
				//identity only holds for aggregates f with f(x)==x (not for sum of squares)
				if( input.getDim1() == 1 && uhi.getOp() != AggOp.SUM_SQ )
				{
					//remove unnecessary col aggregation for 1 row
					HopRewriteUtils.removeChildReference(parent, hi);
					HopRewriteUtils.addChildReference(parent, input, pos);
					parent.refreshSizeInformation();
					hi = input;
					
					LOG.debug("Applied simplifyColwiseAggregate1");
				}
				else if( input.getDim2() == 1 )
				{
					//get old parents (before creating cast over aggregate)
					ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone();

					//simplify col-aggregate to full aggregate
					uhi.setDirection(Direction.RowCol);
					uhi.setDataType(DataType.SCALAR);
					
					//create cast to keep same output datatype
					UnaryOp cast = new UnaryOp(uhi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
							           OpOp1.CAST_AS_MATRIX, uhi);
					HopRewriteUtils.setOutputParameters(cast, 1, 1, DMLTranslator.DMLBlockSize, DMLTranslator.DMLBlockSize, -1);
					
					//rehang cast under all parents
					for( Hop p : parents ) {
						int ix = HopRewriteUtils.getChildReferencePos(p, hi);
						HopRewriteUtils.removeChildReference(p, hi);
						HopRewriteUtils.addChildReference(p, cast, ix);
					}
					
					hi = cast;
					
					LOG.debug("Applied simplifyColwiseAggregate2");
				}
			}			
		}
	}
	
	return hi;
}
/**
 * Simplifies row-wise aggregates (of the types in LOOKUP_VALID_ROW_COL_AGGREGATE)
 * over vectors:
 *   rowAgg(X) -> X,                   if ncol(X)==1 (aggregate over a single column)
 *   rowAgg(X) -> as.matrix(agg(X)),   if nrow(X)==1 (full aggregate, cast to 1x1 matrix)
 * The identity rewrite is not applied to SUM_SQ, because the row-wise sum of squares
 * of a single column is X^2 and not X.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 * @throws HopsException
 */
@SuppressWarnings("unchecked")
private Hop simplifyRowwiseAggregate( Hop parent, Hop hi, int pos ) 
	throws HopsException
{
	if( hi instanceof AggUnaryOp  ) 
	{
		AggUnaryOp uhi = (AggUnaryOp)hi;
		Hop input = uhi.getInput().get(0);
		
		if( HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE) ) {
			if( uhi.getDirection() == Direction.Row  )
			{
				//identity only holds for aggregates f with f(x)==x (not for sum of squares)
				if( input.getDim2() == 1 && uhi.getOp() != AggOp.SUM_SQ )
				{
					//remove unnecessary row aggregation for 1 col
					HopRewriteUtils.removeChildReference(parent, hi);
					HopRewriteUtils.addChildReference(parent, input, pos);
					parent.refreshSizeInformation();
					hi = input;
					
					LOG.debug("Applied simplifyRowwiseAggregate1");
				}
				else if( input.getDim1() == 1 )
				{
					//get old parents (before creating cast over aggregate)
					ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone();

					//simplify row-aggregate to full aggregate
					uhi.setDirection(Direction.RowCol);
					uhi.setDataType(DataType.SCALAR);
					
					//create cast to keep same output datatype
					UnaryOp cast = new UnaryOp(uhi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
							           OpOp1.CAST_AS_MATRIX, uhi);
					HopRewriteUtils.setOutputParameters(cast, 1, 1, DMLTranslator.DMLBlockSize, DMLTranslator.DMLBlockSize, -1);
					
					//rehang cast under all parents
					for( Hop p : parents ) {
						int ix = HopRewriteUtils.getChildReferencePos(p, hi);
						HopRewriteUtils.removeChildReference(p, hi);
						HopRewriteUtils.addChildReference(p, cast, ix);
					}
					
					hi = cast;
					
					LOG.debug("Applied simplifyRowwiseAggregate2");
				}
			}	
		}
	}
	
	return hi;
}
/**
 * Rewrites colSums(X*Y) -> t(Y) %*% X if X is a matrix (nrow>1, ncol>1) and Y a
 * column vector (nrow>1, ncol==1). An additional transpose introduced here is
 * removed later by another rewrite if unnecessary, i.e., if Y==t(Z).
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 * @throws HopsException
 */
private Hop simplifyColSumsMVMult( Hop parent, Hop hi, int pos ) 
	throws HopsException
{
	if( !(hi instanceof AggUnaryOp) )
		return hi;
	
	AggUnaryOp agg = (AggUnaryOp)hi;
	Hop prod = agg.getInput().get(0);
	boolean isColSums = agg.getOp() == AggOp.SUM && agg.getDirection() == Direction.Col;
	boolean isCellMult = prod instanceof BinaryOp && ((BinaryOp)prod).getOp()==OpOp2.MULT;
	if( !isColSums || !isCellMult )
		return hi;
	
	Hop matrix = prod.getInput().get(0);
	Hop vector = prod.getInput().get(1);
	boolean matrixTimesColVector = matrix.getDim1()>1 && matrix.getDim2()>1 
		&& vector.getDim1()>1 && vector.getDim2()==1;
	if( !matrixTimesColVector )
		return hi;
	
	//detach colSums from parent and link the matrix multiply instead
	HopRewriteUtils.removeChildReference(parent, hi);
	ReorgOp vectorT = HopRewriteUtils.createTranspose(vector);
	AggBinaryOp mm = HopRewriteUtils.createMatrixMultiply(vectorT, matrix);
	HopRewriteUtils.addChildReference(parent, mm, pos);
	
	//drop references of intermediates that lost their last consumer
	if( agg.getParent().isEmpty() )
		HopRewriteUtils.removeAllChildReferences(agg);
	if( prod.getParent().isEmpty() )
		HopRewriteUtils.removeAllChildReferences(prod);
	
	LOG.debug("Applied simplifyColSumsMVMult");
	return mm;
}
/**
 * Rewrites rowSums(X*Y) -> X %*% t(Y) if X is a matrix (nrow>1, ncol>1) and Y a
 * row vector (nrow==1, ncol>1). An additional transpose introduced here is removed
 * later by another rewrite if unnecessary, i.e., if Y==t(Z).
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 * @throws HopsException
 */
private Hop simplifyRowSumsMVMult( Hop parent, Hop hi, int pos ) 
	throws HopsException
{
	if( !(hi instanceof AggUnaryOp) )
		return hi;
	
	AggUnaryOp agg = (AggUnaryOp)hi;
	Hop prod = agg.getInput().get(0);
	boolean isRowSums = agg.getOp() == AggOp.SUM && agg.getDirection() == Direction.Row;
	boolean isCellMult = prod instanceof BinaryOp && ((BinaryOp)prod).getOp()==OpOp2.MULT;
	if( !isRowSums || !isCellMult )
		return hi;
	
	Hop matrix = prod.getInput().get(0);
	Hop vector = prod.getInput().get(1);
	boolean matrixTimesRowVector = matrix.getDim1()>1 && matrix.getDim2()>1 
		&& vector.getDim1()==1 && vector.getDim2()>1;
	if( !matrixTimesRowVector )
		return hi;
	
	//detach rowSums from parent and link the matrix multiply instead
	HopRewriteUtils.removeChildReference(parent, hi);
	ReorgOp vectorT = HopRewriteUtils.createTranspose(vector);
	AggBinaryOp mm = HopRewriteUtils.createMatrixMultiply(matrix, vectorT);
	HopRewriteUtils.addChildReference(parent, mm, pos);
	
	//drop references of intermediates that lost their last consumer
	if( agg.getParent().isEmpty() )
		HopRewriteUtils.removeAllChildReferences(agg);
	if( prod.getParent().isEmpty() )
		HopRewriteUtils.removeAllChildReferences(prod);
	
	LOG.debug("Applied simplifyRowSumsMVMult");
	return mm;
}
/**
 * Replaces a unary aggregate (of the types in LOOKUP_VALID_EMPTY_AGGREGATE) over an
 * empty input by zeros: a literal 0 for full aggregates, and an all-zero datagen
 * of the aggregate's output shape for column-wise and row-wise aggregates, e.g.,
 * sum(X) -> 0 if nnz(X)==0.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 * @throws HopsException
 */
private Hop simplifyEmptyAggregate(Hop parent, Hop hi, int pos) 
	throws HopsException
{
	if( !(hi instanceof AggUnaryOp) )
		return hi;
	
	AggUnaryOp agg = (AggUnaryOp)hi;
	Hop data = agg.getInput().get(0);
	if( !HopRewriteUtils.isValidOp(agg.getOp(), LOOKUP_VALID_EMPTY_AGGREGATE)
		|| !HopRewriteUtils.isEmpty(data) )
		return hi;
	
	HopRewriteUtils.removeChildReference(parent, hi);
	
	Direction dir = agg.getDirection();
	Hop zeros;
	if( dir == Direction.RowCol )
		zeros = new LiteralOp(0.0);
	else if( dir == Direction.Col )
		zeros = HopRewriteUtils.createDataGenOp(agg, data, 0); //nrow(agg)=1
	else //Direction.Row
		zeros = HopRewriteUtils.createDataGenOp(data, agg, 0); //ncol(agg)=1
	
	HopRewriteUtils.addChildReference(parent, zeros, pos);
	parent.refreshSizeInformation();
	LOG.debug("Applied simplifyEmptyAggregate");
	return zeros;
}
/**
 * Replaces a unary operation (of the types in LOOKUP_VALID_EMPTY_UNARY) over an
 * empty input by an all-zero datagen shaped like the input, e.g.,
 * round(X) -> matrix(0, nrow(X), ncol(X)) if nnz(X)==0.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 * @throws HopsException
 */
private Hop simplifyEmptyUnaryOperation(Hop parent, Hop hi, int pos) 
	throws HopsException
{
	if( !(hi instanceof UnaryOp) )
		return hi;
	
	UnaryOp unary = (UnaryOp)hi;
	Hop data = unary.getInput().get(0);
	boolean zeroPreserving = HopRewriteUtils.isValidOp(unary.getOp(), LOOKUP_VALID_EMPTY_UNARY);
	if( !zeroPreserving || !HopRewriteUtils.isEmpty(data) )
		return hi;
	
	//replace the unary operation by an all-zero datagen
	HopRewriteUtils.removeChildReference(parent, hi);
	Hop zeros = HopRewriteUtils.createDataGenOp(data, 0);
	HopRewriteUtils.addChildReference(parent, zeros, pos);
	parent.refreshSizeInformation();
	LOG.debug("Applied simplifyEmptyUnaryOperation");
	return zeros;
}
/**
 * Replaces reorg operations over an empty input by all-zero datagens of the
 * corresponding output size:
 *   t(X)          -> matrix(0, ncol(X), nrow(X))
 *   diag(v)       -> n x n zeros, if v is a column vector n x 1 (diagV2M)
 *   diag(X)       -> nrow(X) x 1 zeros, otherwise (diagM2V)
 *   matrix(X,r,c) -> matrix(0, r, c) (reshape)
 * diag is only rewritten if the input dimensions are known. Other reorg operations
 * are left unchanged.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 * @throws HopsException
 */
private Hop simplifyEmptyReorgOperation(Hop parent, Hop hi, int pos) 
	throws HopsException
{
	if( hi instanceof ReorgOp  ) 
	{
		ReorgOp rhi = (ReorgOp)hi;
		Hop input = rhi.getInput().get(0); 
		
		if( HopRewriteUtils.isEmpty(input) ) //empty input
		{
			//reorg-operation-specific rewrite
			Hop hnew = null;
			if( rhi.getOp() == ReOrgOp.TRANSPOSE )
				hnew = HopRewriteUtils.createDataGenOp(input, true, input, true, 0);
			else if( rhi.getOp() == ReOrgOp.DIAG ){
				if( HopRewriteUtils.isDimsKnown(input) ){
					//diag of a column vector (ncol==1) creates a square matrix (diagV2M),
					//diag of a matrix extracts its diagonal as a column vector (diagM2V)
					if( input.getDim2()==1 ) //diagv2m
						hnew = HopRewriteUtils.createDataGenOp(input, false, input, true, 0);
					else //diagm2v
						hnew = HopRewriteUtils.createDataGenOpByVal(
								HopRewriteUtils.createValueHop(input,true), new LiteralOp(1), 0);
				}
			}
			else if( rhi.getOp() == ReOrgOp.RESHAPE )
				hnew = HopRewriteUtils.createDataGenOpByVal(rhi.getInput().get(1), rhi.getInput().get(2), 0);
		
			//modify dag if one of the above rules applied
			if( hnew != null ){ 
				HopRewriteUtils.removeChildReference(parent, hi);
				HopRewriteUtils.addChildReference(parent, hnew, pos);
				parent.refreshSizeInformation();
				hi = hnew;
				
				LOG.debug("Applied simplifyEmptyReorgOperation");
			}
		}
	}
	
	return hi;
}
/**
 * Replaces a sort (order) over an empty input, if the index-return flag (input 3)
 * is a literal:
 *   order(X, index.return=FALSE) -> matrix(0, nrow(X), 1)
 *   order(X, index.return=TRUE)  -> seq(1, nrow(X), 1)
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 * @throws HopsException
 */
private Hop simplifyEmptySortOperation(Hop parent, Hop hi, int pos) 
	throws HopsException
{
	boolean isSort = hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.SORT;
	if( !isSort )
		return hi;
	
	ReorgOp sort = (ReorgOp)hi;
	Hop data = sort.getInput().get(0);
	if( !HopRewriteUtils.isEmpty(data) )
		return hi;
	if( !(sort.getInput().get(3) instanceof LiteralOp) ) //index return unknown
		return hi;
	
	boolean ixret = HopRewriteUtils.getBooleanValue((LiteralOp)sort.getInput().get(3));
	Hop replacement = ixret 
		? HopRewriteUtils.createSeqDataGenOp(data) 
		: HopRewriteUtils.createDataGenOp(data, 0);
	
	if( replacement != null ) {
		HopRewriteUtils.removeChildReference(parent, hi);
		HopRewriteUtils.addChildReference(parent, replacement, pos);
		parent.refreshSizeInformation();
		hi = replacement;
		LOG.debug("Applied simplifyEmptySortOperation (indexreturn="+ixret+").");
	}
	return hi;
}
/**
 * Replaces a matrix multiplication with at least one empty input by an all-zero
 * datagen with the rows of the left and the columns of the right input:
 * X %*% Y -> matrix(0, nrow(X), ncol(Y)), if nnz(X)==0 or nnz(Y)==0.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 * @throws HopsException
 */
private Hop simplifyEmptyMatrixMult(Hop parent, Hop hi, int pos) 
	throws HopsException
{
	boolean isMatMult = hi instanceof AggBinaryOp && ((AggBinaryOp)hi).isMatrixMultiply();
	if( !isMatMult )
		return hi;
	
	Hop left = hi.getInput().get(0);
	Hop right = hi.getInput().get(1);
	boolean anyEmpty = HopRewriteUtils.isEmpty(left) || HopRewriteUtils.isEmpty(right);
	if( !anyEmpty )
		return hi;
	
	HopRewriteUtils.removeChildReference(parent, hi);
	Hop zeros = HopRewriteUtils.createDataGenOp(left, right, 0);
	HopRewriteUtils.addChildReference(parent, zeros, pos);
	parent.refreshSizeInformation();
	LOG.debug("Applied simplifyEmptyMatrixMult");
	return zeros;
}
/**
 * Removes a matrix multiplication with a 1x1 constant-one datagen on the right:
 * X %*% matrix(1,1,1) -> X. If the parent was the only consumer of the matrix
 * multiplication, its input references are removed as well, consistent with
 * simplifyScalarMatrixMult.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 * @throws HopsException
 */
private Hop simplifyIdentityRepMatrixMult(Hop parent, Hop hi, int pos) 
	throws HopsException
{
	if( hi instanceof AggBinaryOp && ((AggBinaryOp)hi).isMatrixMultiply() ) //X%*%Y -> X, if y is matrix(1,1,1)
	{
		Hop left = hi.getInput().get(0);
		Hop right = hi.getInput().get(1);
		
		// X %*% y -> X
		if( HopRewriteUtils.isDimsKnown(right) && right.getDim1()==1 && right.getDim2()==1 && //scalar right
			right instanceof DataGenOp && ((DataGenOp)right).hasConstantValue(1.0)) //matrix(1,)
		{
			HopRewriteUtils.removeChildReference(parent, hi);
			
			//cleanup if parent was the only consumer, to avoid stale parent
			//references from left/right to the removed matrix mult
			if( hi.getParent().isEmpty() )
				HopRewriteUtils.removeAllChildReferences( hi );
			
			HopRewriteUtils.addChildReference(parent, left, pos);
			parent.refreshSizeInformation();
			hi = left;

			LOG.debug("Applied simplifyIdentityRepMatrixMult");
		}
	}
	
	return hi;
}
/**
 * Replaces a matrix multiplication where one side is a 1x1 matrix with known
 * dimensions by a scalar-matrix multiplication:
 *   y %*% X -> as.scalar(y) * X, if y is 1x1 (checked first)
 *   X %*% y -> as.scalar(y) * X, if y is 1x1
 * If the parent was the only consumer of the matrix multiplication, its input
 * references are removed.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 * @throws HopsException
 */
private Hop simplifyScalarMatrixMult(Hop parent, Hop hi, int pos) 
	throws HopsException
{
	if( !(hi instanceof AggBinaryOp && ((AggBinaryOp)hi).isMatrixMultiply()) )
		return hi;
	
	Hop left = hi.getInput().get(0);
	Hop right = hi.getInput().get(1);
	
	//determine which side (if any) is a 1x1 matrix; the left side takes precedence
	Hop scalarSide = null;
	Hop matrixSide = null;
	String logMsg = null;
	if( HopRewriteUtils.isDimsKnown(left) && left.getDim1()==1 && left.getDim2()==1 ) {
		scalarSide = left;
		matrixSide = right;
		logMsg = "Applied simplifyScalarMatrixMult1";
	}
	else if( HopRewriteUtils.isDimsKnown(right) && right.getDim1()==1 && right.getDim2()==1 ) {
		scalarSide = right;
		matrixSide = left;
		logMsg = "Applied simplifyScalarMatrixMult2";
	}
	if( scalarSide == null )
		return hi;
	
	//detach matrix mult and build as.scalar(scalarSide) * matrixSide
	HopRewriteUtils.removeChildReference(parent, hi);
	UnaryOp cast = new UnaryOp(scalarSide.getName(), DataType.SCALAR, ValueType.DOUBLE, 
			OpOp1.CAST_AS_SCALAR, scalarSide);
	HopRewriteUtils.setOutputParameters(cast, 0, 0, 0, 0, 0);
	BinaryOp mult = new BinaryOp(cast.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp2.MULT, cast, matrixSide);
	HopRewriteUtils.setOutputParameters(mult, matrixSide.getDim1(), matrixSide.getDim2(), 
			matrixSide.getRowsInBlock(), matrixSide.getColsInBlock(), -1);
	
	//cleanup if only consumer of intermediate
	if( hi.getParent().isEmpty() ) 
		HopRewriteUtils.removeAllChildReferences( hi );
	
	HopRewriteUtils.addChildReference(parent, mult, pos);
	parent.refreshSizeInformation();
	LOG.debug(logMsg);
	return mult;
}
/**
 * Replaces a matrix multiplication with a left diag (vector-to-matrix, i.e., the
 * diag output has ncol>1 and known dims) by a cell-wise multiplication with the
 * diag input vector:
 *   diag(X) %*% Y -> X * Y, if ncol(Y)==1
 *   diag(X) %*% Y -> Y * X, if ncol(Y)>1 (vector on the right for MV binary cell ops)
 * If the parent was the only consumer of the matrix multiplication, its input
 * references are removed. Similar rewrites for a right diag are not implemented.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop (input at position pos of parent)
 * @param pos position of hi in the inputs of parent
 * @return the hop now referenced by parent at position pos
 * @throws HopsException
 */
private Hop simplifyMatrixMultDiag(Hop parent, Hop hi, int pos) 
	throws HopsException
{
	if( !(hi instanceof AggBinaryOp && ((AggBinaryOp)hi).isMatrixMultiply()) )
		return hi;
	
	Hop left = hi.getInput().get(0);
	Hop right = hi.getInput().get(1);
	
	boolean leftIsDiagV2M = left instanceof ReorgOp && ((ReorgOp)left).getOp()==ReOrgOp.DIAG
		&& HopRewriteUtils.isDimsKnown(left) && left.getDim2()>1;
	if( !leftIsDiagV2M )
		return hi;
	
	long ncolRight = right.getDim2();
	if( ncolRight < 1 ) //unknown or empty: neither rewrite applies
		return hi;
	
	HopRewriteUtils.removeChildReference(parent, hi);
	
	//single column: X * Y; multiple columns: Y * X, because MV binary cell
	//operations require the vector on the right
	boolean singleCol = (ncolRight == 1);
	Hop diagIn = left.getInput().get(0);
	Hop first = singleCol ? diagIn : right;
	Hop second = singleCol ? right : diagIn;
	Hop hnew = new BinaryOp(diagIn.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp2.MULT, first, second);
	HopRewriteUtils.setOutputParameters(hnew, left.getDim1(), right.getDim2(), left.getRowsInBlock(), left.getColsInBlock(), -1);
	LOG.debug(singleCol ? "Applied simplifyMatrixMultDiag1" : "Applied simplifyMatrixMultDiag2");
	
	//cleanup if only consumer of intermediate
	if( hi.getParent().isEmpty() ) 
		HopRewriteUtils.removeAllChildReferences( hi );
	
	HopRewriteUtils.addChildReference(parent, hnew, pos);
	parent.refreshSizeInformation();
	return hnew;
}
/**
 * Rewrites the diagonal of a matrix multiply into a row aggregate over a
 * cell-wise product: diag(X%*%Y) -> rowSums(X * t(Y)). Only applied for
 * diagM2V, i.e., if the diag produces a column vector.
 *
 * @param parent parent hop that consumes hi
 * @param hi current hop, candidate diag
 * @param pos input position of hi within parent
 * @return the new row aggregate if the rewrite applied, otherwise hi
 */
private Hop simplifyDiagMatrixMult(Hop parent, Hop hi, int pos)
{
	//hi must be a diagM2V (column vector output)
	boolean isDiagM2V = hi instanceof ReorgOp 
		&& ((ReorgOp)hi).getOp()==ReOrgOp.DIAG 
		&& hi.getDim2()==1;
	if( !isDiagM2V )
		return hi;
	
	//input of diag must be a matrix multiply X%*%Y
	Hop mm = hi.getInput().get(0);
	if( !(mm instanceof AggBinaryOp && ((AggBinaryOp)mm).isMatrixMultiply()) )
		return hi;
	
	Hop mmLeft = mm.getInput().get(0);
	Hop mmRight = mm.getInput().get(1);
	String name = mmRight.getName();
	DataType dt = mmRight.getDataType();
	ValueType vt = mmRight.getValueType();
	
	//detach diag from parent
	HopRewriteUtils.removeChildReference(parent, hi);
	
	//X * t(Y) (the transpose refreshes its own size information)
	ReorgOp rightT = HopRewriteUtils.createTranspose(mmRight);
	BinaryOp cellMult = new BinaryOp(name, dt, vt, OpOp2.MULT, mmLeft, rightT);
	cellMult.setRowsInBlock(mmRight.getRowsInBlock());
	cellMult.setColsInBlock(mmRight.getColsInBlock());
	cellMult.refreshSizeInformation();
	
	//rowSums(X * t(Y))
	AggUnaryOp rowAgg = new AggUnaryOp(name, dt, vt, AggOp.SUM, Direction.Row, cellMult);
	rowAgg.setRowsInBlock(mmRight.getRowsInBlock());
	rowAgg.setColsInBlock(mmRight.getColsInBlock());
	rowAgg.refreshSizeInformation();
	
	//rehang new subdag under parent node
	HopRewriteUtils.addChildReference(parent, rowAgg, pos);
	parent.refreshSizeInformation();
	
	//cleanup intermediates that lost their last consumer
	if( hi.getParent().isEmpty() )
		HopRewriteUtils.removeAllChildReferences( hi );
	if( mm.getParent().isEmpty() )
		HopRewriteUtils.removeAllChildReferences( mm );
	
	LOG.debug("Applied simplifyDiagMatrixMult");
	return rowAgg;
}
/**
 * Rewrites a full sum over a diagM2V into a trace, i.e., sum(diag(X)) -> trace(X),
 * by bypassing the diag and changing the aggregation operator of hi in place.
 *
 * @param hi current hop, candidate full sum aggregate
 * @return hi (possibly modified in place)
 */
private Hop simplifySumDiagToTrace(Hop hi)
{
	if( !(hi instanceof AggUnaryOp) )
		return hi;
	
	AggUnaryOp agg = (AggUnaryOp) hi;
	boolean isFullSum = agg.getOp()==AggOp.SUM && agg.getDirection()==Direction.RowCol;
	if( !isFullSum )
		return hi;
	
	Hop diag = agg.getInput().get(0);
	boolean isDiagM2V = diag instanceof ReorgOp 
		&& ((ReorgOp)diag).getOp()==ReOrgOp.DIAG 
		&& diag.getDim2()==1;
	if( !isDiagM2V )
		return hi;
	
	//bypass the diag operator and turn sum into trace
	Hop diagInput = diag.getInput().get(0);
	HopRewriteUtils.removeChildReference(agg, diag);
	HopRewriteUtils.addChildReference(agg, diagInput, 0);
	agg.setOp( AggOp.TRACE );
	
	//cleanup if sum was the only consumer of the diag
	if( diag.getParent().isEmpty() )
		HopRewriteUtils.removeAllChildReferences( diag );
	
	LOG.debug("Applied simplifySumDiagToTrace");
	return hi;
}
/**
 * Pushes a scalar multiply below a vector-to-matrix diag, i.e.,
 * diag(v)*s -> diag(v*s) and s*diag(v) -> diag(s*v), in order to (1) reduce
 * the memory required for b(*) and (2) make the binary operation more 
 * efficient (dense vector instead of sparse matrix).
 * 
 * Only applied if the diag input is a column vector, the other input is a 
 * scalar, and the binary operation is the only parent of the diag. Because
 * the binary operation and the diag swap places, ALL parents of the binary
 * operation (not only the given parent) are relinked to the diag.
 *
 * @param parent parent hop that consumes hi (not used directly; all parents of hi are relinked)
 * @param hi current hop, candidate scalar multiply
 * @param pos input position of hi within parent (not used directly)
 * @return the diag operation if the rewrite applied, otherwise hi
 */
@SuppressWarnings("unchecked")
private Hop pushdownBinaryOperationOnDiag(Hop parent, Hop hi, int pos) 
{
	//diag(X)*7 --> diag(X*7) in order to (1) reduce required memory for b(*) and
	//(2) in order to make the binary operation more efficient (dense vector vs sparse matrix)
	if( hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==OpOp2.MULT )
	{
		Hop left = hi.getInput().get(0);
		Hop right = hi.getInput().get(1);
		
		boolean applyLeft = false;
		boolean applyRight = false;
		
		//left input is diag
		if( left instanceof ReorgOp && ((ReorgOp)left).getOp()==ReOrgOp.DIAG
			&& left.getParent().size()==1 //binary op only parent
			&& left.getInput().get(0).getDim2()==1 //col vector
			&& right.getDataType() == DataType.SCALAR )
		{
			applyLeft = true;
		}
		else if( right instanceof ReorgOp && ((ReorgOp)right).getOp()==ReOrgOp.DIAG
				&& right.getParent().size()==1 //binary op only parent
				&& right.getInput().get(0).getDim2()==1 //col vector
				&& left.getDataType() == DataType.SCALAR )
		{
			applyRight = true;
		}
		
		//perform actual rewrite
		if( applyLeft || applyRight )
		{
			//remove all parent links to binary op (since we want to reorder
			//we cannot just look at the current parent); note that each recorded
			//position refers to the parent's input list AFTER earlier removals
			ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone();
			ArrayList<Integer> parentspos = new ArrayList<Integer>(); 
			for(Hop lparent : parents) {
				int lpos = HopRewriteUtils.getChildReferencePos(lparent, hi);
				HopRewriteUtils.removeChildReferenceByPos(lparent, hi, lpos);
				parentspos.add(lpos);
			}
			
			//rewire binop-diag-input into diag-binop-input
			if( applyLeft ) {
				Hop input = left.getInput().get(0);
				HopRewriteUtils.removeChildReferenceByPos(hi, left, 0);
				HopRewriteUtils.removeChildReferenceByPos(left, input, 0);
				HopRewriteUtils.addChildReference(left, hi, 0);
				HopRewriteUtils.addChildReference(hi, input, 0);
				hi.refreshSizeInformation();
				hi = left;
			}
			else if ( applyRight ) {
				Hop input = right.getInput().get(0);
				HopRewriteUtils.removeChildReferenceByPos(hi, right, 1);
				HopRewriteUtils.removeChildReferenceByPos(right, input, 0);
				HopRewriteUtils.addChildReference(right, hi, 0);
				HopRewriteUtils.addChildReference(hi, input, 1);
				hi.refreshSizeInformation();
				hi = right;
			}
			
			//relink all parents to the diag operation; re-insert in reverse (LIFO)
			//order so that the recorded positions exactly undo the removals above,
			//even if a single parent references the binary op multiple times
			for( int i=parents.size()-1; i>=0; i-- ) {
				Hop lparent = parents.get(i);
				int lpos = parentspos.get(i);
				HopRewriteUtils.addChildReference(lparent, hi, lpos);
			}
			
			LOG.debug("Applied pushdownBinaryOperationOnDiag.");
		}
	}
	
	return hi;
}
/**
 * Rewrites full sums over column vectors into a dot product, i.e.,
 * sum(v^2) -> as.scalar(t(v)%*%v) and sum(v1*v2) -> as.scalar(t(v1)%*%v2),
 * in order to exploit the tsmm vector dot product without materializing
 * intermediates. The rewrite is only applied if the ^2 or * operation has no
 * other consumer than the sum.
 * 
 * NOTE: this rewrite was originally restricted to sum(a^2) and transitively
 * sum(a*a) since a general mm a%*%b on MR can be counter-productive (e.g., MMCJ)
 * while tsmm is always beneficial. sum(v1*v2) is now rewritten for column vectors
 * as well, except if an input is itself a multiply (sum(v1*v2*v3) is later
 * compiled into a ta+* lop instead).
 *
 * @param parent parent hop that consumes hi
 * @param hi current hop, candidate full sum aggregate
 * @param pos input position of hi within parent
 * @return the new cast-as-scalar operation if the rewrite applied, otherwise hi
 * @throws HopsException
 */
private Hop simplifyDotProductSum(Hop parent, Hop hi, int pos) 
	throws HopsException
{
	//sum(v^2)/sum(v1*v2) --> as.scalar(t(v)%*%v) in order to exploit tsmm vector dotproduct 
	//w/o materialization of intermediates
	if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.SUM //sum
		&& ((AggUnaryOp)hi).getDirection()==Direction.RowCol //full aggregate
		&& hi.getInput().get(0).getDim2() == 1 ) //vector (for correctness)
	{
		Hop baLeft = null;
		Hop baRight = null;
		
		Hop hi2 = hi.getInput().get(0); //check for ^2 w/o multiple consumers
		//check for sum(v^2), might have been rewritten from sum(v*v);
		//compare the exponent as double because an integer conversion could map
		//non-integral exponents (e.g., 2.5) to 2 and hence trigger an invalid rewrite
		if( hi2 instanceof BinaryOp && ((BinaryOp)hi2).getOp()==OpOp2.POW
			&& hi2.getInput().get(1) instanceof LiteralOp 
			&& HopRewriteUtils.getDoubleValue((LiteralOp)hi2.getInput().get(1))==2
			&& hi2.getParent().size() == 1 ) //no other consumer than sum
		{
			Hop input = hi2.getInput().get(0);
			baLeft = input;
			baRight = input;
		}
		//check for sum(v1*v2), but prevent to rewrite sum(v1*v2*v3) which is later compiled into a ta+* lop
		else if(   hi2 instanceof BinaryOp && ((BinaryOp)hi2).getOp()==OpOp2.MULT
				&& hi2.getInput().get(0).getDim2()==1 && hi2.getInput().get(1).getDim2()==1
				&& hi2.getParent().size() == 1  //no other consumer than sum
				&& !(hi2.getInput().get(0) instanceof BinaryOp && ((BinaryOp)hi2.getInput().get(0)).getOp()==OpOp2.MULT)
				&& !(hi2.getInput().get(1) instanceof BinaryOp && ((BinaryOp)hi2.getInput().get(1)).getOp()==OpOp2.MULT))
		{
			baLeft = hi2.getInput().get(0);
			baRight = hi2.getInput().get(1);
		}
		
		//perform actual rewrite (if necessary)
		if( baLeft != null && baRight != null  )
		{
			//remove link from parent to sum
			HopRewriteUtils.removeChildReference(parent, hi);
			
			//create new operator chain
			ReorgOp trans = HopRewriteUtils.createTranspose(baLeft);
			AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(trans, baRight);
			UnaryOp cast = new UnaryOp(baLeft.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp1.CAST_AS_SCALAR, mmult);
			HopRewriteUtils.setOutputParameters(cast, 0, 0, 0, 0, -1);
			
			//rehang new subdag under parent node
			HopRewriteUtils.addChildReference(parent, cast, pos);
			parent.refreshSizeInformation();
			
			//cleanup if only consumer of intermediate
			if( hi.getParent().isEmpty() ) 
				HopRewriteUtils.removeAllChildReferences( hi );
			if( hi2.getParent().isEmpty() ) 
				HopRewriteUtils.removeAllChildReferences( hi2 );
			
			hi = cast;
			
			LOG.debug("Applied simplifyDotProductSum.");
		}
	}
	
	return hi;
}
/**
 * Replace SUM(X^2) with a fused SUM_SQ(X) HOP. The aggregation direction of the
 * original SUM (full, row, or column aggregate) is kept. The rewrite is only
 * applied if POW(X,2) has no other consumer than the SUM and X has more than
 * one column.
 *
 * @param parent Parent HOP for which hi is an input.
 * @param hi Current HOP for potential rewrite.
 * @param pos Position of hi in parent's list of inputs.
 *
 * @return Either hi or the rewritten HOP replacing it.
 *
 * @throws HopsException
 */
private Hop fuseSumSquared(Hop parent, Hop hi, int pos)
		throws HopsException {
	// if SUM
	if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.SUM) {
		Hop sumInput = hi.getInput().get(0);

		// if input to SUM is POW(X,2), and no other consumers of the POW(X,2) HOP
		if (sumInput instanceof BinaryOp && ((BinaryOp) sumInput).getOp() == OpOp2.POW
				&& sumInput.getInput().get(1) instanceof LiteralOp
				&& HopRewriteUtils.getDoubleValue((LiteralOp) sumInput.getInput().get(1)) == 2
				&& sumInput.getParent().size() == 1) {
			Hop x = sumInput.getInput().get(0);

			// if X is NOT a column vector
			if (x.getDim2() > 1) {
				// perform rewrite from SUM(POW(X,2)) to SUM_SQ(X)
				DataType dt = hi.getDataType();
				ValueType vt = hi.getValueType();
				Direction dir = ((AggUnaryOp) hi).getDirection();
				AggUnaryOp sumSq = new AggUnaryOp("sumSq", dt, vt, AggOp.SUM_SQ, dir, x);
				// SUM_SQ(X) produces the same output as SUM(POW(X,2)), hence keep the
				// output dimensions and block sizes of the replaced HOP (not only the
				// block sizes, which would leave the new HOP's dimensions unset)
				HopRewriteUtils.setOutputParameters(sumSq, hi.getDim1(), hi.getDim2(),
						hi.getRowsInBlock(), hi.getColsInBlock(), -1);
				HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
				HopRewriteUtils.addChildReference(parent, sumSq, pos);

				// cleanup
				if (hi.getParent().isEmpty())
					HopRewriteUtils.removeAllChildReferences(hi);
				if(sumInput.getParent().isEmpty())
					HopRewriteUtils.removeAllChildReferences(sumInput);

				// replace current HOP with new SUM_SQ HOP
				hi = sumSq;
				LOG.debug("Applied fuseSumSquared.");
			}
		}
	}
	return hi;
}
/**
 * Simplifies binary cell operations over two matrix inputs if one or both
 * inputs are known to be empty (as determined by HopRewriteUtils.isEmpty):
 * X * Y -> matrix of zeros (sized like X, or like Y if Y is not a vector),
 * X + Y -> matrix of zeros, Y, or X depending on which inputs are empty,
 * X - Y -> 0 - Y (hi modified in place) if X is empty, or X if Y is empty.
 * Rewrites that drop the left input are only applied if the operation is not
 * a matrix-vector operation, because the result has the size of the left input.
 *
 * @param parent parent hop that consumes hi
 * @param hi current hop, candidate binary operation
 * @param pos input position of hi within parent
 * @return the simplified operation if a rewrite applied, otherwise hi
 * @throws HopsException
 */
private Hop simplifyEmptyBinaryOperation(Hop parent, Hop hi, int pos) 
	throws HopsException
{
	if( hi instanceof BinaryOp ) //b(?) X Y
	{
		BinaryOp bop = (BinaryOp) hi;
		Hop left = hi.getInput().get(0);
		Hop right = hi.getInput().get(1);
	
		if( left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX )
		{
			Hop hnew = null;
			
			//NOTE: these rewrites of binary cell operations need to be aware that right is 
			//potentially a vector but the result is of the size of left
			boolean notBinaryMV = HopRewriteUtils.isNotMatrixVectorBinaryOperation(bop);
			
			switch( bop.getOp() ){
				//X * Y -> matrix(0,nrow(X),ncol(X));
				case MULT: {
					if( HopRewriteUtils.isEmpty(left) ) //empty left and size known
						hnew = HopRewriteUtils.createDataGenOp(left, left, 0);
					else if( HopRewriteUtils.isEmpty(right) //empty right and right not a vector
							&& right.getDim1()>1 && right.getDim2()>1  ) {
						hnew = HopRewriteUtils.createDataGenOp(right, right, 0);
					}
					else if( HopRewriteUtils.isEmpty(right) )//empty right and right potentially a vector
						hnew = HopRewriteUtils.createDataGenOp(left, left, 0);
					break;
				}
				case PLUS: {
					if( HopRewriteUtils.isEmpty(left) && HopRewriteUtils.isEmpty(right) ) //empty left/right and size known
						hnew = HopRewriteUtils.createDataGenOp(left, left, 0);
					else if( HopRewriteUtils.isEmpty(left) && notBinaryMV ) //empty left
						hnew = right;
					else if( HopRewriteUtils.isEmpty(right) ) //empty right
						hnew = left;
					break;
				}
				case MINUS: {
					if( HopRewriteUtils.isEmpty(left) && notBinaryMV ) { //empty left
						HopRewriteUtils.removeChildReference(hi, left);
						HopRewriteUtils.addChildReference(hi, new LiteralOp(0), 0);
						hnew = hi;
					}
					else if( HopRewriteUtils.isEmpty(right) ) //empty and size known
						hnew = left;
					break;
				}
				default:
					hnew = null;
			}
			
			if( hnew != null )
			{
				//replace hi by the simplified operation in parent
				HopRewriteUtils.removeChildReference(parent, hi);
				HopRewriteUtils.addChildReference(parent, hnew, pos);
				parent.refreshSizeInformation();
				
				//cleanup if parent was the only consumer of the replaced operation,
				//otherwise its inputs keep dangling parent links to a dead hop
				//(not applicable if hi itself was modified in place, as for 0-Y)
				if( hnew != hi && hi.getParent().isEmpty() )
					HopRewriteUtils.removeAllChildReferences( hi );
				
				hi = hnew;
				LOG.debug("Applied simplifyEmptyBinaryOperation");
			}
		}
	}
	
	return hi;
}
/**
 * This rewrite tries to reorder minus operators from inputs of matrix 
 * multiply to its output because the output is (except for outer products)
 * usually significantly smaller. Furthermore, this rewrite is a precondition
 * for the important hops-lops rewrite of transpose-matrixmult if the transpose
 * is hidden under the minus. Concretely: (0-Z)%*%Y -> 0-(Z%*%Y) and
 * X%*%(0-Z) -> 0-(X%*%Z); if both inputs qualify, only the left one is pulled up.
 * 
 * NOTE: in this rewrite we need to modify the links to all parents because we 
 * remove existing links of subdags and hence affect all consumers.
 * 
 * TODO select up or down based on size
 * 
 * @param parent parent hop that consumes hi (all parents of hi are relinked)
 * @param hi current hop, candidate matrix multiply
 * @param pos input position of hi within parent
 * @return the new minus operation if the rewrite applied, otherwise hi
 * @throws HopsException
 */
private Hop reorderMinusMatrixMult(Hop parent, Hop hi, int pos) 
	throws HopsException
{
	if( hi instanceof AggBinaryOp && ((AggBinaryOp)hi).isMatrixMultiply() ) //X%*%Y
	{
		Hop hileft = hi.getInput().get(0);
		Hop hiright = hi.getInput().get(1);
		
		if( isMinusFromZero(hileft) ) //X=-Z
			hi = pullUpMinusOverMatrixMult(hi, hileft, 0);
		else if( isMinusFromZero(hiright) ) //Y=-Z
			hi = pullUpMinusOverMatrixMult(hi, hiright, 1);
	}
	
	return hi;
}

/**
 * Indicates if the given hop is a binary minus whose first input is the literal 0, i.e., 0-Z.
 *
 * @param hop hop to check
 * @return true if hop is of the form 0-Z
 * @throws HopsException
 */
private static boolean isMinusFromZero(Hop hop) 
	throws HopsException
{
	return hop instanceof BinaryOp && ((BinaryOp)hop).getOp()==OpOp2.MINUS
		&& hop.getInput().get(0) instanceof LiteralOp
		&& HopRewriteUtils.getDoubleValue((LiteralOp)hop.getInput().get(0))==0.0;
}

/**
 * Pulls the minus 0-Z at input position inputPos of the matrix multiply mm above
 * the matrix multiply: Z is rehung under mm at inputPos, and a new minus 0-mm
 * replaces mm in all of mm's former parents.
 *
 * @param mm matrix multiply hop
 * @param minus the 0-Z input of mm
 * @param inputPos position of minus within the inputs of mm (0 or 1)
 * @return the new minus operation over mm
 * @throws HopsException
 */
@SuppressWarnings("unchecked")
private Hop pullUpMinusOverMatrixMult(Hop mm, Hop minus, int inputPos) 
	throws HopsException
{
	Hop hi2 = minus.getInput().get(1);
	
	//remove link from matrixmult to minus
	HopRewriteUtils.removeChildReference(mm, minus);
	
	//get old parents (before creating minus over matrix mult)
	ArrayList<Hop> parents = (ArrayList<Hop>) mm.getParent().clone();
	
	//create new operators 
	BinaryOp newMinus = new BinaryOp(mm.getName(), mm.getDataType(), mm.getValueType(), OpOp2.MINUS, new LiteralOp(0), mm);
	newMinus.setRowsInBlock(mm.getRowsInBlock());
	newMinus.setColsInBlock(mm.getColsInBlock());
	
	//rehang minus under all parents
	for( Hop p : parents ) {
		int ix = HopRewriteUtils.getChildReferencePos(p, mm);
		HopRewriteUtils.removeChildReference(p, mm);
		HopRewriteUtils.addChildReference(p, newMinus, ix);
	}
	
	//rehang child of minus under matrix mult
	HopRewriteUtils.addChildReference(mm, hi2, inputPos);
	
	//derive the size of the new minus from the matrix mult (whose output size is
	//unaffected because Z and 0-Z have the same dimensions)
	newMinus.refreshSizeInformation();
	
	//cleanup if only consumer of minus
	if( minus.getParent().isEmpty() )
		HopRewriteUtils.removeAllChildReferences( minus );
	
	LOG.debug("Applied reorderMinusMatrixMult");
	return newMinus;
}
/**
 * Replaces the full sum over a matrix multiply by a cheaper expression over the
 * column sums of A and row sums of B: sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)).
 * The sum hop itself is kept; only its input is exchanged. Not applied if the
 * matrix multiply output is not known to exceed 1x1 (dot product, because the
 * aggregate would be removed), or if the matrix multiply has other consumers
 * (to prevent redundant computation).
 *
 * @param parent parent hop (unused)
 * @param hi current hop, candidate full sum aggregate
 * @param pos input position of hi within parent (unused)
 * @return hi, whose input may have been replaced
 */
private Hop simplifySumMatrixMult(Hop parent, Hop hi, int pos)
{
	if( !(hi instanceof AggUnaryOp) )
		return hi;
	
	AggUnaryOp agg = (AggUnaryOp) hi;
	if( agg.getOp()!=AggOp.SUM || agg.getDirection()!=Direction.RowCol )
		return hi;
	
	Hop mm = hi.getInput().get(0);
	if( !(mm instanceof AggBinaryOp) )
		return hi;
	
	boolean isDotProduct = mm.getDim1()<=1 && mm.getDim2()<=1;
	if( isDotProduct || mm.getParent().size()!=1 )
		return hi;
	
	Hop lhs = mm.getInput().get(0);
	Hop rhs = mm.getInput().get(1);
	
	//detach matrix multiply from the sum
	HopRewriteUtils.removeChildReference(hi, mm);
	
	//t(colSums(A))
	AggUnaryOp colAgg = new AggUnaryOp(lhs.getName(), lhs.getDataType(), lhs.getValueType(), AggOp.SUM, Direction.Col, lhs);
	colAgg.setRowsInBlock(lhs.getRowsInBlock());
	colAgg.setColsInBlock(lhs.getColsInBlock());
	colAgg.refreshSizeInformation();
	ReorgOp colAggT = HopRewriteUtils.createTranspose(colAgg);
	
	//rowSums(B)
	AggUnaryOp rowAgg = new AggUnaryOp(rhs.getName(), rhs.getDataType(), rhs.getValueType(), AggOp.SUM, Direction.Row, rhs);
	rowAgg.setRowsInBlock(rhs.getRowsInBlock());
	rowAgg.setColsInBlock(rhs.getColsInBlock());
	rowAgg.refreshSizeInformation();
	
	//t(colSums(A)) * rowSums(B)
	BinaryOp prod = new BinaryOp(rhs.getName(), rhs.getDataType(), rhs.getValueType(), OpOp2.MULT, colAggT, rowAgg);
	prod.setRowsInBlock(rhs.getRowsInBlock());
	prod.setColsInBlock(rhs.getColsInBlock());
	prod.refreshSizeInformation();
	
	//rehang new subdag under the (unchanged) sum
	HopRewriteUtils.addChildReference(hi, prod, 0);
	hi.refreshSizeInformation();
	
	//cleanup if the sum was the only consumer of the matrix multiply
	if( mm.getParent().isEmpty() )
		HopRewriteUtils.removeAllChildReferences( mm );
	
	LOG.debug("Applied simplifySumMatrixMult.");
	return hi;
}
/**
 * Converts a 1x1 matrix right input of a binary operation that supports
 * matrix-scalar operations into a scalar, i.e., X * s -> X * as.scalar(s),
 * by modifying hi in place. Only applied if both inputs are matrices and the
 * dimensions of the right input are known.
 *
 * @param hi current hop, candidate binary operation
 * @return hi (possibly modified in place)
 * @throws HopsException
 */
private Hop simplifyScalarMVBinaryOperation(Hop hi) 
	throws HopsException
{
	if( !(hi instanceof BinaryOp) || !((BinaryOp)hi).supportsMatrixScalarOperations() )
		return hi;
	if( hi.getInput().get(0).getDataType()!=DataType.MATRIX 
		|| hi.getInput().get(1).getDataType()!=DataType.MATRIX )
		return hi;
	
	Hop rhs = hi.getInput().get(1);
	boolean rhsIsSingleCell = HopRewriteUtils.isDimsKnown(rhs) 
		&& rhs.getDim1()==1 && rhs.getDim2()==1;
	if( rhsIsSingleCell ) {
		//replace right input by as.scalar(right)
		HopRewriteUtils.removeChildReference(hi, rhs);
		UnaryOp asScalar = new UnaryOp(rhs.getName(), DataType.SCALAR, ValueType.DOUBLE, 
			OpOp1.CAST_AS_SCALAR, rhs);
		HopRewriteUtils.setOutputParameters(asScalar, 0, 0, 0, 0, 0);
		HopRewriteUtils.addChildReference(hi, asScalar, 1);
		LOG.debug("Applied simplifyScalarMVBinaryOperation.");
	}
	return hi;
}
/**
 * Replaces a full sum over a not-equal-zero comparison, i.e.,
 * sum(ppred(X,0,"!=")) or equivalently sum(X!=0) / sum(0!=X), by a literal
 * holding the number of non-zeros of X, if nnz(X) is known (> 0).
 *
 * @param parent parent hop that consumes hi
 * @param hi current hop, candidate full sum aggregate
 * @param pos input position of hi within parent
 * @return the new literal if the rewrite applied, otherwise hi
 * @throws HopsException
 */
private Hop simplifyNnzComputation(Hop parent, Hop hi, int pos) 
	throws HopsException
{
	//sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known		
	if(    hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.SUM  //sum
		&& ((AggUnaryOp)hi).getDirection() == Direction.RowCol	            //full aggregate
		&& hi.getInput().get(0) instanceof BinaryOp 
		&& ((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.NOTEQUAL )
	{
		Hop ppred = hi.getInput().get(0);
		Hop X = null;
		if(    ppred.getInput().get(0) instanceof LiteralOp 
			&& HopRewriteUtils.getDoubleValue((LiteralOp)ppred.getInput().get(0))==0 )
		{
			X = ppred.getInput().get(1);
		}
		else if(   ppred.getInput().get(1) instanceof LiteralOp 
				&& HopRewriteUtils.getDoubleValue((LiteralOp)ppred.getInput().get(1))==0 )
		{
			X = ppred.getInput().get(0);
		}
	
		//apply rewrite if known nnz 
		if( X != null && X.getNnz() > 0 ){
			Hop hnew = new LiteralOp(X.getNnz());
			HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
			HopRewriteUtils.addChildReference(parent, hnew, pos);
			
			//cleanup if only consumer of intermediates (sum and ppred), otherwise
			//X would keep a dangling parent link to an unreferenced ppred
			if( hi.getParent().isEmpty() )
				HopRewriteUtils.removeAllChildReferences( hi );
			if( ppred.getParent().isEmpty() )
				HopRewriteUtils.removeAllChildReferences( ppred );
			
			hi = hnew;
			LOG.debug("Applied simplifyNnzComputation.");	
		}
	}
	
	return hi;
}
}