/* This file is part of jpcsp. Jpcsp is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. Jpcsp is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with Jpcsp. If not, see <http://www.gnu.org/licenses/>. */ package jpcsp.util; import static org.objectweb.asm.tree.AbstractInsnNode.JUMP_INSN; import static org.objectweb.asm.tree.AbstractInsnNode.LABEL; import static org.objectweb.asm.tree.AbstractInsnNode.LINE; import static org.objectweb.asm.tree.AbstractInsnNode.LOOKUPSWITCH_INSN; import static org.objectweb.asm.tree.AbstractInsnNode.TABLESWITCH_INSN; import java.io.IOException; import java.io.PrintWriter; import java.io.StringWriter; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.ListIterator; import java.util.Set; import org.apache.log4j.Logger; import org.objectweb.asm.ClassReader; import org.objectweb.asm.ClassVisitor; import org.objectweb.asm.ClassWriter; import org.objectweb.asm.Opcodes; import org.objectweb.asm.tree.AbstractInsnNode; import org.objectweb.asm.tree.ClassNode; import org.objectweb.asm.tree.FieldInsnNode; import org.objectweb.asm.tree.FieldNode; import org.objectweb.asm.tree.InsnNode; import org.objectweb.asm.tree.IntInsnNode; import org.objectweb.asm.tree.JumpInsnNode; import org.objectweb.asm.tree.LabelNode; import org.objectweb.asm.tree.LdcInsnNode; import org.objectweb.asm.tree.LookupSwitchInsnNode; import org.objectweb.asm.tree.MethodInsnNode; import org.objectweb.asm.tree.MethodNode; import org.objectweb.asm.tree.TableSwitchInsnNode; import org.objectweb.asm.tree.analysis.Analyzer; import org.objectweb.asm.tree.analysis.AnalyzerException; import org.objectweb.asm.tree.analysis.BasicInterpreter; import org.objectweb.asm.tree.analysis.Frame; import org.objectweb.asm.util.TraceClassVisitor; /** * @author gid15 * * Specialize a Java class by modifying its code to exclude part of it, * based on a list of field values dynamically computed at runtime. * The specialized class contains only the part of the code that will be * executed for the given field values, without the overhead for testing these values. * * For example, given the following class: * public class Test { * public static int testValue; * int test(int parameter) { * if (testValue == 0) { * return 0; * } else if (testValue < 0) { * return -parameter; * } else { * return parameter; * } * } * } * * A specialized class for the following field values: * testValue = 123; * would be * public class SpecialitedTest1 { * int test(int parameter) { * return parameter; * } * } * * and for * testValue = -123; * it would be * public class SpecialitedTest2 { * int test(int parameter) { * return -parameter; * } * } * * The following code statements can be evaluated by the specializer: * - if * - switch * - while * on field values of the following types: * - int * - byte * - short * - boolean * - float */ public class ClassSpecializer { private static Logger log = Logger.getLogger("classSpecializer"); private static SpecializedClassLoader classLoader = new SpecializedClassLoader(); private static HashSet<Class<?>> tracedClasses = new HashSet<Class<?>>(); public Class<?> specialize(String name, Class<?> c, HashMap<String, Object> variables) { ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS); ClassVisitor cv = cw; StringWriter debugOutput = null; if (log.isTraceEnabled()) { // Dump the class to be specialized (only once) if (!tracedClasses.contains(c)) { StringWriter classTrace = new StringWriter(); ClassVisitor classTraceCv = new TraceClassVisitor(new PrintWriter(classTrace)); try { ClassReader cr = new ClassReader(c.getName().replace('.', '/')); cr.accept(classTraceCv, 0); log.trace(String.format("Dump of class to be specialized: %s", c.getName())); log.trace(classTrace); } catch (IOException e) { // Ignore Exception } tracedClasses.add(c); } log.trace(String.format("Specializing class %s", name)); String[] variableNames = variables.keySet().toArray(new String[variables.size()]); Arrays.sort(variableNames); for (String variableName : variableNames) { log.trace(String.format("Variable %s=%s", variableName, variables.get(variableName))); } debugOutput = new StringWriter(); PrintWriter debugPrintWriter = new PrintWriter(debugOutput); cv = new TraceClassVisitor(cv, debugPrintWriter); //cv = new TraceClassVisitor(debugPrintWriter); } try { ClassReader cr = new ClassReader(c.getName().replace('.', '/')); ClassNode cn = new SpecializedClassVisitor(name, variables); cr.accept(cn, 0); cn.accept(cv); } catch (IOException e) { log.error("Cannot read class", e); } if (debugOutput != null) { log.trace(debugOutput.toString()); } Class<?> specializedClass = null; try { specializedClass = classLoader.defineClass(name, cw.toByteArray()); } catch (ClassFormatError e) { log.error("Error while defining specialized class", e); } return specializedClass; } private static class SpecializedClassVisitor extends ClassNode { private final String specializedClassName; private final HashMap<String, Object> variables; private Object value; private AbstractInsnNode deleteUpToInsn; private String className; private String superClassName; public SpecializedClassVisitor(String specializedClassName, HashMap<String, Object> variables) { this.specializedClassName = specializedClassName; this.variables = variables; } @Override public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { className = name; superClassName = superName; // Define the specialized class as extending the original class super.visit(version, access, specializedClassName, signature, name, interfaces); } @Override public void visitEnd() { // Visit all the methods for (Iterator<?> it = methods.iterator(); it.hasNext(); ) { MethodNode method = (MethodNode) it.next(); visitMethod(method); } // Delete all the fields used as specialization variables for (ListIterator<?> lit = fields.listIterator(); lit.hasNext(); ) { FieldNode field = (FieldNode) lit.next(); if ((field.access & Opcodes.ACC_STATIC) != 0 && variables.containsKey(field.name)) { lit.remove(); } } super.visitEnd(); } private void visitMethod(MethodNode method) { final boolean isConstructor = "<init>".equals(method.name); deleteUpToInsn = null; for (ListIterator<?> lit = method.instructions.iterator(); lit.hasNext(); ) { AbstractInsnNode insn = (AbstractInsnNode) lit.next(); if (deleteUpToInsn != null) { if (insn == deleteUpToInsn) { deleteUpToInsn = null; } else { // Do not delete labels, they could be used as a target from a previous jump. // Also keep line numbers for easier debugging. if (insn.getType() != LABEL && insn.getType() != LINE) { lit.remove(); } continue; } } if (insn.getType() == AbstractInsnNode.FRAME) { // Remove all the FRAME information, they will be calculated // anew after the class specialization. lit.remove(); } else if (insn.getOpcode() == Opcodes.GETSTATIC) { FieldInsnNode fieldInsn = (FieldInsnNode) insn; if (variables.containsKey(fieldInsn.name)) { boolean processed = false; value = variables.get(fieldInsn.name); AbstractInsnNode nextInsn = insn.getNext(); if (analyseIfTestInt(method, insn)) { processed = true; } else if (nextInsn != null && nextInsn.getType() == TABLESWITCH_INSN) { TableSwitchInsnNode switchInsn = (TableSwitchInsnNode) nextInsn; LabelNode label = null; if (isIntValue(value)) { int n = getIntValue(value); if (n >= switchInsn.min && n <= switchInsn.max) { int i = n - switchInsn.min; if (i < switchInsn.labels.size()) { label = (LabelNode) switchInsn.labels.get(i); } } } if (label == null) { label = switchInsn.dflt; } if (label != null) { // Replace the table switch instruction by a GOTO to the switch label method.instructions.set(insn, new JumpInsnNode(Opcodes.GOTO, label)); processed = true; } } else if (nextInsn != null && nextInsn.getType() == LOOKUPSWITCH_INSN) { LookupSwitchInsnNode switchInsn = (LookupSwitchInsnNode) nextInsn; LabelNode label = null; if (isIntValue(value)) { int n = getIntValue(value); int i = 0; for (Object value : switchInsn.keys) { if (value instanceof Integer) { if (((Integer) value).intValue() == n) { label = (LabelNode) switchInsn.labels.get(i); break; } } i++; } } if (label == null) { label = switchInsn.dflt; } if (label != null) { // Replace the table switch instruction by a GOTO to the switch label method.instructions.set(insn, new JumpInsnNode(Opcodes.GOTO, label)); processed = true; } } else if (nextInsn != null && nextInsn.getType() == AbstractInsnNode.INSN) { int opcode = nextInsn.getOpcode(); int n = 0; float f = 0f; boolean isIntConstant = false; boolean isFloatConstant = false; switch (opcode) { case Opcodes.ICONST_M1: n = -1; isIntConstant = true; break; case Opcodes.ICONST_0: n = 0; isIntConstant = true; break; case Opcodes.ICONST_1: n = 1; isIntConstant = true; break; case Opcodes.ICONST_2: n = 2; isIntConstant = true; break; case Opcodes.ICONST_3: n = 3; isIntConstant = true; break; case Opcodes.ICONST_4: n = 4; isIntConstant = true; break; case Opcodes.ICONST_5: n = 5; isIntConstant = true; break; case Opcodes.FCONST_0: f = 0f; isFloatConstant = true; break; case Opcodes.FCONST_1: f = 1f; isFloatConstant = true; break; case Opcodes.FCONST_2: f = 2f; isFloatConstant = true; break; } if (isIntConstant) { if (analyseIfTestInt(method, insn, nextInsn, n)) { processed = true; } } else if (isFloatConstant) { if (analyseIfTestFloat(method, insn, nextInsn, f)) { processed = true; } } } else if (nextInsn != null && nextInsn.getType() == AbstractInsnNode.INT_INSN) { IntInsnNode intInsn = (IntInsnNode) nextInsn; if (analyseIfTestInt(method, insn, nextInsn, intInsn.operand)) { processed = true; } } else if (nextInsn != null && nextInsn.getType() == AbstractInsnNode.LDC_INSN) { LdcInsnNode ldcInsn = (LdcInsnNode) nextInsn; if (isIntValue(ldcInsn.cst)) { if (analyseIfTestInt(method, insn, nextInsn, getIntValue(ldcInsn.cst))) { processed = true; } } else if (isFloatValue(ldcInsn.cst)) { if (analyseIfTestFloat(method, insn, nextInsn, getFloatValue(ldcInsn.cst))) { processed = true; } } } if (!processed) { // Replace the field access by its constant value AbstractInsnNode constantInsn = getConstantInsn(value); if (constantInsn != null) { method.instructions.set(insn, constantInsn); } } } else { if (fieldInsn.owner.equals(className)) { // Replace the class name by the specialized class name fieldInsn.owner = specializedClassName; } } } else if (insn.getOpcode() == Opcodes.PUTSTATIC) { FieldInsnNode fieldInsn = (FieldInsnNode) insn; if (!variables.containsKey(fieldInsn.name)) { if (fieldInsn.owner.equals(className)) { // Replace the class name by the specialized class name fieldInsn.owner = specializedClassName; } } } else if (insn.getType() == AbstractInsnNode.METHOD_INSN) { MethodInsnNode methodInsn = (MethodInsnNode) insn; if (methodInsn.owner.equals(className)) { // Replace the class name by the specialized class name methodInsn.owner = specializedClassName; } else if (isConstructor && methodInsn.owner.equals(superClassName)) { // Update the call to the constructor of the parent class methodInsn.owner = className; } } } // Delete all the information about local variables, they are no longer correct // (the class loader would complain). method.localVariables.clear(); optimizeJumps(method); removeDeadCode(method); optimizeJumps(method); removeUnusedLabels(method); removeUselessLineNumbers(method); } /** * Optimize the jumps from a method: * - jumps to a "GOTO label" instruction * are replaced with a direct jump to "label"; * - a GOTO to the next instruction is deleted; * - a GOTO to a RETURN or ATHROW instruction * is replaced with this RETURN or ATHROW instruction. * * @param method the method to be optimized */ private void optimizeJumps(MethodNode method) { for (ListIterator<?> lit = method.instructions.iterator(); lit.hasNext(); ) { AbstractInsnNode insn = (AbstractInsnNode) lit.next(); if (insn.getType() == JUMP_INSN) { JumpInsnNode jumpInsn = (JumpInsnNode) insn; LabelNode label = jumpInsn.label; AbstractInsnNode target; // while target == goto l, replace label with l while (true) { target = label; while (target != null && target.getOpcode() < 0) { target = target.getNext(); } if (target != null && target.getOpcode() == Opcodes.GOTO) { label = ((JumpInsnNode) target).label; } else { break; } } // update target jumpInsn.label = label; boolean removeJump = false; if (jumpInsn.getOpcode() == Opcodes.GOTO) { // Delete a GOTO to the next instruction AbstractInsnNode next = jumpInsn.getNext(); while (next != null) { if (next == label) { removeJump = true; break; } else if (next.getOpcode() >= 0) { break; } next = next.getNext(); } } if (removeJump) { lit.remove(); } else { // if possible, replace jump with target instruction if (jumpInsn.getOpcode() == Opcodes.GOTO && target != null) { switch (target.getOpcode()) { case Opcodes.IRETURN: case Opcodes.LRETURN: case Opcodes.FRETURN: case Opcodes.DRETURN: case Opcodes.ARETURN: case Opcodes.RETURN: case Opcodes.ATHROW: // replace instruction with clone of target method.instructions.set(insn, target.clone(null)); } } } } } } /** * Remove the dead code - or unreachable code - from a method. * * @param method the method to be updated */ private void removeDeadCode(MethodNode method) { try { // Analyze the method using the BasicInterpreter. // As a result, the computed frames are null for instructions // that cannot be reached. Analyzer analyzer = new Analyzer(new BasicInterpreter()); analyzer.analyze(specializedClassName, method); Frame[] frames = analyzer.getFrames(); AbstractInsnNode[] insns = method.instructions.toArray(); for (int i = 0; i < frames.length; i++) { AbstractInsnNode insn = insns[i]; if (frames[i] == null && insn.getType() != AbstractInsnNode.LABEL) { // This instruction was not reached by the analyzer method.instructions.remove(insn); insns[i] = null; } } } catch (AnalyzerException e) { // Ignore error } } /** * Remove unused labels, i.e. labels that are not referenced. * * @param method the method to be updated */ private void removeUnusedLabels(MethodNode method) { // Scan for all the used labels Set<LabelNode> usedLabels = new HashSet<LabelNode>(); for (ListIterator<?> lit = method.instructions.iterator(); lit.hasNext(); ) { AbstractInsnNode insn = (AbstractInsnNode) lit.next(); if (insn.getType() == JUMP_INSN) { JumpInsnNode jumpInsn = (JumpInsnNode) insn; usedLabels.add(jumpInsn.label); } else if (insn.getType() == TABLESWITCH_INSN) { TableSwitchInsnNode tableSwitchInsn = (TableSwitchInsnNode) insn; for (Iterator<?> it = tableSwitchInsn.labels.iterator(); it.hasNext(); ) { LabelNode labelNode = (LabelNode) it.next(); if (labelNode != null) { usedLabels.add(labelNode); } } } else if (insn.getType() == LOOKUPSWITCH_INSN) { LookupSwitchInsnNode loopupSwitchInsn = (LookupSwitchInsnNode) insn; for (Iterator<?> it = loopupSwitchInsn.labels.iterator(); it.hasNext(); ) { LabelNode labelNode = (LabelNode) it.next(); if (labelNode != null) { usedLabels.add(labelNode); } } } } // Remove all the label instructions not being identified in the scan for (ListIterator<?> lit = method.instructions.iterator(); lit.hasNext(); ) { AbstractInsnNode insn = (AbstractInsnNode) lit.next(); if (insn.getType() == LABEL) { if (!usedLabels.contains(insn)) { lit.remove(); } } } } /** * Remove unused line numbers, i.e. line numbers where there is no code. * * @param method the method to be updated */ private void removeUselessLineNumbers(MethodNode method) { // Remove all the line numbers being immediately followed by another line number. for (ListIterator<?> lit = method.instructions.iterator(); lit.hasNext(); ) { AbstractInsnNode insn = (AbstractInsnNode) lit.next(); if (insn.getType() == LINE) { AbstractInsnNode nextInsn = insn.getNext(); if (nextInsn != null && nextInsn.getType() == LINE) { lit.remove(); } } } } private boolean analyseIfTestInt(MethodNode method, AbstractInsnNode insn) { return analyseIfTestInt(method, insn, insn, null); } private boolean analyseIfTestInt(MethodNode method, AbstractInsnNode insn, AbstractInsnNode valueInsn, Integer testValue) { boolean eliminateJump = false; AbstractInsnNode nextInsn = valueInsn.getNext(); if (nextInsn != null && nextInsn.getType() == JUMP_INSN) { JumpInsnNode jumpInsn = (JumpInsnNode) nextInsn; boolean doJump = false; switch (jumpInsn.getOpcode()) { case Opcodes.IFEQ: if (testValue == null && isIntValue(value)) { doJump = getIntValue(value) == 0; eliminateJump = true; } break; case Opcodes.IFNE: if (testValue == null && isIntValue(value)) { doJump = getIntValue(value) != 0; eliminateJump = true; } break; case Opcodes.IFLT: if (testValue == null && isIntValue(value)) { doJump = getIntValue(value) < 0; eliminateJump = true; } break; case Opcodes.IFGE: if (testValue == null && isIntValue(value)) { doJump = getIntValue(value) >= 0; eliminateJump = true; } break; case Opcodes.IFGT: if (testValue == null && isIntValue(value)) { doJump = getIntValue(value) > 0; eliminateJump = true; } break; case Opcodes.IFLE: if (testValue == null && isIntValue(value)) { doJump = getIntValue(value) <= 0; eliminateJump = true; } break; case Opcodes.IF_ICMPEQ: if (testValue != null && isIntValue(value)) { doJump = getIntValue(value) == testValue.intValue(); eliminateJump = true; } break; case Opcodes.IF_ICMPNE: if (testValue != null && isIntValue(value)) { doJump = getIntValue(value) != testValue.intValue(); eliminateJump = true; } break; case Opcodes.IF_ICMPLT: if (testValue != null && isIntValue(value)) { doJump = getIntValue(value) < testValue.intValue(); eliminateJump = true; } break; case Opcodes.IF_ICMPGE: if (testValue != null && isIntValue(value)) { doJump = getIntValue(value) >= testValue.intValue(); eliminateJump = true; } break; case Opcodes.IF_ICMPGT: if (testValue != null && isIntValue(value)) { doJump = getIntValue(value) > testValue.intValue(); eliminateJump = true; } break; case Opcodes.IF_ICMPLE: if (testValue != null && isIntValue(value)) { doJump = getIntValue(value) <= testValue.intValue(); eliminateJump = true; } break; } if (eliminateJump) { if (doJump) { // Replace the expression test by a fixed GOTO. // The skipped instructions will be eliminated by dead code analysis. method.instructions.set(insn, new JumpInsnNode(Opcodes.GOTO, jumpInsn.label)); } else { method.instructions.remove(insn); } deleteUpToInsn = jumpInsn.getNext(); } } return eliminateJump; } private boolean analyseIfTestFloat(MethodNode method, AbstractInsnNode insn, AbstractInsnNode valueInsn, float testValue) { boolean eliminateJump = false; AbstractInsnNode nextInsn = valueInsn.getNext(); if (nextInsn != null && (nextInsn.getOpcode() == Opcodes.FCMPL || nextInsn.getOpcode() == Opcodes.FCMPG)) { AbstractInsnNode nextNextInsn = nextInsn.getNext(); if (nextNextInsn != null && nextNextInsn.getType() == JUMP_INSN) { JumpInsnNode jumpInsn = (JumpInsnNode) nextNextInsn; boolean doJump = false; switch (jumpInsn.getOpcode()) { case Opcodes.IFEQ: if (isFloatValue(value)) { doJump = getFloatValue(value) == testValue; eliminateJump = true; } break; case Opcodes.IFNE: if (isFloatValue(value)) { doJump = getFloatValue(value) != testValue; eliminateJump = true; } break; case Opcodes.IFLT: if (isFloatValue(value)) { doJump = getFloatValue(value) < testValue; eliminateJump = true; } break; case Opcodes.IFGE: if (isFloatValue(value)) { doJump = getFloatValue(value) >= testValue; eliminateJump = true; } break; case Opcodes.IFGT: if (isFloatValue(value)) { doJump = getFloatValue(value) > testValue; eliminateJump = true; } break; case Opcodes.IFLE: if (isFloatValue(value)) { doJump = getFloatValue(value) <= testValue; eliminateJump = true; } break; } if (eliminateJump) { if (doJump) { // Replace the expression test by a fixed GOTO. // The skipped instructions will be eliminated by dead code analysis. method.instructions.set(insn, new JumpInsnNode(Opcodes.GOTO, jumpInsn.label)); } else { method.instructions.remove(insn); } deleteUpToInsn = jumpInsn.getNext(); } } } return eliminateJump; } private boolean isIntValue(Object value) { return (value instanceof Integer) || (value instanceof Boolean); } private int getIntValue(Object value) { if (value instanceof Integer) { return ((Integer) value).intValue(); } if (value instanceof Boolean) { return value == Boolean.FALSE ? 0 : 1; } return 0; } private boolean isFloatValue(Object value) { return (value instanceof Float); } private float getFloatValue(Object value) { if (value instanceof Float) { return ((Float) value).floatValue(); } return 0f; } private AbstractInsnNode getConstantInsn(Object value) { AbstractInsnNode constantInsn = null; if (isIntValue(value)) { int n = getIntValue(value); // Find the optimum opcode to represent this integer value switch (n) { case -1: constantInsn = new InsnNode(Opcodes.ICONST_M1); break; case 0: constantInsn = new InsnNode(Opcodes.ICONST_0); break; case 1: constantInsn = new InsnNode(Opcodes.ICONST_1); break; case 2: constantInsn = new InsnNode(Opcodes.ICONST_2); break; case 3: constantInsn = new InsnNode(Opcodes.ICONST_3); break; case 4: constantInsn = new InsnNode(Opcodes.ICONST_4); break; case 5: constantInsn = new InsnNode(Opcodes.ICONST_5); break; default: if (Byte.MIN_VALUE <= n && n < Byte.MAX_VALUE) { constantInsn = new IntInsnNode(Opcodes.BIPUSH, n); } else if (Short.MIN_VALUE <= n && n < Short.MAX_VALUE) { constantInsn = new IntInsnNode(Opcodes.SIPUSH, n); } else { constantInsn = new LdcInsnNode(new Integer(n)); } break; } } return constantInsn; } } private static class SpecializedClassLoader extends ClassLoader { public Class<?> defineClass(String name, byte[] b) { return defineClass(name, b, 0, b.length); } } }