package jef.codegen;
import java.io.File;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import javax.persistence.ManyToMany;
import javax.persistence.ManyToOne;
import javax.persistence.OneToMany;
import javax.persistence.OneToOne;
import jef.accelerator.asm.ASMUtils;
import jef.accelerator.asm.Attribute;
import jef.accelerator.asm.ClassReader;
import jef.accelerator.asm.ClassVisitor;
import jef.accelerator.asm.ClassWriter;
import jef.accelerator.asm.FieldVisitor;
import jef.accelerator.asm.Label;
import jef.accelerator.asm.MethodVisitor;
import jef.accelerator.asm.Opcodes;
import jef.accelerator.asm.Type;
import jef.accelerator.asm.commons.AnnotationDef;
import jef.accelerator.asm.commons.FieldExtCallback;
import jef.accelerator.asm.commons.FieldExtDef;
import jef.tools.Assert;
import jef.tools.IOUtils;
import jef.tools.StringUtils;
import org.apache.commons.lang.ArrayUtils;
public class EnhanceTaskASM {
private File root;
/**
 * Creates an enhancer that may fall back to a directory when resolving parent classes.
 *
 * @param root  directory stored for later superclass lookups; may be {@code null}
 * @param roots accepted for signature compatibility; this constructor does not use it
 */
public EnhanceTaskASM(File root, File[] roots) {
    this.root = root;
}

/**
 * Creates an enhancer with no root directory configured ({@code root} stays {@code null}).
 */
public EnhanceTaskASM() {
    this(null, null);
}
/**
 * Enhances the class contained in {@code classdata}.
 *
 * @param classdata
 *            raw bytes of the class to enhance; must not be {@code null}
 * @param fieldEumData
 *            raw bytes of the companion class whose enum constants name the
 *            fields to track; {@code null} is allowed
 * @return {@code null} when the class does not need enhancement, an empty
 *         array when the class has already been enhanced, otherwise the
 *         bytes of the enhanced class
 * @throws Exception
 */
public byte[] doEnhance(byte[] classdata, byte[] fieldEumData) throws Exception {
    Assert.notNull(classdata);
    final List<String> trackedFields = parseEnumFields(fieldEumData);
    byte[] result;
    try {
        result = enhanceClass(new ClassReader(classdata), trackedFields);
    } catch (EnhancedException alreadyEnhanced) {
        // The class already carries the enhancement marker: report that with an empty array.
        result = ArrayUtils.EMPTY_BYTE_ARRAY;
    }
    return result;
}
/**
 * Collects the names of all fields flagged {@code ACC_ENUM} in the given class bytes.
 *
 * @param fieldEumData raw class bytes to scan; may be {@code null}
 * @return a new mutable list of the enum-flagged field names, empty when
 *         {@code fieldEumData} is {@code null}
 */
public List<String> parseEnumFields(byte[] fieldEumData) {
    final List<String> names = new ArrayList<String>();
    if (fieldEumData == null) {
        return names;
    }
    ClassReader source = new ClassReader(fieldEumData);
    ClassVisitor collector = new ClassVisitor(Opcodes.ASM5) {
        @Override
        public FieldVisitor visitField(int access, String name, String desc, String sig, Object value) {
            boolean enumConstant = (access & Opcodes.ACC_ENUM) != 0;
            if (enumConstant) {
                names.add(name);
            }
            // Field contents are of no interest, so no field visitor is returned.
            return null;
        }
    };
    // Method bodies are not needed to read field declarations.
    source.accept(collector, ClassReader.SKIP_CODE);
    return names;
}
/**
 * Signal thrown while visiting a class that already carries the "jefd"
 * marker attribute; {@code doEnhance} catches it and returns an empty array.
 */
private static final class EnhancedException extends RuntimeException {
    private static final long serialVersionUID = 1L;

    /**
     * Returns this instance without recording a stack trace.
     */
    @Override
    public synchronized Throwable fillInStackTrace() {
        return this;
    }
}
/**
 * Instruments the class read by {@code reader} and returns the rewritten bytes.
 * <p>
 * Non-public classes, and classes that {@code isEntityClass} rejects, are
 * skipped. For every other class:
 * <ul>
 * <li>setters of fields listed in {@code enumFields} are wrapped with
 * {@link SetterVisitor};</li>
 * <li>getters/setters of fields that are either listed in {@code enumFields}
 * and annotated with {@code @Lob}, or not listed and annotated with one of
 * the JPA relation annotations, are wrapped with {@link GetterVisitor} /
 * {@link SetterOfClearLazyload};</li>
 * <li>a "jefd" attribute is appended as a marker. If the input already has
 * that attribute, an {@code EnhancedException} is thrown instead.</li>
 * </ul>
 *
 * @param reader     reader over the class to instrument
 * @param enumFields field names to track; must not be {@code null}
 * @return the instrumented class bytes, or {@code null} when the class is skipped
 */
public byte[] enhanceClass(ClassReader reader, final List<String> enumFields) {
    if ((reader.getAccess() & Opcodes.ACC_PUBLIC) == 0) {
        return null; // non-public classes are skipped
    }
    boolean isEntityInterface = isEntityClass(reader.getInterfaces(), reader.getSuperName(), !enumFields.isEmpty());
    if (!isEntityInterface) {
        return null;
    }
    ClassWriter cw = new ClassWriter(0);
    reader.accept(new ClassVisitor(Opcodes.ASM5, cw) {
        private List<String> nonStaticFields = new ArrayList<String>();
        private List<String> lobAndRefFields = new ArrayList<String>();
        private String typeName;

        @Override
        public void visit(int version, int access, String name, String sig, String superName, String[] interfaces) {
            this.typeName = name.replace('.', '/');
            // The output is always stamped as class-file version 1.6, whatever the input was.
            // NOTE(review): presumably because the injected branch in SetterVisitor comes
            // without stack map frames (the writer is created with flags 0) -- confirm.
            version = Opcodes.V1_6;
            super.visit(version, access, name, sig, superName, interfaces);
        }

        @Override
        public void visitAttribute(Attribute attr) {
            if ("jefd".equals(attr.type)) {
                // Marker written by a previous run: abort, the class is already enhanced.
                throw new EnhancedException();
            }
            super.visitAttribute(attr);
        }

        @Override
        public void visitEnd() {
            // Tag the class so that a later run can recognise it as enhanced.
            Attribute attr = new Attribute("jefd", new byte[] { 0x1f });
            super.visitAttribute(attr);
            // Complete the visit on the delegate as the ClassVisitor contract requires.
            super.visitEnd();
        }

        @Override
        public FieldVisitor visitField(final int access, final String name, final String desc, String sig, final Object value) {
            FieldVisitor visitor = super.visitField(access, name, desc, sig, value);
            if ((access & Opcodes.ACC_STATIC) != 0) {
                return visitor;
            }
            nonStaticFields.add(name);
            return new FieldExtDef(Opcodes.ASM5, new FieldExtCallback(visitor) {
                public void onFieldRead(FieldExtDef info) {
                    if (enumFields.contains(name)) {
                        // Tracked field: only @Lob ones take part in lazy loading.
                        AnnotationDef annotation = info.getAnnotation("Ljavax/persistence/Lob;");
                        if (annotation != null) {
                            lobAndRefFields.add(name);
                        }
                    } else {
                        // Untracked field: lazy loading applies to JPA relation fields.
                        boolean isRelation = info.getAnnotation(OneToMany.class) != null
                                || info.getAnnotation(ManyToOne.class) != null
                                || info.getAnnotation(ManyToMany.class) != null
                                || info.getAnnotation(OneToOne.class) != null;
                        if (isRelation) {
                            lobAndRefFields.add(name);
                        }
                    }
                }
            });
        }

        @Override
        public MethodVisitor visitMethod(int access, String name, String desc, String sig, String[] exceptions) {
            if ((access & Opcodes.ACC_STATIC) != 0) {
                // The injected code loads `this` from local slot 0, which a static
                // method does not have; leave static accessors untouched.
                return super.visitMethod(access, name, desc, sig, exceptions);
            }
            String fieldName;
            if (name.startsWith("get")) {
                fieldName = StringUtils.uncapitalize(name.substring(3));
                return asGetter(fieldName, access, name, desc, exceptions, sig);
            } else if (name.startsWith("is")) {
                fieldName = StringUtils.uncapitalize(name.substring(2));
                return asGetter(fieldName, access, name, desc, exceptions, sig);
            } else if (name.startsWith("set")) {
                fieldName = StringUtils.uncapitalize(name.substring(3));
                return asSetter(fieldName, access, name, desc, exceptions, sig);
            }
            return super.visitMethod(access, name, desc, sig, exceptions);
        }

        // Wraps a no-argument getter of a lazy-loadable field; anything else passes through.
        private MethodVisitor asGetter(String fieldName, int access, String name, String desc, String[] exceptions, String sig) {
            MethodVisitor mv = super.visitMethod(access, name, desc, sig, exceptions);
            Type[] types = Type.getArgumentTypes(desc);
            if (fieldName.length() == 0 || types.length > 0) {
                return mv;
            }
            if (lobAndRefFields.contains(fieldName)) {
                return new GetterVisitor(mv, fieldName, typeName);
            }
            return mv;
        }

        // Wraps a single-argument setter of a tracked or lazy-loadable field; anything else passes through.
        private MethodVisitor asSetter(String fieldName, int access, String name, String desc, String[] exceptions, String sig) {
            MethodVisitor mv = super.visitMethod(access, name, desc, sig, exceptions);
            Type[] types = Type.getArgumentTypes(desc);
            if (fieldName.length() == 0 || types.length != 1) {
                return mv;
            }
            if (enumFields.contains(fieldName) && nonStaticFields.contains(fieldName)) {
                return new SetterVisitor(mv, fieldName, typeName, types[0]);
            } else if (lobAndRefFields.contains(fieldName)) {
                return new SetterOfClearLazyload(mv, fieldName, typeName);
            } else {
                // Special case: a boolean field named isXxx whose setter is setXxx().
                String altFieldName = "is" + StringUtils.capitalize(fieldName);
                if (enumFields.contains(altFieldName)) {
                    return new SetterVisitor(mv, altFieldName, typeName, types[0]);
                }
            }
            return mv;
        }
    }, 0);
    return cw.toByteArray();
}
/**
 * Decides whether a class is an entity by walking up its superclass chain.
 * <p>
 * A class qualifies when it extends {@code jef/database/DataObject} or
 * implements {@code jef/database/IQueryableEntity}, directly or through a
 * parent. Parent classes are loaded from the system class path first and,
 * failing that, from the {@code root} directory.
 *
 * @param interfaces   internal names of the directly implemented interfaces
 *                     (as returned by {@code ClassReader.getInterfaces()})
 * @param superName    internal name of the superclass
 * @param defaultValue answer to give when a parent class cannot be located
 * @return {@code true} if the class is an entity; {@code defaultValue} if a
 *         parent class file cannot be found; {@code false} otherwise,
 *         including when reading a parent class fails
 */
private boolean isEntityClass(String[] interfaces, String superName, boolean defaultValue) {
    if ("jef/database/DataObject".equals(superName)) {
        return true; // the vast majority of entities extend this class
    }
    // ClassReader reports interfaces as internal names ("a/b/C"), not descriptors ("La/b/C;").
    if (ArrayUtils.contains(interfaces, "jef/database/IQueryableEntity")) {
        return true;
    }
    if (superName == null || "java/lang/Object".equals(superName)) {
        return false; // top of the hierarchy reached without a match
    }
    // Check the parent class recursively.
    ClassReader cl = null;
    try {
        URL url = ClassLoader.getSystemResource(superName + ".class");
        if (url == null && root != null && root.exists()) {
            File parent = new File(root, superName + ".class");
            if (parent.exists()) {
                url = parent.toURI().toURL();
            }
        }
        if (url == null) { // parent class not found, so no exact answer is possible
            return defaultValue;
        }
        byte[] parent = IOUtils.toByteArray(url);
        cl = new ClassReader(parent);
    } catch (Exception e) {
        // Best effort: an unreadable parent is reported and treated as "not an entity".
        e.printStackTrace();
    }
    if (cl != null) {
        return isEntityClass(cl.getInterfaces(), cl.getSuperName(), defaultValue);
    }
    return false;
}
/**
 * Method adapter that makes a getter call {@code beforeGet(fieldName)} on
 * {@code this} before its original body runs. Shape of the result:
 *
 * <pre>
 * public byte[] getBinaryData_x();
 *   Code:
 *    0: aload_0
 *    1: ldc           // String binaryData
 *    3: invokevirtual // Method beforeGet:(Ljava/lang/String;)V
 *    6: aload_0
 *    7: getfield      // Field binaryData:[B
 *   10: areturn
 * </pre>
 */
static class GetterVisitor extends MethodVisitor implements Opcodes {
    /** Stack slots needed by the injected prologue: {@code this} plus the field name. */
    private static final int PROLOGUE_STACK = 2;

    private final String name;
    private final String typeName;

    /**
     * @param mv       visitor that receives the instrumented method
     * @param name     field name passed to {@code beforeGet}
     * @param typeName internal name of the class that owns the getter
     */
    public GetterVisitor(MethodVisitor mv, String name, String typeName) {
        super(Opcodes.ASM5, mv);
        this.name = name;
        this.typeName = typeName;
    }

    @Override
    public void visitCode() {
        // Open the code section first, then emit the prologue ahead of the original body.
        super.visitCode();
        mv.visitVarInsn(ALOAD, 0);
        mv.visitLdcInsn(name);
        mv.visitMethodInsn(INVOKEVIRTUAL, typeName, "beforeGet", "(Ljava/lang/String;)V", false);
    }

    @Override
    public void visitMaxs(int maxStack, int maxLocals) {
        // The prologue runs on an empty stack, so the method needs the larger of
        // its own requirement and the prologue's -- never less than the original.
        mv.visitMaxs(Math.max(maxStack, PROLOGUE_STACK), maxLocals);
    }

    // Drop the local variable table.
    @Override
    public void visitLocalVariable(String name, String desc, String signature, Label start, Label end, int index) {
    }
}
/**
 * Method adapter that makes a setter call {@code beforeSet(fieldName)} on
 * {@code this} before its original body runs.
 */
static class SetterOfClearLazyload extends MethodVisitor implements Opcodes {
    /** Minimum max-stack value reported for instrumented setters (the injected prologue itself uses two slots). */
    private static final int MIN_STACK = 4;

    private final String name;
    private final String typeName;

    /**
     * @param mv       visitor that receives the instrumented method
     * @param name     field name passed to {@code beforeSet}
     * @param typeName internal name of the class that owns the setter
     */
    public SetterOfClearLazyload(MethodVisitor mv, String name, String typeName) {
        super(Opcodes.ASM5, mv);
        this.name = name;
        this.typeName = typeName;
    }

    // Drop the local variable table; otherwise jd-gui cannot decompile the
    // injected code of the generated class properly.
    @Override
    public void visitLocalVariable(String name, String desc, String signature, Label start, Label end, int index) {
    }

    @Override
    public void visitCode() {
        // Open the code section first, then emit the prologue ahead of the original body.
        super.visitCode();
        mv.visitVarInsn(ALOAD, 0);
        mv.visitLdcInsn(name);
        mv.visitMethodInsn(INVOKEVIRTUAL, typeName, "beforeSet", "(Ljava/lang/String;)V", false);
    }

    @Override
    public void visitMaxs(int maxStack, int maxLocals) {
        // Never report less stack than the original body requires.
        mv.visitMaxs(Math.max(maxStack, MIN_STACK), maxLocals);
    }
}
/**
 * Method adapter that makes a setter report the new value through
 * {@code prepareUpdate(Field, Object)} when {@code _recordUpdate} is set,
 * before its original body runs. Shape of the result:
 *
 * <pre>
 * public void setBinaryData_x(byte[]);
 *   Code:
 *    0: aload_0
 *    1: getfield      // Field _recordUpdate:Z
 *    4: ifeq 16
 *    7: aload_0
 *    8: getstatic     // Field jef/orm/onetable/model/TestEntity$Field.binaryData:
 *                     //       Ljef/orm/onetable/model/TestEntity$Field;
 *   11: aload_1
 *   13: invokevirtual // Method prepareUpdate:(Ljef/database/Field;Ljava/lang/Object;)V
 *   16: aload_0
 *   17: aload_1
 *   18: putfield      // Field binaryData:[B
 *   21: return
 * </pre>
 */
static class SetterVisitor extends MethodVisitor implements Opcodes {
    /**
     * Stack slots needed by the injected prologue: {@code this}, the
     * {@code $Field} constant and a parameter of up to two slots.
     */
    private static final int PROLOGUE_STACK = 4;

    private final String name;
    private final String typeName;
    private final Type paramType;

    /**
     * @param mv        visitor that receives the instrumented method
     * @param name      name of the static member read from {@code typeName + "$Field"}
     * @param typeName  internal name of the class that owns the setter
     * @param paramType type of the setter's single parameter
     */
    public SetterVisitor(MethodVisitor mv, String name, String typeName, Type paramType) {
        super(Opcodes.ASM5, mv);
        this.name = name;
        this.typeName = typeName;
        this.paramType = paramType;
    }

    // Drop the local variable table; otherwise jd-gui cannot decompile the
    // injected code of the generated class properly.
    @Override
    public void visitLocalVariable(String name, String desc, String signature, Label start, Label end, int index) {
    }

    @Override
    public void visitCode() {
        // Open the code section first, then emit the prologue ahead of the original body.
        super.visitCode();
        // if (this._recordUpdate) { ... }
        mv.visitVarInsn(ALOAD, 0);
        mv.visitFieldInsn(GETFIELD, typeName, "_recordUpdate", "Z");
        Label norecord = new Label();
        mv.visitJumpInsn(IFEQ, norecord);
        // this.prepareUpdate(<Type>$Field.<name>, <boxed parameter>)
        mv.visitVarInsn(ALOAD, 0);
        mv.visitFieldInsn(GETSTATIC, typeName + "$Field", name, "L" + typeName + "$Field;");
        if (paramType.isPrimitive()) {
            // Primitive arguments are loaded with their own opcode and then wrapped.
            mv.visitVarInsn(ASMUtils.getLoadIns(paramType), 1);
            ASMUtils.doWrap(mv, paramType);
        } else {
            mv.visitVarInsn(ALOAD, 1);
        }
        mv.visitMethodInsn(INVOKEVIRTUAL, typeName, "prepareUpdate", "(Ljef/database/Field;Ljava/lang/Object;)V", false);
        mv.visitLabel(norecord);
    }

    @Override
    public void visitMaxs(int maxStack, int maxLocals) {
        // Never report less stack than the original body requires.
        mv.visitMaxs(Math.max(maxStack, PROLOGUE_STACK), maxLocals);
    }
}
}