/* * JVSTM: a Java library for Software Transactional Memory * Copyright (C) 2005 INESC-ID Software Engineering Group * http://www.esw.inesc-id.pt * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public * License as published by the Free Software Foundation; either * version 2.1 of the License, or (at your option) any later version. * * This library is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this library; if not, write to the Free Software * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA * * Author's contact: * INESC-ID Software Engineering Group * Rua Alves Redol 9 * 1000 - 029 Lisboa * Portugal */ package jvstm.atomic; import static org.objectweb.asm.Opcodes.ACC_FINAL; import static org.objectweb.asm.Opcodes.ACC_PRIVATE; import static org.objectweb.asm.Opcodes.ACC_PUBLIC; import static org.objectweb.asm.Opcodes.ACC_PROTECTED; import static org.objectweb.asm.Opcodes.ACC_STATIC; import static org.objectweb.asm.Opcodes.ACONST_NULL; import static org.objectweb.asm.Opcodes.ALOAD; import static org.objectweb.asm.Opcodes.ARETURN; import static org.objectweb.asm.Opcodes.IRETURN; import static org.objectweb.asm.Opcodes.ICONST_0; import static org.objectweb.asm.Opcodes.ASM4; import static org.objectweb.asm.Opcodes.ASTORE; import static org.objectweb.asm.Opcodes.DUP; import static org.objectweb.asm.Opcodes.GETFIELD; import static org.objectweb.asm.Opcodes.ILOAD; import static org.objectweb.asm.Opcodes.INVOKESPECIAL; import static org.objectweb.asm.Opcodes.INVOKESTATIC; import static org.objectweb.asm.Opcodes.INVOKEVIRTUAL; import static org.objectweb.asm.Opcodes.NEW; import static org.objectweb.asm.Opcodes.POP; import static org.objectweb.asm.Opcodes.PUTFIELD; import static org.objectweb.asm.Opcodes.RETURN; import static org.objectweb.asm.Opcodes.V1_6; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import org.objectweb.asm.ClassReader; import org.objectweb.asm.ClassVisitor; import org.objectweb.asm.ClassWriter; import org.objectweb.asm.MethodVisitor; import org.objectweb.asm.Type; import org.objectweb.asm.tree.AbstractInsnNode; import org.objectweb.asm.tree.AnnotationNode; import org.objectweb.asm.tree.ClassNode; import org.objectweb.asm.tree.InnerClassNode; import org.objectweb.asm.tree.InsnList; import org.objectweb.asm.tree.InsnNode; import org.objectweb.asm.tree.MethodInsnNode; import org.objectweb.asm.tree.MethodNode; import org.objectweb.asm.tree.TypeInsnNode; import org.objectweb.asm.tree.VarInsnNode; public class ProcessParNestAnnotations { private static final Type UNSAFE_SPAWN = Type.getType(jvstm.atomic.UnsafeSpawn.class); private static final Type PARALLEL_SPAWN = Type.getType(jvstm.atomic.ParallelSpawn.class); private static final Type COMBINER = Type.getType(jvstm.atomic.Combiner.class); private static final Type PAR_NEST = Type.getType(jvstm.atomic.ParNest.class); private static final Type ARRAY_LIST = Type.getType(java.util.ArrayList.class); private static final Type TRANSACTION = Type.getType(jvstm.Transaction.class); private final String[] files; public ProcessParNestAnnotations(String[] files) { this.files = files; } public static void main(final String args[]) throws Exception { new ProcessParNestAnnotations(args).process(); } public void process() { for (String file : files) { processFile(new File(file)); } } public static void processFile(File file) { if (file.isDirectory()) { for (File subFile : file.listFiles()) { processFile(subFile); } } else { String fileName = file.getName(); if (fileName.endsWith(".class")) { processClassFile(file); } } } private static List<String> alreadyProcessed; // maps the method to the callable name private static Map<String, String> callablesCreated; private static String createUniqueMethodName(String methodName) { alreadyProcessed.add(methodName); int count = 0; for (String name : alreadyProcessed) { if (methodName.equals(name)) { count++; } } return methodName + (count > 0 ? "$" + count : ""); } protected static void processClassFile(File classFile) { alreadyProcessed = new ArrayList<String>(); callablesCreated = new HashMap<String, String>(); InputStream is = null; try { // get an input stream to read the bytecode of the class is = new FileInputStream(classFile); ClassNode cn = new ClassNode(ASM4); ClassReader cr = new ClassReader(is); cr.accept(cn, 0); List<MethodNode> parNestedMethods = new ArrayList<MethodNode>(); MethodNode combinerMethod = null; MethodNode execMethod = null; List<MethodNode> staticMethodsToAdd = new ArrayList<MethodNode>(); boolean parallelSpawn = extendsParallelSpawn(cn); boolean unsafeSpawn = extendsUnsafeSpawn(cn); if (parallelSpawn || unsafeSpawn) { Iterator<MethodNode> methodIter = cn.methods.iterator(); while (methodIter.hasNext()) { MethodNode mn = methodIter.next(); if (mn.name.equals("exec") && execMethod == null) { execMethod = mn; continue; } if (mn.invisibleAnnotations == null) { continue; } for (AnnotationNode an : mn.invisibleAnnotations) { if (an.desc.equals(PAR_NEST.getDescriptor())) { // Ensure the method can be called from outside mn.access = (mn.access & ~ACC_PRIVATE) | ACC_PUBLIC; parNestedMethods.add(mn); String uniqueMethodName = createUniqueMethodName(mn.name); String callableClass; if (parallelSpawn) { callableClass = cn.name + "$nested$work$unit$" + uniqueMethodName; } else { callableClass = cn.name + "$unsafe$work$unit$" + uniqueMethodName; } callablesCreated.put(mn.name, callableClass); boolean readOnlyCallable = ( an.values == null ) ? false : (Boolean) an.values.get(1); generateCallable(classFile, cn.name, callableClass, mn, readOnlyCallable, unsafeSpawn); staticMethodsToAdd.add(generateStaticCallableCreation(cn, cn.name, callableClass, mn)); break; } else if (an.desc.equals(COMBINER.getDescriptor())) { if (combinerMethod != null) { throw new RuntimeException("Class: " + cn.name + " contains two @Combiner methods: " + combinerMethod.name + " and " + mn.name); } combinerMethod = mn; } } } // TODO Verify the @Combiner method // The return should be of the same type of the parameterization // of the ParallelSpawn for (MethodNode methodToAdd : staticMethodsToAdd) { cn.methods.add(methodToAdd); } if (alreadyProcessed.size() == 0) { throw new RuntimeException("Class: " + cn.name + " must have at least one method annotated with @ParNested"); } if (combinerMethod == null) { throw new RuntimeException("Class: " + cn.name + " must have one method annotated with @Combiner"); } List<Integer> localVariablesIdx = new ArrayList<Integer>(); int numberLocalVariables = 0; int listIndex = execMethod.maxLocals; execMethod.maxLocals++; InsnList preamble = new InsnList(); preamble.add(new TypeInsnNode(NEW, ARRAY_LIST.getInternalName())); preamble.add(new InsnNode(DUP)); preamble.add(new MethodInsnNode(INVOKESPECIAL, ARRAY_LIST.getInternalName(), "<init>", "()V")); preamble.add(new VarInsnNode(ASTORE, listIndex)); Iterator<AbstractInsnNode> execInstIter = execMethod.instructions.iterator(); while (execInstIter.hasNext()) { AbstractInsnNode instr = execInstIter.next(); // Look out for calls to methods if (instr.getOpcode() == INVOKEVIRTUAL || instr.getOpcode() == INVOKESPECIAL) { MethodInsnNode methodInstr = (MethodInsnNode) instr; // Is method being called annotated with @ParNested for (MethodNode parNestedMethod : parNestedMethods) { if (parNestedMethod.name.equals(methodInstr.name)) { numberLocalVariables++; } } } } for (int i = 0; i < numberLocalVariables; i++) { localVariablesIdx.add(i, execMethod.maxLocals); execMethod.maxLocals++; } int callablesManipulated = 0; execInstIter = execMethod.instructions.iterator(); while (execInstIter.hasNext()) { AbstractInsnNode instr = execInstIter.next(); // Look out for calls to methods if (instr.getOpcode() != INVOKEVIRTUAL && instr.getOpcode() != INVOKESPECIAL) { continue; } MethodInsnNode methodInstr = (MethodInsnNode) instr; // Is method being called annotated with @ParNested boolean isParNestedMethod = false; for (MethodNode parNestedMethod : parNestedMethods) { if (parNestedMethod.name.equals(methodInstr.name)) { isParNestedMethod = true; break; } } if (!isParNestedMethod) { continue; } // Let's change this call // If it was a call to: @ParNested public int add(int i1, // int i2) // add(foo, bar) -> add$static$callable$creator(this, foo, // bar) // the 'this' will be already in the right place in the // stack // because the method being called now is static whereas // previously // it was not methodInstr.setOpcode(INVOKESTATIC); methodInstr.name = methodInstr.name + "$static$callable$creator"; for (MethodNode staticCreated : staticMethodsToAdd) { if (staticCreated.name.equals(methodInstr.name)) { methodInstr.desc = staticCreated.desc; break; } } InsnList midterm = new InsnList(); // Store the callable instantiated in local variable midterm.add(new VarInsnNode(ASTORE, localVariablesIdx.get(callablesManipulated))); // Load the list midterm.add(new VarInsnNode(ALOAD, listIndex)); // Load the callable midterm.add(new VarInsnNode(ALOAD, localVariablesIdx.get(callablesManipulated))); // Add it to the list midterm.add(new MethodInsnNode(INVOKEVIRTUAL, ARRAY_LIST.getInternalName(), "add", "(Ljava/lang/Object;)Z")); // Pop the boolean that results from the add(Object) // May reuse a POP if the previous call had a return if (methodInstr.getNext().getOpcode() != POP) { midterm.add(new InsnNode(POP)); } // Add this set of instructions after the call to the // constrution of the callable execMethod.instructions.insert(methodInstr, midterm); callablesManipulated++; } // Insert the preamble in the start execMethod.instructions.insert(preamble); InsnList finish = new InsnList(); // Push 'this' for the call to the combiner method finish.add(new VarInsnNode(ALOAD, 0)); // Call the static method current() of jvstm.Transaction finish.add(new MethodInsnNode(INVOKESTATIC, TRANSACTION.getInternalName(), "current", "()Ljvstm/Transaction;")); // Load the callables list finish.add(new VarInsnNode(ALOAD, listIndex)); // Call the manage parnested method finish.add(new MethodInsnNode(INVOKEVIRTUAL, TRANSACTION.getInternalName(), "manageNestedParallelTxs", "(Ljava/util/List;)Ljava/util/List;")); // Call the combiner method finish.add(new MethodInsnNode(INVOKEVIRTUAL, cn.name, combinerMethod.name, combinerMethod.desc)); // Return what the combiner returns finish.add(new InsnNode(ARETURN)); // Remove the "return null" that's supposed to be at the end of // the exec method execInstIter = execMethod.instructions.iterator(); while (execInstIter.hasNext()) { AbstractInsnNode curNode = execInstIter.next(); if (!execInstIter.hasNext()) { // Insert the finish in the end execMethod.instructions.insert(curNode.getPrevious().getPrevious(), finish); execMethod.instructions.remove(curNode.getPrevious()); execMethod.instructions.remove(curNode); break; } } } ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS); cn.accept(cw); writeClassFile(classFile, cw.toByteArray()); } catch (IOException e) { throw new Error("Error processing class file", e); } finally { if (is != null) { try { is.close(); } catch (IOException e) { } } } } private static boolean isStatic(MethodNode mn) { return (mn.access & ACC_STATIC) > 0; } private static boolean isPrimitive(Type type) { int sort = type.getSort(); return sort != Type.VOID && sort != Type.ARRAY && sort != Type.OBJECT && sort != Type.METHOD; } private static final Object[][] primitiveWrappers = new Object[][] { { "java/lang/Boolean", Type.BOOLEAN_TYPE }, { "java/lang/Byte", Type.BYTE_TYPE }, { "java/lang/Character", Type.CHAR_TYPE }, { "java/lang/Short", Type.SHORT_TYPE }, { "java/lang/Integer", Type.INT_TYPE }, { "java/lang/Long", Type.LONG_TYPE }, { "java/lang/Float", Type.FLOAT_TYPE }, { "java/lang/Double", Type.DOUBLE_TYPE } }; private static Type toObject(Type primitiveType) { for (Object[] map : primitiveWrappers) { if (primitiveType.equals(map[1])) return Type.getObjectType((String) map[0]); } throw new AssertionError(); } private static String getCallableCtorDesc(String className, MethodNode mn) { List<Type> callableCtorDescList = new ArrayList<Type>(); if (!isStatic(mn)) callableCtorDescList.add(Type.getObjectType(className)); callableCtorDescList.addAll(Arrays.asList(Type.getArgumentTypes(mn.desc))); String callableCtorDesc = Type.getMethodDescriptor(Type.VOID_TYPE, callableCtorDescList.toArray(new Type[0])); return callableCtorDesc; } private static MethodNode generateStaticCallableCreation(ClassNode classNode, String className, String callableClass, MethodNode mn) { MethodNode staticMethod = new MethodNode(V1_6, mn.access | ACC_STATIC, mn.name + "$static$callable$creator", "(L" + className + ";" + mn.desc.substring(1, mn.desc.indexOf(')') + 1) + "L" + callableClass + ";", mn.signature, new String[0]); InsnList content = new InsnList(); content.add(new TypeInsnNode(NEW, callableClass)); content.add(new InsnNode(DUP)); int pos = 0; // Push the instance of the class being modified (first argument of this // synthetized method) content.add(new VarInsnNode(ALOAD, pos++)); // Push arguments of original method on the stack for callable creation for (Type t : Type.getArgumentTypes(mn.desc)) { content.add(new VarInsnNode(t.getOpcode(ILOAD), pos)); pos += t.getSize(); } // Instantiate the callable content.add(new MethodInsnNode(INVOKESPECIAL, callableClass, "<init>", getCallableCtorDesc(className, mn))); // Return it from the static method content.add(new InsnNode(ARETURN)); staticMethod.instructions.add(content); return staticMethod; } private static void generateCallable(File classFile, String className, String callableClass, MethodNode mn, boolean readOnly, boolean unsafe) { Type returnType = Type.getReturnType(mn.desc); List<Type> arguments = new ArrayList<Type>(Arrays.asList(Type.getArgumentTypes(mn.desc))); if (!isStatic(mn)) arguments.add(0, Type.getObjectType(className)); ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS); cw.visit( V1_6, ACC_FINAL, callableClass, unsafe ? "Ljvstm/UnsafeParallelTask<" : "Ljvstm/ParallelTask<" + (isPrimitive(returnType) ? toObject(returnType) : (returnType.equals(Type.VOID_TYPE) ? Type .getObjectType("java/lang/Void") : returnType)).getDescriptor() + ">;", unsafe ? "jvstm/UnsafeParallelTask" : "jvstm/ParallelTask", new String[] {}); cw.visitSource("JVSTM Generated Wrapper Class", null); // Create fields to hold arguments { int fieldPos = 0; for (Type t : arguments) { cw.visitField(ACC_PRIVATE | ACC_FINAL, "arg" + (fieldPos++), t.getDescriptor(), null, null); } } // Create constructor { MethodVisitor mv = cw.visitMethod(ACC_PUBLIC, "<init>", getCallableCtorDesc(className, mn), null, null); mv.visitCode(); mv.visitVarInsn(ALOAD, 0); mv.visitMethodInsn(INVOKESPECIAL, unsafe ? "jvstm/UnsafeParallelTask" : "jvstm/ParallelTask", "<init>", "()V"); int localsPos = 0; int fieldPos = 0; for (Type t : arguments) { mv.visitVarInsn(ALOAD, 0); mv.visitVarInsn(t.getOpcode(ILOAD), localsPos + 1); mv.visitFieldInsn(PUTFIELD, callableClass, "arg" + fieldPos++, t.getDescriptor()); localsPos += t.getSize(); } mv.visitInsn(RETURN); mv.visitMaxs(0, 0); mv.visitEnd(); } // Create execute method { MethodVisitor mv = cw.visitMethod(ACC_PUBLIC, "execute", "()Ljava/lang/Object;", null, null); mv.visitCode(); int fieldPos = 0; for (Type t : arguments) { mv.visitVarInsn(ALOAD, 0); mv.visitFieldInsn(GETFIELD, callableClass, "arg" + fieldPos++, t.getDescriptor()); } mv.visitMethodInsn(isStatic(mn) ? INVOKESTATIC : INVOKEVIRTUAL, className, mn.name, mn.desc); if (returnType.equals(Type.VOID_TYPE)) mv.visitInsn(ACONST_NULL); else if (isPrimitive(returnType)) boxWrap(returnType, mv); mv.visitInsn(ARETURN); mv.visitMaxs(0, 0); mv.visitEnd(); } // Create the readOnly method { if (readOnly) { MethodVisitor mv = cw.visitMethod(ACC_PROTECTED, "isReadOnly", "()Z", null, null); mv.visitCode(); mv.visitInsn(ICONST_0); mv.visitInsn(IRETURN); mv.visitMaxs(0, 0); mv.visitEnd(); } } /* protected boolean isReadOnly() { return false; }*/ // Write the callable class file in the same directory as the original // class file String callableFileName = callableClass.substring(Math.max(callableClass.lastIndexOf('/'), 0)) + ".class"; writeClassFile(new File((classFile.getParent() == null ? "" : classFile.getParent() + File.separatorChar) + callableFileName), cw.toByteArray()); } private static void boxWrap(Type primitiveType, MethodVisitor mv) { Type objectType = toObject(primitiveType); mv.visitMethodInsn(INVOKESTATIC, objectType.getInternalName(), "valueOf", "(" + primitiveType.getDescriptor() + ")" + objectType.getDescriptor()); } private static boolean extendsParallelSpawn(ClassNode cn) { // TODO Support extending ParallelSpawn at multiple levels, the // issue at the moment is that the exec() method could be // elsewhere... // Plus, if we attempt to check on a class that extends some // class whose .class is not in the project (Thread.class for instance) // we run into problems for (String implementedInterfaceName : cn.interfaces) { if (implementedInterfaceName.equals(PARALLEL_SPAWN.getInternalName())) { return true; } } return false; } private static boolean extendsUnsafeSpawn(ClassNode cn) { for (String implementedInterfaceName : cn.interfaces) { if (implementedInterfaceName.equals(UNSAFE_SPAWN.getInternalName())) { return true; } } return false; } protected static void writeClassFile(File classFile, byte[] bytecode) { FileOutputStream fos = null; try { fos = new FileOutputStream(classFile); fos.write(bytecode); } catch (IOException e) { throw new Error("Couldn't write class file", e); } finally { if (fos != null) { try { fos.close(); } catch (IOException e) { } } } } static class ParNestMethodTransformer extends ClassVisitor { private final List<MethodNode> methods = new ArrayList<MethodNode>(); private final List<String> parNestMethodNames = new ArrayList<String>(); private final MethodNode atomicClInit; private final File classFile; private String className; public ParNestMethodTransformer(ClassVisitor cv, File originalClassFile) { super(ASM4, cv); classFile = originalClassFile; atomicClInit = new MethodNode(ACC_STATIC, "<clinit>", "()V", null, null); atomicClInit.visitCode(); } @Override public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { className = name; System.err.println("Class: " + name); cv.visit(version, access, name, signature, superName, interfaces); } @Override public MethodVisitor visitMethod(int access, String name, String desc, String signature, String[] exceptions) { // Use a MethodNode to represent the method MethodNode mn = new MethodNode(access, name, desc, signature, exceptions); methods.add(mn); return mn; } @Override public void visitInnerClass(String name, String outerName, String innerName, int access) { System.err.println("Inner Class: " + name + " inner name: " + innerName + " outer: " + outerName); InnerClassNode n = new InnerClassNode(name, outerName, innerName, access); cv.visitInnerClass(name, outerName, innerName, access); } @Override public void visitEnd() { MethodNode clInit = null; boolean hasParNest = false; for (MethodNode mn : methods) { if (mn.name.equals("<clinit>")) { clInit = mn; continue; } if (mn.invisibleAnnotations != null) { for (AnnotationNode an : mn.invisibleAnnotations) { if (an.desc.equals(PAR_NEST.getDescriptor())) { System.out.println("Method " + mn.name + " is tagged with @ParNest"); hasParNest = true; // Create new transactified method // transactify(mn, an); break; } } } // Visit method, so it will be present on the output class mn.accept(cv); } if (hasParNest) { // Insert <clinit> into class if (clInit != null) { // Merge existing clinit with our additions clInit.instructions.accept(atomicClInit); } else { atomicClInit.visitInsn(RETURN); } atomicClInit.visitMaxs(0, 0); atomicClInit.visitEnd(); atomicClInit.accept(cv); } else { // Preserve existing <clinit> if (clInit != null) clInit.accept(cv); } cv.visitEnd(); } private String getMethodName(String methodName) { // Count number of atomic methods with same name int count = 0; for (String name : parNestMethodNames) { if (name.equals(methodName)) count++; } // Add another one parNestMethodNames.add(methodName); return methodName + (count > 0 ? "$" + count : ""); } } }