package net.sf.orcc.backends.c.dal.transform;
import java.util.ArrayList;
import java.util.List;
import net.sf.orcc.df.Action;
import net.sf.orcc.df.Actor;
import net.sf.orcc.ir.Block;
import net.sf.orcc.ir.BlockBasic;
import net.sf.orcc.ir.BlockIf;
import net.sf.orcc.ir.ExprBinary;
import net.sf.orcc.ir.ExprUnary;
import net.sf.orcc.ir.ExprVar;
import net.sf.orcc.ir.Expression;
import net.sf.orcc.ir.InstAssign;
import net.sf.orcc.ir.InstCall;
import net.sf.orcc.ir.InstLoad;
import net.sf.orcc.ir.InstReturn;
import net.sf.orcc.ir.Instruction;
import net.sf.orcc.ir.IrFactory;
import net.sf.orcc.ir.OpBinary;
import net.sf.orcc.ir.OpUnary;
import net.sf.orcc.ir.Procedure;
import net.sf.orcc.ir.Var;
import net.sf.orcc.ir.util.AbstractIrVisitor;
import net.sf.orcc.ir.util.ExpressionPrinter;
import net.sf.orcc.ir.util.IrUtil;
import net.sf.orcc.util.OrccLogger;
/**
* Perform if-conversion (branch predication) on guards
*
* @author Jani Boutellier
*
*/
public class IfConverter {
private Detector irVisitor;
private List<Var> varList;
private Expression preExpression;
private Expression conditionExpression;
private Expression thenExpression;
private Expression elseExpression;
private Expression postExpressionCopy;
private Expression tmpIfParent;
private ExpressionPrinter exPr;
private class Detector extends AbstractIrVisitor<Void> {
public Detector() {
super(true);
}
@Override
public Void caseExprVar(ExprVar expr) {
boolean found = false;
for (Var var : varList) {
if (var.getName().equals(expr.getUse().getVariable().getName())) {
found = true;
}
}
if (!found) {
varList.add(expr.getUse().getVariable());
}
return null;
}
}
private class VarFinder extends AbstractIrVisitor<Void> {
private String name;
private List<Expression> exprList;
public VarFinder(String name) {
this.name = name;
this.exprList = new ArrayList<Expression>();
}
private void checkOperand(Expression expr, Expression operand) {
if (operand.isExprVar()) {
if (name.equals(((ExprVar)operand).getUse().getVariable().getName())) {
exprList.add(expr);
}
}
}
@Override
public Void caseExprUnary(ExprUnary expr) {
doSwitch(expr.getExpr());
checkOperand(expr, expr.getExpr());
return null;
}
@Override
public Void caseExprBinary(ExprBinary expr) {
doSwitch(expr.getE1());
checkOperand(expr, expr.getE1());
doSwitch(expr.getE2());
checkOperand(expr, expr.getE2());
return null;
}
public List<Expression> getList() {
return exprList;
}
}
private List<VarInstructionPair> loadList;
private class VarInstructionPair {
Var var;
Instruction inst;
VarInstructionPair(Var var, Instruction inst) {
this.var = var;
this.inst = inst;
}
}
private Instruction getInstruction(Var var) {
for (VarInstructionPair pair : loadList) {
if (var.getName().equals(pair.var.getName())) {
return pair.inst;
}
}
OrccLogger.warnln("No load instruction found for variable "
+ var.getName());
return null;
}
public IfConverter () {
this.irVisitor = new Detector();
this.loadList = new ArrayList<VarInstructionPair>();
this.varList = new ArrayList<Var>();
this.exPr = new ExpressionPrinter();
}
private boolean extractExpressions(Action action) {
for (Block b : action.getScheduler().getBlocks()) {
if (b.isBlockBasic()) {
if (conditionExpression == null) {
preExpression = extractComputeExpression(((BlockBasic) b).getInstructions());
if (preExpression != null) {
OrccLogger.warnln("IfConverter: PreExpression contains compute expression");
}
} else {
if (postExpressionCopy != null) {
OrccLogger.warnln("IfConverter: Multiple PostExpressions");
}
postExpressionCopy = IrUtil.copy(extractComputeExpression(((BlockBasic) b).getInstructions()));
}
} else if (b.isBlockIf()) {
conditionExpression = ((BlockIf) b).getCondition();
if (((BlockIf) b).getThenBlocks().size() < 1) {
OrccLogger.warnln("IfConverter: ThenBlock empty in "
+ action.getName());
} else if (((BlockIf) b).getThenBlocks().size() > 1) {
OrccLogger.warnln("IfConverter: more than one ThenBlock in "
+ action.getName());
} else {
for (Block bb : ((BlockIf) b).getThenBlocks()) {
thenExpression = extractComputeExpression(((BlockBasic) bb).getInstructions());
if (thenExpression == null) {
thenExpression = createAssignFromLastLoad(((BlockBasic) bb).getInstructions());
if (thenExpression == null) {
OrccLogger.warnln("Failed to extract compute expression from ThenBlock of " + action.getName());
}
}
}
}
if (((BlockIf) b).getElseBlocks() != null) {
if (((BlockIf) b).getElseBlocks().size() < 1) {
OrccLogger.warnln("IfConverter: ElseBlock empty in "
+ action.getName());
} else if (((BlockIf) b).getElseBlocks().size() > 1) {
OrccLogger.warnln("IfConverter: more than one ElseBlock in "
+ action.getName());
} else {
for (Block bb : ((BlockIf) b).getElseBlocks()) {
elseExpression = extractComputeExpression(((BlockBasic) bb).getInstructions());
if (elseExpression == null) {
elseExpression = createAssignFromLastLoad(((BlockBasic) bb).getInstructions());
if (elseExpression == null) {
OrccLogger.warnln("Failed to extract compute expression from ElseBlock of " + action.getName());
}
}
}
}
} else {
OrccLogger.warnln("IfConverter: no compute expression in ElseBlock of"
+ action.getName());
}
if(extractComputeExpression(((BlockIf) b).getJoinBlock().getInstructions()) != null) {
OrccLogger.warnln("IfConverter: unsupported JoinBlock encountered "
+ action.getName());
}
} else {
OrccLogger.warnln("IfConverter: unsupported block type in guard of "
+ action.getName());
return false;
}
}
return true;
}
private Expression extractComputeExpression(List<Instruction> guard) {
Expression compute = null;
for (Instruction i : guard) {
if (i.isInstAssign()) {
return ((InstAssign) i).getValue();
} else if (i.isInstReturn()) {
return ((InstReturn) i).getValue();
}
}
return compute;
}
private Expression createAssignFromLastLoad(List<Instruction> guard) {
Expression compute = null;
Instruction lastInstruction = guard.get(guard.size() - 1);
if (lastInstruction.isInstLoad()) {
compute = IrFactory.eINSTANCE.createExprVar(((InstLoad) lastInstruction).getSource().getVariable());
}
return compute;
}
private void extractLoadBlockBasic(BlockBasic b) {
for (Instruction i : b.getInstructions()) {
if (i.isInstLoad()) {
loadList.add(new VarInstructionPair(
((InstLoad) i).getTarget().getVariable(), i));
} else if (i.isInstCall()) {
loadList.add(new VarInstructionPair(
((InstCall) i).getTarget().getVariable(), i));
}
}
}
private boolean extractLoads(Action action) {
for (Block b : action.getScheduler().getBlocks()) {
if (b.isBlockBasic()) {
extractLoadBlockBasic((BlockBasic) b);
} else if (b.isBlockIf()) {
for (Block bb : ((BlockIf) b).getThenBlocks()) {
extractLoadBlockBasic((BlockBasic)bb);
}
if (((BlockIf) b).getElseBlocks() != null) {
for (Block bb : ((BlockIf) b).getElseBlocks()) {
extractLoadBlockBasic((BlockBasic)bb);
}
}
extractLoadBlockBasic(((BlockIf) b).getJoinBlock());
} else {
OrccLogger.warnln("IfConverter: unsupported block type in guard of "
+ action.getName());
return false;
}
}
return true;
}
private boolean hasIfBlock(Procedure proc) {
if (proc != null) {
for (Block b : proc.getBlocks()) {
if (b.isBlockIf()) {
return true;
}
}
}
return false;
}
public void ifConvertGuards(Actor actor) {
for (Action a : actor.getActions()) {
if (hasIfBlock(a.getScheduler())) {
Procedure newScheduler = doIfConversion(a);
if (newScheduler != null) {
a.setScheduler(newScheduler);
}
}
}
}
private void extractTmpIf(Expression expr) {
VarFinder varFinder = new VarFinder("tmp_if");
varFinder.doSwitch(expr);
if (varFinder.getList().size() > 1) {
OrccLogger.warnln("IfConverter: more than one instances of tmp_if in expression");
} else if (varFinder.getList().size() == 1) {
tmpIfParent = varFinder.getList().get(0);
} else {
if (expr.isExprVar()) {
tmpIfParent = expr;
} else {
OrccLogger.warnln("IfConverter: Error locating join point in " + exPr.doSwitch(expr));
}
}
}
private Procedure doIfConversion(Action action) {
if (extractLoads(action)) {
preExpression = null;
conditionExpression = null;
thenExpression = null;
elseExpression = null;
postExpressionCopy = null;
tmpIfParent = null;
extractExpressions(action);
extractTmpIf(postExpressionCopy);
return createNewGuard(action);
}
return null;
}
private Expression negateExpression(Expression expr) {
return IrFactory.eINSTANCE.createExprUnary(OpUnary.LOGIC_NOT, expr,
expr.getType());
}
private Expression andExpression(Expression expr1, Expression expr2) {
return IrFactory.eINSTANCE.createExprBinary(expr1, OpBinary.LOGIC_AND,
expr2, expr1.getType());
}
private Expression orExpression(Expression expr1, Expression expr2) {
return IrFactory.eINSTANCE.createExprBinary(expr1, OpBinary.LOGIC_OR,
expr2, expr1.getType());
}
private void replaceSubExpr1(ExprBinary binPar, Expression compute) {
if (binPar.getE1().isExprVar()) {
if(((ExprVar)binPar.getE1()).getUse().getVariable().getName().equals("tmp_if")) {
binPar.setE1(compute);
}
}
}
private void replaceSubExpr2(ExprBinary binPar, Expression compute) {
if (binPar.getE2().isExprVar()) {
if(((ExprVar)binPar.getE2()).getUse().getVariable().getName().equals("tmp_if")) {
binPar.setE2(compute);
}
}
}
private Expression replaceTmpIf(Expression tmpIfPar, Expression compute) {
if (tmpIfPar.isExprVar()) {
return compute;
} else if (tmpIfPar.isExprUnary()) {
((ExprUnary)tmpIfPar).setExpr(compute);
} else {
replaceSubExpr1((ExprBinary)tmpIfPar, compute);
replaceSubExpr2((ExprBinary)tmpIfPar, compute);
}
return tmpIfPar;
}
private Procedure createNewGuard(Action action) {
Procedure proc = IrFactory.eINSTANCE.createProcedure("isSchedulable_"
+ action.getName(), 0,
IrFactory.eINSTANCE.createTypeBool());
Var result = proc.newTempLocalVariable(
IrFactory.eINSTANCE.createTypeBool(), "result");
Expression leftExpression = andExpression(IrUtil.copy(conditionExpression), IrUtil.copy(thenExpression));
Expression rightExpression = andExpression(negateExpression(IrUtil.copy(conditionExpression)), IrUtil.copy(elseExpression));
Expression computeExpression = orExpression(leftExpression, rightExpression);
computeExpression = replaceTmpIf(tmpIfParent, computeExpression);
computeExpression = postExpressionCopy;
irVisitor.doSwitch(computeExpression);
proc.getBlocks().add(IrFactory.eINSTANCE.createBlockBasic());
for(Var var : varList) {
Instruction inst = getInstruction(var);
if(inst == null) {
return null;
}
proc.getLast().add(IrUtil.copy(inst));
}
InstAssign assign = IrFactory.eINSTANCE.createInstAssign(result,
computeExpression);
proc.getLast().add(assign);
proc.getLast().add(IrFactory.eINSTANCE.createInstReturn(
IrFactory.eINSTANCE.createExprVar(result)));
return proc;
}
}