/* Copyright (c) 2006, Sriram Srinivasan * * You may distribute this software under the terms of the license * specified in the file "License" */ package kilim.analysis; import static kilim.Constants.D_FIBER; import static kilim.Constants.D_INT; import static kilim.Constants.D_VOID; import static kilim.Constants.FIBER_CLASS; import static kilim.Constants.TASK_CLASS; import static kilim.analysis.VMType.TOBJECT; import static kilim.analysis.VMType.loadVar; import static org.objectweb.asm.Opcodes.ALOAD; import static org.objectweb.asm.Opcodes.ASTORE; import static org.objectweb.asm.Opcodes.DUP; import static org.objectweb.asm.Opcodes.GETFIELD; import static org.objectweb.asm.Opcodes.GOTO; import static org.objectweb.asm.Opcodes.INVOKESTATIC; import static org.objectweb.asm.Opcodes.INVOKEVIRTUAL; import static org.objectweb.asm.Opcodes.RETURN; import java.util.ArrayList; import java.util.List; import kilim.Constants; import org.objectweb.asm.AnnotationVisitor; import org.objectweb.asm.Attribute; import org.objectweb.asm.ClassVisitor; import org.objectweb.asm.MethodVisitor; import org.objectweb.asm.tree.AnnotationNode; import org.objectweb.asm.tree.LabelNode; import org.objectweb.asm.tree.LocalVariableNode; import org.objectweb.asm.tree.LookupSwitchInsnNode; import org.objectweb.asm.tree.TableSwitchInsnNode; import org.objectweb.asm.tree.TryCatchBlockNode; /** * This class takes the basic blocks from a MethodFlow and generates * all the extra code to support continuations. */ public class MethodWeaver { private ClassWeaver classWeaver; private MethodFlow methodFlow; private boolean isPausable; private int maxVars; private int maxStack; /** * The last parameter to a pausable method is a Fiber ref. The rest of the * code doesn't know this because we do local surgery, and so is likely to * stomp on the corresponding local var. We need to save this in a slot * beyond (the original) maxLocals that is a safe haven for keeping the * fiberVar. */ private int fiberVar; private int numWordsInSig; private ArrayList<CallWeaver> callWeavers = new ArrayList<CallWeaver>(5); MethodWeaver(ClassWeaver cw, MethodFlow mf) { this.classWeaver = cw; this.methodFlow = mf; isPausable = mf.isPausable(); fiberVar = methodFlow.maxLocals; maxVars = fiberVar + 1; maxStack = methodFlow.maxStack + 1; // plus Fiber if (!mf.isAbstract()) { createCallWeavers(); } } public void accept(ClassVisitor cv) { MethodFlow mf = methodFlow; String[] exceptions = ClassWeaver.toStringArray(mf.exceptions); String desc = mf.desc; String sig = mf.signature; if (mf.isPausable()) { desc = desc.replace(")", D_FIBER + ')'); if (sig != null) sig = sig.replace(")", D_FIBER + ')'); } MethodVisitor mv = cv.visitMethod(mf.access, mf.name, desc, sig, exceptions); if (!mf.isAbstract()) { if (mf.isPausable()) { accept(mv); } else { mf.accept(mv); } } else { mv.visitEnd(); } } void accept(MethodVisitor mv) { visitAttrs(mv); visitCode(mv); mv.visitEnd(); } private void visitAttrs(MethodVisitor mv) { MethodFlow mf = methodFlow; // visits the method attributes int i, j, n; if (mf.annotationDefault != null) { AnnotationVisitor av = mv.visitAnnotationDefault(); MethodFlow.acceptAnnotation(av, null, mf.annotationDefault); av.visitEnd(); } n = mf.visibleAnnotations == null ? 0 : mf.visibleAnnotations.size(); for (i = 0; i < n; ++i) { AnnotationNode an = (AnnotationNode) mf.visibleAnnotations.get(i); an.accept(mv.visitAnnotation(an.desc, true)); } n = mf.invisibleAnnotations == null ? 0 : mf.invisibleAnnotations.size(); for (i = 0; i < n; ++i) { AnnotationNode an = (AnnotationNode) mf.invisibleAnnotations.get(i); an.accept(mv.visitAnnotation(an.desc, false)); } n = mf.visibleParameterAnnotations == null ? 0 : mf.visibleParameterAnnotations.length; for (i = 0; i < n; ++i) { List<?> l = mf.visibleParameterAnnotations[i]; if (l == null) { continue; } for (j = 0; j < l.size(); ++j) { AnnotationNode an = (AnnotationNode) l.get(j); an.accept(mv.visitParameterAnnotation(i, an.desc, true)); } } n = mf.invisibleParameterAnnotations == null ? 0 : mf.invisibleParameterAnnotations.length; for (i = 0; i < n; ++i) { List<?> l = mf.invisibleParameterAnnotations[i]; if (l == null) { continue; } for (j = 0; j < l.size(); ++j) { AnnotationNode an = (AnnotationNode) l.get(j); an.accept(mv.visitParameterAnnotation(i, an.desc, false)); } } n = mf.attrs == null ? 0 : mf.attrs.size(); for (i = 0; i < n; ++i) { mv.visitAttribute((Attribute) mf.attrs.get(i)); } } private void visitCode(MethodVisitor mv) { mv.visitCode(); methodFlow.resetLabels(); visitTryCatchBlocks(mv); visitInstructions(mv); visitLocals(mv); visitLineNumbers(mv); mv.visitMaxs(maxStack, maxVars); } private void visitLineNumbers(MethodVisitor mv) { methodFlow.visitLineNumbers(mv); } private void visitLocals(MethodVisitor mv) { for (Object l: methodFlow.localVariables) { ((LocalVariableNode)l).accept(mv); } } private void visitInstructions(MethodVisitor mv) { MethodFlow mf = methodFlow; genPrelude(mv); BasicBlock lastBB = null; for (BasicBlock bb : mf.getBasicBlocks()) { int from = bb.startPos; if (bb.isPausable() && bb.startFrame != null) { genPausableMethod(mv, bb); from = bb.startPos + 1; // first instruction is consumed } else if (bb.isCatchHandler()) { List<CallWeaver> cwList = getCallsUnderCatchBlock(bb); if (cwList != null) { genException(mv, bb, cwList); from = bb.startPos + 1; // first instruction is consumed } // else no different from any other block } int to = bb.endPos; for (int i = from; i <= to; i++) { LabelNode l = mf.getLabelAt(i); if (l != null) { l.accept(mv); } bb.getInstruction(i).accept(mv); } lastBB = bb; } if (lastBB != null) { LabelNode l = methodFlow.getLabelAt(lastBB.endPos+1); if (l != null) { l.accept(mv); } } } private List<CallWeaver> getCallsUnderCatchBlock(BasicBlock catchBB) { List<CallWeaver> cwList = null; // create it lazily for (CallWeaver cw: callWeavers) { for (Handler h: cw.bb.handlers) { if (h.catchBB == catchBB) { if (cwList == null) { cwList = new ArrayList<CallWeaver>(callWeavers.size()); } if (!cwList.contains(cw)) { cwList.add(cw); } } } } return cwList; } /** * For a method invocation f(...), this method assumes that the arguments to * the call have already been pushed in. We need to push in the Fiber as the * final argument, make the call, then add the code for post-calls, then * leave it to visitInstructions() to resume visiting the remaining * instructions in the block * * <pre> * F_CALL: * aload <fiberVar> * invokevirtual fiber.down() ;; returns Fiber * ... invoke .... * aload <fiberVar> * ... post call code * F_RESUME: * </pre> * * @param bb * The BasicBlock that contains the pausable method invocation as the first * instruction * @param mv */ private void genPausableMethod(MethodVisitor mv, BasicBlock bb) { CallWeaver caw = null; if (bb.isGetCurrentTask()) { genGetCurrentTask(mv, bb); return; } for (CallWeaver cw : callWeavers) { if (cw.getBasicBlock() == bb) { caw = cw; break; } } caw.genCall(mv); caw.genPostCall(mv); } /* * The Task.getCurrentTask() method is marked pausable to force * the caller to be pausable too. But the method doesn't really * pause; it merely looks up the task from the fiber. This is a * special case where the call to getCurrentTask is replaced by * <pre> * load fiberVar * getfield task * @param mv */ void genGetCurrentTask(MethodVisitor mv, BasicBlock bb) { bb.startLabel.accept(mv); loadVar(mv, TOBJECT, getFiberVar()); mv.visitFieldInsn(GETFIELD, FIBER_CLASS, "task", Constants.D_TASK); } private boolean hasGetCurrentTask() { MethodFlow mf = methodFlow; for (BasicBlock bb : mf.getBasicBlocks()) { if (!bb.isPausable() || bb.startFrame==null) continue; if (bb.isGetCurrentTask()) return true; } return false; } private void createCallWeavers() { MethodFlow mf = methodFlow; for (BasicBlock bb : mf.getBasicBlocks()) { if (!bb.isPausable() || bb.startFrame==null) continue; // No prelude needed for Task.getCurrentTask(). if (bb.isGetCurrentTask()) continue; CallWeaver cw = new CallWeaver(this, bb); callWeavers.add(cw); } } /** * * Say there are two invocations to two pausable methods obj.f(int) * (virtual) and fs(double) (a static call) ; load fiber from last arg, and * save it in a fresh register ; lest it gets stomped on. This is because we * only patch locally, and don't change the other instructions. * * <pre> * aload lastVar * dup * astore fiberVar * switch (fiber.pc) { * default: 0: START * 1: F_PASS_DOWN * 2: FS_PASS_DOWN * } * </pre> */ private void genPrelude(MethodVisitor mv) { assert isPausable : "MethodWeaver.genPrelude called for nonPausable method"; if (callWeavers.size() == 0 && (!hasGetCurrentTask())) { // Method has been marked pausable, but does not call any pausable methods, nor Task.getCurrentTask. // Prelude is not needed at all. return; } MethodFlow mf = methodFlow; // load fiber from last var int lastVar = getFiberArgVar(); mv.visitVarInsn(ALOAD, lastVar); if (lastVar < fiberVar) { if (callWeavers.size() > 0) { mv.visitInsn(DUP); // for storing into fiberVar } mv.visitVarInsn(ASTORE, getFiberVar()); } if (callWeavers.size() == 0) { // No pausable method calls, but Task.getCurrentTask() is present. // We don't need the rest of the prelude. return; } mv.visitFieldInsn(GETFIELD, FIBER_CLASS, "pc", D_INT); // The prelude doesn't need more than two words in the stack. // The callweaver gen* methods may need more. ensureMaxStack(2); // switch stmt LabelNode startLabel = mf.getOrCreateLabelAtPos(0); LabelNode errLabel = new LabelNode(); LabelNode[] labels = new LabelNode[callWeavers.size() + 1]; labels[0] = startLabel; for (int i = 0; i < callWeavers.size(); i++) { labels[i + 1] = new LabelNode(); } new TableSwitchInsnNode(0, callWeavers.size(), errLabel, labels).accept(mv); errLabel.accept(mv); mv.visitVarInsn(ALOAD, getFiberVar()); mv.visitMethodInsn(INVOKEVIRTUAL, FIBER_CLASS, "wrongPC", "()V"); // Generate pass through down code, one for each pausable method // invocation int last = callWeavers.size() - 1; for (int i = 0; i <= last; i++) { CallWeaver cw = callWeavers.get(i); labels[i+1].accept(mv); cw.genRewind(mv); } startLabel.accept(mv); } boolean isStatic() { return methodFlow.isStatic(); } int getFiberArgVar() { int lastVar = getNumWordsInSig(); if (!isStatic()) { lastVar++; } return lastVar; } /* * The number of words in the argument; doubles/longs occupy * two local vars. */ int getNumWordsInSig() { if (numWordsInSig != -1) { String[]args = TypeDesc.getArgumentTypes(methodFlow.desc); int size = 0; for (int i = 0; i < args.length; i++) { size += TypeDesc.isDoubleWord(args[i]) ? 2 : 1; } numWordsInSig = size; } return numWordsInSig; } /** * Generate code for only those catch blocks that are reachable * from one or more pausable blocks. fiber.pc tells us which * nested call possibly caused an exception, fiber.status tells us * whether there is any state that needs to be restored, and * fiber.curState gives us access to that state. * * ; Figure out which pausable method could have caused this. * * switch (fiber.upEx()) { * 0: goto NORMAL_EXCEPTION_HANDLING; * 2: goto RESTORE_F * } * RESTORE_F: * if (fiber.curStatus == HAS_STATE) { * restore variables from the state. don't restore stack * goto NORMAL_EXCEPTION_HANDLING * } * ... other RESTOREs * * NORMAL_EXCEPTION_HANDLING: */ private void genException(MethodVisitor mv, BasicBlock bb, List<CallWeaver> cwList) { bb.startLabel.accept(mv); LabelNode resumeLabel = new LabelNode(); VMType.loadVar(mv, VMType.TOBJECT, getFiberVar()); mv.visitMethodInsn(INVOKEVIRTUAL, FIBER_CLASS, "upEx", "()I"); // fiber.pc is on stack LabelNode[] labels = new LabelNode[cwList.size()]; int[] keys = new int[cwList.size()]; for (int i = 0; i < cwList.size(); i++) { labels[i] = new LabelNode(); keys[i] = callWeavers.indexOf(cwList.get(i)) + 1; } new LookupSwitchInsnNode(resumeLabel, keys, labels).accept(mv); int i = 0; for (CallWeaver cw: cwList) { if (i > 0) { // This is the jump (to normal exception handling) for the previous // switch case. mv.visitJumpInsn(GOTO, resumeLabel.getLabel()); } labels[i].accept(mv); cw.genRestoreEx(mv, labels[i]); i++; } // Consume the first instruction because we have already consumed the // corresponding label. (The standard visitInstructions code does a // visitLabel before visiting the instruction itself) resumeLabel.accept(mv); bb.getInstruction(bb.startPos).accept(mv); } int getFiberVar() { return fiberVar; // The first available slot } void visitTryCatchBlocks(MethodVisitor mv) { MethodFlow mf = methodFlow; ArrayList<BasicBlock> bbs = mf.getBasicBlocks(); ArrayList<Handler> allHandlers = new ArrayList<Handler>(bbs.size() * 2); for (BasicBlock bb : bbs) { allHandlers.addAll(bb.handlers); } allHandlers = Handler.consolidate(allHandlers); for (Handler h : allHandlers) { new TryCatchBlockNode(mf.getLabelAt(h.from), mf.getOrCreateLabelAtPos(h.to+1), h.catchBB.startLabel, h.type).accept(mv); } } void ensureMaxVars(int numVars) { if (numVars > maxVars) { maxVars = numVars; } } void ensureMaxStack(int numStack) { if (numStack > maxStack) { maxStack = numStack; } } int getPC(CallWeaver weaver) { for (int i = 0; i < callWeavers.size(); i++) { if (callWeavers.get(i) == weaver) return i + 1; } assert false : " No weaver found"; return 0; } public String createStateClass(ValInfoList valInfoList) { return classWeaver.createStateClass(valInfoList); } void makeNotWovenMethod(ClassVisitor cv, MethodFlow mf) { if (classWeaver.isInterface()) { MethodVisitor mv = cv.visitMethod(mf.access, mf.name, mf.desc, mf.signature, ClassWeaver.toStringArray(mf.exceptions)); mv.visitEnd(); } else { // Turn of abstract modifier int access = mf.access; access &= ~Constants.ACC_ABSTRACT; MethodVisitor mv = cv.visitMethod(access, mf.name, mf.desc, mf.signature, ClassWeaver.toStringArray(mf.exceptions)); mv.visitCode(); visitAttrs(mv); mv.visitMethodInsn(INVOKESTATIC, TASK_CLASS, "errNotWoven", "()V"); String rdesc = TypeDesc.getReturnTypeDesc(mf.desc); // stack size depends on return type, because we want to load // a constant of the appropriate size on the stack for // the corresponding xreturn instruction. int stacksize = 0; if (rdesc != D_VOID) { // ICONST_0; IRETURN or ACONST_NULL; ARETURN etc. stacksize = TypeDesc.isDoubleWord(rdesc) ? 2 : 1; int vmt = VMType.toVmType(rdesc); mv.visitInsn(VMType.constInsn[vmt]); mv.visitInsn(VMType.retInsn[vmt]); } else { mv.visitInsn(RETURN); } int numlocals; if ((mf.access & Constants.ACC_ABSTRACT) != 0) { // The abstract method doesn't contain the number of locals required to hold the // args, so we need to calculate it. numlocals = getNumWordsInSig() + 1 /* fiber */; if (!mf.isStatic()) numlocals++; } else { numlocals = mf.maxLocals + 1; } mv.visitMaxs(stacksize, numlocals); mv.visitEnd(); } } }