/* * Copyright (c) 2013, Ecole Polytechnique Fédérale de Lausanne * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * Neither the name of the Ecole Polytechnique Fédérale de Lausanne nor the names of its * contributors may be used to endorse or promote products derived from this * software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY * WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF * SUCH DAMAGE. 
*/
package net.sf.orcc.backends.transform;

import java.util.ArrayList;
import java.util.List;

import net.sf.orcc.ir.Arg;
import net.sf.orcc.ir.ArgByRef;
import net.sf.orcc.ir.ArgByVal;
import net.sf.orcc.ir.Block;
import net.sf.orcc.ir.BlockBasic;
import net.sf.orcc.ir.BlockIf;
import net.sf.orcc.ir.BlockWhile;
import net.sf.orcc.ir.Def;
import net.sf.orcc.ir.ExprBinary;
import net.sf.orcc.ir.ExprBool;
import net.sf.orcc.ir.ExprFloat;
import net.sf.orcc.ir.ExprInt;
import net.sf.orcc.ir.ExprString;
import net.sf.orcc.ir.ExprVar;
import net.sf.orcc.ir.Expression;
import net.sf.orcc.ir.InstAssign;
import net.sf.orcc.ir.InstCall;
import net.sf.orcc.ir.InstLoad;
import net.sf.orcc.ir.InstStore;
import net.sf.orcc.ir.IrFactory;
import net.sf.orcc.ir.OpBinary;
import net.sf.orcc.ir.Procedure;
import net.sf.orcc.ir.Use;
import net.sf.orcc.ir.Var;
import net.sf.orcc.ir.util.AbstractIrVisitor;
import net.sf.orcc.ir.util.IrUtil;
import net.sf.orcc.util.OrccLogger;
import net.sf.orcc.util.Void;

/**
 * A simple (full) loop unrolling transformation.
 * 
 * <p>
 * A {@link BlockWhile} carrying the {@code "unroll"} attribute is unrolled when
 * its condition has the form {@code index < N} or {@code index <= N} (with N an
 * integer literal) and its body contains only {@link BlockBasic} blocks. Every
 * assign, call, load and store of the body is replicated once per iteration,
 * with each read of the loop index replaced by the literal iteration number.
 * The loop is then replaced by one {@link BlockBasic} holding all the copies.
 * Instructions of any other kind in the loop body are not replicated and are
 * therefore dropped together with the loop.
 * </p>
 * 
 * <p>
 * NOTE(review): iterations are numbered {@code 0..N-1} ({@code 0..N} for
 * {@code <=}). The transformation does not verify that the index starts at 0,
 * is incremented by exactly 1, or is unused after the loop. Confirm that only
 * such loops are tagged with "unroll".
 * </p>
 * 
 * @author Endri Bezati
 * 
 */
public class LoopUnrolling extends AbstractIrVisitor<Void> {

    /**
     * Builds a fresh copy of an expression in which every use of the loop
     * index is replaced by an integer literal.
     * 
     * <p>
     * NOTE(review): only binary, bool, float, int, string and variable
     * expressions have a case here. Check what the inherited default returns
     * for other expression kinds (e.g. unary).
     * </p>
     */
    private class InnerExpressionVisitor extends AbstractIrVisitor<Expression> {

        /** Literal substituted for the loop index. */
        private final int idxLiteral;

        /** The loop index variable to substitute. */
        private final Var index;

        public InnerExpressionVisitor(Var index, int idxLiteral) {
            super(true);
            this.index = index;
            this.idxLiteral = idxLiteral;
        }

        @Override
        public Expression caseExprBinary(ExprBinary expr) {
            Expression e1 = doSwitch(expr.getE1());
            Expression e2 = doSwitch(expr.getE2());
            return IrFactory.eINSTANCE.createExprBinary(e1, expr.getOp(), e2,
                    IrUtil.copy(expr.getType()));
        }

        @Override
        public Expression caseExprBool(ExprBool object) {
            return IrUtil.copy(object);
        }

        @Override
        public Expression caseExprFloat(ExprFloat object) {
            return IrUtil.copy(object);
        }

        @Override
        public Expression caseExprInt(ExprInt object) {
            return IrUtil.copy(object);
        }

        @Override
        public Expression caseExprString(ExprString object) {
            return IrUtil.copy(object);
        }

        @Override
        public Expression caseExprVar(ExprVar object) {
            Var var = object.getUse().getVariable();
            if (var == index) {
                return IrFactory.eINSTANCE.createExprInt(idxLiteral);
            }
            return IrUtil.copy(object);
        }
    }

    /**
     * Walks the body of the loop being unrolled and appends, for each
     * supported instruction, one copy per iteration to
     * {@code unrollBlocks.get(i)}.
     * 
     * <p>
     * The original loop body is left untouched. It is deleted as a whole once
     * the copies have been placed.
     * </p>
     */
    private class InstructionUnroller extends AbstractIrVisitor<Void> {

        /** Copies a list of index expressions with the loop index substituted. */
        private List<Expression> unrollIndexes(List<Expression> indexes,
                InnerExpressionVisitor innerVisitor) {
            List<Expression> newIndexes = new ArrayList<Expression>(
                    indexes.size());
            for (Expression expression : indexes) {
                newIndexes.add(innerVisitor.doSwitch(expression));
            }
            return newIndexes;
        }

        @Override
        public Void caseInstAssign(InstAssign assign) {
            Var target = assign.getTarget().getVariable();
            if (target == index) {
                // Updates of the loop index (e.g. index = index + 1) are not
                // replicated, because every copy reads a literal index instead.
                // The original instruction is not deleted here: that would
                // mutate the instruction list being visited. It would also
                // leave the loop without its increment if the loop were not
                // replaced. It goes away when the whole loop is deleted.
                return null;
            }
            for (int i = 0; i <= repetition; i++) {
                InnerExpressionVisitor innerVisitor = new InnerExpressionVisitor(
                        index, i);
                Expression unrolledValue = innerVisitor.doSwitch(assign
                        .getValue());
                InstAssign instAssign = IrFactory.eINSTANCE.createInstAssign(
                        target, unrolledValue);
                // Append: the body is visited in order, so appending keeps the
                // original instruction order even across several basic blocks.
                unrollBlocks.get(i).getInstructions().add(instAssign);
            }
            return null;
        }

        @Override
        public Void caseInstCall(InstCall call) {
            // A call whose result is not stored presumably has no target
            // (getTarget() == null). Only copy a target when there is one.
            Def callTarget = call.getTarget();
            List<Arg> arguments = call.getArguments();
            for (int i = 0; i <= repetition; i++) {
                InnerExpressionVisitor innerVisitor = new InnerExpressionVisitor(
                        index, i);
                List<Arg> unrolledArguments = new ArrayList<Arg>(
                        arguments.size());
                for (Arg arg : arguments) {
                    if (arg.isByVal()) {
                        ArgByVal argByVal = (ArgByVal) arg;
                        Expression expr = innerVisitor.doSwitch(argByVal
                                .getValue());
                        unrolledArguments.add(IrFactory.eINSTANCE
                                .createArgByVal(expr));
                    } else {
                        ArgByRef argByRef = (ArgByRef) arg;
                        Var var = argByRef.getUse().getVariable();
                        Use use = IrFactory.eINSTANCE.createUse(var);
                        ArgByRef unrolledArg = IrFactory.eINSTANCE
                                .createArgByRef();
                        unrolledArg.setUse(use);
                        unrolledArg.getIndexes().addAll(
                                unrollIndexes(argByRef.getIndexes(),
                                        innerVisitor));
                        // Previously missing: by-reference arguments were
                        // built but never added, so they were lost.
                        unrolledArguments.add(unrolledArg);
                    }
                }
                InstCall instCall = IrFactory.eINSTANCE.createInstCall();
                if (callTarget != null) {
                    Def defTarget = IrFactory.eINSTANCE.createDef(callTarget
                            .getVariable());
                    instCall.setTarget(defTarget);
                }
                instCall.setProcedure(call.getProcedure());
                instCall.getArguments().addAll(unrolledArguments);
                unrollBlocks.get(i).getInstructions().add(instCall);
            }
            return null;
        }

        @Override
        public Void caseInstLoad(InstLoad load) {
            Var target = load.getTarget().getVariable();
            Var source = load.getSource().getVariable();
            for (int i = 0; i <= repetition; i++) {
                InnerExpressionVisitor innerVisitor = new InnerExpressionVisitor(
                        index, i);
                List<Expression> newIndexes = unrollIndexes(load.getIndexes(),
                        innerVisitor);
                InstLoad instLoad = IrFactory.eINSTANCE.createInstLoad(target,
                        source, newIndexes);
                unrollBlocks.get(i).getInstructions().add(instLoad);
            }
            return null;
        }

        @Override
        public Void caseInstStore(InstStore store) {
            Var target = store.getTarget().getVariable();
            for (int i = 0; i <= repetition; i++) {
                InnerExpressionVisitor innerVisitor = new InnerExpressionVisitor(
                        index, i);
                List<Expression> newIndexes = unrollIndexes(
                        store.getIndexes(), innerVisitor);
                Expression value = innerVisitor.doSwitch(store.getValue());
                InstStore instStore = IrFactory.eINSTANCE.createInstStore(
                        target, newIndexes, value);
                unrollBlocks.get(i).getInstructions().add(instStore);
            }
            return null;
        }
    }

    /** The loop index variable of the loop currently being unrolled. **/
    private Var index = null;

    /** Number of the last generated iteration (iterations 0..repetition). **/
    private int repetition = 0;

    /** One block per iteration, receiving that iteration's copies. **/
    private List<BlockBasic> unrollBlocks;

    /**
     * Visits the loop's content first (via {@code super}), then unrolls the
     * loop itself if it carries the {@code "unroll"} attribute.
     */
    @Override
    public Void caseBlockWhile(BlockWhile blockWhile) {
        super.caseBlockWhile(blockWhile);
        if (blockWhile.hasAttribute("unroll")) {
            unroll(blockWhile);
        }
        return null;
    }

    /**
     * Tests whether a blockWhile contains only BlockBasic blocks (no nested
     * if or while blocks).
     * 
     * @param blockWhile
     *            the loop to inspect
     * @return {@code true} if none of the loop's direct blocks is a BlockIf or
     *         a BlockWhile
     */
    protected boolean testSimpleLoop(BlockWhile blockWhile) {
        for (Block block : blockWhile.getBlocks()) {
            if (block.isBlockIf() || block.isBlockWhile()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Fully unrolls {@code blockWhile} and replaces it in its container with a
     * single BlockBasic, or logs a warning and leaves the loop untouched when
     * the loop does not match the supported shape.
     * 
     * @param blockWhile
     *            the loop to unroll
     */
    protected void unroll(BlockWhile blockWhile) {
        Expression condition = blockWhile.getCondition();
        if (!condition.isExprBinary()) {
            warnNotUnrolled(blockWhile);
            return;
        }

        ExprBinary exprBin = (ExprBinary) condition;
        Expression e1 = exprBin.getE1();
        Expression e2 = exprBin.getE2();
        OpBinary op = exprBin.getOp();
        // Only "index < literal" and "index <= literal" are supported.
        if (!(e1.isExprVar() && e2.isExprInt() && (op == OpBinary.LT || op == OpBinary.LE))) {
            warnNotUnrolled(blockWhile);
            return;
        }
        if (!testSimpleLoop(blockWhile)) {
            warnNotUnrolled(blockWhile);
            return;
        }

        // Resolve where the replacement goes before doing any work.
        List<Block> containerBlocks = getContainerBlocks(blockWhile);
        if (containerBlocks == null) {
            warnNotUnrolled(blockWhile);
            return;
        }

        index = ((ExprVar) e1).getUse().getVariable();
        int bound = ((ExprInt) e2).getIntValue();
        repetition = (op == OpBinary.LT) ? bound - 1 : bound;

        // repetition may be negative (e.g. "i < 0"). ArrayList rejects a
        // negative capacity, so clamp it. No iteration is generated then.
        unrollBlocks = new ArrayList<BlockBasic>(Math.max(repetition + 1, 0));
        for (int i = 0; i <= repetition; i++) {
            unrollBlocks.add(IrFactory.eINSTANCE.createBlockBasic());
        }

        // Unroll loop
        InstructionUnroller unroller = new InstructionUnroller();
        unroller.doSwitch(blockWhile.getBlocks());

        // Concatenate the iterations in order and put them where the loop was.
        BlockBasic combinedBlock = IrFactory.eINSTANCE.createBlockBasic();
        for (BlockBasic unrolled : unrollBlocks) {
            combinedBlock.getInstructions().addAll(unrolled.getInstructions());
        }
        int idxBlockWhile = containerBlocks.indexOf(blockWhile);
        containerBlocks.add(idxBlockWhile, combinedBlock);
        IrUtil.delete(blockWhile);
    }

    /**
     * Returns the block list that directly contains {@code blockWhile}: the
     * procedure body, the enclosing loop body, or the then/else branch of the
     * enclosing if. Returns {@code null} for any other container.
     */
    private List<Block> getContainerBlocks(BlockWhile blockWhile) {
        Object container = blockWhile.eContainer();
        if (container instanceof Procedure) {
            return ((Procedure) container).getBlocks();
        } else if (container instanceof BlockWhile) {
            return ((BlockWhile) container).getBlocks();
        } else if (container instanceof BlockIf) {
            BlockIf blockIf = (BlockIf) container;
            // Test if then or else blocks contain this blockWhile
            if (blockIf.getThenBlocks().contains(blockWhile)) {
                return blockIf.getThenBlocks();
            }
            return blockIf.getElseBlocks();
        }
        return null;
    }

    /** Logs that {@code blockWhile} was left as is. */
    private void warnNotUnrolled(BlockWhile blockWhile) {
        OrccLogger.warnln("LoopUnrolling cannot process on loop at line: "
                + blockWhile.getLineNumber());
    }
}