package org.n3r.eql.eqler.generators;
import org.n3r.eql.Eql;
import org.n3r.eql.EqlTran;
import org.n3r.eql.eqler.annotations.EqlerConfig;
import org.n3r.eql.trans.EqlTranThreadLocal;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Type;
import java.lang.reflect.Method;
import static org.n3r.eql.util.Asms.p;
import static org.objectweb.asm.Opcodes.*;
/**
 * Generates the bodies of the transaction-control methods of an eqler
 * implementation class: {@code void start()}, {@code void commit()},
 * {@code void rollback()} and {@code void close()}.
 *
 * <p>Every generated method reads the {@link EqlTran} bound to the current
 * thread via {@code EqlTranThreadLocal.get()} and keeps it in local variable
 * slot {@link #TRAN_SLOT} while working on it.
 *
 * @param <T> the eqler type being implemented
 */
public class TranableMethodGenerator<T> {
    /** Descriptor of every method handled here: no arguments, returns void. */
    private static final String VOID_NO_ARGS_DESC = "()V";
    /** Descriptor of a no-arg method returning an {@code EqlTran} ({@code get()}, {@code newTran()}). */
    private static final String RETURNS_TRAN_DESC = "()Lorg/n3r/eql/EqlTran;";
    /** Descriptor of {@code EqlTranThreadLocal.set(EqlTran)}. */
    private static final String SET_TRAN_DESC = "(Lorg/n3r/eql/EqlTran;)V";
    /** Descriptor of the single-String constructor invoked on the Eql class. */
    private static final String STRING_CTOR_DESC = "(Ljava/lang/String;)V";
    /** Descriptor of {@code me()} invoked on the freshly constructed Eql. */
    private static final String ME_DESC = "()Lorg/n3r/eql/Eql;";
    /**
     * Local variable slot holding the current transaction. Slot 0 is
     * {@code this}; the generated methods take no arguments, so slot 1 is
     * the first free local.
     */
    private static final int TRAN_SLOT = 1;

    private final String methodName;
    private final String eqlClassName;
    private final EqlerConfig eqlerConfig;
    private final ClassWriter cw;

    /**
     * @param classWriter receives the generated method
     * @param method      the method being implemented; its name selects the
     *                    body to emit, and its {@code @EqlerConfig} (if any)
     *                    takes precedence over the type-level one
     * @param eqlerClass  consulted for {@code @EqlerConfig} when the method
     *                    itself is not annotated
     */
    public TranableMethodGenerator(ClassWriter classWriter, Method method, Class<T> eqlerClass) {
        this.methodName = method.getName();
        this.cw = classWriter;
        // Method-level configuration overrides the configuration declared on the eqler type.
        EqlerConfig methodConfig = method.getAnnotation(EqlerConfig.class);
        this.eqlerConfig = methodConfig != null
                ? methodConfig
                : eqlerClass.getAnnotation(EqlerConfig.class);
        // Decide from the resolved configuration, not from a single lookup.
        this.eqlClassName = this.eqlerConfig != null
                ? Type.getInternalName(this.eqlerConfig.eql())
                : p(Eql.class);
    }

    /**
     * Emits the complete {@code public void <methodName>()} method.
     * Names other than start/commit/rollback/close yield an empty body;
     * {@link #isEqlTranableMethod(Method)} identifies the supported methods.
     */
    public void generate() {
        MethodVisitor mv = cw.visitMethod(ACC_PUBLIC, methodName, VOID_NO_ARGS_DESC, null, null);
        mv.visitCode();
        if ("start".equals(methodName)) {
            start(mv);
        } else if ("commit".equals(methodName) || "rollback".equals(methodName)) {
            commitOrRollback(mv, methodName);
        } else if ("close".equals(methodName)) {
            close(mv);
        }
        mv.visitInsn(RETURN);
        // NOTE(review): -1/-1 assumes cw was created with COMPUTE_MAXS or COMPUTE_FRAMES — confirm at the caller.
        mv.visitMaxs(-1, -1);
        mv.visitEnd();
    }

    /**
     * Emits {@code tran = EqlTranThreadLocal.get(); if (tran != null) { EqlTranThreadLocal.clear(); tran.close(); }}.
     * The thread-local is cleared before {@code close()} is invoked.
     */
    private void close(MethodVisitor mv) {
        Label end = storeCurrentTranAndJumpIfNull(mv);
        mv.visitMethodInsn(INVOKESTATIC, p(EqlTranThreadLocal.class), "clear", VOID_NO_ARGS_DESC, false);
        mv.visitVarInsn(ALOAD, TRAN_SLOT);
        mv.visitMethodInsn(INVOKEINTERFACE, p(EqlTran.class), "close", VOID_NO_ARGS_DESC, true);
        visitTranLabel(mv, end);
    }

    /**
     * Emits {@code tran = EqlTranThreadLocal.get(); if (tran != null) tran.<methodName>();}.
     *
     * @param methodName the no-arg {@code EqlTran} method to invoke ("commit" or "rollback")
     */
    private void commitOrRollback(MethodVisitor mv, String methodName) {
        Label end = storeCurrentTranAndJumpIfNull(mv);
        mv.visitVarInsn(ALOAD, TRAN_SLOT);
        mv.visitMethodInsn(INVOKEINTERFACE, p(EqlTran.class), methodName, VOID_NO_ARGS_DESC, true);
        visitTranLabel(mv, end);
    }

    /**
     * Pushes {@code new <eqlClassName>(value).me()} onto the operand stack, where
     * {@code value} is the resolved {@code @EqlerConfig#value()} or "DEFAULT"
     * when no configuration was found.
     */
    private void newEql(MethodVisitor mv) {
        mv.visitTypeInsn(NEW, eqlClassName);
        mv.visitInsn(DUP);
        mv.visitLdcInsn(eqlerConfig != null ? eqlerConfig.value() : "DEFAULT");
        mv.visitMethodInsn(INVOKESPECIAL, eqlClassName, "<init>", STRING_CTOR_DESC, false);
        mv.visitMethodInsn(INVOKEVIRTUAL, eqlClassName, "me", ME_DESC, false);
    }

    /**
     * Emits:
     * <pre>{@code
     * tran = EqlTranThreadLocal.get();
     * if (tran != null) return;          // keep the transaction already bound
     * tran = new <Eql>(value).me().newTran();
     * EqlTranThreadLocal.set(tran);
     * }</pre>
     */
    private void start(MethodVisitor mv) {
        Label noTran = storeCurrentTranAndJumpIfNull(mv);
        mv.visitInsn(RETURN);
        visitTranLabel(mv, noTran);
        newEql(mv);
        mv.visitMethodInsn(INVOKEVIRTUAL, p(Eql.class), "newTran", RETURNS_TRAN_DESC, false);
        mv.visitVarInsn(ASTORE, TRAN_SLOT);
        mv.visitVarInsn(ALOAD, TRAN_SLOT);
        mv.visitMethodInsn(INVOKESTATIC, p(EqlTranThreadLocal.class), "set", SET_TRAN_DESC, false);
    }

    /**
     * Emits {@code tran = EqlTranThreadLocal.get()} into {@link #TRAN_SLOT},
     * followed by a conditional jump taken when {@code tran == null}.
     *
     * @return the jump target, not yet placed; place it with {@link #visitTranLabel}
     */
    private Label storeCurrentTranAndJumpIfNull(MethodVisitor mv) {
        mv.visitMethodInsn(INVOKESTATIC, p(EqlTranThreadLocal.class), "get", RETURNS_TRAN_DESC, false);
        mv.visitVarInsn(ASTORE, TRAN_SLOT);
        mv.visitVarInsn(ALOAD, TRAN_SLOT);
        Label ifNull = new Label();
        mv.visitJumpInsn(IFNULL, ifNull);
        return ifNull;
    }

    /**
     * Places {@code label} and declares its stack map frame. The frame is the
     * initial frame (only {@code this}) with one appended local, the
     * {@link EqlTran} stored in {@link #TRAN_SLOT}. The operand stack is empty
     * because IFNULL consumed the tested value.
     */
    private void visitTranLabel(MethodVisitor mv, Label label) {
        mv.visitLabel(label);
        mv.visitFrame(F_APPEND, 1, new Object[]{p(EqlTran.class)}, 0, null);
    }

    /**
     * Tells whether {@code method} is one of the transaction-control methods
     * this generator supports: {@code void start()}, {@code void commit()},
     * {@code void rollback()} or {@code void close()}.
     */
    public static boolean isEqlTranableMethod(Method method) {
        if (!VOID_NO_ARGS_DESC.equals(Type.getMethodDescriptor(method))) {
            return false;
        }
        String name = method.getName();
        return "start".equals(name)
                || "commit".equals(name)
                || "rollback".equals(name)
                || "close".equals(name);
    }
}