/**
 *
 */
package soottocfg.cfg.optimization;

import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.jgrapht.Graphs;

import com.google.common.base.Verify;

import soottocfg.cfg.Program;
import soottocfg.cfg.SourceLocation;
import soottocfg.cfg.expression.Expression;
import soottocfg.cfg.method.CfgBlock;
import soottocfg.cfg.method.CfgEdge;
import soottocfg.cfg.method.Method;
import soottocfg.cfg.statement.AssignStatement;
import soottocfg.cfg.statement.CallStatement;
import soottocfg.cfg.statement.Statement;
import soottocfg.cfg.type.ReferenceType;
import soottocfg.cfg.variable.Variable;
import soottocfg.soot.transformers.ArrayTransformer;

/**
 * Inlines the bodies of called methods into their callers, starting from the
 * entry point of a {@link Program}. A call is inlined when the callee passes
 * {@link #canBeInlined(Method, Method)} and is either called fewer than
 * {@code maxOccurences} times or has fewer than {@code maxSize} statements.
 *
 * @author schaef
 */
public class CfgCallInliner {

    /** Counter used to give every copied callee variable a unique name. */
    private int freshInt = 0;

    /*
     * Note that these maps have to go from String->Integer, because
     * we have to use the method name since the hashcode of the method
     * changes while we inline.
     */
    /** Number of statements per method name, as counted by {@link #computeStats(Program)}. */
    Map<String, Integer> totalStmts = new HashMap<String, Integer>();
    /** Number of call statements targeting each method name. */
    Map<String, Integer> totalCallsTo = new HashMap<String, Integer>();
    /** Names of the methods that {@link #inlineCalls(Method, int, int)} has already visited. */
    Set<String> alreadyInlined = new HashSet<String>();

    final Program program;

    /**
     * Creates an inliner for the given program and collects the per-method
     * statistics (statement counts and call counts) used for the inlining
     * decision.
     *
     * @param p the program whose calls are to be inlined
     */
    public CfgCallInliner(Program p) {
        this.program = p;
        computeStats(p);
    }

    /**
     * Fills {@link #totalCallsTo} and {@link #totalStmts} for all methods of
     * the program. Every method of the program gets an entry in
     * {@link #totalCallsTo} (0 if it is never called); {@link #totalStmts}
     * only gets entries for methods returned by {@code p.getMethods()}.
     *
     * @param p the program to analyze
     */
    private void computeStats(Program p) {
        for (Method m : p.getMethods()) {
            if (!totalCallsTo.containsKey(m.getMethodName())) {
                totalCallsTo.put(m.getMethodName(), 0);
            }
            for (Method callee : calledMethods(m)) {
                String calleeName = callee.getMethodName();
                Integer previous = totalCallsTo.get(calleeName);
                totalCallsTo.put(calleeName, previous == null ? 1 : previous + 1);
            }
            int stmtCount = 0;
            for (CfgBlock b : m.vertexSet()) {
                stmtCount += b.getStatements().size();
            }
            totalStmts.put(m.getMethodName(), stmtCount);
        }
    }

    /**
     * Collects the call targets of all call statements in {@code m}.
     * A method that is called several times appears several times in the
     * result.
     *
     * @param m the method to scan
     * @return the targets of all {@link CallStatement}s found in {@code m}
     */
    private List<Method> calledMethods(Method m) {
        List<Method> res = new LinkedList<Method>();
        for (CfgBlock b : m.vertexSet()) {
            for (Statement s : b.getStatements()) {
                if (s instanceof CallStatement) {
                    res.add(((CallStatement) s).getCallTarget());
                }
            }
        }
        return res;
    }

    /**
     * Computes the set of methods that are transitively called from
     * {@code main}, including {@code main} itself.
     *
     * @param main the method to start from
     * @return all methods reachable from {@code main} via call statements
     */
    private Set<Method> reachableMethod(Method main) {
        Set<Method> reachable = new HashSet<Method>();
        List<Method> todo = new LinkedList<Method>();
        todo.add(main);
        while (!todo.isEmpty()) {
            Method m = todo.remove(0);
            reachable.add(m);
            for (Method n : calledMethods(m)) {
                if (!reachable.contains(n) && !todo.contains(n)) {
                    todo.add(n);
                }
            }
        }
        return reachable;
    }

    /**
     * Inlines calls starting from the program's entry point, folds
     * straight-line sequences in the entry point afterwards, and removes all
     * methods that are no longer reachable from it. Does nothing if both
     * limits are non-positive.
     *
     * @param maxSize       callees with fewer statements than this are inlined
     * @param maxOccurences callees called fewer times than this are inlined
     */
    public void inlineFromMain(int maxSize, int maxOccurences) {
        if (maxSize <= 0 && maxOccurences <= 0) {
            return;
        }
        Method mainMethod = program.getEntryPoint();
        inlineCalls(mainMethod, maxSize, maxOccurences);

        FoldStraighLineSeq folder = new FoldStraighLineSeq();
        folder.fold(mainMethod);

        Set<Method> reachable = reachableMethod(mainMethod);
        Set<Method> toRemove = new HashSet<Method>();
        for (Method m : program.getMethods()) {
            if (!reachable.contains(m)) {
                toRemove.add(m);
            }
        }
        program.removeMethods(toRemove);
    }

    /**
     * Checks the structural preconditions for inlining {@code callee} into
     * {@code caller}: the call must not be directly recursive, and the callee
     * must be neither a constructor nor a static initializer. If the
     * {@code arrayInv} option is set, methods whose {@code this} variable
     * belongs to a class whose name starts with
     * {@link ArrayTransformer#arrayTypeName} are excluded as well.
     *
     * @param caller the method containing the call
     * @param callee the call target
     * @return {@code true} if the callee may be copied into the caller
     */
    private boolean canBeInlined(Method caller, Method callee) {
        boolean res = !callee.equals(caller) && !callee.isConstructor() && !callee.isStaticInitializer();
        if (soottocfg.Options.v().arrayInv() && callee.getThisVariable() != null
                && ((ReferenceType) callee.getThisVariable().getType()).getClassVariable().getName()
                        .startsWith(ArrayTransformer.arrayTypeName)) {
            // TODO: for Rody's array model we must not inline array stuff.
            return false;
        }
        return res;
    }

    /**
     * Checks the size/occurrence budget for inlining {@code callee}: it is
     * within budget if it is called fewer than {@code maxOccurences} times or
     * has fewer than {@code maxSize} statements. A callee without recorded
     * statistics does not satisfy the corresponding criterion (instead of
     * failing with a {@link NullPointerException} on unboxing).
     *
     * @param callee        the call target
     * @param maxSize       exclusive upper bound on the callee's statement count
     * @param maxOccurences exclusive upper bound on the number of calls to the callee
     * @return {@code true} if at least one of the two criteria holds
     */
    private boolean isWithinInlineBudget(Method callee, int maxSize, int maxOccurences) {
        String name = callee.getMethodName();
        Integer calls = totalCallsTo.get(name);
        if (calls != null && calls < maxOccurences) {
            return true;
        }
        Integer size = totalStmts.get(name);
        return size != null && size < maxSize;
    }

    /**
     * Creates a fresh edge that carries a deep copy of the label of
     * {@code oldEdge}, if that edge has a label.
     *
     * @param oldEdge the edge whose label is copied
     * @return a new, not yet attached edge
     */
    private static CfgEdge copyEdgeWithLabel(CfgEdge oldEdge) {
        CfgEdge newEdge = new CfgEdge();
        if (oldEdge.getLabel().isPresent()) {
            newEdge.setLabel(oldEdge.getLabel().get().deepCopy());
        }
        return newEdge;
    }

    /**
     * Inlines all inlineable calls in {@code method}. Callees are processed
     * first (depth first); each method is processed at most once, which is
     * tracked by name in {@link #alreadyInlined}.
     *
     * @param method        the method whose calls are inlined
     * @param maxSize       callees with fewer statements than this are inlined
     * @param maxOccurences callees called fewer times than this are inlined
     */
    private void inlineCalls(Method method, int maxSize, int maxOccurences) {
        if (!alreadyInlined.add(method.getMethodName())) {
            // already visited (this also stops the recursion on call cycles)
            return;
        }
        enforceSingleInlineableCallPerBlock(method, maxSize, maxOccurences);

        List<CfgBlock> toRemove = new LinkedList<CfgBlock>();
        // Iterate over snapshots because copyCalleeBody adds blocks to the method.
        for (CfgBlock b : new HashSet<CfgBlock>(method.vertexSet())) {
            for (Statement s : new LinkedList<Statement>(b.getStatements())) {
                if (!(s instanceof CallStatement)) {
                    continue;
                }
                CallStatement cs = (CallStatement) s;
                Method callee = cs.getCallTarget();
                // first apply inlining to the callee
                inlineCalls(callee, maxSize, maxOccurences);
                if (canBeInlined(method, callee) && isWithinInlineBudget(callee, maxSize, maxOccurences)) {
                    // now copy the callee into the caller.
                    copyCalleeBody(method, b, cs);
                    /*
                     * copyCalleeBody disconnects b from the graph; b is only
                     * removed after the loop. We keep scanning the remaining
                     * statements of b, but enforceSingleInlineableCallPerBlock
                     * guarantees that there is no second call to inline in
                     * this block.
                     */
                    toRemove.add(b);
                }
            }
        }
        method.removeAllVertices(toRemove);
    }

    /**
     * Before inlining, we enforce that each block has at most one
     * call statement that can be inlined.
     *
     * @param method        the method whose blocks are split
     * @param maxSize       callees with fewer statements than this are inlined
     * @param maxOccurences callees called fewer times than this are inlined
     */
    private void enforceSingleInlineableCallPerBlock(Method method, int maxSize, int maxOccurences) {
        // snapshot: splitting adds new blocks to the method
        for (CfgBlock b : new HashSet<CfgBlock>(method.vertexSet())) {
            splitBlockIfNecessary(method, b, maxSize, maxOccurences);
        }
    }

    /**
     * Splits blocks that contain more than one inlineable call statement.
     * The block is cut in front of the second inlineable call; the tail is
     * moved to a new block that takes over all outgoing edges of {@code b},
     * and the new block is then split recursively. Uses the same
     * inlineability test as {@link #inlineCalls(Method, int, int)}.
     *
     * @param m             the method that contains {@code b}
     * @param b             the block to inspect
     * @param maxSize       callees with fewer statements than this are inlined
     * @param maxOccurences callees called fewer times than this are inlined
     */
    private void splitBlockIfNecessary(Method m, CfgBlock b, int maxSize, int maxOccurences) {
        int inlineableCalls = 0;
        for (Statement s : new LinkedList<Statement>(b.getStatements())) {
            if (!(s instanceof CallStatement)) {
                continue;
            }
            Method callee = ((CallStatement) s).getCallTarget();
            if (!canBeInlined(m, callee) || !isWithinInlineBudget(callee, maxSize, maxOccurences)) {
                continue;
            }
            inlineableCalls++;
            if (inlineableCalls <= 1) {
                continue;
            }

            // s is the second inlineable call: move s and everything after it
            // into a new block.
            int idx = b.getStatements().indexOf(s);
            List<Statement> rest = new LinkedList<Statement>(
                    b.getStatements().subList(idx, b.getStatements().size()));
            b.removeStatements(rest);
            CfgBlock nextBlock = new CfgBlock(m);
            nextBlock.getStatements().addAll(rest);

            // Re-route the outgoing edges of b to start at nextBlock.
            for (CfgBlock suc : new LinkedList<CfgBlock>(Graphs.successorListOf(m, b))) {
                CfgEdge newEdge = copyEdgeWithLabel(m.getEdge(b, suc));
                m.removeEdge(b, suc);
                m.addEdge(nextBlock, suc, newEdge);
            }
            m.addEdge(b, nextBlock);
            splitBlockIfNecessary(m, nextBlock, maxSize, maxOccurences);
            return;
        }
    }

    /**
     * Copies the body of the call target into caller at the
     * position of call. {@code block} is disconnected from the graph but not
     * removed; removing it is left to {@link #inlineCalls(Method, int, int)}.
     *
     * @param caller the method that contains {@code block}
     * @param block  the block that contains {@code call}
     * @param call   the call statement to replace by the callee's body
     */
    private void copyCalleeBody(Method caller, CfgBlock block, CallStatement call) {
        /*
         * First replace block by two blocks:
         * preBlock that contains all statements of block up to the call, and
         * postBlock that contains all statements of block after the call.
         */
        int callIdx = block.getStatements().indexOf(call);
        CfgBlock preBlock = new CfgBlock(caller);
        preBlock.getStatements().addAll(block.getStatements().subList(0, callIdx));
        CfgBlock postBlock = new CfgBlock(caller);
        if (callIdx + 1 < block.getStatements().size()) {
            postBlock.getStatements()
                    .addAll(block.getStatements().subList(callIdx + 1, block.getStatements().size()));
        }

        // Incoming edges of block now lead to preBlock (labels are copied).
        for (CfgBlock pre : new LinkedList<CfgBlock>(Graphs.predecessorListOf(caller, block))) {
            CfgEdge newEdge = copyEdgeWithLabel(caller.getEdge(pre, block));
            caller.removeEdge(pre, block);
            caller.addEdge(pre, preBlock, newEdge);
        }
        // Outgoing edges of block now start at postBlock (labels are copied).
        for (CfgBlock post : new LinkedList<CfgBlock>(Graphs.successorListOf(caller, block))) {
            CfgEdge newEdge = copyEdgeWithLabel(caller.getEdge(block, post));
            caller.removeEdge(block, post);
            caller.addEdge(postBlock, post, newEdge);
        }
        Verify.verify(caller.outDegreeOf(block) + caller.inDegreeOf(block) == 0);

        /*
         * block may be the source and the sink at the same time (e.g. a
         * caller with a single block), so both have to be checked
         * independently. The comparisons are written so that a missing
         * source or sink does not cause a NullPointerException.
         */
        if (block.equals(caller.getSource())) {
            caller.setSource(preBlock);
        }
        if (block.equals(caller.getSink())) {
            caller.setSink(postBlock);
        }
        // don't remove block yet. Otherwise the numbering of the blocks
        // gets all messed up. It is removed in inlineCalls.

        Method callee = call.getCallTarget();
        Verify.verifyNotNull(callee.getSource());

        /*
         * Create a map from callee locals and formals to fresh caller locals.
         */
        Map<Variable, Variable> varSubstitionMap = new HashMap<Variable, Variable>();
        List<Variable> toCopy = new LinkedList<Variable>();
        toCopy.addAll(callee.getInParams());
        toCopy.addAll(callee.getOutParams());
        toCopy.addAll(callee.getLocals());
        for (Variable v : toCopy) {
            Variable local = new Variable("cp_" + v.getName() + "_" + (++freshInt), v.getType());
            caller.addLocalVariable(local);
            varSubstitionMap.put(v, local);
        }

        // Assign the call arguments to the copies of the callee's in-params.
        SourceLocation loc = call.getSourceLocation();
        for (int i = 0; i < callee.getInParams().size(); i++) {
            Variable v = callee.getInParam(i);
            preBlock.addStatement(
                    new AssignStatement(loc, varSubstitionMap.get(v).mkExp(loc), call.getArguments().get(i)));
        }

        /*
         * Clone all blocks of the callee into the caller, substituting the
         * callee's variables by the fresh caller locals.
         */
        Map<CfgBlock, CfgBlock> cloneMap = new HashMap<CfgBlock, CfgBlock>();
        for (CfgBlock cur : callee.vertexSet()) {
            CfgBlock clone = new CfgBlock(caller);
            for (Statement s : cur.getStatements()) {
                clone.addStatement(s.substitute(varSubstitionMap));
            }
            cloneMap.put(cur, clone);
        }

        /*
         * Add all edges for the copied blocks.
         */
        for (CfgEdge edge : callee.edgeSet()) {
            CfgBlock src = callee.getEdgeSource(edge);
            CfgBlock tgt = callee.getEdgeTarget(edge);
            if (cloneMap.containsKey(src) && cloneMap.containsKey(tgt)) {
                CfgEdge newEdge = new CfgEdge();
                if (edge.getLabel().isPresent()) {
                    newEdge.setLabel(edge.getLabel().get().substitute(varSubstitionMap));
                }
                caller.addEdge(cloneMap.get(src), cloneMap.get(tgt), newEdge);
            }
        }

        /*
         * Connect the copied blocks with the caller.
         */
        caller.addEdge(preBlock, cloneMap.get(callee.getSource()));
        if (cloneMap.containsKey(callee.getSink())) {
            // if callee loops forever, this might not be reached.
            caller.addEdge(cloneMap.get(callee.getSink()), postBlock);
        }

        /*
         * Now update the out variables: each receiver of the call gets the
         * value of the copy of the corresponding out-param. The assignments
         * are inserted at the front of postBlock.
         */
        for (int i = 0; i < callee.getOutParams().size(); i++) {
            if (i < call.getReceiver().size()) {
                Expression receiver = call.getReceiver().get(i);
                postBlock.addStatement(0, new AssignStatement(loc, receiver,
                        varSubstitionMap.get(callee.getOutParams().get(i)).mkExp(loc)));
            } else {
                System.err.println("More outparams than receivers " + call);
            }
        }
    }
}