package xsched.wala.optimizations; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.Map.Entry; import com.ibm.wala.classLoader.IClass; import com.ibm.wala.classLoader.IMethod; import com.ibm.wala.ipa.callgraph.CGNode; import com.ibm.wala.ipa.callgraph.CallGraph; import com.ibm.wala.ipa.callgraph.propagation.InstanceKey; import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis; import com.ibm.wala.ipa.cha.ClassHierarchy; import com.ibm.wala.ssa.IR; import com.ibm.wala.ssa.SSAInstruction; import com.ibm.wala.ssa.SSAMonitorInstruction; import xsched.analysis.core.AnalysisResult; import xsched.analysis.wala.WalaScheduleAnalysisDriver; public class SynchronizationRemovalOptimization { private final AnalysisResult<CGNode> schedule; private final WalaScheduleAnalysisDriver driver; public SynchronizationRemovalOptimization(WalaScheduleAnalysisDriver driver) { this.driver = driver; this.schedule = driver.scheduleAnalysisResult(); } private void computeRequiredVariablesForNode(GlobalPointsToInfo globalInfo, CGNode node, Set<InstanceKey> usedInParallel, Map<CGNode, Set<Variable>> out) { Set<Variable> variables = out.get(node); if(variables == null) { variables = new HashSet<Variable>(); out.put(node, variables); } LocalPointsToInfo localInfo = globalInfo.pointsToSet(node); for(Entry<Variable, Set<InstanceKey>> entry : localInfo.info().entrySet()) { if(Util.containsAny(usedInParallel, entry.getValue())) { variables.add(entry.getKey()); } } } public Map<CGNode, Set<Variable>> computeRequiredSyncPointsByCGNode(Reachability<CGNode, CGNode> reachability) { //step 1: find all variables that are used for syncing Set<Variable> syncVariables = collectSynchronizationVariables(); //step 2: compute where those variables may point to (context dependent) GlobalPointsToInfo globalPointsToInfo = GlobalPointsToInfo.make(driver.pointerAnalysis(), driver.callGraph(), syncVariables); //step 3: for each task node, collect all the reachable non task nodes and collect their variable points to sets; so kinda "flatten" the call graph into the task nodes Map<CGNode, LocalPointsToInfo> syncedVariablesByTask = globalPointsToInfo.collectLocalPointsToSetsForTasks(reachability, schedule.tasks()); //we don't want the whole variable points to set info but are happy with just flattening everything into a set of instance keys per task //at the same time, filter out all non escaping instances Map<CGNode, Set<InstanceKey>> syncedInstancesByTask = Util.mapToInstanceKeys(syncedVariablesByTask, driver.escapeAnalysisResult()); //per task, find what instance keys are also used by other parallel tasks Map<CGNode, Set<InstanceKey>> syncedInParallelByTask = Util.collectInstanceKeysUsedInParallel(schedule, syncedInstancesByTask); //now we want to "undo" all the folding from above and find for each CGNode reachable by a task what variables it has to keep //this could also be mapped from the CG node to the method to get a more "global" but less precise view Map<CGNode, Set<Variable>> requiredSyncsByCGNode = new HashMap<CGNode, Set<Variable>>(); for(Entry<CGNode, Set<InstanceKey>> entry : syncedInParallelByTask.entrySet()) { CGNode task = entry.getKey(); Set<InstanceKey> usedInParallel = entry.getValue(); computeRequiredVariablesForNode(globalPointsToInfo, task, usedInParallel, requiredSyncsByCGNode); //then do the same for other nodes for(CGNode nonTaskNode : reachability.nonTaskNodesReachableByTask(task)) { computeRequiredVariablesForNode(globalPointsToInfo, nonTaskNode, usedInParallel, requiredSyncsByCGNode); } } return requiredSyncsByCGNode; } private Set<Variable> collectSynchronizationVariables() { ClassHierarchy classHierarchy = driver.classHierarchy(); Set<Variable> syncPoints = new HashSet<Variable>(); //step 1: iterate all classes and find synchronization points in their methods for(IClass cls : classHierarchy) { Collection<IMethod> methods = cls.getDeclaredMethods(); for(IMethod method : methods) { collectSynchronizationSSAVariables(syncPoints, method, driver.irForMethod(method)); } } return syncPoints; } private void collectSynchronizationSSAVariables(Set<Variable> syncPoints, IMethod method, IR ir) { if(method.isSynchronized()) { if(method.isStatic()) { syncPoints.add(new Variable(method, Variable.CLASS)); } else { syncPoints.add(new Variable(method, Variable.THIS)); } } for(SSAInstruction instruction : ir.getInstructions()) { if(instruction instanceof SSAMonitorInstruction) { SSAMonitorInstruction monitor = (SSAMonitorInstruction)instruction; if(monitor.isMonitorEnter()) { syncPoints.add(new Variable(method, monitor.getRef())); } } } } }