package pt.ist.fenixframework.atomic; import java.io.*; import java.util.*; import org.objectweb.asm.*; import org.objectweb.asm.tree.*; import static org.objectweb.asm.Opcodes.*; public class ProcessAtomicAnnotations { private static final Type ATOMIC = Type.getType(pt.ist.fenixframework.Atomic.class); private static final Type ATOMIC_CONTEXT = Type.getType(AtomicContext.class); private static final Type ATOMIC_INSTANCE = Type.getObjectType("pt/ist/fenixframework/atomic/AtomicInstance"); private static final Map<String,Object> ATOMIC_ELEMENTS; private static final List<FieldNode> ATOMIC_FIELDS; private static final String ATOMIC_INSTANCE_CTOR_DESC; static { Map<String,Object> atomicElements = new HashMap<String,Object>(); for (java.lang.reflect.Method element : pt.ist.fenixframework.Atomic.class.getDeclaredMethods()) { Object defaultValue = element.getDefaultValue(); if (defaultValue instanceof Class) { defaultValue = Type.getType((Class<?>) defaultValue); } atomicElements.put(element.getName(), defaultValue); } ATOMIC_ELEMENTS = Collections.unmodifiableMap(atomicElements); try { InputStream is = Thread.currentThread().getContextClassLoader() .getResourceAsStream(ATOMIC_INSTANCE.getInternalName() + ".class"); ClassReader cr = new ClassReader(is); ClassNode cNode = new ClassNode(); cr.accept(cNode, 0); ATOMIC_FIELDS = cNode.fields != null ? cNode.fields : Collections.<FieldNode>emptyList(); StringBuffer ctorDescriptor = new StringBuffer("("); for (FieldNode field : ATOMIC_FIELDS) ctorDescriptor.append(field.desc); ctorDescriptor.append(")V"); ATOMIC_INSTANCE_CTOR_DESC = ctorDescriptor.toString(); } catch (IOException e) { throw new Error("Error opening AtomicInstance class. Have you run GenerateAtomicInstance?", e); } } private ProcessAtomicAnnotations() {} public static void main (final String args[]) throws Exception { for (String file : args) { ProcessAtomicAnnotations.processFile(new File(file)); } } public static void processFiles(File [] files) { for (File file : files) { processFile(file); } } public static void processFile(File file) { if (file.isDirectory()) { for (File subFile : file.listFiles()) { processFile(subFile); } } else { String fileName = file.getName(); if (fileName.endsWith(".class")) { processClassFile(file); } } } protected static void processClassFile(File classFile) { InputStream is = null; try { // get an input stream to read the bytecode of the class is = new FileInputStream(classFile); ClassReader cr = new ClassReader(is); ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS); ClassVisitor cv = cw; // Add here other visitors to run AFTER the AtomicMethodTransformer cv = new AtomicMethodTransformer(cv, classFile); // Add here other visitors to run BEFORE the AtomicMethodTransformer cr.accept(cv, 0); writeClassFile(classFile, cw.toByteArray()); } catch (IOException e) { throw new Error("Error processing class file", e); } finally { if (is != null) { try { is.close(); } catch (IOException e) { } } } } protected static void writeClassFile(File classFile, byte[] bytecode) { FileOutputStream fos = null; try { fos = new FileOutputStream(classFile); fos.write(bytecode); } catch (IOException e) { throw new Error("Couldn't write class file", e); } finally { if (fos != null) { try { fos.close(); } catch (IOException e) { } } } } static class AtomicMethodTransformer extends ClassVisitor { private final List<MethodNode> methods = new ArrayList<MethodNode>(); private final List<String> atomicMethodNames = new ArrayList<String>(); private final MethodNode atomicClInit; private final File classFile; private String className; public AtomicMethodTransformer(ClassVisitor cv, File originalClassFile) { super(ASM4, cv); classFile = originalClassFile; atomicClInit = new MethodNode(ACC_STATIC, "<clinit>", "()V", null, null); atomicClInit.visitCode(); } @Override public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { className = name; cv.visit(version, access, name, signature, superName, interfaces); } @Override public MethodVisitor visitMethod(int access, String name, String desc, String signature, String[] exceptions) { // Use a MethodNode to represent the method MethodNode mn = new MethodNode(access, name, desc, signature, exceptions); methods.add(mn); return mn; } @Override public void visitEnd() { MethodNode clInit = null; boolean hasAtomic = false; for (MethodNode mn : methods) { if (mn.name.equals("<clinit>")) { clInit = mn; continue; } if (mn.invisibleAnnotations != null) { for (AnnotationNode an : mn.invisibleAnnotations) { if (an.desc.equals(ATOMIC.getDescriptor())) { //System.out.println("Method " + mn.name + " is tagged with @Atomic"); hasAtomic = true; // Create new transactified method transactify(mn, an); break; } } } // Visit method, so it will be present on the output class mn.accept(cv); } if (hasAtomic) { // Insert <clinit> into class if (clInit != null) { // Merge existing clinit with our additions clInit.instructions.accept(atomicClInit); } else { atomicClInit.visitInsn(RETURN); } atomicClInit.visitMaxs(0, 0); atomicClInit.visitEnd(); atomicClInit.accept(cv); } else { // Preserve existing <clinit> if (clInit != null) clInit.accept(cv); } cv.visitEnd(); } /** * To transactify a method, considering that the original method was * "@Atomic @SomethingElse long add(Object o, int i)", and part of the class "Xpto", * we generate the following code: * public static [final] AtomicContext context$add = pt.ist.fenixframework.atomic.AtomicContext.newContext(); * * @SomethingElse * long add(Object o, int i) { * static class pt$ist$fenixframework$callable$add implements Callable { * Xpto arg0; * Object arg1; * int arg2; * * pt$ist$fenixframework$callable$add(Xpto x, Object o, int i) { * arg0 = x; * arg1 = o; * arg2 = i; * } * * Object call() { * return arg0.atomic$add(arg1, arg2); * } * } * return context$add.doTransactionally(new pt$ist$fenixframework$callable$add(this, o, i)); * } * * long atomic$add(Object o, int i) { * // original method * } * Note that any annotations from the original method (aside from @Atomic) are removed from the * atomic$ version. **/ private void transactify(MethodNode mn, AnnotationNode atomicAnnotation) { // Mangle name if there are multiple atomic methods with the same name String methodName = getMethodName(mn.name); // Name for context field String fieldName = "context$" + methodName; // Name for callable class String callableClass = className + "$pt$ist$fenixframework$callable$" + methodName; // Generate new method which will invoke the context with the Callable MethodVisitor atomicMethod = cv.visitMethod(mn.access, mn.name, mn.desc, mn.signature, mn.exceptions.toArray(new String[0])); // Remove @Atomic annotation mn.invisibleAnnotations.remove(atomicAnnotation); // Copy other annotations from the original method to the newly created method for (AnnotationNode an : mn.invisibleAnnotations) { an.accept(atomicMethod.visitAnnotation(an.desc, false)); } if (mn.visibleAnnotations != null) { for (AnnotationNode an : mn.visibleAnnotations) { an.accept(atomicMethod.visitAnnotation(an.desc, true)); } } // Generate replacement method generateMethodCode(mn, atomicMethod, fieldName, callableClass); // Create field to save context cv.visitField(ACC_PUBLIC | ACC_STATIC | ACC_FINAL, fieldName, ATOMIC_CONTEXT.getDescriptor(), null, null); // Generate callable class generateCallable(callableClass, mn); // Add code to clinit to initialize the field // Add default parameters from @Atomic Map<String,Object> atomicElements = new HashMap<String,Object>(ATOMIC_ELEMENTS); // Copy parameters from method annotation if (atomicAnnotation.values != null) { Iterator<Object> it = atomicAnnotation.values.iterator(); while (it.hasNext()) { // ASM stores annotation values as String1, Object1, String2, Object2, ... in the values list atomicElements.put((String) it.next(), it.next()); } } // Push @Atomic parameters on the stack and create AtomicInstance atomicClInit.visitTypeInsn(NEW, ATOMIC_INSTANCE.getInternalName()); atomicClInit.visitInsn(DUP); for (FieldNode field : ATOMIC_FIELDS) { atomicClInit.visitLdcInsn(atomicElements.get(field.name)); } atomicClInit.visitMethodInsn(INVOKESPECIAL, ATOMIC_INSTANCE.getInternalName(), "<init>", ATOMIC_INSTANCE_CTOR_DESC); // Obtain atomic context for this method atomicClInit.visitMethodInsn(INVOKESTATIC, ((Type) atomicElements.get("contextFactory")).getInternalName(), "newContext", "(" + ATOMIC.getDescriptor() + ")" + ATOMIC_CONTEXT.getDescriptor()); atomicClInit.visitFieldInsn(PUTSTATIC, className, fieldName, ATOMIC_CONTEXT.getDescriptor()); // Rename original method mn.name = "atomic$" + mn.name; // Remove annotations from original method mn.invisibleAnnotations = Collections.<AnnotationNode>emptyList(); mn.visibleAnnotations = Collections.<AnnotationNode>emptyList(); // If the method was private, it now becomes package protected, the callable can access it mn.access &= ~ACC_PRIVATE; } private void generateMethodCode(MethodNode mn, MethodVisitor mv, String fieldName, String callableClass) { mv.visitCode(); mv.visitFieldInsn(GETSTATIC, className, fieldName, ATOMIC_CONTEXT.getDescriptor()); mv.visitTypeInsn(NEW, callableClass); mv.visitInsn(DUP); int pos = 0; // Push arguments for original method on the stack if (!isStatic(mn)) mv.visitVarInsn(ALOAD, pos++); for (Type t : Type.getArgumentTypes(mn.desc)) { mv.visitVarInsn(t.getOpcode(ILOAD), pos); pos += t.getSize(); } mv.visitMethodInsn(INVOKESPECIAL, callableClass, "<init>", getCallableCtorDesc(mn)); mv.visitMethodInsn(INVOKEINTERFACE, ATOMIC_CONTEXT.getInternalName(), "doTransactionally", "(Ljava/util/concurrent/Callable;)Ljava/lang/Object;"); // Return value Type returnType = Type.getReturnType(mn.desc); if (returnType.getSort() == Type.OBJECT || returnType.getSort() == Type.ARRAY) { mv.visitTypeInsn(CHECKCAST, returnType.getInternalName()); } else if (isPrimitive(returnType)) { // Return is native, we have to unbox the value from the AtomicContext boxUnwrap(returnType, mv); } mv.visitInsn(returnType.getOpcode(IRETURN)); mv.visitMaxs(0, 0); mv.visitEnd(); } private static boolean isStatic(MethodNode mn) { return (mn.access & ACC_STATIC) > 0; } private String getCallableCtorDesc(MethodNode mn) { List<Type> callableCtorDescList = new ArrayList<Type>(); if (!isStatic(mn)) callableCtorDescList.add(Type.getObjectType(className)); callableCtorDescList.addAll(Arrays.asList(Type.getArgumentTypes(mn.desc))); String callableCtorDesc = Type.getMethodDescriptor(Type.VOID_TYPE, callableCtorDescList.toArray(new Type[0])); return callableCtorDesc; } private String getMethodName(String methodName) { // Count number of atomic methods with same name int count = 0; for (String name : atomicMethodNames) { if (name.equals(methodName)) count++; } // Add another one atomicMethodNames.add(methodName); return methodName + (count > 0 ? "$" + count : ""); } private void generateCallable(String callableClass, MethodNode mn) { Type returnType = Type.getReturnType(mn.desc); List<Type> arguments = new ArrayList<Type>(Arrays.asList(Type.getArgumentTypes(mn.desc))); if (!isStatic(mn)) arguments.add(0, Type.getObjectType(className)); ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS); cw.visit(V1_6, ACC_FINAL, callableClass, "Ljava/lang/Object;Ljava/util/concurrent/Callable<" + (isPrimitive(returnType) ? toObject(returnType) : (returnType.equals(Type.VOID_TYPE) ? Type.getObjectType("java/lang/Void") : returnType)).getDescriptor() + ">;", "java/lang/Object", new String[] { "java/util/concurrent/Callable" }); cw.visitSource("Generated Wrapper Class", null); // Create fields to hold arguments { int fieldPos = 0; for (Type t : arguments) { cw.visitField(ACC_PRIVATE | ACC_FINAL, "arg" + (fieldPos++), t.getDescriptor(), null, null); } } // Create constructor { MethodVisitor mv = cw.visitMethod(ACC_PUBLIC, "<init>", getCallableCtorDesc(mn), null, null); mv.visitCode(); mv.visitVarInsn(ALOAD, 0); mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V"); int localsPos = 0; int fieldPos = 0; for (Type t : arguments) { mv.visitVarInsn(ALOAD, 0); mv.visitVarInsn(t.getOpcode(ILOAD), localsPos+1); mv.visitFieldInsn(PUTFIELD, callableClass, "arg" + fieldPos++, t.getDescriptor()); localsPos += t.getSize(); } mv.visitInsn(RETURN); mv.visitMaxs(0, 0); mv.visitEnd(); } // Create call method { // Note: Usually when in Java you implement an interface with generics, such as Callable<Xpto>, // javac generates a Xpto call() method and an Object call() tagged as "public bridge synthetic" // that calls the previous one. Here, we generate the non-generic version immediately. MethodVisitor mv = cw.visitMethod(ACC_PUBLIC, "call", "()Ljava/lang/Object;", null, null); mv.visitCode(); int fieldPos = 0; for (Type t : arguments) { mv.visitVarInsn(ALOAD, 0); mv.visitFieldInsn(GETFIELD, callableClass, "arg" + fieldPos++, t.getDescriptor()); } mv.visitMethodInsn(isStatic(mn) ? INVOKESTATIC : INVOKEVIRTUAL, className, "atomic$" + mn.name, mn.desc); if (returnType.equals(Type.VOID_TYPE)) mv.visitInsn(ACONST_NULL); else if (isPrimitive(returnType)) boxWrap(returnType, mv); mv.visitInsn(ARETURN); mv.visitMaxs(0, 0); mv.visitEnd(); } // Write the callable class file in the same directory as the original class file String callableFileName = callableClass.substring(Math.max(callableClass.lastIndexOf('/'), 0)) + ".class"; writeClassFile(new File(classFile.getParent() + File.separatorChar + callableFileName), cw.toByteArray()); } private static final Object[][] primitiveWrappers = new Object[][] { {"java/lang/Boolean", Type.BOOLEAN_TYPE}, {"java/lang/Byte", Type.BYTE_TYPE}, {"java/lang/Character", Type.CHAR_TYPE}, {"java/lang/Short", Type.SHORT_TYPE}, {"java/lang/Integer", Type.INT_TYPE}, {"java/lang/Long", Type.LONG_TYPE}, {"java/lang/Float", Type.FLOAT_TYPE}, {"java/lang/Double", Type.DOUBLE_TYPE} }; private static Type toObject(Type primitiveType) { for (Object[] map : primitiveWrappers) { if (primitiveType.equals(map[1])) return Type.getObjectType((String) map[0]); } throw new AssertionError(); } private static boolean isPrimitive(Type type) { int sort = type.getSort(); return sort != Type.VOID && sort != Type.ARRAY && sort != Type.OBJECT && sort != Type.METHOD; } private static void boxWrap(Type primitiveType, MethodVisitor mv) { Type objectType = toObject(primitiveType); mv.visitMethodInsn(INVOKESTATIC, objectType .getInternalName(), "valueOf", "(" + primitiveType.getDescriptor() + ")" + objectType.getDescriptor()); } private static void boxUnwrap(Type primitiveType, MethodVisitor mv) { Type objectType = toObject(primitiveType); mv.visitTypeInsn(CHECKCAST, objectType.getInternalName()); mv.visitMethodInsn(INVOKEVIRTUAL, objectType.getInternalName(), primitiveType.getClassName() + "Value", "()" + primitiveType.getDescriptor()); } } }