package soottocfg.soot.memory_model;

import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.Set;

import com.google.common.base.Verify;

import soottocfg.cfg.Program;
import soottocfg.cfg.expression.BinaryExpression;
import soottocfg.cfg.expression.Expression;
import soottocfg.cfg.expression.IdentifierExpression;
import soottocfg.cfg.expression.literal.NullLiteral;
import soottocfg.cfg.method.CfgBlock;
import soottocfg.cfg.method.CfgEdge;
import soottocfg.cfg.method.Method;
import soottocfg.cfg.statement.AssertStatement;
import soottocfg.cfg.statement.AssignStatement;
import soottocfg.cfg.statement.AssumeStatement;
import soottocfg.cfg.statement.CallStatement;
import soottocfg.cfg.statement.NewStatement;
import soottocfg.cfg.statement.PullStatement;
import soottocfg.cfg.statement.PushStatement;
import soottocfg.cfg.statement.Statement;
import soottocfg.cfg.type.ReferenceType;
import soottocfg.cfg.variable.ClassVariable;
import soottocfg.cfg.variable.Variable;
import soottocfg.soot.SootToCfg;

/**
 * Rewrites the push/pull statements of a program by repeatedly applying a
 * set of local rewrite rules (remove duplicates, reorder, move pulls up and
 * pushes down) until no rule applies any more.
 *
 * Throughout this class a constructor call is treated like a pull; its
 * "object" is the first argument of the call (see {@link #getObject}).
 */
public class PushPullSimplifier {

    private static boolean debug = false;

    /**
     * Simplifies every method of the program until a fixed point is reached.
     *
     * @param p the program whose methods are rewritten in place
     * @return true if at least one rule was applied anywhere
     */
    public boolean simplify(Program p) {
        boolean change = false;
        Method[] ms = p.getMethods();
        for (Method m : ms) {
            if (debug) {
                System.out.println("Simplifying method " + m.getMethodName());
                System.out.println(m);
            }
            Set<CfgBlock> blocks = m.vertexSet();
            int simplifications;
            do {
                // intra-block simplification
                for (CfgBlock block : blocks) {
                    if (simplify(block)) {
                        change = true;
                    }
                }

                // inter-block simplifications
                simplifications = 0;
                simplifications += movePullsUpInCFG(m);
                simplifications += movePushesDownInCFG(m);
                if (simplifications > 0) {
                    change = true;
                }
            } while (simplifications > 0);

            if (debug) {
                System.out.println("SIMPLIFIED:\n" + m);
            }
        }
        return change;
    }

    /**
     * Applies the intra-block rules (I) to (IX) to a single block until none
     * of them fires.
     *
     * @param b the block to rewrite in place
     * @return true if at least one rule was applied
     */
    public boolean simplify(CfgBlock b) {
        boolean change = false;
        int simplifications;
        do {
            simplifications = 0;
            simplifications += removeConseqPulls(b);
            simplifications += removeConseqPushs(b);
            simplifications += removePullAfterPush(b);
            simplifications += removePushAfterPull(b);
            simplifications += movePullUp(b);
            simplifications += movePushDown(b);
            simplifications += swapPushPull(b);
            simplifications += orderPulls(b);
            simplifications += orderPushes(b);
            // simplifications += assumeFalseEatPreceeding(b);
            if (simplifications > 0) {
                change = true;
            }
        } while (simplifications > 0);
        return change;
    }

    /* Rule I: of two adjacent pulls of the same object, drop the second. */
    private int removeConseqPulls(CfgBlock b) {
        int removed = 0;
        List<Statement> stmts = b.getStatements();
        for (int i = 0; i + 1 < stmts.size(); i++) {
            if (isPullOrConstructorCall(stmts.get(i)) && stmts.get(i + 1) instanceof PullStatement) {
                Statement pull1 = stmts.get(i);
                Statement pull2 = stmts.get(i + 1);
                if (getObject(pull1).sameVariable(getObject(pull2))) {
                    if (debug) {
                        System.out.println("Applied rule (I); removed " + pull2);
                    }
                    b.removeStatement(pull2);
                    removed++;
                }
            }
        }
        return removed;
    }

    /* Rule II: of two adjacent pushes of the same object, drop the first. */
    private int removeConseqPushs(CfgBlock b) {
        int removed = 0;
        List<Statement> stmts = b.getStatements();
        for (int i = 0; i + 1 < stmts.size(); i++) {
            if (stmts.get(i) instanceof PushStatement && stmts.get(i + 1) instanceof PushStatement) {
                PushStatement push1 = (PushStatement) stmts.get(i);
                PushStatement push2 = (PushStatement) stmts.get(i + 1);
                if (getObject(push1).sameVariable(getObject(push2))) {
                    if (debug) {
                        System.out.println("Applied rule (II); removed " + push1);
                    }
                    b.removeStatement(push1);
                    removed++;
                }
            }
        }
        return removed;
    }

    /* Rule III: drop a pull that directly follows a push of the very same variables. */
    private int removePullAfterPush(CfgBlock b) {
        int removed = 0;
        List<Statement> stmts = b.getStatements();
        for (int i = 0; i + 1 < stmts.size(); i++) {
            if (stmts.get(i) instanceof PushStatement && stmts.get(i + 1) instanceof PullStatement) {
                PushStatement push = (PushStatement) stmts.get(i);
                PullStatement pull = (PullStatement) stmts.get(i + 1);
                if (sameVars(push, pull)) {
                    if (debug) {
                        System.out.println("Applied rule (III); removed " + pull);
                    }
                    b.removeStatement(pull);
                    removed++;
                }
            }
        }
        return removed;
    }

    /* Rule IV: drop a push that directly follows a pull (or constructor call) of the very same variables. */
    private int removePushAfterPull(CfgBlock b) {
        int removed = 0;
        List<Statement> stmts = b.getStatements();
        for (int i = 0; i + 1 < stmts.size(); i++) {
            if (isPullOrConstructorCall(stmts.get(i)) && stmts.get(i + 1) instanceof PushStatement) {
                Statement pull = stmts.get(i);
                Statement push = stmts.get(i + 1);
                if (sameVarsStatement(push, pull)) {
                    if (debug) {
                        System.out.println("Applied rule (IV); removed " + push);
                    }
                    b.removeStatement(push);
                    removed++;
                }
            }
        }
        return removed;
    }

    /*
     * Rule V: move a pull (or constructor call) one position up, past an
     * assignment, new, assume or assert statement, when the variables involved
     * are distinct (or, for asserts, when the assert is not a null check that
     * guards the pull).
     */
    private int movePullUp(CfgBlock b) {
        int moved = 0;
        List<Statement> stmts = b.getStatements();
        for (int i = 0; i + 1 < stmts.size(); i++) {
            if (!isPullOrConstructorCall(stmts.get(i + 1))) {
                continue;
            }
            Statement pull = stmts.get(i + 1);
            Statement s = stmts.get(i);
            boolean swap = false;

            if (s instanceof AssignStatement) {
                Set<IdentifierExpression> pullvars = pull.getIdentifierExpressions();
                pullvars.addAll(pull.getDefIdentifierExpressions());
                AssignStatement as = (AssignStatement) s;
                // only the left-hand side of the assignment is compared against the pull
                Set<IdentifierExpression> svars = as.getLeft().getUseIdentifierExpressions();
                swap = distinct(svars, pullvars);
            } else if (s instanceof NewStatement || s instanceof AssumeStatement) {
                Set<IdentifierExpression> pullvars = pull.getIdentifierExpressions();
                pullvars.addAll(pull.getDefIdentifierExpressions());
                Set<IdentifierExpression> svars = s.getUseIdentifierExpressions();
                svars.addAll(s.getDefIdentifierExpressions());
                swap = distinct(svars, pullvars);
            } else if (s instanceof AssertStatement) {
                // do not move past null check
                AssertStatement as = (AssertStatement) s;
                swap = i == 0
                        || !(pull instanceof PullStatement)
                        || !isNullCheckBeforePull(stmts.get(i - 1), as, (PullStatement) pull);
            }

            if (swap) {
                b.swapStatements(i, i + 1);
                if (debug) {
                    System.out.println("Applied rule (V); swapped " + s + " and " + pull);
                }
                moved++;
            }
        }
        return moved;
    }

    /* Rule VI: move a push one position down, past an assignment, assert, new or assume statement. */
    private int movePushDown(CfgBlock b) {
        int moved = 0;
        List<Statement> stmts = b.getStatements();
        for (int i = 0; i + 1 < stmts.size(); i++) {
            if (stmts.get(i) instanceof PushStatement) {
                PushStatement push = (PushStatement) stmts.get(i);
                Statement s = stmts.get(i + 1);
                if (s instanceof AssignStatement || s instanceof AssertStatement
                        || s instanceof NewStatement || s instanceof AssumeStatement) {
                    b.swapStatements(i, i + 1);
                    if (debug) {
                        System.out.println("Applied rule (VI); swapped " + push + " and " + s);
                    }
                    moved++;
                }
            }
        }
        return moved;
    }

    /* Rule VII: swap a push with the pull (or constructor call) that follows it, if their objects cannot alias. */
    private int swapPushPull(CfgBlock b) {
        int swapped = 0;
        List<Statement> stmts = b.getStatements();
        for (int i = 0; i + 1 < stmts.size(); i++) {
            if (stmts.get(i) instanceof PushStatement && isPullOrConstructorCall(stmts.get(i + 1))) {
                Statement push = stmts.get(i);
                Statement pull = stmts.get(i + 1);
                // only swap if the objects in the pull and push do not point to the same location
                if (!SootToCfg.getPointsToAnalysis().mayAlias(getObject(pull), getObject(push))) {
                    b.swapStatements(i, i + 1);
                    if (debug) {
                        System.out.println("Applied rule (VII); swapped " + push + " and " + pull);
                    }
                    swapped++;
                }
            }
        }
        return swapped;
    }

    /*
     * Rule VIII: order adjacent pulls by the string form of their object so
     * that duplicates become adjacent and can be removed by rule (I).
     * Two pulls are swapped when the first object's name sorts before the
     * second one's.
     * NOTE(review): this is the opposite comparison direction from
     * orderPushes (rule IX); confirm that this asymmetry is intended.
     */
    private int orderPulls(CfgBlock b) {
        int swapped = 0;
        List<Statement> stmts = b.getStatements();
        for (int i = 0; i + 1 < stmts.size(); i++) {
            if (isPullOrConstructorCall(stmts.get(i)) && isPullOrConstructorCall(stmts.get(i + 1))) {
                Statement pull1 = stmts.get(i);
                Statement pull2 = stmts.get(i + 1);
                if (getObject(pull1).toString().compareTo(getObject(pull2).toString()) < 0) {
                    // only swap if none of the vars in the two pulls point to the same location
                    Set<IdentifierExpression> pull1vars = pull1.getIdentifierExpressions();
                    Set<IdentifierExpression> pull2vars = pull2.getIdentifierExpressions();
                    if (distinct(pull1vars, pull2vars)) {
                        b.swapStatements(i, i + 1);
                        if (debug) {
                            System.out.println("Applied rule (VIII); swapped " + pull1 + " and " + pull2);
                        }
                        swapped++;
                    }
                }
            }
        }
        return swapped;
    }

    /*
     * Rule IX: order adjacent pushes by the string form of their object so
     * that duplicates become adjacent and can be removed by rule (II).
     * Two pushes are swapped when the first object's name sorts after the
     * second one's.
     */
    private int orderPushes(CfgBlock b) {
        int swapped = 0;
        List<Statement> stmts = b.getStatements();
        for (int i = 0; i + 1 < stmts.size(); i++) {
            if (stmts.get(i) instanceof PushStatement && stmts.get(i + 1) instanceof PushStatement) {
                PushStatement push1 = (PushStatement) stmts.get(i);
                PushStatement push2 = (PushStatement) stmts.get(i + 1);
                if (push1.getObject().toString().compareTo(push2.getObject().toString()) > 0) {
                    // only swap if none of the vars in the two pushes point to the same location
                    Set<IdentifierExpression> push1vars = push1.getIdentifierExpressions();
                    Set<IdentifierExpression> push2vars = push2.getIdentifierExpressions();
                    if (distinct(push1vars, push2vars)) {
                        b.swapStatements(i, i + 1);
                        if (debug) {
                            System.out.println("Applied rule (IX); swapped " + push1 + " and " + push2);
                        }
                        swapped++;
                    }
                }
            }
        }
        return swapped;
    }

    /* Rule X (new) -- currently disabled */
//    private int assumeFalseEatPreceeding(CfgBlock b) {
//        int eaten = 0;
//        List<Statement> stmts = b.getStatements();
//        for (int i = 0; i < stmts.size(); i++) {
//            if (stmts.get(i) instanceof AssumeStatement) {
//                AssumeStatement as = (AssumeStatement) stmts.get(i);
//                if (as.getExpression() instanceof BooleanLiteral &&
//                        ((BooleanLiteral) as.getExpression()).equals(BooleanLiteral.falseLiteral())) {
//                    //Found one! Now eat everything except asserts.
//                    Set<Statement> toRemove = new HashSet<Statement>();
//                    int j = i - 1;
//                    while (j >= 0 && !(stmts.get(j) instanceof AssertStatement)) {
//                        System.out.println("Assume(false) eating " + stmts.get(j));
//                        toRemove.add(stmts.get(j));
//                        j--;
//                    }
//                    b.removeStatements(toRemove);
//                }
//            }
//        }
//        return eaten;
//    }

    /**
     * Moves the pulls (and constructor calls) that lead a block into the
     * predecessors of that block that are closer to the method's source.
     *
     * @param m the method whose blocks are rewritten in place
     * @return the number of statements added to predecessor blocks
     */
    private int movePullsUpInCFG(Method m) {
        int moves = 0;
        for (CfgBlock b : m.vertexSet()) {
            if (debug) {
                System.out.println("Checking block " + b.getLabel() + " for pulls to move up");
            }
            List<Statement> stmts = b.getStatements();
            int s = 0;
            Set<Statement> toRemove = new HashSet<Statement>();
            while (s < stmts.size() && isPullOrConstructorCall(stmts.get(s))) {
                Statement pull = stmts.get(s);
                Set<CfgEdge> incoming = b.getMethod().incomingEdgesOf(b);
                Set<CfgBlock> moveTo = new HashSet<CfgBlock>();
                boolean nothingMoves = false;
                if (debug) {
                    System.out.println("Let's see if we can move " + pull + " up in the CFG...");
                }

                for (CfgEdge in : incoming) {
                    CfgBlock prev = b.getMethod().getEdgeSource(in);
                    // only move up in CFG
                    if (m.distanceToSource(prev) < m.distanceToSource(b)) {
                        if (in.getLabel().isPresent()
                                && !distinct(in.getLabel().get().getUseIdentifierExpressions(),
                                        pull.getIdentifierExpressions())) {
                            // edge label contains a ref to a variable of the pull, do not move this pull
                            if (debug) {
                                System.out.println("Label not distinct: " + pull);
                            }
                            nothingMoves = true;
                            break;
                        }
                        moveTo.add(prev);
                    } else if (!existsPath(b, prev)) {
                        // prev is not closer to the source and is not reachable
                        // from b either, so it is not a loop back edge: give up
                        // on moving this pull.
                        nothingMoves = true;
                        break;
                    } else if (debug) {
                        /*
                         * There is a path from b to prev although b is closer
                         * to the source, i.e. b is a loop header and prev lies
                         * inside the loop. We want to move the pull out of the
                         * loop into the other predecessors, but not into prev,
                         * as we would then be moving it around in circles.
                         */
                        System.out.println("Skipping the copy of " + pull + " from " + b + " to " + prev);
                    }
                }

                if (!nothingMoves) {
                    for (CfgBlock prev : moveTo) {
                        // don't create references to the same statement in multiple blocks
                        if (toRemove.contains(pull)) {
                            pull = pull.deepCopy();
                        } else {
                            toRemove.add(pull);
                        }
                        prev.addStatement(pull);
                        moves++;
                        if (debug) {
                            System.out.println("Moved " + pull + " up in CFG.");
                        }
                    }
                }
                s++;
            }
            b.removeStatements(toRemove);
        }
        return moves;
    }

    /**
     * Moves the pushes that trail a block to the front of the successors of
     * that block that are closer to the method's sink.
     *
     * @param m the method whose blocks are rewritten in place
     * @return the number of statements added to successor blocks
     */
    private int movePushesDownInCFG(Method m) {
        int moves = 0;
        for (CfgBlock b : m.vertexSet()) {
            if (debug) {
                System.out.println("Checking block " + b.getLabel() + " for pushes to move down");
            }
            List<Statement> stmts = b.getStatements();
            int s = stmts.size() - 1;
            Set<Statement> toRemove = new HashSet<Statement>();
            while (s >= 0 && stmts.get(s) instanceof PushStatement) {
                PushStatement push = (PushStatement) stmts.get(s);
                Set<CfgEdge> outgoing = b.getMethod().outgoingEdgesOf(b);
                Set<CfgBlock> moveTo = new HashSet<CfgBlock>();
                boolean nothingMoves = false;
                if (debug) {
                    System.out.println("Let's see if we can move " + push + " down in the CFG...");
                }

                for (CfgEdge out : outgoing) {
                    CfgBlock next = b.getMethod().getEdgeTarget(out);
                    // only move down in CFG
                    if (m.distanceToSink(next) < m.distanceToSink(b)) {
                        if (out.getLabel().isPresent()
                                && !distinct(out.getLabel().get().getUseIdentifierExpressions(),
                                        push.getIdentifierExpressions())) {
                            // edge label contains a ref to a variable of the push, do not move this push
                            if (debug) {
                                System.out.println("Label not distinct: " + push);
                            }
                            nothingMoves = true;
                            break;
                        }
                        if (!hasBeenPulledIn(push, next)) {
                            // object has not been pulled in or before the successor block, do not move this push
                            if (debug) {
                                System.out.println("Not pulled in predecessor of block " + next.getLabel() + ": " + push);
                            }
                            nothingMoves = true;
                            break;
                        }
                        moveTo.add(next);
                    } else if (!existsPath(next, b)) {
                        // Analogous to moving pulls up: next is neither closer
                        // to the sink nor part of a loop through b, so give up.
                        // Once nothingMoves is set the remaining edges cannot
                        // change the outcome, hence the break (same as in
                        // movePullsUpInCFG).
                        nothingMoves = true;
                        break;
                    } else if (debug) {
                        System.out.println("Skipping the copy of " + push + " from " + b + " to " + next);
                    }
                }

                if (!nothingMoves) {
                    for (CfgBlock next : moveTo) {
                        // don't create references to the same statement in multiple blocks
                        if (toRemove.contains(push)) {
                            push = (PushStatement) push.deepCopy();
                        } else {
                            toRemove.add(push);
                        }
                        next.addStatement(0, push);
                        moves++;
                        if (debug) {
                            System.out.println("Moved " + push + " down in CFG.");
                        }
                    }
                }
                s--;
            }
            b.removeStatements(toRemove);
        }
        return moves;
    }

    /**
     * Decides whether two sets of identifiers can be treated as referring to
     * disjoint variables / heap locations.
     *
     * @return false as soon as one pair is the same variable or is judged to
     *         possibly alias; true otherwise
     */
    private boolean distinct(Set<IdentifierExpression> vars1, Set<IdentifierExpression> vars2) {
        if (debug) {
            System.out.println("Checking distinctness of " + vars1 + " and " + vars2);
        }
        for (IdentifierExpression exp1 : vars1) {
            for (IdentifierExpression exp2 : vars2) {
                if (debug) {
                    System.out.println("Checking distinctness of " + exp1 + exp1.getType() + " and " + exp2 + exp2.getType());
                }
                if (exp1.sameVariable(exp2)) {
                    if (debug) {
                        System.out.println("Same var: " + exp1 + " and " + exp2);
                    }
                    return false;
                } else if (exp1.getType() instanceof ReferenceType && exp2.getType() instanceof ReferenceType) {
                    if (soottocfg.Options.v().memPrecision() >= 3) {
                        // higher precision setting: ask the points-to analysis
                        if (SootToCfg.getPointsToAnalysis().mayAlias(exp1, exp2)) {
                            return false;
                        }
                    } else {
                        // otherwise fall back to a comparison of the static class types
                        ReferenceType rt1 = (ReferenceType) exp1.getType();
                        ReferenceType rt2 = (ReferenceType) exp2.getType();
                        ClassVariable cv1 = rt1.getClassVariable();
                        ClassVariable cv2 = rt2.getClassVariable();
                        // NOTE(review): the second disjunct is negated, so a pair is
                        // only left "distinct" when cv1 is a superclass but not a
                        // subclass of cv2 -- confirm that this is the intended test.
                        if (cv1 != null && cv2 != null && (cv1.subclassOf(cv2) || !cv1.superclassOf(cv2))) {
                            return false;
                        }
                    }
                }
            }
        }
        return true;
    }

    /* True if the push's right-hand sides are, position by position, the same variables as the pull's left-hand sides. */
    private boolean sameVars(PushStatement push, PullStatement pull) {
        List<Expression> pushvars = push.getRight();
        List<IdentifierExpression> pullvars = pull.getLeft();
        if (pushvars.size() != pullvars.size()) {
            return false;
        }
        for (int i = 0; i < pushvars.size(); i++) {
            if (!(pushvars.get(i) instanceof IdentifierExpression)) {
                return false;
            }
            IdentifierExpression ie1 = (IdentifierExpression) pullvars.get(i);
            IdentifierExpression ie2 = (IdentifierExpression) pushvars.get(i);
            if (!ie1.sameVariable(ie2)) {
                return false;
            }
        }
        return true;
    }

    /* Same as above, with the receivers of a (constructor) call taking the role of the pull's left-hand sides. */
    private boolean sameVars(PushStatement push, CallStatement pull) {
        List<Expression> pushvars = push.getRight();
        List<Expression> pullvars = pull.getReceiver();
        if (pushvars.size() != pullvars.size()) {
            return false;
        }
        for (int i = 0; i < pushvars.size(); i++) {
            if (!(pullvars.get(i) instanceof IdentifierExpression)) {
                return false;
            }
            if (!(pushvars.get(i) instanceof IdentifierExpression)) {
                return false;
            }
            IdentifierExpression ie1 = (IdentifierExpression) pullvars.get(i);
            IdentifierExpression ie2 = (IdentifierExpression) pushvars.get(i);
            if (!ie1.sameVariable(ie2)) {
                return false;
            }
        }
        return true;
    }

    /* Dispatches to the matching sameVars overload; push must be a PushStatement, pull a pull or constructor call. */
    private boolean sameVarsStatement(Statement push, Statement pull) {
        Verify.verify(push instanceof PushStatement);
        Verify.verify(pull instanceof PullStatement || isConstructorCall(pull));
        if (pull instanceof PullStatement) {
            return sameVars((PushStatement) push, (PullStatement) pull);
        }
        return sameVars((PushStatement) push, (CallStatement) pull);
    }

    /* True if s is a call statement whose call target is a constructor. */
    private boolean isConstructorCall(Statement s) {
        if (s instanceof CallStatement) {
            CallStatement cs = (CallStatement) s;
            return cs.getCallTarget().isConstructor();
        }
        return false;
    }

    /* True if s is a pull statement or a constructor call (which this class treats like a pull). */
    private boolean isPullOrConstructorCall(Statement s) {
        return s instanceof PullStatement || isConstructorCall(s);
    }

    /*
     * Returns the object of a pull or push, or the first argument of a
     * constructor call; null for any other statement.
     */
    private IdentifierExpression getObject(Statement s) {
        if (s instanceof PullStatement) {
            return (IdentifierExpression) ((PullStatement) s).getObject();
        }
        if (s instanceof PushStatement) {
            return (IdentifierExpression) ((PushStatement) s).getObject();
        }
        if (isConstructorCall(s)) {
            return (IdentifierExpression) ((CallStatement) s).getArguments().get(0);
        }
        return null;
    }

    // check if the object of a push has been pulled in or on the path to CfgBlock b
    private boolean hasBeenPulledIn(PushStatement push, CfgBlock b) {
        IdentifierExpression pushed = getObject(push);
        // Backwards breadth-first search over the predecessors of b. Blocks are
        // marked when they are enqueued, so each block is queued at most once
        // without having to scan the queue.
        Set<CfgBlock> seen = new HashSet<CfgBlock>();
        Queue<CfgBlock> q = new LinkedList<CfgBlock>();
        seen.add(b);
        q.add(b);
        while (!q.isEmpty()) {
            CfgBlock cur = q.poll();
            for (Statement s : cur.getStatements()) {
                if (isPullOrConstructorCall(s) && getObject(s).sameVariable(pushed)) {
                    return true;
                }
            }
            for (CfgEdge in : cur.getMethod().incomingEdgesOf(cur)) {
                CfgBlock prev = cur.getMethod().getEdgeSource(in);
                if (seen.add(prev)) {
                    q.add(prev);
                }
            }
        }
        return false;
    }

    /* True if 'to' can be reached from 'from' along outgoing edges (a block reaches itself). */
    private boolean existsPath(CfgBlock from, CfgBlock to) {
        // Forward breadth-first search; blocks are marked when enqueued so that
        // each block is queued at most once.
        Set<CfgBlock> seen = new HashSet<CfgBlock>();
        Queue<CfgBlock> q = new LinkedList<CfgBlock>();
        seen.add(from);
        q.add(from);
        while (!q.isEmpty()) {
            CfgBlock cur = q.poll();
            if (cur.equals(to)) {
                return true;
            }
            for (CfgEdge out : cur.getMethod().outgoingEdgesOf(cur)) {
                CfgBlock next = cur.getMethod().getEdgeTarget(out);
                if (seen.add(next)) {
                    q.add(next);
                }
            }
        }
        return false;
    }

    /**
     * Checks whether the statement preceding an assert is an assignment of
     * the form {@code ... := x != null} where {@code x} is a variable of the
     * given pull (its object or any other variable it mentions).
     *
     * The previous implementation compared the pull's variables against the
     * pull's own object variable and never looked at {@code x}, so the test
     * did not depend on which variable was actually null-checked.
     *
     * @param previous the statement directly before the assert
     * @param as       the assert statement (not inspected)
     * @param pull     the pull that is a candidate for moving above the assert
     * @return true if {@code previous} null-checks a variable of {@code pull}
     */
    private boolean isNullCheckBeforePull(Statement previous, AssertStatement as, PullStatement pull) {
        Variable pullVar = ((IdentifierExpression) pull.getObject()).getVariable();
        if (!(previous instanceof AssignStatement)) {
            return false;
        }
        Expression rhs = ((AssignStatement) previous).getRight();
        if (!(rhs instanceof BinaryExpression)) {
            return false;
        }
        BinaryExpression be = (BinaryExpression) rhs;
        if (be.getOp() != BinaryExpression.BinaryOperator.Ne) {
            return false;
        }
        if (!(be.getRight() instanceof NullLiteral) || !(be.getLeft() instanceof IdentifierExpression)) {
            return false;
        }

        IdentifierExpression ie = (IdentifierExpression) be.getLeft();
        Variable checked = ie.getVariable();
        boolean guardsPull = checked.equals(pullVar);
        if (!guardsPull) {
            for (Variable v : pull.getAllVariables()) {
                if (v.equals(checked)) {
                    guardsPull = true;
                    break;
                }
            }
        }
        if (guardsPull && debug) {
            System.out.println("Found null check for " + ie);
        }
        return guardsPull;
    }
}