/** * (C) Copyright IBM Corp. 2010, 2015 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *  */ package com.ibm.bi.dml.hops.rewrite; import java.util.ArrayList; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.log4j.Level; import org.apache.log4j.Logger; import com.ibm.bi.dml.hops.Hop; import com.ibm.bi.dml.hops.HopsException; import com.ibm.bi.dml.hops.OptimizerUtils; import com.ibm.bi.dml.parser.DMLProgram; import com.ibm.bi.dml.parser.ForStatement; import com.ibm.bi.dml.parser.ForStatementBlock; import com.ibm.bi.dml.parser.FunctionStatement; import com.ibm.bi.dml.parser.FunctionStatementBlock; import com.ibm.bi.dml.parser.IfStatement; import com.ibm.bi.dml.parser.IfStatementBlock; import com.ibm.bi.dml.parser.LanguageException; import com.ibm.bi.dml.parser.ParForStatementBlock; import com.ibm.bi.dml.parser.StatementBlock; import com.ibm.bi.dml.parser.WhileStatement; import com.ibm.bi.dml.parser.WhileStatementBlock; /** * This program rewriter applies a variety of rule-based rewrites * on all hop dags of the given program in one pass over the entire * program. * */ public class ProgramRewriter { private static final Log LOG = LogFactory.getLog(ProgramRewriter.class.getName()); //internal local debug level private static final boolean LDEBUG = false; private static final boolean CHECK = false; private ArrayList<HopRewriteRule> _dagRuleSet = null; private ArrayList<StatementBlockRewriteRule> _sbRuleSet = null; static{ // for internal debugging only if( LDEBUG ) { Logger.getLogger("com.ibm.bi.dml.hops.rewrite") .setLevel((Level) Level.DEBUG); } } public ProgramRewriter() { // by default which is used during initial compile // apply all (static and dynamic) rewrites this( true, true ); } public ProgramRewriter( boolean staticRewrites, boolean dynamicRewrites ) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<HopRewriteRule>(); //initialize StatementBlock rewrite ruleSet (with fixed rewrite order) _sbRuleSet = new ArrayList<StatementBlockRewriteRule>(); //STATIC REWRITES (which do not rely on size information) if( staticRewrites ) { //add static HOP DAG rewrite rules _dagRuleSet.add( new RewriteTransientWriteParentHandling() ); _dagRuleSet.add( new RewriteRemoveReadAfterWrite() ); //dependency: before blocksize _dagRuleSet.add( new RewriteBlockSizeAndReblock() ); _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); if( OptimizerUtils.ALLOW_CONSTANT_FOLDING ) _dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) _dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) //dependency: simplifications (no need to merge leafs again) _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION ) _dagRuleSet.add( new RewriteIndexingVectorization() ); //dependency: cse, simplifications _dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing() ); //dependency: reblock //add statment block rewrite rules if( OptimizerUtils.ALLOW_BRANCH_REMOVAL ) _sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS ) _sbRuleSet.add( new RewriteSplitDagUnknownCSVRead() ); //dependency: reblock if( OptimizerUtils.ALLOW_INDIVIDUAL_SB_SPECIFIC_OPS ) _sbRuleSet.add( new RewriteSplitDagDataDependentOperators() ); if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION ) _sbRuleSet.add( new RewriteForLoopVectorization() ); //dependency: reblock (reblockop) _sbRuleSet.add( new RewriteInjectSparkLoopCheckpointing(true) ); //dependency: reblock (blocksizes) } // DYNAMIC REWRITES (which do require size information) if( dynamicRewrites ) { _dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) { _dagRuleSet.add( new RewriteAlgebraicSimplificationDynamic() ); //dependencies: cse _dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse } //reapply cse after rewrites because (1) applied rewrites on operators w/ multiple parents, and //(2) newly introduced operators potentially created redundancy (incl leaf merge to allow for cse) if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) _dagRuleSet.add( new RewriteCommonSubexpressionElimination(true) ); //dependency: simplifications } } /** * Construct a program rewriter for a given rewrite which is passed from outside. * * @param rewrite */ public ProgramRewriter( HopRewriteRule rewrite ) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<HopRewriteRule>(); _dagRuleSet.add( rewrite ); _sbRuleSet = new ArrayList<StatementBlockRewriteRule>(); } /** * Construct a program rewriter for a given rewrite which is passed from outside. * * @param rewrite */ public ProgramRewriter( StatementBlockRewriteRule rewrite ) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<HopRewriteRule>(); _sbRuleSet = new ArrayList<StatementBlockRewriteRule>(); _sbRuleSet.add( rewrite ); } /** * Construct a program rewriter for the given rewrite sets which are passed from outside. * * @param rewrite */ public ProgramRewriter( ArrayList<HopRewriteRule> hRewrites, ArrayList<StatementBlockRewriteRule> sbRewrites ) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<HopRewriteRule>(); _dagRuleSet.addAll( hRewrites ); _sbRuleSet = new ArrayList<StatementBlockRewriteRule>(); _sbRuleSet.addAll( sbRewrites ); } /** * * @param dmlp * @return * @throws LanguageException * @throws HopsException */ public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp) throws LanguageException, HopsException { ProgramRewriteStatus state = new ProgramRewriteStatus(); // for each namespace, handle function statement blocks for (String namespaceKey : dmlp.getNamespaces().keySet()) for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) { FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey,fname); rewriteStatementBlockHopDAGs(fsblock, state); rewriteStatementBlock(fsblock, state); } // handle regular statement blocks in "main" method for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { StatementBlock current = dmlp.getStatementBlock(i); rewriteStatementBlockHopDAGs(current, state); } dmlp.setStatementBlocks( rewriteStatementBlocks(dmlp.getStatementBlocks(), state) ); return state; } /** * * @param current * @throws LanguageException * @throws HopsException */ public void rewriteStatementBlockHopDAGs(StatementBlock current, ProgramRewriteStatus state) throws LanguageException, HopsException { //ensure robustness for calls from outside if( state == null ) state = new ProgramRewriteStatus(); if (current instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock)current; FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); for (StatementBlock sb : fstmt.getBody()) rewriteStatementBlockHopDAGs(sb, state); } else if (current instanceof WhileStatementBlock) { WhileStatementBlock wsb = (WhileStatementBlock) current; WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); wsb.setPredicateHops(rewriteHopDAG(wsb.getPredicateHops(), state)); for (StatementBlock sb : wstmt.getBody()) rewriteStatementBlockHopDAGs(sb, state); } else if (current instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) current; IfStatement istmt = (IfStatement)isb.getStatement(0); isb.setPredicateHops(rewriteHopDAG(isb.getPredicateHops(), state)); for (StatementBlock sb : istmt.getIfBody()) rewriteStatementBlockHopDAGs(sb, state); for (StatementBlock sb : istmt.getElseBody()) rewriteStatementBlockHopDAGs(sb, state); } else if (current instanceof ForStatementBlock) //incl parfor { ForStatementBlock fsb = (ForStatementBlock) current; ForStatement fstmt = (ForStatement)fsb.getStatement(0); fsb.setFromHops(rewriteHopDAG(fsb.getFromHops(), state)); fsb.setToHops(rewriteHopDAG(fsb.getToHops(), state)); fsb.setIncrementHops(rewriteHopDAG(fsb.getIncrementHops(), state)); for (StatementBlock sb : fstmt.getBody()) rewriteStatementBlockHopDAGs(sb, state); } else //generic (last-level) { current.set_hops( rewriteHopDAGs(current.get_hops(), state) ); } } /** * * @param roots * @throws LanguageException * @throws HopsException */ public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException { for( HopRewriteRule r : _dagRuleSet ) { Hop.resetVisitStatus( roots ); //reset for each rule roots = r.rewriteHopDAGs(roots, state); if( CHECK ) { LOG.info("Validation after: "+r.getClass().getName()); HopDagValidator.validateHopDag(roots); } } return roots; } /** * * @param root * @throws LanguageException * @throws HopsException */ public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException { for( HopRewriteRule r : _dagRuleSet ) { root.resetVisitStatus(); //reset for each rule root = r.rewriteHopDAG(root, state); if( CHECK ) { LOG.info("Validation after: "+r.getClass().getName()); HopDagValidator.validateHopDag(root); } } return root; } /** * * @param sbs * @return * @throws HopsException */ public ArrayList<StatementBlock> rewriteStatementBlocks( ArrayList<StatementBlock> sbs, ProgramRewriteStatus state ) throws HopsException { //ensure robustness for calls from outside if( state == null ) state = new ProgramRewriteStatus(); ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>(); //rewrite statement blocks (with potential expansion) for( StatementBlock sb : sbs ) tmp.addAll( rewriteStatementBlock(sb, state) ); //copy results into original collection sbs.clear(); sbs.addAll( tmp ); return sbs; } /** * * @param sb * @return * @throws HopsException */ private ArrayList<StatementBlock> rewriteStatementBlock( StatementBlock sb, ProgramRewriteStatus status ) throws HopsException { ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>(); ret.add(sb); //recursive invocation if (sb instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock)sb; FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); fstmt.setBody( rewriteStatementBlocks(fstmt.getBody(), status) ); } else if (sb instanceof WhileStatementBlock) { WhileStatementBlock wsb = (WhileStatementBlock) sb; WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); wstmt.setBody( rewriteStatementBlocks( wstmt.getBody(), status ) ); } else if (sb instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) sb; IfStatement istmt = (IfStatement)isb.getStatement(0); istmt.setIfBody( rewriteStatementBlocks( istmt.getIfBody(), status ) ); istmt.setElseBody( rewriteStatementBlocks( istmt.getElseBody(), status ) ); } else if (sb instanceof ForStatementBlock) //incl parfor { //maintain parfor context information (e.g., for checkpointing) boolean prestatus = status.isInParforContext(); if( sb instanceof ParForStatementBlock ) status.setInParforContext(true); ForStatementBlock fsb = (ForStatementBlock) sb; ForStatement fstmt = (ForStatement)fsb.getStatement(0); fstmt.setBody( rewriteStatementBlocks(fstmt.getBody(), status) ); status.setInParforContext(prestatus); } //apply rewrite rules for( StatementBlockRewriteRule r : _sbRuleSet ) { ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>(); for( StatementBlock sbc : ret ) tmp.addAll( r.rewriteStatementBlock(sbc, status) ); //take over set of rewritten sbs ret.clear(); ret.addAll(tmp); } return ret; } }