/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.log4j.Level; import org.apache.log4j.Logger; import org.apache.sysml.conf.CompilerConfig.ConfigType; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.parser.DMLProgram; import org.apache.sysml.parser.ForStatement; import org.apache.sysml.parser.ForStatementBlock; import org.apache.sysml.parser.FunctionStatement; import org.apache.sysml.parser.FunctionStatementBlock; import org.apache.sysml.parser.IfStatement; import org.apache.sysml.parser.IfStatementBlock; import org.apache.sysml.parser.LanguageException; import org.apache.sysml.parser.ParForStatementBlock; import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.parser.WhileStatement; import org.apache.sysml.parser.WhileStatementBlock; /** * This program rewriter applies a variety of rule-based rewrites * on all hop dags of the given program in one pass over the entire * program. * */ public class ProgramRewriter { private static final Log LOG = LogFactory.getLog(ProgramRewriter.class.getName()); //internal local debug level private static final boolean LDEBUG = false; private static final boolean CHECK = false; private ArrayList<HopRewriteRule> _dagRuleSet = null; private ArrayList<StatementBlockRewriteRule> _sbRuleSet = null; static{ // for internal debugging only if( LDEBUG ) { Logger.getLogger("org.apache.sysml.hops.rewrite") .setLevel((Level) Level.DEBUG); } } public ProgramRewriter() { // by default which is used during initial compile // apply all (static and dynamic) rewrites this( true, true ); } public ProgramRewriter( boolean staticRewrites, boolean dynamicRewrites ) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<HopRewriteRule>(); //initialize StatementBlock rewrite ruleSet (with fixed rewrite order) _sbRuleSet = new ArrayList<StatementBlockRewriteRule>(); //STATIC REWRITES (which do not rely on size information) if( staticRewrites ) { //add static HOP DAG rewrite rules _dagRuleSet.add( new RewriteTransientWriteParentHandling() ); _dagRuleSet.add( new RewriteRemoveReadAfterWrite() ); //dependency: before blocksize _dagRuleSet.add( new RewriteBlockSizeAndReblock() ); _dagRuleSet.add( new RewriteCompressedReblock() ); _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); if( OptimizerUtils.ALLOW_CONSTANT_FOLDING ) _dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) _dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) //dependency: simplifications (no need to merge leafs again) _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION ) _dagRuleSet.add( new RewriteIndexingVectorization() ); //dependency: cse, simplifications _dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing() ); //dependency: reblock //add statment block rewrite rules if( OptimizerUtils.ALLOW_BRANCH_REMOVAL ) _sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS ) _sbRuleSet.add( new RewriteSplitDagUnknownCSVRead() ); //dependency: reblock if( ConfigurationManager.getCompilerConfigFlag(ConfigType.ALLOW_INDIVIDUAL_SB_SPECIFIC_OPS) ) _sbRuleSet.add( new RewriteSplitDagDataDependentOperators() ); if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION ) _sbRuleSet.add( new RewriteForLoopVectorization() ); //dependency: reblock (reblockop) _sbRuleSet.add( new RewriteInjectSparkLoopCheckpointing(true) ); //dependency: reblock (blocksizes) if( OptimizerUtils.ALLOW_LOOP_UPDATE_IN_PLACE ) _sbRuleSet.add( new RewriteMarkLoopVariablesUpdateInPlace() ); } // DYNAMIC REWRITES (which do require size information) if( dynamicRewrites ) { _dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) { _dagRuleSet.add( new RewriteAlgebraicSimplificationDynamic() ); //dependencies: cse _dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse } } // cleanup after all rewrites applied // (newly introduced operators, introduced redundancy after rewrites w/ multiple parents) _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) _dagRuleSet.add( new RewriteCommonSubexpressionElimination(true) ); } /** * Construct a program rewriter for a given rewrite which is passed from outside. * * @param rewrites the HOP rewrite rules */ public ProgramRewriter( HopRewriteRule... rewrites ) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<HopRewriteRule>(); for( HopRewriteRule rewrite : rewrites ) _dagRuleSet.add( rewrite ); _sbRuleSet = new ArrayList<StatementBlockRewriteRule>(); } /** * Construct a program rewriter for a given rewrite which is passed from outside. * * @param rewrites the statement block rewrite rules */ public ProgramRewriter( StatementBlockRewriteRule... rewrites ) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<HopRewriteRule>(); _sbRuleSet = new ArrayList<StatementBlockRewriteRule>(); for( StatementBlockRewriteRule rewrite : rewrites ) _sbRuleSet.add( rewrite ); } /** * Construct a program rewriter for the given rewrite sets which are passed from outside. * * @param hRewrites HOP rewrite rules * @param sbRewrites statement block rewrite rules */ public ProgramRewriter( ArrayList<HopRewriteRule> hRewrites, ArrayList<StatementBlockRewriteRule> sbRewrites ) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<HopRewriteRule>(); _dagRuleSet.addAll( hRewrites ); _sbRuleSet = new ArrayList<StatementBlockRewriteRule>(); _sbRuleSet.addAll( sbRewrites ); } public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp) throws LanguageException, HopsException { ProgramRewriteStatus state = new ProgramRewriteStatus(); // for each namespace, handle function statement blocks for (String namespaceKey : dmlp.getNamespaces().keySet()) for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) { FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey,fname); rewriteStatementBlockHopDAGs(fsblock, state); rewriteStatementBlock(fsblock, state); } // handle regular statement blocks in "main" method for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { StatementBlock current = dmlp.getStatementBlock(i); rewriteStatementBlockHopDAGs(current, state); } dmlp.setStatementBlocks( rewriteStatementBlocks(dmlp.getStatementBlocks(), state) ); return state; } public void rewriteStatementBlockHopDAGs(StatementBlock current, ProgramRewriteStatus state) throws LanguageException, HopsException { //ensure robustness for calls from outside if( state == null ) state = new ProgramRewriteStatus(); if (current instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock)current; FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); for (StatementBlock sb : fstmt.getBody()) rewriteStatementBlockHopDAGs(sb, state); } else if (current instanceof WhileStatementBlock) { WhileStatementBlock wsb = (WhileStatementBlock) current; WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); wsb.setPredicateHops(rewriteHopDAG(wsb.getPredicateHops(), state)); for (StatementBlock sb : wstmt.getBody()) rewriteStatementBlockHopDAGs(sb, state); } else if (current instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) current; IfStatement istmt = (IfStatement)isb.getStatement(0); isb.setPredicateHops(rewriteHopDAG(isb.getPredicateHops(), state)); for (StatementBlock sb : istmt.getIfBody()) rewriteStatementBlockHopDAGs(sb, state); for (StatementBlock sb : istmt.getElseBody()) rewriteStatementBlockHopDAGs(sb, state); } else if (current instanceof ForStatementBlock) //incl parfor { ForStatementBlock fsb = (ForStatementBlock) current; ForStatement fstmt = (ForStatement)fsb.getStatement(0); fsb.setFromHops(rewriteHopDAG(fsb.getFromHops(), state)); fsb.setToHops(rewriteHopDAG(fsb.getToHops(), state)); fsb.setIncrementHops(rewriteHopDAG(fsb.getIncrementHops(), state)); for (StatementBlock sb : fstmt.getBody()) rewriteStatementBlockHopDAGs(sb, state); } else //generic (last-level) { current.set_hops( rewriteHopDAGs(current.get_hops(), state) ); } } public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException { for( HopRewriteRule r : _dagRuleSet ) { Hop.resetVisitStatus( roots ); //reset for each rule roots = r.rewriteHopDAGs(roots, state); if( CHECK ) { LOG.info("Validation after: "+r.getClass().getName()); HopDagValidator.validateHopDag(roots); } } return roots; } public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException { if( root == null ) return root; for( HopRewriteRule r : _dagRuleSet ) { root.resetVisitStatus(); //reset for each rule root = r.rewriteHopDAG(root, state); if( CHECK ) { LOG.info("Validation after: "+r.getClass().getName()); HopDagValidator.validateHopDag(root); } } return root; } public ArrayList<StatementBlock> rewriteStatementBlocks( ArrayList<StatementBlock> sbs, ProgramRewriteStatus state ) throws HopsException { //ensure robustness for calls from outside if( state == null ) state = new ProgramRewriteStatus(); ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>(); //rewrite statement blocks (with potential expansion) for( StatementBlock sb : sbs ) tmp.addAll( rewriteStatementBlock(sb, state) ); //copy results into original collection sbs.clear(); sbs.addAll( tmp ); return sbs; } private ArrayList<StatementBlock> rewriteStatementBlock( StatementBlock sb, ProgramRewriteStatus status ) throws HopsException { ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>(); ret.add(sb); //recursive invocation if (sb instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock)sb; FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); fstmt.setBody( rewriteStatementBlocks(fstmt.getBody(), status) ); } else if (sb instanceof WhileStatementBlock) { WhileStatementBlock wsb = (WhileStatementBlock) sb; WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); wstmt.setBody( rewriteStatementBlocks( wstmt.getBody(), status ) ); } else if (sb instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) sb; IfStatement istmt = (IfStatement)isb.getStatement(0); istmt.setIfBody( rewriteStatementBlocks( istmt.getIfBody(), status ) ); istmt.setElseBody( rewriteStatementBlocks( istmt.getElseBody(), status ) ); } else if (sb instanceof ForStatementBlock) //incl parfor { //maintain parfor context information (e.g., for checkpointing) boolean prestatus = status.isInParforContext(); if( sb instanceof ParForStatementBlock ) status.setInParforContext(true); ForStatementBlock fsb = (ForStatementBlock) sb; ForStatement fstmt = (ForStatement)fsb.getStatement(0); fstmt.setBody( rewriteStatementBlocks(fstmt.getBody(), status) ); status.setInParforContext(prestatus); } //apply rewrite rules for( StatementBlockRewriteRule r : _sbRuleSet ) { ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>(); for( StatementBlock sbc : ret ) tmp.addAll( r.rewriteStatementBlock(sbc, status) ); //take over set of rewritten sbs ret.clear(); ret.addAll(tmp); } return ret; } }