/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; import java.util.HashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.DataGenOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.QuaternaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.DataGenMethod; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.Hop.OpOp3; import org.apache.sysml.hops.Hop.OpOp4; import org.apache.sysml.hops.Hop.ParamBuiltinOp; import org.apache.sysml.hops.Hop.ReOrgOp; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LeftIndexingOp; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.ReorgOp; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.UnaryOp; import 
org.apache.sysml.lops.MapMultChain.ChainType; import org.apache.sysml.parser.DataExpression; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; /** * Rule: Algebraic Simplifications. Simplifies binary expressions * in terms of two major purposes: (1) rewrite binary operations * to unary operations when possible (in CP this reduces the memory * estimate, in MR this allows map-only operations and hence prevents * unnecessary shuffle and sort) and (2) remove binary operations that * are in itself are unnecessary (e.g., *1 and /1). * */ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule { private static final Log LOG = LogFactory.getLog(RewriteAlgebraicSimplificationDynamic.class.getName()); //valid aggregation operation types for rowOp to Op conversions (not all operations apply) private static AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.VAR}; //valid aggregation operation types for empty (sparse-safe) operations (not all operations apply) //AggOp.MEAN currently not due to missing count/corrections private static AggOp[] LOOKUP_VALID_EMPTY_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.PROD, AggOp.TRACE}; private static AggOp[] LOOKUP_VALID_UNNECESSARY_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.PROD, AggOp.TRACE}; //valid unary operation types for empty (sparse-safe) operations (not all operations apply) private static OpOp1[] LOOKUP_VALID_EMPTY_UNARY = new OpOp1[]{OpOp1.ABS, OpOp1.SIN, OpOp1.TAN, OpOp1.SQRT, OpOp1.ROUND, OpOp1.CUMSUM}; //valid pseudo-sparse-safe binary operators for wdivmm private static OpOp2[] LOOKUP_VALID_WDIVMM_BINARY = new OpOp2[]{OpOp2.MULT, OpOp2.DIV}; //valid unary and binary operators for wumm private static OpOp1[] LOOKUP_VALID_WUMM_UNARY = new OpOp1[]{OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.EXP, OpOp1.LOG, OpOp1.SQRT, 
OpOp1.SIGMOID, OpOp1.SPROP}; private static OpOp2[] LOOKUP_VALID_WUMM_BINARY = new OpOp2[]{OpOp2.MULT, OpOp2.POW}; @Override public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException { if( roots == null ) return roots; //one pass rewrite-descend (rewrite created pattern) for( Hop h : roots ) rule_AlgebraicSimplification( h, false ); Hop.resetVisitStatus(roots); //one pass descend-rewrite (for rollup) for( Hop h : roots ) rule_AlgebraicSimplification( h, true ); return roots; } @Override public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException { if( root == null ) return root; //one pass rewrite-descend (rewrite created pattern) rule_AlgebraicSimplification( root, false ); root.resetVisitStatus(); //one pass descend-rewrite (for rollup) rule_AlgebraicSimplification( root, true ); return root; } /** * Note: X/y -> X * 1/y would be useful because * cheaper than / and sparsesafe; however, * (1) the results would be not exactly the same (2 rounds instead of 1) and (2) it should * come before constant folding while the other simplifications should come after constant * folding. Hence, not applied yet. 
* * @param hop high-level operator * @param descendFirst true if recursively process children first * @throws HopsException if HopsException occurs */ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) throws HopsException { if(hop.isVisited()) return; //recursively process children for( int i=0; i<hop.getInput().size(); i++) { Hop hi = hop.getInput().get(i); //process childs recursively first (to allow roll-up) if( descendFirst ) rule_AlgebraicSimplification(hi, descendFirst); //see below //apply actual simplification rewrites (of childs incl checks) hi = removeEmptyRightIndexing(hop, hi, i); //e.g., X[,1] -> matrix(0,ru-rl+1,cu-cl+1), if nnz(X)==0 hi = removeUnnecessaryRightIndexing(hop, hi, i); //e.g., X[,1] -> X, if output == input size hi = removeEmptyLeftIndexing(hop, hi, i); //e.g., X[,1]=Y -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0 and nnz(Y)==0 hi = removeUnnecessaryLeftIndexing(hop, hi, i); //e.g., X[,1]=Y -> Y, if output == input dims if(OptimizerUtils.ALLOW_OPERATOR_FUSION) hi = fuseLeftIndexingChainToAppend(hop, hi, i); //e.g., X[,1]=A; X[,2]=B -> X=cbind(A,B), iff ncol(X)==2 and col1/2 lix hi = removeUnnecessaryCumulativeOp(hop, hi, i); //e.g., cumsum(X) -> X, if nrow(X)==1; hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., matrix(X) -> X, if dims(in)==dims(out); r(X)->X, if 1x1 dims hi = removeUnnecessaryOuterProduct(hop, hi, i); //e.g., X*(Y%*%matrix(1,...) 
-> X*Y, if Y col vector if(OptimizerUtils.ALLOW_OPERATOR_FUSION) hi = fuseDatagenAndReorgOperation(hop, hi, i); //e.g., t(rand(rows=10,cols=1)) -> rand(rows=1,cols=10), if one dim=1 hi = simplifyColwiseAggregate(hop, hi, i); //e.g., colsums(X) -> sum(X) or X, if col/row vector hi = simplifyRowwiseAggregate(hop, hi, i); //e.g., rowsums(X) -> sum(X) or X, if row/col vector hi = simplifyColSumsMVMult(hop, hi, i); //e.g., colSums(X*Y) -> t(Y) %*% X, if Y col vector hi = simplifyRowSumsMVMult(hop, hi, i); //e.g., rowSums(X*Y) -> X %*% t(Y), if Y row vector hi = simplifyUnnecessaryAggregate(hop, hi, i); //e.g., sum(X) -> as.scalar(X), if 1x1 dims hi = simplifyEmptyAggregate(hop, hi, i); //e.g., sum(X) -> 0, if nnz(X)==0 hi = simplifyEmptyUnaryOperation(hop, hi, i); //e.g., round(X) -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0 hi = simplifyEmptyReorgOperation(hop, hi, i); //e.g., t(X) -> matrix(0, ncol(X), nrow(X)) hi = simplifyEmptySortOperation(hop, hi, i); //e.g., order(X) -> seq(1, nrow(X)), if nnz(X)==0 hi = simplifyEmptyMatrixMult(hop, hi, i); //e.g., X%*%Y -> matrix(0,...), if nnz(Y)==0 | X if Y==matrix(1,1,1) hi = simplifyIdentityRepMatrixMult(hop, hi, i); //e.g., X%*%y -> X if y matrix(1,1,1); hi = simplifyScalarMatrixMult(hop, hi, i); //e.g., X%*%y -> X*as.scalar(y), if y is a 1-1 matrix hi = simplifyMatrixMultDiag(hop, hi, i); //e.g., diag(X)%*%Y -> X*Y, if ncol(Y)==1 / -> Y*X if ncol(Y)>1 hi = simplifyDiagMatrixMult(hop, hi, i); //e.g., diag(X%*%Y)->rowSums(X*t(Y)); if col vector hi = simplifySumDiagToTrace(hi); //e.g., sum(diag(X)) -> trace(X); if col vector hi = pushdownBinaryOperationOnDiag(hop, hi, i); //e.g., diag(X)*7 -> diag(X*7); if col vector hi = pushdownSumOnAdditiveBinary(hop, hi, i); //e.g., sum(A+B) -> sum(A)+sum(B); if dims(A)==dims(B) if(OptimizerUtils.ALLOW_OPERATOR_FUSION) { hi = simplifyWeightedSquaredLoss(hop, hi, i); //e.g., sum(W * (X - U %*% t(V)) ^ 2) -> wsl(X, U, t(V), W, true), hi = simplifyWeightedSigmoidMMChains(hop, hi, i); //e.g., W 
* sigmoid(Y%*%t(X)) -> wsigmoid(W, Y, t(X), type) hi = simplifyWeightedDivMM(hop, hi, i); //e.g., t(U) %*% (X/(U%*%t(V))) -> wdivmm(X, U, t(V), left) hi = simplifyWeightedCrossEntropy(hop, hi, i); //e.g., sum(X*log(U%*%t(V))) -> wcemm(X, U, t(V)) hi = simplifyWeightedUnaryMM(hop, hi, i); //e.g., X*exp(U%*%t(V)) -> wumm(X, U, t(V), exp) hi = simplifyDotProductSum(hop, hi, i); //e.g., sum(v^2) -> t(v)%*%v if ncol(v)==1 hi = fuseSumSquared(hop, hi, i); //e.g., sum(X^2) -> sumSq(X), if ncol(X)>1 hi = fuseAxpyBinaryOperationChain(hop, hi, i); //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y) } hi = reorderMinusMatrixMult(hop, hi, i); //e.g., (-t(X))%*%y->-(t(X)%*%y), TODO size hi = simplifySumMatrixMult(hop, hi, i); //e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product / wsloss hi = simplifyEmptyBinaryOperation(hop, hi, i); //e.g., X*Y -> matrix(0,nrow(X), ncol(X)) / X+Y->X / X-Y -> X hi = simplifyScalarMVBinaryOperation(hi); //e.g., X*y -> X*as.scalar(y), if y is a 1-1 matrix hi = simplifyNnzComputation(hop, hi, i); //e.g., sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known hi = simplifyNrowNcolComputation(hop, hi, i); //e.g., nrow(X) -> literal(nrow(X)), if nrow known to remove data dependency hi = simplifyTableSeqExpand(hop, hi, i); //e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true) //process childs recursively after rewrites (to investigate pattern newly created by rewrites) if( !descendFirst ) rule_AlgebraicSimplification(hi, descendFirst); } hop.setVisited(); } private Hop removeEmptyRightIndexing(Hop parent, Hop hi, int pos) throws HopsException { if( hi instanceof IndexingOp && hi.getDataType()==DataType.MATRIX ) //indexing op { Hop input = hi.getInput().get(0); if( input.getNnz()==0 && //nnz input known and empty HopRewriteUtils.isDimsKnown(hi)) //output dims known { //remove unnecessary right indexing Hop hnew = HopRewriteUtils.createDataGenOpByVal( new LiteralOp(hi.getDim1()), new 
LiteralOp(hi.getDim2()), 0); HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); HopRewriteUtils.cleanupUnreferenced(hi, input); hi = hnew; LOG.debug("Applied removeEmptyRightIndexing"); } } return hi; } private Hop removeUnnecessaryRightIndexing(Hop parent, Hop hi, int pos) { if( hi instanceof IndexingOp ) //indexing op { Hop input = hi.getInput().get(0); if( HopRewriteUtils.isEqualSize(hi, input) //equal dims && !(hi.getDim1()==1 && hi.getDim2()==1) ) //not 1-1 matrix/frame { //equal dims of right indexing input and output -> no need for indexing //(not applied for 1-1 matrices because low potential and issues w/ error //handling if out of range indexing) //remove unnecessary right indexing HopRewriteUtils.replaceChildReference(parent, hi, input, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = input; LOG.debug("Applied removeUnnecessaryRightIndexing"); } } return hi; } private Hop removeEmptyLeftIndexing(Hop parent, Hop hi, int pos) throws HopsException { if( hi instanceof LeftIndexingOp && hi.getDataType() == DataType.MATRIX ) //left indexing op { Hop input1 = hi.getInput().get(0); //lhs matrix Hop input2 = hi.getInput().get(1); //rhs matrix if( input1.getNnz()==0 //nnz original known and empty && input2.getNnz()==0 ) //nnz input known and empty { //remove unnecessary right indexing Hop hnew = HopRewriteUtils.createDataGenOp( input1, 0); HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); HopRewriteUtils.cleanupUnreferenced(hi, input2); hi = hnew; LOG.debug("Applied removeEmptyLeftIndexing"); } } return hi; } private Hop removeUnnecessaryLeftIndexing(Hop parent, Hop hi, int pos) { if( hi instanceof LeftIndexingOp ) //left indexing op { Hop input = hi.getInput().get(1); //rhs matrix/frame if( HopRewriteUtils.isEqualSize(hi, input) ) //equal dims { //equal dims of left indexing input and output -> no need for indexing //remove unnecessary right indexing HopRewriteUtils.replaceChildReference(parent, hi, input, pos); 
HopRewriteUtils.cleanupUnreferenced(hi); hi = input; LOG.debug("Applied removeUnnecessaryLeftIndexing"); } } return hi; } private Hop fuseLeftIndexingChainToAppend(Hop parent, Hop hi, int pos) { boolean applied = false; //pattern1: X[,1]=A; X[,2]=B -> X=cbind(A,B); matrix / frame if( hi instanceof LeftIndexingOp //first lix && HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi) && hi.getInput().get(0) instanceof LeftIndexingOp //second lix && HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi.getInput().get(0)) && hi.getInput().get(0).getParent().size()==1 //first lix is single consumer && hi.getInput().get(0).getInput().get(0).getDim2() == 2 ) //two column matrix { Hop input2 = hi.getInput().get(1); //rhs matrix Hop pred2 = hi.getInput().get(4); //cl=cu Hop input1 = hi.getInput().get(0).getInput().get(1); //lhs matrix Hop pred1 = hi.getInput().get(0).getInput().get(4); //cl=cu if( pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1 && pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2 && input1.getDataType()!=DataType.SCALAR && input2.getDataType()!=DataType.SCALAR ) { //create new cbind operation and rewrite inputs BinaryOp bop = HopRewriteUtils.createBinary(input1, input2, OpOp2.CBIND); HopRewriteUtils.replaceChildReference(parent, hi, bop, pos); hi = bop; applied = true; } } //pattern1: X[1,]=A; X[2,]=B -> X=rbind(A,B) if( !applied && hi instanceof LeftIndexingOp //first lix && HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi) && hi.getInput().get(0) instanceof LeftIndexingOp //second lix && HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi.getInput().get(0)) && hi.getInput().get(0).getParent().size()==1 //first lix is single consumer && hi.getInput().get(0).getInput().get(0).getDim1() == 2 ) //two column matrix { Hop input2 = hi.getInput().get(1); //rhs matrix Hop pred2 = hi.getInput().get(2); //rl=ru Hop input1 = hi.getInput().get(0).getInput().get(1); //lhs matrix Hop pred1 
= hi.getInput().get(0).getInput().get(2); //rl=ru if( pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1 && pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2 && input1.getDataType()!=DataType.SCALAR && input2.getDataType()!=DataType.SCALAR ) { //create new cbind operation and rewrite inputs BinaryOp bop = HopRewriteUtils.createBinary(input1, input2, OpOp2.RBIND); HopRewriteUtils.replaceChildReference(parent, hi, bop, pos); hi = bop; applied = true; LOG.debug("Applied fuseLeftIndexingChainToAppend2 (line "+hi.getBeginLine()+")"); } } return hi; } private Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos) { if( hi instanceof UnaryOp && ((UnaryOp)hi).isCumulativeUnaryOperation() ) { Hop input = hi.getInput().get(0); //input matrix if( HopRewriteUtils.isDimsKnown(input) //dims input known && input.getDim1()==1 ) //1 row { OpOp1 op = ((UnaryOp)hi).getOp(); //remove unnecessary unary cumsum operator HopRewriteUtils.replaceChildReference(parent, hi, input, pos); hi = input; LOG.debug("Applied removeUnnecessaryCumulativeOp: "+op); } } return hi; } private Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos) { if( hi instanceof ReorgOp ) { ReorgOp rop = (ReorgOp) hi; Hop input = hi.getInput().get(0); boolean apply = false; //equal dims of reshape input and output -> no need for reshape because //byrow always refers to both input/output and hence gives the same result apply |= (rop.getOp()==ReOrgOp.RESHAPE && HopRewriteUtils.isEqualSize(hi, input)); //1x1 dimensions of transpose/reshape -> no need for reorg apply |= ((rop.getOp()==ReOrgOp.TRANSPOSE || rop.getOp()==ReOrgOp.RESHAPE) && rop.getDim1()==1 && rop.getDim2()==1); if( apply ) { HopRewriteUtils.replaceChildReference(parent, hi, input, pos); hi = input; LOG.debug("Applied removeUnnecessaryReorg."); } } return hi; } private Hop removeUnnecessaryOuterProduct(Hop parent, Hop hi, int pos) { if( hi instanceof BinaryOp ) //binary 
cell operation { Hop right = hi.getInput().get(1); //check for column replication if( HopRewriteUtils.isMatrixMultiply(right) //matrix mult with datagen && right.getInput().get(1) instanceof DataGenOp && ((DataGenOp)right.getInput().get(1)).getOp()==DataGenMethod.RAND && ((DataGenOp)right.getInput().get(1)).hasConstantValue(1d) && right.getInput().get(1).getDim1() == 1 //row vector for replication && right.getInput().get(0).getDim2() == 1 ) //column vector for mv binary { //remove unnecessary outer product HopRewriteUtils.replaceChildReference(hi, right, right.getInput().get(0), 1 ); HopRewriteUtils.cleanupUnreferenced(right); LOG.debug("Applied removeUnnecessaryOuterProduct1 (line "+right.getBeginLine()+")"); } //check for row replication else if( HopRewriteUtils.isMatrixMultiply(right) //matrix mult with datagen && right.getInput().get(0) instanceof DataGenOp && ((DataGenOp)right.getInput().get(0)).hasConstantValue(1d) && right.getInput().get(0).getDim2() == 1 //colunm vector for replication && right.getInput().get(1).getDim1() == 1 ) //row vector for mv binary { //remove unnecessary outer product HopRewriteUtils.replaceChildReference(hi, right, right.getInput().get(1), 1 ); HopRewriteUtils.cleanupUnreferenced(right); LOG.debug("Applied removeUnnecessaryOuterProduct2 (line "+right.getBeginLine()+")"); } } return hi; } @SuppressWarnings("unchecked") private Hop fuseDatagenAndReorgOperation(Hop parent, Hop hi, int pos) { if( HopRewriteUtils.isTransposeOperation(hi) && hi.getInput().get(0) instanceof DataGenOp //datagen && hi.getInput().get(0).getParent().size()==1 ) //transpose only consumer { DataGenOp dop = (DataGenOp)hi.getInput().get(0); if( (dop.getOp() == DataGenMethod.RAND || dop.getOp() == DataGenMethod.SINIT) && (dop.getDim1()==1 || dop.getDim2()==1 )) { //relink all parents and dataop (remove transpose) HopRewriteUtils.removeAllChildReferences(hi); ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone(); for( int i=0; i<parents.size(); i++ ) { 
Hop lparent = parents.get(i); int ppos = HopRewriteUtils.getChildReferencePos(lparent, hi); HopRewriteUtils.removeChildReferenceByPos(lparent, hi, ppos); HopRewriteUtils.addChildReference(lparent, dop, pos); } //flip rows/cols attributes in datagen HashMap<String, Integer> rparams = dop.getParamIndexMap(); int pos1 = rparams.get(DataExpression.RAND_ROWS); int pos2 = rparams.get(DataExpression.RAND_COLS); rparams.put(DataExpression.RAND_ROWS, pos2); rparams.put(DataExpression.RAND_COLS, pos1); dop.refreshSizeInformation(); hi = dop; LOG.debug("Applied fuseDatagenReorgOperation."); } } return hi; } @SuppressWarnings("unchecked") private Hop simplifyColwiseAggregate( Hop parent, Hop hi, int pos ) throws HopsException { if( hi instanceof AggUnaryOp ) { AggUnaryOp uhi = (AggUnaryOp)hi; Hop input = uhi.getInput().get(0); if( HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE) ) { if( uhi.getDirection() == Direction.Col ) { if( input.getDim1() == 1 ) { if (uhi.getOp() == AggOp.VAR) { // For the column variance aggregation, if the input is a row vector, // the column variances will each be zero. // Therefore, perform a rewrite from COLVAR(X) to a row vector of zeros. Hop emptyRow = HopRewriteUtils.createDataGenOp(uhi, input, 0); HopRewriteUtils.replaceChildReference(parent, hi, emptyRow, pos); HopRewriteUtils.cleanupUnreferenced(hi, input); hi = emptyRow; LOG.debug("Applied simplifyColwiseAggregate for colVars"); } else { // All other valid column aggregations over a row vector will result // in the row vector itself. // Therefore, remove unnecessary col aggregation for 1 row. 
HopRewriteUtils.replaceChildReference(parent, hi, input, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = input; LOG.debug("Applied simplifyColwiseAggregate1"); } } else if( input.getDim2() == 1 ) { //get old parents (before creating cast over aggregate) ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone(); //simplify col-aggregate to full aggregate uhi.setDirection(Direction.RowCol); uhi.setDataType(DataType.SCALAR); //create cast to keep same output datatype UnaryOp cast = HopRewriteUtils.createUnary(uhi, OpOp1.CAST_AS_MATRIX); //rehang cast under all parents for( Hop p : parents ) { int ix = HopRewriteUtils.getChildReferencePos(p, hi); HopRewriteUtils.replaceChildReference(p, hi, cast, ix); } hi = cast; LOG.debug("Applied simplifyColwiseAggregate2"); } } } } return hi; } @SuppressWarnings("unchecked") private Hop simplifyRowwiseAggregate( Hop parent, Hop hi, int pos ) throws HopsException { if( hi instanceof AggUnaryOp ) { AggUnaryOp uhi = (AggUnaryOp)hi; Hop input = uhi.getInput().get(0); if( HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE) ) { if( uhi.getDirection() == Direction.Row ) { if( input.getDim2() == 1 ) { if (uhi.getOp() == AggOp.VAR) { // For the row variance aggregation, if the input is a column vector, // the row variances will each be zero. // Therefore, perform a rewrite from ROWVAR(X) to a column vector of // zeros. Hop emptyCol = HopRewriteUtils.createDataGenOp(input, uhi, 0); HopRewriteUtils.replaceChildReference(parent, hi, emptyCol, pos); HopRewriteUtils.cleanupUnreferenced(hi, input); // replace current HOP with new empty column HOP hi = emptyCol; LOG.debug("Applied simplifyRowwiseAggregate for rowVars"); } else { // All other valid row aggregations over a column vector will result // in the column vector itself. 
// Therefore, remove unnecessary row aggregation for 1 col HopRewriteUtils.replaceChildReference(parent, hi, input, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = input; LOG.debug("Applied simplifyRowwiseAggregate1"); } } else if( input.getDim1() == 1 ) { //get old parents (before creating cast over aggregate) ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone(); //simplify row-aggregate to full aggregate uhi.setDirection(Direction.RowCol); uhi.setDataType(DataType.SCALAR); //create cast to keep same output datatype UnaryOp cast = HopRewriteUtils.createUnary(uhi, OpOp1.CAST_AS_MATRIX); //rehang cast under all parents for( Hop p : parents ) { int ix = HopRewriteUtils.getChildReferencePos(p, hi); HopRewriteUtils.replaceChildReference(p, hi, cast, ix); } hi = cast; LOG.debug("Applied simplifyRowwiseAggregate2"); } } } } return hi; } private Hop simplifyColSumsMVMult( Hop parent, Hop hi, int pos ) throws HopsException { //colSums(X*Y) -> t(Y) %*% X, if Y col vector; additional transpose later //removed by other rewrite if unnecessary, i.e., if Y==t(Z) if( hi instanceof AggUnaryOp ) { AggUnaryOp uhi = (AggUnaryOp)hi; Hop input = uhi.getInput().get(0); if( uhi.getOp() == AggOp.SUM && uhi.getDirection() == Direction.Col //colsums && HopRewriteUtils.isBinary(input, OpOp2.MULT) ) //b(*) { Hop left = input.getInput().get(0); Hop right = input.getInput().get(1); if( left.getDim1()>1 && left.getDim2()>1 && right.getDim1()>1 && right.getDim2()==1 ) // MV (col vector) { //create new operators ReorgOp trans = HopRewriteUtils.createTranspose(right); AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(trans, left); //relink new child HopRewriteUtils.replaceChildReference(parent, hi, mmult, pos); HopRewriteUtils.cleanupUnreferenced(uhi, input); hi = mmult; LOG.debug("Applied simplifyColSumsMVMult"); } } } return hi; } private Hop simplifyRowSumsMVMult( Hop parent, Hop hi, int pos ) throws HopsException { //rowSums(X * Y) -> X %*% t(Y), if Y row vector; 
additional transpose later //removed by other rewrite if unnecessary, i.e., if Y==t(Z) if( hi instanceof AggUnaryOp ) { AggUnaryOp uhi = (AggUnaryOp)hi; Hop input = uhi.getInput().get(0); if( uhi.getOp() == AggOp.SUM && uhi.getDirection() == Direction.Row //rowsums && HopRewriteUtils.isBinary(input, OpOp2.MULT) ) //b(*) { Hop left = input.getInput().get(0); Hop right = input.getInput().get(1); if( left.getDim1()>1 && left.getDim2()>1 && right.getDim1()==1 && right.getDim2()>1 ) // MV (row vector) { //create new operators ReorgOp trans = HopRewriteUtils.createTranspose(right); AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(left, trans); //relink new child HopRewriteUtils.replaceChildReference(parent, hi, mmult, pos); HopRewriteUtils.cleanupUnreferenced(hi, input); hi = mmult; LOG.debug("Applied simplifyRowSumsMVMult"); } } } return hi; } private Hop simplifyUnnecessaryAggregate(Hop parent, Hop hi, int pos) throws HopsException { //e.g., sum(X) -> as.scalar(X) if 1x1 (applies to sum, min, max, prod, trace) if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol ) { AggUnaryOp uhi = (AggUnaryOp)hi; Hop input = uhi.getInput().get(0); if( HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_UNNECESSARY_AGGREGATE) ){ if( input.getDim1()==1 && input.getDim2()==1 ) { UnaryOp cast = HopRewriteUtils.createUnary(input, OpOp1.CAST_AS_SCALAR); //remove unnecessary aggregation HopRewriteUtils.replaceChildReference(parent, hi, cast, pos); hi = cast; LOG.debug("Applied simplifyUnncessaryAggregate"); } } } return hi; } private Hop simplifyEmptyAggregate(Hop parent, Hop hi, int pos) throws HopsException { if( hi instanceof AggUnaryOp ) { AggUnaryOp uhi = (AggUnaryOp)hi; Hop input = uhi.getInput().get(0); if( HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_EMPTY_AGGREGATE) ){ if( HopRewriteUtils.isEmpty(input) ) { Hop hnew = null; if( uhi.getDirection() == Direction.RowCol ) hnew = new LiteralOp(0.0); else if( uhi.getDirection() == Direction.Col 
) hnew = HopRewriteUtils.createDataGenOp(uhi, input, 0); //nrow(uhi)=1 else //if( uhi.getDirection() == Direction.Row ) hnew = HopRewriteUtils.createDataGenOp(input, uhi, 0); //ncol(uhi)=1 //add new child to parent input HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; LOG.debug("Applied simplifyEmptyAggregate"); } } } return hi; } private Hop simplifyEmptyUnaryOperation(Hop parent, Hop hi, int pos) throws HopsException { if( hi instanceof UnaryOp ) { UnaryOp uhi = (UnaryOp)hi; Hop input = uhi.getInput().get(0); if( HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_EMPTY_UNARY) ){ if( HopRewriteUtils.isEmpty(input) ) { //create literal add it to parent Hop hnew = HopRewriteUtils.createDataGenOp(input, 0); HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; LOG.debug("Applied simplifyEmptyUnaryOperation"); } } } return hi; } private Hop simplifyEmptyReorgOperation(Hop parent, Hop hi, int pos) throws HopsException { if( hi instanceof ReorgOp ) { ReorgOp rhi = (ReorgOp)hi; Hop input = rhi.getInput().get(0); if( HopRewriteUtils.isEmpty(input) ) //empty input { //reorg-operation-specific rewrite Hop hnew = null; if( rhi.getOp() == ReOrgOp.TRANSPOSE ) hnew = HopRewriteUtils.createDataGenOp(input, true, input, true, 0); else if( rhi.getOp() == ReOrgOp.REV ) hnew = HopRewriteUtils.createDataGenOp(input, 0); else if( rhi.getOp() == ReOrgOp.DIAG ) { if( HopRewriteUtils.isDimsKnown(input) ) { if( input.getDim2()==1 ) //diagv2m hnew = HopRewriteUtils.createDataGenOp(input, false, input, true, 0); else //diagm2v hnew = HopRewriteUtils.createDataGenOpByVal( HopRewriteUtils.createValueHop(input,true), new LiteralOp(1), 0); } } else if( rhi.getOp() == ReOrgOp.RESHAPE ) hnew = HopRewriteUtils.createDataGenOpByVal(rhi.getInput().get(1), rhi.getInput().get(2), 0); //modify dag if one of the above rules applied if( hnew != null ){ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; LOG.debug("Applied 
simplifyEmptyReorgOperation"); } } } return hi; } private Hop simplifyEmptySortOperation(Hop parent, Hop hi, int pos) throws HopsException { //order(X, indexreturn=FALSE) -> matrix(0,nrow(X),1) //order(X, indexreturn=TRUE) -> seq(1,nrow(X),1) if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.SORT ) { ReorgOp rhi = (ReorgOp)hi; Hop input = rhi.getInput().get(0); if( HopRewriteUtils.isEmpty(input) ) //empty input { //reorg-operation-specific rewrite Hop hnew = null; boolean ixret = false; if( rhi.getInput().get(3) instanceof LiteralOp ) //index return known { ixret = HopRewriteUtils.getBooleanValue((LiteralOp)rhi.getInput().get(3)); if( ixret ) hnew = HopRewriteUtils.createSeqDataGenOp(input); else hnew = HopRewriteUtils.createDataGenOp(input, 0); } //modify dag if one of the above rules applied if( hnew != null ){ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; LOG.debug("Applied simplifyEmptySortOperation (indexreturn="+ixret+")."); } } } return hi; } private Hop simplifyEmptyMatrixMult(Hop parent, Hop hi, int pos) throws HopsException { if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y -> matrix(0, ) { Hop left = hi.getInput().get(0); Hop right = hi.getInput().get(1); if( HopRewriteUtils.isEmpty(left) //one input empty || HopRewriteUtils.isEmpty(right) ) { //create datagen and add it to parent Hop hnew = HopRewriteUtils.createDataGenOp(left, right, 0); HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; LOG.debug("Applied simplifyEmptyMatrixMult"); } } return hi; } private Hop simplifyIdentityRepMatrixMult(Hop parent, Hop hi, int pos) throws HopsException { if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y -> X, if y is matrix(1,1,1) { Hop left = hi.getInput().get(0); Hop right = hi.getInput().get(1); // X %*% y -> X if( HopRewriteUtils.isDimsKnown(right) && right.getDim1()==1 && right.getDim2()==1 && //scalar right right instanceof DataGenOp && ((DataGenOp)right).getOp()==DataGenMethod.RAND && 
((DataGenOp)right).hasConstantValue(1.0)) //matrix(1,) { HopRewriteUtils.replaceChildReference(parent, hi, left, pos); hi = left; LOG.debug("Applied simplifyIdentiyMatrixMult"); } } return hi; } private Hop simplifyScalarMatrixMult(Hop parent, Hop hi, int pos) throws HopsException { if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y { Hop left = hi.getInput().get(0); Hop right = hi.getInput().get(1); // y %*% X -> as.scalar(y) * X if( HopRewriteUtils.isDimsKnown(left) && left.getDim1()==1 && left.getDim2()==1 ) //scalar left { UnaryOp cast = HopRewriteUtils.createUnary(left, OpOp1.CAST_AS_SCALAR); BinaryOp mult = HopRewriteUtils.createBinary(cast, right, OpOp2.MULT); //add mult to parent HopRewriteUtils.replaceChildReference(parent, hi, mult, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = mult; LOG.debug("Applied simplifyScalarMatrixMult1"); } // X %*% y -> X * as.scalar(y) else if( HopRewriteUtils.isDimsKnown(right) && right.getDim1()==1 && right.getDim2()==1 ) //scalar right { UnaryOp cast = HopRewriteUtils.createUnary(right, OpOp1.CAST_AS_SCALAR); BinaryOp mult = HopRewriteUtils.createBinary(cast, left, OpOp2.MULT); //add mult to parent HopRewriteUtils.replaceChildReference(parent, hi, mult, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = mult; LOG.debug("Applied simplifyScalarMatrixMult2"); } } return hi; } private Hop simplifyMatrixMultDiag(Hop parent, Hop hi, int pos) throws HopsException { Hop hnew = null; if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y { Hop left = hi.getInput().get(0); Hop right = hi.getInput().get(1); // diag(X) %*% Y -> X * Y / diag(X) %*% Y -> Y * X // previously rep required for the second case: diag(X) %*% Y -> (X%*%ones) * Y if( left instanceof ReorgOp && ((ReorgOp)left).getOp()==ReOrgOp.DIAG //left diag && HopRewriteUtils.isDimsKnown(left) && left.getDim2()>1 ) //diagV2M { //System.out.println("diag mm rewrite: dim2(right)="+right.getDim2()); if( right.getDim2()==1 ) //right column vector { //create binary operation over 
input and right Hop input = left.getInput().get(0); //diag input hnew = HopRewriteUtils.createBinary(input, right, OpOp2.MULT); LOG.debug("Applied simplifyMatrixMultDiag1"); } else if( right.getDim2()>1 ) //multi column vector { //create binary operation over input and right; in contrast to above rewrite, //we need to switch the order because MV binary cell operations require vector on the right Hop input = left.getInput().get(0); //diag input hnew = HopRewriteUtils.createBinary(right, input, OpOp2.MULT); //NOTE: previously to MV binary cell operations we replicated the left //(if moderate number of columns: 2), but this is no longer required LOG.debug("Applied simplifyMatrixMultDiag2"); } } //notes: similar rewrites would be possible for the right side as well, just transposed into the right alignment } //if one of the above rewrites applied if( hnew !=null ){ //add mult to parent HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = hnew; } return hi; } private Hop simplifyDiagMatrixMult(Hop parent, Hop hi, int pos) { if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.DIAG && hi.getDim2()==1 ) //diagM2V { Hop hi2 = hi.getInput().get(0); if( HopRewriteUtils.isMatrixMultiply(hi2) ) //X%*%Y { Hop left = hi2.getInput().get(0); Hop right = hi2.getInput().get(1); //create new operators (incl refresh size inside for transpose) ReorgOp trans = HopRewriteUtils.createTranspose(right); BinaryOp mult = HopRewriteUtils.createBinary(left, trans, OpOp2.MULT); AggUnaryOp rowSum = HopRewriteUtils.createAggUnaryOp(mult, AggOp.SUM, Direction.Row); //rehang new subdag under parent node HopRewriteUtils.replaceChildReference(parent, hi, rowSum, pos); HopRewriteUtils.cleanupUnreferenced(hi, hi2); hi = rowSum; LOG.debug("Applied simplifyDiagMatrixMult"); } } return hi; } private Hop simplifySumDiagToTrace(Hop hi) { if( hi instanceof AggUnaryOp ) { AggUnaryOp au = (AggUnaryOp) hi; if( au.getOp()==AggOp.SUM && 
au.getDirection()==Direction.RowCol ) //sum { Hop hi2 = au.getInput().get(0); if( hi2 instanceof ReorgOp && ((ReorgOp)hi2).getOp()==ReOrgOp.DIAG && hi2.getDim2()==1 ) //diagM2V { Hop hi3 = hi2.getInput().get(0); //remove diag operator HopRewriteUtils.replaceChildReference(au, hi2, hi3, 0); HopRewriteUtils.cleanupUnreferenced(hi2); //change sum to trace au.setOp( AggOp.TRACE ); LOG.debug("Applied simplifySumDiagToTrace"); } } } return hi; } @SuppressWarnings("unchecked") private Hop pushdownBinaryOperationOnDiag(Hop parent, Hop hi, int pos) { //diag(X)*7 --> diag(X*7) in order to (1) reduce required memory for b(*) and //(2) in order to make the binary operation more efficient (dense vector vs sparse matrix) if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) ) { Hop left = hi.getInput().get(0); Hop right = hi.getInput().get(1); boolean applyLeft = false; boolean applyRight = false; //left input is diag if( left instanceof ReorgOp && ((ReorgOp)left).getOp()==ReOrgOp.DIAG && left.getParent().size()==1 //binary op only parent && left.getInput().get(0).getDim2()==1 //col vector && right.getDataType() == DataType.SCALAR ) { applyLeft = true; } else if( right instanceof ReorgOp && ((ReorgOp)right).getOp()==ReOrgOp.DIAG && right.getParent().size()==1 //binary op only parent && right.getInput().get(0).getDim2()==1 //col vector && left.getDataType() == DataType.SCALAR ) { applyRight = true; } //perform actual rewrite if( applyLeft || applyRight ) { //remove all parent links to binary op (since we want to reorder //we cannot just look at the current parent) ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone(); ArrayList<Integer> parentspos = new ArrayList<Integer>(); for(Hop lparent : parents) { int lpos = HopRewriteUtils.getChildReferencePos(lparent, hi); HopRewriteUtils.removeChildReferenceByPos(lparent, hi, lpos); parentspos.add(lpos); } //rewire binop-diag-input into diag-binop-input if( applyLeft ) { Hop input = left.getInput().get(0); 
HopRewriteUtils.removeChildReferenceByPos(hi, left, 0); HopRewriteUtils.removeChildReferenceByPos(left, input, 0); HopRewriteUtils.addChildReference(left, hi, 0); HopRewriteUtils.addChildReference(hi, input, 0); hi.refreshSizeInformation(); hi = left; } else if ( applyRight ) { Hop input = right.getInput().get(0); HopRewriteUtils.removeChildReferenceByPos(hi, right, 1); HopRewriteUtils.removeChildReferenceByPos(right, input, 0); HopRewriteUtils.addChildReference(right, hi, 0); HopRewriteUtils.addChildReference(hi, input, 1); hi.refreshSizeInformation(); hi = right; } //relink all parents to the diag operation for( int i=0; i<parents.size(); i++ ) { Hop lparent = parents.get(i); int lpos = parentspos.get(i); HopRewriteUtils.addChildReference(lparent, hi, lpos); } LOG.debug("Applied pushdownBinaryOperationOnDiag."); } } return hi; } /** * patterns: sum(A+B)->sum(A)+sum(B); sum(A-B)->sum(A)-sum(B) * * @param parent the parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator */ private Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos) { //all patterns headed by full sum over binary operation if( hi instanceof AggUnaryOp //full sum root over binaryop && ((AggUnaryOp)hi).getDirection()==Direction.RowCol && ((AggUnaryOp)hi).getOp() == AggOp.SUM && hi.getInput().get(0) instanceof BinaryOp && hi.getInput().get(0).getParent().size()==1 ) //single parent { BinaryOp bop = (BinaryOp) hi.getInput().get(0); Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); if( HopRewriteUtils.isEqualSize(left, right) //dims(A) == dims(B) && left.getDataType() == DataType.MATRIX && right.getDataType() == DataType.MATRIX ) { OpOp2 applyOp = ( bop.getOp() == OpOp2.PLUS //pattern a: sum(A+B)->sum(A)+sum(B) || bop.getOp() == OpOp2.MINUS ) //pattern b: sum(A-B)->sum(A)-sum(B) ? 
bop.getOp() : null; if( applyOp != null ) { //create new subdag sum(A) bop sum(B) AggUnaryOp sum1 = HopRewriteUtils.createSum(left); AggUnaryOp sum2 = HopRewriteUtils.createSum(right); BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp); //rewire new subdag HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos); HopRewriteUtils.cleanupUnreferenced(hi, bop); hi = newBin; LOG.debug("Applied pushdownSumOnAdditiveBinary."); } } } return hi; } /** * Searches for weighted squared loss expressions and replaces them with a quaternary operator. * Currently, this search includes the following three patterns: * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting) * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting) * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting) * * NOTE: We include transpose into the pattern because during runtime we need to compute * U%*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation * without additional memory requirements for internal transpose. * * This rewrite is conceptually a static rewrite; however, the current MR runtime only supports * U/V factors of rank up to the blocksize (1000). We enforce this contraint here during the general * rewrite because this is an uncommon case. Also, the intention is to remove this constaint as soon * as we generalized the runtime or hop/lop compilation. 
*
	 * @param parent parent high-level operator
	 * @param hi high-level operator
	 * @param pos position handed to replaceChildReference when relinking under parent
	 * @return the new quaternary operator if one of the patterns matched, otherwise hi
	 * @throws HopsException if HopsException occurs
	 */
	private Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos) 
		throws HopsException
	{
		//NOTE: there might be also a general simplification without custom operator
		//via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
		Hop hnew = null;
		
		//the order of the conjuncts matters: the instanceof checks guard the
		//casts and the nested getInput() accesses that follow them
		if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
			&& ((AggUnaryOp)hi).getOp() == AggOp.SUM     //all patterns rooted by sum()
			&& hi.getInput().get(0) instanceof BinaryOp  //all patterns subrooted by binary op
			&& hi.getInput().get(0).getDim2() > 1  )     //not applied for vector-vector mult
		{
			BinaryOp bop = (BinaryOp) hi.getInput().get(0);
			//set by the first matching pattern; later patterns are skipped once it is true
			boolean appliedPattern = false;
			
			//Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
			//alternative pattern: sum (W * (U %*% t(V) - X) ^ 2)
			if( bop.getOp()==OpOp2.MULT && HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW)
				&& bop.getInput().get(0).getDataType()==DataType.MATRIX
				&& HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) //prevent mv
				&& bop.getInput().get(1).getInput().get(1) instanceof LiteralOp
				&& HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1).getInput().get(1))==2)
			{
				Hop W = bop.getInput().get(0);
				Hop tmp = bop.getInput().get(1).getInput().get(0); //(X - U %*% t(V))
				
				if( HopRewriteUtils.isBinary(tmp, OpOp2.MINUS)
					&& HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv
					&& tmp.getInput().get(0).getDataType() == DataType.MATRIX )
				{
					//uvIndex: input position of U %*% t(V) within the minus (-1: no match)
					//a) sum (W * (X - U %*% t(V)) ^ 2)
					int uvIndex = -1;
					if( tmp.getInput().get(1) instanceof AggBinaryOp //ba guarantees matrices
						&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
					{
						uvIndex = 1;
					}
					//b) sum (W * (U %*% t(V) - X) ^ 2)
					else if(tmp.getInput().get(0) instanceof AggBinaryOp //ba guarantees matrices
						&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
					{
						uvIndex = 0;
					}
					
					if( uvIndex >= 0 ) //rewrite match
					{
						//X is the other input of the minus
						Hop X = tmp.getInput().get((uvIndex==0)?1:0);
						Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
						Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
						
						//the operator takes V, not t(V): strip an existing transpose
						//or add one if the right mm input is not a transpose
						if( !HopRewriteUtils.isTransposeOperation(V) ) {
							V = HopRewriteUtils.createTranspose(V);
						}
						else{
							V = V.getInput().get(0);
						}
						
						//handle special case of post_nz
						if( HopRewriteUtils.isNonZeroIndicator(W, X) ){
							W = new LiteralOp(1);
						}
						
						//construct quaternary hop (last argument true only for this post-weighting pattern)
						hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, 
							OpOp4.WSLOSS, X, U, V, W, true);
						HopRewriteUtils.setOutputParametersForScalar(hnew);
						
						appliedPattern = true;
						LOG.debug("Applied simplifyWeightedSquaredLoss1"+uvIndex+" (line "+hi.getBeginLine()+")");
					}
				}
			}
			
			//Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
			//alternative pattern: sum ((W * (U %*% t(V)) - X) ^ 2)
			if( !appliedPattern
				&& bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
				&& HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1))==2
				&& HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)
				&& bop.getInput().get(0).getDataType()==DataType.MATRIX
				&& HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv
				&& bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX)
			{
				Hop lleft = bop.getInput().get(0).getInput().get(0);
				Hop lright = bop.getInput().get(0).getInput().get(1);
				
				//wuvIndex: input position of W * (U %*% t(V)) within the minus (-1: no match)
				//a) sum ((X - W * (U %*% t(V))) ^ 2)
				int wuvIndex = -1;
				if( lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp ){
					wuvIndex = 1;
				}
				//b) sum ((W * (U %*% t(V)) - X) ^ 2)
				else if( lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp ){
					wuvIndex = 0;
				}
				
				if( wuvIndex >= 0 ) //rewrite match
				{
					Hop X = bop.getInput().get(0).getInput().get((wuvIndex==0)?1:0);
					Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex); //(W * (U %*% t(V)))
					
					//the cast is safe because wuvIndex was only set for a BinaryOp input
					if( ((BinaryOp)tmp).getOp()==OpOp2.MULT
						&& tmp.getInput().get(0).getDataType() == DataType.MATRIX
						&& HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv
						&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
					{
						Hop W = tmp.getInput().get(0);
						Hop U = tmp.getInput().get(1).getInput().get(0);
						Hop V = tmp.getInput().get(1).getInput().get(1);
						
						//the operator takes V, not t(V)
						if( !HopRewriteUtils.isTransposeOperation(V) ) {
							V = HopRewriteUtils.createTranspose(V);
						}
						else {
							V = V.getInput().get(0);
						}
						
						hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, 
							OpOp4.WSLOSS, X, U, V, W, false);
						HopRewriteUtils.setOutputParametersForScalar(hnew);
						
						appliedPattern = true;
						LOG.debug("Applied simplifyWeightedSquaredLoss2"+wuvIndex+" (line "+hi.getBeginLine()+")");
					}
				}
			}
			
			//Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
			//alternative pattern: sum (((U %*% t(V)) - X) ^ 2)
			if( !appliedPattern
				&& bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
				&& HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1))==2
				&& HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)
				&& bop.getInput().get(0).getDataType()==DataType.MATRIX
				&& HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv
				&& bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX)
			{
				Hop lleft = bop.getInput().get(0).getInput().get(0);
				Hop lright = bop.getInput().get(0).getInput().get(1);
				
				//uvIndex: input position of U %*% t(V) within the minus (-1: no match)
				//a) sum ((X - (U %*% t(V))) ^ 2)
				int uvIndex = -1;
				if( lright instanceof AggBinaryOp //ba guarantees matrices
					&& HopRewriteUtils.isSingleBlock(lright.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
				{
					uvIndex = 1;
				}
				//b) sum (((U %*% t(V)) - X) ^ 2)
				else if( lleft instanceof AggBinaryOp //ba guarantees matrices
					&& HopRewriteUtils.isSingleBlock(lleft.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
				{
					uvIndex = 0;
				}
				
				if( uvIndex >= 0 ) //rewrite match
				{
					Hop X = bop.getInput().get(0).getInput().get((uvIndex==0)?1:0);
					Hop tmp = bop.getInput().get(0).getInput().get(uvIndex); //(U %*% t(V))
					Hop W = new LiteralOp(1); //no weighting 
					Hop U = tmp.getInput().get(0);
					Hop V = tmp.getInput().get(1);
					
					//the operator takes V, not t(V)
					if( !HopRewriteUtils.isTransposeOperation(V) ) {
						V = HopRewriteUtils.createTranspose(V);
					}
					else {
						V = V.getInput().get(0);
					}
					
					hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, 
						OpOp4.WSLOSS, X, U, V, W, false);
					HopRewriteUtils.setOutputParametersForScalar(hnew);
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedSquaredLoss3"+uvIndex+" (line "+hi.getBeginLine()+")");
				}
			}
		}
		
		//relink new hop into original position
		//(the replaced sum subdag is not explicitly cleaned up in this method)
		if( hnew != null ) {
			HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
			hi = hnew;
		}
		
		return hi;
	}
	
	/**
	 * Searches for weighted sigmoid expressions and replaces them with a quaternary
	 * operator of type WSIGMOID. The search includes the following four patterns:
	 * 1) W * sigmoid(Y%*%t(X)) (basic)
	 * 2) W * sigmoid(-(Y%*%t(X))) (minus)
	 * 3) W * log(sigmoid(Y%*%t(X))) (log)
	 * 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus)
	 * 
	 * The negation in patterns 2 and 4 is matched as a binary minus whose left
	 * input is a literal with value 0. All patterns require an element-wise
	 * multiply with more than one output column and equally sized inputs, and the
	 * left input of the matrix multiply has to pass the single-block check.
	 * 
	 * @param parent parent high-level operator
	 * @param hi high-level operator
	 * @param pos position handed to replaceChildReference when relinking under parent
	 * @return the new quaternary operator if one of the patterns matched, otherwise hi
	 * @throws HopsException if HopsException occurs
	 */
	private Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos) 
		throws HopsException
	{
		Hop hnew = null;
		
		if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) //all patterns subrooted by W *
			&& hi.getDim2() > 1       //not applied for vector-vector mult
			&& HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
			&& hi.getInput().get(0).getDataType()==DataType.MATRIX
			&& hi.getInput().get(1) instanceof UnaryOp ) //sigmoid/log
		{
			UnaryOp uop = (UnaryOp) hi.getInput().get(1);
			//set by the first matching pattern; later patterns are skipped once it is true
			boolean appliedPattern = false;
			
			//NOTE(review): across the four patterns the two trailing boolean constructor
			//arguments are (false,false), (false,true), (true,false), (true,true),
			//i.e., presumably (log, minus) - confirm against QuaternaryOp
			
			//Pattern 1) W * sigmoid(Y%*%t(X)) (basic)
			if( uop.getOp() == OpOp1.SIGMOID
				&& uop.getInput().get(0) instanceof AggBinaryOp
				&& HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0),true) )
			{
				Hop W = hi.getInput().get(0);
				Hop Y = uop.getInput().get(0).getInput().get(0);
				Hop tX = uop.getInput().get(0).getInput().get(1);
				
				//strip an existing transpose or add one, so that the operator
				//receives the un-transposed right factor
				if( !HopRewriteUtils.isTransposeOperation(tX) ) { 
					tX = HopRewriteUtils.createTranspose(tX);
				}
				else
					tX = tX.getInput().get(0);
				
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
					OpOp4.WSIGMOID, W, Y, tX, false, false);
				hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
				hnew.refreshSizeInformation();
				
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedSigmoid1 (line "+hi.getBeginLine()+")");
			}
			
			//Pattern 2) W * sigmoid(-(Y%*%t(X))) (minus)
			if( !appliedPattern
				&& uop.getOp() == OpOp1.SIGMOID
				&& HopRewriteUtils.isBinary(uop.getInput().get(0), OpOp2.MINUS)
				&& uop.getInput().get(0).getInput().get(0) instanceof LiteralOp
				&& HopRewriteUtils.getDoubleValueSafe(
					(LiteralOp)uop.getInput().get(0).getInput().get(0))==0
				&& uop.getInput().get(0).getInput().get(1) instanceof AggBinaryOp
				&& HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(1).getInput().get(0),true))
			{
				Hop W = hi.getInput().get(0);
				Hop Y = uop.getInput().get(0).getInput().get(1).getInput().get(0);
				Hop tX = uop.getInput().get(0).getInput().get(1).getInput().get(1);
				
				if( !HopRewriteUtils.isTransposeOperation(tX) ) { 
					tX = HopRewriteUtils.createTranspose(tX);
				}
				else
					tX = tX.getInput().get(0);
				
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
					OpOp4.WSIGMOID, W, Y, tX, false, true);
				hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
				hnew.refreshSizeInformation();
				
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedSigmoid2 (line "+hi.getBeginLine()+")");
			}
			
			//Pattern 3) W * log(sigmoid(Y%*%t(X))) (log)
			if( !appliedPattern
				&& uop.getOp() == OpOp1.LOG
				&& HopRewriteUtils.isUnary(uop.getInput().get(0), OpOp1.SIGMOID)
				&& uop.getInput().get(0).getInput().get(0) instanceof AggBinaryOp
				&& HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0).getInput().get(0),true) )
			{
				Hop W = hi.getInput().get(0);
				Hop Y = uop.getInput().get(0).getInput().get(0).getInput().get(0);
				Hop tX = uop.getInput().get(0).getInput().get(0).getInput().get(1);
				
				if( !HopRewriteUtils.isTransposeOperation(tX) ) { 
					tX = HopRewriteUtils.createTranspose(tX);
				}
				else
					tX = tX.getInput().get(0);
				
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
					OpOp4.WSIGMOID, W, Y, tX, true, false);
				hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
				hnew.refreshSizeInformation();
				
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedSigmoid3 (line "+hi.getBeginLine()+")");
			}
			
			//Pattern 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus)
			if( !appliedPattern
				&& uop.getOp() == OpOp1.LOG
				&& HopRewriteUtils.isUnary(uop.getInput().get(0), OpOp1.SIGMOID)
				&& HopRewriteUtils.isBinary(uop.getInput().get(0).getInput().get(0), OpOp2.MINUS) )
			{
				BinaryOp bop = (BinaryOp) uop.getInput().get(0).getInput().get(0);
				
				if( bop.getInput().get(0) instanceof LiteralOp
					&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)bop.getInput().get(0))==0
					&& bop.getInput().get(1) instanceof AggBinaryOp
					&& HopRewriteUtils.isSingleBlock(bop.getInput().get(1).getInput().get(0),true))
				{
					Hop W = hi.getInput().get(0);
					Hop Y = bop.getInput().get(1).getInput().get(0);
					Hop tX = bop.getInput().get(1).getInput().get(1);
					
					if( !HopRewriteUtils.isTransposeOperation(tX) ) { 
						tX = HopRewriteUtils.createTranspose(tX);
					}
					else
						tX = tX.getInput().get(0);
					
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
						OpOp4.WSIGMOID, W, Y, tX, true, true);
					hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
					hnew.refreshSizeInformation();
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedSigmoid4 (line "+hi.getBeginLine()+")");
				}
			}
		}
		
		//relink new hop into original position
		if( hnew != null ) {
			HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
			hi = hnew;
		}
		
		return hi;
	}
	
	/**
	 * Searches for weighted divide/multiply matrix-multiplication expressions and
	 * replaces them with a quaternary operator of type WDIVMM. The patterns, as
	 * labelled in the method body, are:
	 * 1)  t(U) %*% (W/(U%*%t(V))), alternatively with W*(...)
	 * 1e) t(U) %*% (W/(U%*%t(V) + x))
	 * 2)  (W/(U%*%t(V))) %*% V, alternatively with W*(...)
	 * 2e) (W/(U%*%t(V) + x)) %*% V
	 * 3)  t(U) %*% ((X!=0)*(U%*%t(V)-X))
	 * 4)  ((X!=0)*(U%*%t(V)-X)) %*% V
	 * 5)  t(U) %*% (W*(U%*%t(V)-X))
	 * 6)  (W*(U%*%t(V)-X)) %*% V
	 * 7)  (W*(U%*%t(V)))
	 * 
	 * Patterns are tried in this order and the first match wins. For the left
	 * patterns (1, 1e, 3, 5) the new operator is additionally wrapped into a transpose.
	 * 
	 * @param parent parent high-level operator
	 * @param hi high-level operator
	 * @param pos position handed to replaceChildReference when relinking under parent
	 * @return the new operator if one of the patterns matched, otherwise hi
	 * @throws HopsException if HopsException occurs
	 */
	private Hop simplifyWeightedDivMM(Hop parent, Hop hi, int pos) 
		throws HopsException
	{
		Hop hnew = null;
		boolean appliedPattern = false;
		
		//left/right patterns rooted by 'ab - b(div)' or 'ab - b(mult)'
		//note: we do not rewrite t(X)%*%(w*(X%*%v)) where w and v are vectors (see mmchain ops) 
		if( HopRewriteUtils.isMatrixMultiply(hi)
			&& (hi.getInput().get(0) instanceof BinaryOp
			&& HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(0)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
			|| hi.getInput().get(1) instanceof
BinaryOp && hi.getDim2() > 1 //not applied for vector-vector mult
			//NOTE(review): '&&' binds tighter than '||', so the getDim2() guard above only
			//restricts the right-input alternative, not the left-input one - confirm intent
			&& HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(1)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)) )
		{
			Hop left = hi.getInput().get(0);
			Hop right = hi.getInput().get(1);
			
			//NOTE(review): the integer constructor argument is 1 for the left patterns
			//(1, 3, 5), 2 for the right patterns (2, 4, 6), 3/4 for the eps variants and
			//0 for the basic pattern 7 - presumably an index into the wdivmm types; confirm
			
			//Pattern 1) t(U) %*% (W/(U%*%t(V)))
			//alternative pattern: t(U) %*% (W*(U%*%t(V)))
			if( right instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)right).getOp(),LOOKUP_VALID_WDIVMM_BINARY)
				&& HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) //prevent mv
				&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1))
				&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = right.getInput().get(0);
				Hop U = right.getInput().get(1).getInput().get(0);
				Hop V = right.getInput().get(1).getInput().get(1);
				
				if( HopRewriteUtils.isTransposeOfItself(left, U) )
				{
					//the operator takes V, not t(V)
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = HopRewriteUtils.createTranspose(V);
					else
						V = V.getInput().get(0);
					
					//distinguishes W*(...) from W/(...)
					boolean mult = ((BinaryOp)right).getOp() == OpOp2.MULT;
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
						OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 1, mult, false);
					hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
					hnew.refreshSizeInformation();
					
					//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
					hnew = HopRewriteUtils.createTranspose(hnew);
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM1 (line "+hi.getBeginLine()+")");
				}
			}
			
			//Pattern 1e) t(U) %*% (W/(U%*%t(V) + x))
			if( !appliedPattern
				&& HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[1]) //DIV
				&& HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) //prevent mv
				&& HopRewriteUtils.isBinary(right.getInput().get(1), Hop.OpOp2.PLUS)
				&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
				&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
				&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = right.getInput().get(0);
				Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
				Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
				Hop X = right.getInput().get(1).getInput().get(1); //scalar x (eps)
				
				if( HopRewriteUtils.isTransposeOfItself(left, U) )
				{
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = HopRewriteUtils.createTranspose(V);
					else
						V = V.getInput().get(0);
					
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
						OpOp4.WDIVMM, W, U, V, X, 3, false, false); // 3=>DIV_LEFT_EPS
					hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
					hnew.refreshSizeInformation();
					
					//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
					hnew = HopRewriteUtils.createTranspose(hnew);
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM1e (line "+hi.getBeginLine()+")");
				}
			}
			
			//Pattern 2) (W/(U%*%t(V))) %*% V
			//alternative pattern: (W*(U%*%t(V))) %*% V
			if( !appliedPattern
				&& left instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)left).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
				&& HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) //prevent mv
				&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1))
				&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = left.getInput().get(0);
				Hop U = left.getInput().get(1).getInput().get(0);
				Hop V = left.getInput().get(1).getInput().get(1);
				
				if( HopRewriteUtils.isTransposeOfItself(right, V) )
				{
					//reuse the existing right mm input instead of creating a new transpose
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = right;
					else
						V = V.getInput().get(0);
					
					//distinguishes W*(...) from W/(...)
					boolean mult = ((BinaryOp)left).getOp() == OpOp2.MULT;
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
						OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 2, mult, false);
					hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
					hnew.refreshSizeInformation();
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM2 (line "+hi.getBeginLine()+")");
				}
			}
			
			//Pattern 2e) (W/(U%*%t(V) + x)) %*% V
			if( !appliedPattern
				&& HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[1]) //DIV
				&& HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) //prevent mv
				&& HopRewriteUtils.isBinary(left.getInput().get(1), Hop.OpOp2.PLUS)
				&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
				&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
				&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = left.getInput().get(0);
				Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
				Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
				Hop X = left.getInput().get(1).getInput().get(1); //scalar x (eps)
				
				if( HopRewriteUtils.isTransposeOfItself(right, V) )
				{
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = right;
					else
						V = V.getInput().get(0);
					
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
						OpOp4.WDIVMM, W, U, V, X, 4, false, false); // 4=>DIV_RIGHT_EPS
					hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
					hnew.refreshSizeInformation();
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM2e (line "+hi.getBeginLine()+")");
				}
			}
			
			//Pattern 3) t(U) %*% ((X!=0)*(U%*%t(V)-X))
			if( !appliedPattern
				&& HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
				&& HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS)
				&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
				&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
				&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = right.getInput().get(0);
				Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
				Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
				Hop X = right.getInput().get(1).getInput().get(1);
				
				if( HopRewriteUtils.isNonZeroIndicator(W, X)        //W-X constraint
					&& HopRewriteUtils.isTransposeOfItself(left, U) ) //t(U)-U constraint
				{
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = HopRewriteUtils.createTranspose(V);
					else
						V = V.getInput().get(0);
					
					//X (not the indicator W) is passed as first operator input;
					//W is only used for the output block sizes
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
						OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 1, true, true);
					hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
					hnew.refreshSizeInformation();
					
					//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
					hnew = HopRewriteUtils.createTranspose(hnew);
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM3 (line "+hi.getBeginLine()+")");
				}
			}
			
			//Pattern 4) ((X!=0)*(U%*%t(V)-X)) %*% V
			if( !appliedPattern
				&& HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
				&& HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS)
				&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
				&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
				&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = left.getInput().get(0);
				Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
				Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
				Hop X = left.getInput().get(1).getInput().get(1);
				
				if( HopRewriteUtils.isNonZeroIndicator(W, X)        //W-X constraint
					&& HopRewriteUtils.isTransposeOfItself(right, V) )  //V-t(V) constraint
				{
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = right;
					else
						V = V.getInput().get(0);
					
					//X (not the indicator W) is passed as first operator input
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
						OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 2, true, true);
					hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
					hnew.refreshSizeInformation();
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM4 (line "+hi.getBeginLine()+")");
				}
			}
			
			//Pattern 5) t(U) %*% (W*(U%*%t(V)-X))
			//(same structural guard as pattern 3, but without the non-zero indicator constraint)
			if( !appliedPattern
				&& HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
				&& HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS)
				&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
				&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
				&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = right.getInput().get(0);
				Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
				Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
				Hop X = right.getInput().get(1).getInput().get(1);
				
				if( HopRewriteUtils.isTransposeOfItself(left, U) ) //t(U)-U constraint
				{
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = HopRewriteUtils.createTranspose(V);
					else
						V = V.getInput().get(0);
					
					//note: x and w exchanged compared to patterns 1-4, 7
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
						OpOp4.WDIVMM, W, U, V, X, 1, true, true);
					hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
					hnew.refreshSizeInformation();
					
					//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
					hnew = HopRewriteUtils.createTranspose(hnew);
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM5 (line "+hi.getBeginLine()+")");
				}
			}
			
			//Pattern 6) (W*(U%*%t(V)-X)) %*% V
			//(same structural guard as pattern 4, but without the non-zero indicator constraint)
			if( !appliedPattern
				&& HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
				&& HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS)
				&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
				&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
				&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = left.getInput().get(0);
				Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
				Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
				Hop X = left.getInput().get(1).getInput().get(1);
				
				if( HopRewriteUtils.isTransposeOfItself(right, V) ) //V-t(V) constraint
				{
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = right;
					else
						V = V.getInput().get(0);
					
					//note: x and w exchanged compared to patterns 1-4, 7
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
						OpOp4.WDIVMM, W, U, V, X, 2, true, true);
					hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
					hnew.refreshSizeInformation();
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM6 (line "+hi.getBeginLine()+")");
				}
			}
		}
		
		//Pattern 7) (W*(U%*%t(V)))
		//(rooted by an element-wise multiply, hence checked outside the matrix-multiply branch)
		if( !appliedPattern
			&& HopRewriteUtils.isBinary(hi, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
			&& HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
			&& hi.getDim2() > 1 //not applied for vector-vector mult
			&& hi.getInput().get(0).getDataType() == DataType.MATRIX
			&& hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock() //W spans more than one column block
			&& HopRewriteUtils.isOuterProductLikeMM(hi.getInput().get(1))
			&& (((AggBinaryOp) hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE || hi.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
			&& HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
		{
			Hop W = hi.getInput().get(0);
			Hop U = hi.getInput().get(1).getInput().get(0);
			Hop V = hi.getInput().get(1).getInput().get(1);
			
			//for this basic pattern, we're more conservative and only apply wdivmm if
			//the factors are not known to be sparse
			if( !HopRewriteUtils.isSparse(U) && !HopRewriteUtils.isSparse(V) )
			{
				if( !HopRewriteUtils.isTransposeOperation(V) )
					V = HopRewriteUtils.createTranspose(V);
				else
					V = V.getInput().get(0);
				
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
					OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 0, true, false);
				hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
				hnew.refreshSizeInformation();
				
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedDivMM7 (line "+hi.getBeginLine()+")");
			}
		}
		
		//relink new hop into original position
		if( hnew != null ) {
			HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
			hi = hnew;
		}
		
		return hi;
	}
	
	/**
	 * Searches for weighted cross entropy expressions and replaces them with a
	 * scalar quaternary operator of type WCEMM. The search includes two patterns:
	 * 1) sum( X * log(U %*% t(V)))
	 * 2) sum( X * log(U %*% t(V) + eps)), where eps is a scalar literal
	 * 
	 * Both patterns require a full sum over an element-wise multiply with more
	 * than one column and equally sized inputs, and the left input of the
	 * matrix multiply has to pass the single-block check.
	 * 
	 * @param parent parent high-level operator
	 * @param hi high-level operator
	 * @param pos position handed to replaceChildReference when relinking under parent
	 * @return the new quaternary operator if one of the patterns matched, otherwise hi
	 * @throws HopsException if HopsException occurs
	 */
	private Hop simplifyWeightedCrossEntropy(Hop parent, Hop hi, int pos) 
		throws HopsException
	{
		Hop hnew = null;
		boolean appliedPattern = false;
		
		if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
			&& ((AggUnaryOp)hi).getOp() == AggOp.SUM     //pattern rooted by sum()
			&& hi.getInput().get(0) instanceof BinaryOp  //pattern subrooted by binary op
			&& hi.getInput().get(0).getDim2() > 1   )    //not applied for vector-vector mult
		{
			BinaryOp bop = (BinaryOp) hi.getInput().get(0);
			Hop left = bop.getInput().get(0);
			Hop right = bop.getInput().get(1);
			
			//Pattern 1) sum( X * log(U %*% t(V)))
			if( bop.getOp()==OpOp2.MULT && left.getDataType()==DataType.MATRIX
				&& HopRewriteUtils.isEqualSize(left, right)  //prevent mb
				&& HopRewriteUtils.isUnary(right, OpOp1.LOG)
				&& right.getInput().get(0) instanceof AggBinaryOp  //ba guarantees matrices
				&& HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
			{
				Hop X = left;
				Hop U = right.getInput().get(0).getInput().get(0);
				Hop V = right.getInput().get(0).getInput().get(1);
				
				//the operator takes V, not t(V)
				if( !HopRewriteUtils.isTransposeOperation(V) )
					V = HopRewriteUtils.createTranspose(V);
				else
					V = V.getInput().get(0);
				
				//no eps: a literal 0.0 is passed in the eps position together with type 0
				hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM, X, U, V,
					new LiteralOp(0.0), 0, false, false);
				hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
				appliedPattern = true;
				
				LOG.debug("Applied simplifyWeightedCEMM (line "+hi.getBeginLine()+")");
			}
			
			//Pattern 2) sum( X * log(U %*% t(V) + eps))
			if( !appliedPattern
				&& bop.getOp()==OpOp2.MULT && left.getDataType()==DataType.MATRIX
				&& HopRewriteUtils.isEqualSize(left, right)
				&& HopRewriteUtils.isUnary(right, OpOp1.LOG)
				&& HopRewriteUtils.isBinary(right.getInput().get(0), OpOp2.PLUS)
				&& right.getInput().get(0).getInput().get(0) instanceof AggBinaryOp
				&& right.getInput().get(0).getInput().get(1) instanceof LiteralOp
				&& right.getInput().get(0).getInput().get(1).getDataType() == DataType.SCALAR
				&& HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0).getInput().get(0),true))
			{
				Hop X = left;
				Hop U = right.getInput().get(0).getInput().get(0).getInput().get(0);
				Hop V = right.getInput().get(0).getInput().get(0).getInput().get(1);
				Hop eps = right.getInput().get(0).getInput().get(1);
				
				//the operator takes V, not t(V)
				if( !HopRewriteUtils.isTransposeOperation(V) )
					V = HopRewriteUtils.createTranspose(V);
				else
					V = V.getInput().get(0);
				
				hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, 
					OpOp4.WCEMM, X, U, V, eps, 1, false, false); // 1 => BASIC_EPS
				hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
				
				//NOTE(review): unlike pattern 1, appliedPattern is not set here;
				//nothing reads it after this point in this method
				LOG.debug("Applied simplifyWeightedCEMMEps (line "+hi.getBeginLine()+")");
			}
		}
		
		//relink new hop into original position
		if( hnew != null ) {
			HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
			hi = hnew;
		}
		
		return hi;
	}
	
	/**
	 * Searches for weighted unary matrix-multiplication expressions and replaces
	 * them with a quaternary operator of type WUMM. The search includes two patterns:
	 * 1) (W*uop(U%*%t(V))) for a unary operation in LOOKUP_VALID_WUMM_UNARY
	 * 2) (W*sop(U%*%t(V),c)) for a binary operation in LOOKUP_VALID_WUMM_BINARY
	 *    with a literal operand of value 2 (2a: matrix-scalar, 2b: scalar-matrix, mult only)
	 * 
	 * The root operation has to be in LOOKUP_VALID_WDIVMM_BINARY; whether it is a
	 * multiply is passed to the new operator as a flag.
	 * 
	 * @param parent parent high-level operator
	 * @param hi high-level operator
	 * @param pos position handed to replaceChildReference when relinking under parent
	 * @return the new quaternary operator if one of the patterns matched, otherwise hi
	 * @throws HopsException if HopsException occurs
	 */
	private Hop simplifyWeightedUnaryMM(Hop parent, Hop hi, int pos) 
		throws HopsException
	{
		Hop hnew = null;
		boolean appliedPattern = false;
		
		//Pattern 1) (W*uop(U%*%t(V)))
		if( hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(),LOOKUP_VALID_WDIVMM_BINARY)
			&& HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
			&& hi.getDim2() > 1 //not applied for vector-vector mult
			&& hi.getInput().get(0).getDataType() == DataType.MATRIX
			&& hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock() //W spans more than one column block
			&& hi.getInput().get(1) instanceof UnaryOp
			&& HopRewriteUtils.isValidOp(((UnaryOp)hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_UNARY)
			&& hi.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
			&& HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
		{
			Hop W = hi.getInput().get(0);
			Hop U = hi.getInput().get(1).getInput().get(0).getInput().get(0);
			Hop V = hi.getInput().get(1).getInput().get(0).getInput().get(1);
			boolean mult = ((BinaryOp)hi).getOp()==OpOp2.MULT;
			OpOp1 op = ((UnaryOp)hi.getInput().get(1)).getOp();
			
			//the operator takes V, not t(V)
			if( !HopRewriteUtils.isTransposeOperation(V) )
				V = HopRewriteUtils.createTranspose(V);
			else
				V = V.getInput().get(0);
			
			//unary variant: the binary operation argument is null
			hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
				OpOp4.WUMM, W, U, V, mult, op, null);
			hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
			hnew.refreshSizeInformation();
			
			appliedPattern = true;
			LOG.debug("Applied simplifyWeightedUnaryMM1 (line "+hi.getBeginLine()+")");
		}
		
		//Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
		if( !appliedPattern
			&& hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(),LOOKUP_VALID_WDIVMM_BINARY)
			&& HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
			&& hi.getDim2() > 1 //not applied for vector-vector mult
			&& hi.getInput().get(0).getDataType() == DataType.MATRIX
			&& hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock() //W spans more than one column block
			&& hi.getInput().get(1) instanceof BinaryOp
			&& HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_BINARY) )
		{
			Hop left = hi.getInput().get(1).getInput().get(0);
			Hop right = hi.getInput().get(1).getInput().get(1);
			//the matrix multiply U%*%t(V) if pattern 2a or 2b matches, otherwise null
			Hop abop = null;
			
			//pattern 2a) matrix-scalar operations
			if( right.getDataType()==DataType.SCALAR && right instanceof LiteralOp
				&& HopRewriteUtils.getDoubleValue((LiteralOp)right)==2 //pow2, mult2
				&& left instanceof AggBinaryOp
				&& HopRewriteUtils.isSingleBlock(left.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				abop = left;
			}
			//pattern 2b) scalar-matrix operations
			else if( left.getDataType()==DataType.SCALAR && left instanceof LiteralOp
				&& HopRewriteUtils.getDoubleValue((LiteralOp)left)==2 //mult2
				&& ((BinaryOp)hi.getInput().get(1)).getOp() == OpOp2.MULT
				&& right instanceof AggBinaryOp
				&& HopRewriteUtils.isSingleBlock(right.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				abop = right;
			}
			
			if( abop != null ) {
				Hop W = hi.getInput().get(0);
				Hop U = abop.getInput().get(0);
				Hop V = abop.getInput().get(1);
				boolean mult = ((BinaryOp)hi).getOp()==OpOp2.MULT;
				OpOp2 op = ((BinaryOp)hi.getInput().get(1)).getOp();
				
				//the operator takes V, not t(V)
				if( !HopRewriteUtils.isTransposeOperation(V) )
					V = HopRewriteUtils.createTranspose(V);
				else
					V = V.getInput().get(0);
				
				//binary variant: the unary operation argument is null
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
					OpOp4.WUMM, W, U, V, mult, null, op);
				hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
				hnew.refreshSizeInformation();
				
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedUnaryMM2 (line "+hi.getBeginLine()+")");
			}
		}
		
		//relink new hop into original position
		if( hnew != null ) {
			HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
			hi = hnew;
		}
		
		return hi;
	}
	
	/**
	 * NOTE: dot-product-sum could be also applied to sum(a*b). However, we 
	 * restrict ourselfs to sum(a^2) and transitively sum(a*a) since a general mm
	 * a%*%b on MR can be also counter-productive (e.g., MMCJ) while tsmm is always 
	 * beneficial. 
* * @param parent parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator * @throws HopsException if HopsException occurs */ private Hop simplifyDotProductSum(Hop parent, Hop hi, int pos) throws HopsException { //sum(v^2)/sum(v1*v2) --> as.scalar(t(v)%*%v) in order to exploit tsmm vector dotproduct //w/o materialization of intermediates if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.SUM //sum && ((AggUnaryOp)hi).getDirection()==Direction.RowCol //full aggregate && hi.getInput().get(0).getDim2() == 1 ) //vector (for correctness) { Hop baLeft = null; Hop baRight = null; Hop hi2 = hi.getInput().get(0); //check for ^2 w/o multiple consumers //check for sum(v^2), might have been rewritten from sum(v*v) if( HopRewriteUtils.isBinary(hi2, OpOp2.POW) && hi2.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)hi2.getInput().get(1))==2 && hi2.getParent().size() == 1 ) //no other consumer than sum { Hop input = hi2.getInput().get(0); baLeft = input; baRight = input; } //check for sum(v1*v2), but prevent to rewrite sum(v1*v2*v3) which is later compiled into a ta+* lop else if( HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1) //no other consumer than sum && hi2.getInput().get(0).getDim2()==1 && hi2.getInput().get(1).getDim2()==1 && !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT) && !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT) ) { baLeft = hi2.getInput().get(0); baRight = hi2.getInput().get(1); } //perform actual rewrite (if necessary) if( baLeft != null && baRight != null ) { //create new operator chain ReorgOp trans = HopRewriteUtils.createTranspose(baLeft); AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(trans, baRight); UnaryOp cast = HopRewriteUtils.createUnary(mmult, OpOp1.CAST_AS_SCALAR); //rehang new subdag under parent node HopRewriteUtils.replaceChildReference(parent, hi, cast, pos); HopRewriteUtils.cleanupUnreferenced(hi, hi2); hi = 
cast; LOG.debug("Applied simplifyDotProductSum."); } } return hi; } /** * Replace SUM(X^2) with a fused SUM_SQ(X) HOP. * * @param parent Parent HOP for which hi is an input. * @param hi Current HOP for potential rewrite. * @param pos Position of hi in parent's list of inputs. * * @return Either hi or the rewritten HOP replacing it. * * @throws HopsException if HopsException occurs */ private Hop fuseSumSquared(Hop parent, Hop hi, int pos) throws HopsException { // if SUM if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.SUM) { Hop sumInput = hi.getInput().get(0); // if input to SUM is POW(X,2), and no other consumers of the POW(X,2) HOP if( HopRewriteUtils.isBinary(sumInput, OpOp2.POW) && sumInput.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) sumInput.getInput().get(1)) == 2 && sumInput.getParent().size() == 1) { Hop x = sumInput.getInput().get(0); // if X is NOT a column vector if (x.getDim2() > 1) { // perform rewrite from SUM(POW(X,2)) to SUM_SQ(X) Direction dir = ((AggUnaryOp) hi).getDirection(); AggUnaryOp sumSq = HopRewriteUtils.createAggUnaryOp(x, AggOp.SUM_SQ, dir); HopRewriteUtils.replaceChildReference(parent, hi, sumSq, pos); HopRewriteUtils.cleanupUnreferenced(hi, sumInput); hi = sumSq; LOG.debug("Applied fuseSumSquared."); } } } return hi; } private Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos) { //patterns: (a) X + s*Y -> X +* sY, (b) s*Y+X -> X +* sY, (c) X - s*Y -> X -* sY if( hi instanceof BinaryOp && (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS) ) { BinaryOp bop = (BinaryOp) hi; Hop left = bop.getInput().get(0); Hop right = bop.getInput().get(1); Hop ternop = null; //pattern (a) X + s*Y -> X +* sY if( bop.getOp() == OpOp2.PLUS && left.getDataType()==DataType.MATRIX && HopRewriteUtils.isScalarMatrixBinaryMult(right) && right.getParent().size() == 1 ) //single consumer s*Y { Hop smid = right.getInput().get( 
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); Hop mright = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0); ternop = (smid instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ? left : HopRewriteUtils.createTernaryOp(left, smid, mright, OpOp3.PLUS_MULT); LOG.debug("Applied fuseAxpyBinaryOperationChain1. (line " +hi.getBeginLine()+")"); } //pattern (b) s*Y + X -> X +* sY else if( bop.getOp() == OpOp2.PLUS && right.getDataType()==DataType.MATRIX && HopRewriteUtils.isScalarMatrixBinaryMult(left) && left.getParent().size() == 1 //single consumer s*Y && HopRewriteUtils.isEqualSize(left, right)) //correctness matrix-vector { Hop smid = left.getInput().get( (left.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); Hop mright = left.getInput().get( (left.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0); ternop = (smid instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ? right : HopRewriteUtils.createTernaryOp(right, smid, mright, OpOp3.PLUS_MULT); LOG.debug("Applied fuseAxpyBinaryOperationChain2. (line " +hi.getBeginLine()+")"); } //pattern (c) X - s*Y -> X -* sY else if( bop.getOp() == OpOp2.MINUS && left.getDataType()==DataType.MATRIX && HopRewriteUtils.isScalarMatrixBinaryMult(right) && right.getParent().size() == 1 ) //single consumer s*Y { Hop smid = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); Hop mright = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0); ternop = (smid instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ? left : HopRewriteUtils.createTernaryOp(left, smid, mright, OpOp3.MINUS_MULT); LOG.debug("Applied fuseAxpyBinaryOperationChain3. 
(line " +hi.getBeginLine()+")"); } //rewire parent-child operators if rewrite applied if( ternop != null ) { HopRewriteUtils.replaceChildReference(parent, hi, ternop, pos); hi = ternop; } } return hi; } private Hop simplifyEmptyBinaryOperation(Hop parent, Hop hi, int pos) throws HopsException { if( hi instanceof BinaryOp ) //b(?) X Y { BinaryOp bop = (BinaryOp) hi; Hop left = hi.getInput().get(0); Hop right = hi.getInput().get(1); if( left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX ) { Hop hnew = null; //NOTE: these rewrites of binary cell operations need to be aware that right is //potentially a vector but the result is of the size of left boolean notBinaryMV = HopRewriteUtils.isNotMatrixVectorBinaryOperation(bop); switch( bop.getOp() ){ //X * Y -> matrix(0,nrow(X),ncol(X)); case MULT: { if( HopRewriteUtils.isEmpty(left) ) //empty left and size known hnew = HopRewriteUtils.createDataGenOp(left, left, 0); else if( HopRewriteUtils.isEmpty(right) //empty right and right not a vector && right.getDim1()>1 && right.getDim2()>1 ) { hnew = HopRewriteUtils.createDataGenOp(right, right, 0); } else if( HopRewriteUtils.isEmpty(right) )//empty right and right potentially a vector hnew = HopRewriteUtils.createDataGenOp(left, left, 0); break; } case PLUS: { if( HopRewriteUtils.isEmpty(left) && HopRewriteUtils.isEmpty(right) ) //empty left/right and size known hnew = HopRewriteUtils.createDataGenOp(left, left, 0); else if( HopRewriteUtils.isEmpty(left) && notBinaryMV ) //empty left hnew = right; else if( HopRewriteUtils.isEmpty(right) ) //empty right hnew = left; break; } case MINUS: { if( HopRewriteUtils.isEmpty(left) && notBinaryMV ) { //empty left HopRewriteUtils.removeChildReference(hi, left); HopRewriteUtils.addChildReference(hi, new LiteralOp(0), 0); hnew = hi; } else if( HopRewriteUtils.isEmpty(right) ) //empty and size known hnew = left; break; } default: hnew = null; } if( hnew != null ) { //create datagen and add it to parent 
HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; LOG.debug("Applied simplifyEmptyBinaryOperation"); } } } return hi; } /** * This is rewrite tries to reorder minus operators from inputs of matrix * multiply to its output because the output is (except for outer products) * usually significantly smaller. Furthermore, this rewrite is a precondition * for the important hops-lops rewrite of transpose-matrixmult if the transpose * is hidden under the minus. * * NOTE: in this rewrite we need to modify the links to all parents because we * remove existing links of subdags and hence affect all consumers. * * @param parent the parent high-level operator * @param hi high-level operator * @param pos position * @return high-level operator * @throws HopsException if HopsException occurs */ @SuppressWarnings("unchecked") private Hop reorderMinusMatrixMult(Hop parent, Hop hi, int pos) throws HopsException { if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y { Hop hileft = hi.getInput().get(0); Hop hiright = hi.getInput().get(1); if( HopRewriteUtils.isBinary(hileft, OpOp2.MINUS) //X=-Z && hileft.getInput().get(0) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)hileft.getInput().get(0))==0.0 && hi.dimsKnown() && hileft.getInput().get(1).dimsKnown() //size comparison && HopRewriteUtils.compareSize(hi, hileft.getInput().get(1)) < 0 ) { Hop hi2 = hileft.getInput().get(1); //remove link from matrixmult to minus HopRewriteUtils.removeChildReference(hi, hileft); //get old parents (before creating minus over matrix mult) ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone(); //create new operators BinaryOp minus = HopRewriteUtils.createBinary(new LiteralOp(0), hi, OpOp2.MINUS); //rehang minus under all parents for( Hop p : parents ) { int ix = HopRewriteUtils.getChildReferencePos(p, hi); HopRewriteUtils.removeChildReference(p, hi); HopRewriteUtils.addChildReference(p, minus, ix); } //rehang child of minus under matrix mult 
HopRewriteUtils.addChildReference(hi, hi2, 0); //cleanup if only consumer of minus HopRewriteUtils.cleanupUnreferenced(hileft); hi = minus; LOG.debug("Applied reorderMinusMatrixMult (line "+hi.getBeginLine()+")."); } else if( HopRewriteUtils.isBinary(hiright, OpOp2.MINUS) //X=-Z && hiright.getInput().get(0) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)hiright.getInput().get(0))==0.0 && hi.dimsKnown() && hiright.getInput().get(1).dimsKnown() //size comparison && HopRewriteUtils.compareSize(hi, hiright.getInput().get(1)) < 0 ) { Hop hi2 = hiright.getInput().get(1); //remove link from matrixmult to minus HopRewriteUtils.removeChildReference(hi, hiright); //get old parents (before creating minus over matrix mult) ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone(); //create new operators BinaryOp minus = HopRewriteUtils.createBinary(new LiteralOp(0), hi, OpOp2.MINUS); //rehang minus under all parents for( Hop p : parents ) { int ix = HopRewriteUtils.getChildReferencePos(p, hi); HopRewriteUtils.removeChildReference(p, hi); HopRewriteUtils.addChildReference(p, minus, ix); } //rehang child of minus under matrix mult HopRewriteUtils.addChildReference(hi, hi2, 1); //cleanup if only consumer of minus HopRewriteUtils.cleanupUnreferenced(hiright); hi = minus; LOG.debug("Applied reorderMinusMatrixMult (line "+hi.getBeginLine()+")."); } } return hi; } private Hop simplifySumMatrixMult(Hop parent, Hop hi, int pos) { //sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), later rewritten to dot-product //colSums(A%*%B) -> colSums(A)%*%B //rowSums(A%*%B) -> A%*%rowSums(B) //-- if not dot product, not applied since aggregate removed //-- if sum not the only consumer, not applied to prevent redundancy if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.SUM //sum && hi.getInput().get(0) instanceof AggBinaryOp //A%*%B && (hi.getInput().get(0).getDim1()>1 || hi.getInput().get(0).getDim2()>1) //not dot product && 
hi.getInput().get(0).getParent().size()==1 ) //not multiple consumers of matrix mult { Hop hi2 = hi.getInput().get(0); Hop left = hi2.getInput().get(0); Hop right = hi2.getInput().get(1); //remove link from parent to matrix mult HopRewriteUtils.removeChildReference(hi, hi2); //create new operators Hop root = null; //pattern: sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), later rewritten to dot-product if( ((AggUnaryOp)hi).getDirection() == Direction.RowCol ) { AggUnaryOp colSum = HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col); ReorgOp trans = HopRewriteUtils.createTranspose(colSum); AggUnaryOp rowSum = HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row); root = HopRewriteUtils.createBinary(trans, rowSum, OpOp2.MULT); LOG.debug("Applied simplifySumMatrixMult RC."); } //colSums(A%*%B) -> colSums(A)%*%B else if( ((AggUnaryOp)hi).getDirection() == Direction.Col ) { AggUnaryOp colSum = HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col); root = HopRewriteUtils.createMatrixMultiply(colSum, right); LOG.debug("Applied simplifySumMatrixMult C."); } //rowSums(A%*%B) -> A%*%rowSums(B) else if( ((AggUnaryOp)hi).getDirection() == Direction.Row ) { AggUnaryOp rowSum = HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row); root = HopRewriteUtils.createMatrixMultiply(left, rowSum); LOG.debug("Applied simplifySumMatrixMult R."); } //rehang new subdag under current node (keep hi intact) HopRewriteUtils.addChildReference(hi, root, 0); hi.refreshSizeInformation(); //cleanup if only consumer of intermediate HopRewriteUtils.cleanupUnreferenced(hi2); } return hi; } private Hop simplifyScalarMVBinaryOperation(Hop hi) throws HopsException { if( hi instanceof BinaryOp && ((BinaryOp)hi).supportsMatrixScalarOperations() //e.g., X * s && hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1).getDataType()==DataType.MATRIX ) { Hop right = hi.getInput().get(1); //X * s -> X * as.scalar(s) if( 
HopRewriteUtils.isDimsKnown(right) && right.getDim1()==1 && right.getDim2()==1 ) //scalar right { //remove link to right child and introduce cast UnaryOp cast = HopRewriteUtils.createUnary(right, OpOp1.CAST_AS_SCALAR); HopRewriteUtils.replaceChildReference(hi, right, cast, 1); LOG.debug("Applied simplifyScalarMVBinaryOperation."); } } return hi; } private Hop simplifyNnzComputation(Hop parent, Hop hi, int pos) throws HopsException { //sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.SUM //sum && ((AggUnaryOp)hi).getDirection() == Direction.RowCol //full aggregate && HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.NOTEQUAL) ) { Hop ppred = hi.getInput().get(0); Hop X = null; if( ppred.getInput().get(0) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)ppred.getInput().get(0))==0 ) { X = ppred.getInput().get(1); } else if( ppred.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)ppred.getInput().get(1))==0 ) { X = ppred.getInput().get(0); } //apply rewrite if known nnz if( X != null && X.getNnz() > 0 ){ Hop hnew = new LiteralOp(X.getNnz()); HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = hnew; LOG.debug("Applied simplifyNnzComputation."); } } return hi; } private Hop simplifyNrowNcolComputation(Hop parent, Hop hi, int pos) throws HopsException { //nrow(X) -> literal(nrow(X)), ncol(X) -> literal(ncol(X)), if respective dims known //(this rewrite aims to remove unnecessary data dependencies to X which trigger computation //even if the intermediate is otherwise not required, e.g., when part of a fused operator) if( hi instanceof UnaryOp ) { if( ((UnaryOp)hi).getOp()==OpOp1.NROW && hi.getInput().get(0).getDim1()>0 ) { Hop hnew = new LiteralOp(hi.getInput().get(0).getDim1()); HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos, false); HopRewriteUtils.cleanupUnreferenced(hi); hi = 
hnew; LOG.debug("Applied simplifyNrowComputation."); } else if( ((UnaryOp)hi).getOp()==OpOp1.NCOL && hi.getInput().get(0).getDim2()>0 ) { Hop hnew = new LiteralOp(hi.getInput().get(0).getDim2()); HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos, false); HopRewriteUtils.cleanupUnreferenced(hi); hi = hnew; LOG.debug("Applied simplifyNcolComputation."); } } return hi; } private Hop simplifyTableSeqExpand(Hop parent, Hop hi, int pos) throws HopsException { //pattern: table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true) //note: this rewrite supports both left/right sequence if( hi instanceof TernaryOp && hi.getInput().size()==5 //table without weights && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(2), 1) //i.e., weight of 1 && hi.getInput().get(3) instanceof LiteralOp && hi.getInput().get(4) instanceof LiteralOp) { Hop first = hi.getInput().get(0); Hop second = hi.getInput().get(1); //pattern a: table(seq(1,nrow(v)), v, nrow(v), m, 1) if( HopRewriteUtils.isBasic1NSequence(first, second, true) && second.dimsKnown() && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), second.getDim1()) ) { //setup input parameter hops HashMap<String,Hop> args = new HashMap<String,Hop>(); args.put("target", second); args.put("max", hi.getInput().get(4)); args.put("dir", new LiteralOp("cols")); args.put("ignore", new LiteralOp(false)); args.put("cast", new LiteralOp(true)); //create new hop ParameterizedBuiltinOp pbop = HopRewriteUtils .createParameterizedBuiltinOp(second, args, ParamBuiltinOp.REXPAND); HopRewriteUtils.replaceChildReference(parent, hi, pbop, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = pbop; LOG.debug("Applied simplifyTableSeqExpand1 (line "+hi.getBeginLine()+")"); } //pattern b: table(v, seq(1,nrow(v)), m, nrow(v)) else if( HopRewriteUtils.isBasic1NSequence(second, first, true) && first.dimsKnown() && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(4), first.getDim1()) ) { //setup input parameter hops 
HashMap<String,Hop> args = new HashMap<String,Hop>(); args.put("target", first); args.put("max", hi.getInput().get(3)); args.put("dir", new LiteralOp("rows")); args.put("ignore", new LiteralOp(false)); args.put("cast", new LiteralOp(true)); //create new hop ParameterizedBuiltinOp pbop = HopRewriteUtils .createParameterizedBuiltinOp(first, args, ParamBuiltinOp.REXPAND); HopRewriteUtils.replaceChildReference(parent, hi, pbop, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = pbop; LOG.debug("Applied simplifyTableSeqExpand2 (line "+hi.getBeginLine()+")"); } } return hi; } }