package org.enumerable.lambda.support.expression;
import japa.parser.ast.expr.Expression;
import org.enumerable.lambda.Fn0;
import org.enumerable.lambda.Fn1;
import org.enumerable.lambda.Fn2;
import org.enumerable.lambda.Fn3;
import org.enumerable.lambda.annotation.LambdaLocal;
import org.enumerable.lambda.weaving.InMemoryCompiler;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.LocalVariableNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.analysis.Analyzer;
import org.objectweb.asm.tree.analysis.Frame;
import org.objectweb.asm.util.ASMifierMethodVisitor;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Reader;
import java.io.StringReader;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.List;
import static org.enumerable.lambda.exception.UncheckedException.uncheck;
import static org.objectweb.asm.Type.*;
public class LambdaExpressionTrees {
static int expressionId = 1;
static InMemoryCompiler compiler = new InMemoryCompiler();
public static Expression parseExpression(String expression) {
try {
Class<?> parserClass = Class.forName("japa.parser.ASTParser");
Constructor<?> ctor = parserClass.getConstructor(Reader.class);
ctor.setAccessible(true);
Object parser = ctor.newInstance(new StringReader(expression));
Method method = parserClass.getMethod("Expression");
method.setAccessible(true);
return (Expression) method.invoke(parser);
} catch (Exception e) {
throw uncheck(e);
}
}
static Expression parseExpressionFromSingleMethodClass(Class<?> aClass, String... parameters) {
return parseExpressionFromMethod(Fn0.getLambdaMethod(aClass), parameters);
}
public static Expression parseExpressionFromMethod(Method method, String... parameters) {
try {
MethodNode mn = findMethodNode(method);
LocalVariableNode[] parameterLocals = new LocalVariableNode[parameters.length];
Type[] argumentTypes = getArgumentTypes(mn.desc);
int realIndex = 1;
for (int i = 0; i < parameters.length; i++) {
parameterLocals[i] = new LocalVariableNode(parameters[i], argumentTypes[i].getDescriptor(), null,
null, null, realIndex);
realIndex += argumentTypes[i].getSize();
}
final ExpressionInterpreter interpreter = new ExpressionInterpreter(mn, parameterLocals);
Analyzer analyzer = new Analyzer(interpreter) {
protected Frame newFrame(Frame src) {
Frame frame = super.newFrame(src);
interpreter.setCurrentFrame(frame);
return frame;
}
protected void newControlFlowEdge(int insn, int successor) {
interpreter.newControlFlowEdge(insn, successor);
}
};
interpreter.analyzer = analyzer;
analyzer.analyze(getInternalName(method.getDeclaringClass()), mn);
return interpreter.expression;
} catch (Exception e) {
throw uncheck(e);
}
}
static void printASMifiedMethod(Method method) {
try {
MethodNode mn = findMethodNode(method);
ASMifierMethodVisitor asm = new ASMifierMethodVisitor();
mn.accept(asm);
PrintWriter pw = new PrintWriter(System.out);
asm.print(pw);
pw.flush();
} catch (Exception e) {
throw uncheck(e);
}
}
@SuppressWarnings("unchecked")
static MethodNode findMethodNode(Method method) throws IOException {
String className = method.getDeclaringClass().getName();
ClassReader cr;
if (InMemoryCompiler.bytesByClassName.containsKey(className))
cr = new ClassReader(InMemoryCompiler.bytesByClassName.get(className));
else
cr = new ClassReader(className);
ClassNode cn = new ClassNode();
cr.accept(cn, 0);
String descriptor = getMethodDescriptor(method);
for (MethodNode mn : (List<MethodNode>) cn.methods) {
if (method.getName().equals(mn.name) && descriptor.equals(mn.desc))
return mn;
}
throw new IllegalStateException("Cannot find method which does exist");
}
@SuppressWarnings("unchecked")
public static <R extends Expression> R toExpression(Fn0<?> fn) {
if (fn.getClass().getDeclaredFields().length > 0)
throw new IllegalArgumentException("Turning Closures into Expressions isn't supported");
List<LambdaLocal> parameters = fn.getParameters();
String[] parameterNames = new String[parameters.size()];
for (int i = 0; i < parameters.size(); i++)
parameterNames[i] = parameters.get(i).name();
return (R) parseExpressionFromSingleMethodClass(fn.getClass(), parameterNames);
}
public static <R> Fn0<R> toFn0(Class<R> returnType, Expression expression) {
try {
String className = "ExpressionFn0_" + expressionId++;
String source = "class " + className + " extends " + Fn0.class.getName() + "{ public "
+ typeToString(returnType) + " call() { return " + expression + "; }}";
return compileAndCreate(className, source);
} catch (Exception e) {
throw uncheck(e);
}
}
public static <A1, R> Fn1<A1, R> toFn1(Class<R> returnType, Class<A1> a1Type, String a1Name,
Expression expression) {
try {
String className = "ExpressionFn1_" + expressionId++;
String source = "class " + className + " extends " + Fn1.class.getName() + "{ public "
+ typeToString(returnType) + " call(" + typeToString(a1Type) + " " + a1Name + ") { return "
+ expression + "; } public " + typeToString(returnType) + " call(Object " + a1Name
+ ") { return call((" + typeToString(a1Type) + ") " + a1Name + "); } }";
return compileAndCreate(className, source);
} catch (Exception e) {
throw uncheck(e);
}
}
public static <A1, A2, R> Fn2<A1, A2, R> toFn2(Class<R> returnType, Class<A1> a1Type, String a1Name,
Class<A2> a2Type, String a2Name, Expression expression) {
try {
String className = "ExpressionFn2_" + expressionId++;
String source = "class " + className + " extends " + Fn2.class.getName() + "{ public "
+ typeToString(returnType) + " call(" + typeToString(a1Type) + " " + a1Name + ", "
+ typeToString(a2Type) + " " + a2Name + ") { return " + expression + "; } public "
+ typeToString(returnType) + " call(Object " + a1Name + ", Object " + a2Name
+ ") { return call((" + typeToString(a1Type) + ") " + a1Name + ", (" + typeToString(a2Type)
+ ") " + a2Name + "); } }";
return compileAndCreate(className, source);
} catch (Exception e) {
throw uncheck(e);
}
}
public static <A1, A2, A3, R> Fn3<A1, A2, A3, R> toFn3(Class<R> returnType, Class<A1> a1Type, String a1Name,
Class<A2> a2Type, String a2Name, Class<A2> a3Type, String a3Name, Expression expression) {
try {
String className = "ExpressionFn3_" + expressionId++;
String source = "class " + className + " extends " + Fn3.class.getName() + "{ public "
+ typeToString(returnType) + " call(" + typeToString(a1Type) + " " + a1Name + ", "
+ typeToString(a2Type) + " " + a2Name + ", " + typeToString(a3Type) + " " + a3Name
+ ") { return " + expression + "; } public " + typeToString(returnType) + " call(Object "
+ a1Name + ", Object " + a2Name + ", Object " + a3Name + ") { return call(("
+ typeToString(a1Type) + ") " + a1Name + ", (" + typeToString(a2Type) + ") " + a2Name + ", ("
+ typeToString(a3Type) + ") " + a3Name + "); } }";
return compileAndCreate(className, source);
} catch (Exception e) {
throw uncheck(e);
}
}
static String typeToString(Class<?> returnType) {
return returnType.isArray() ? returnType.getComponentType().getName() + "[]" : returnType.getName();
}
@SuppressWarnings("unchecked")
static <R extends Fn0<?>> R compileAndCreate(String className, String source) throws IOException,
InstantiationException, IllegalAccessException, InvocationTargetException {
Class<?> aClass = compiler.compile(className, source);
Constructor<?> ctor = aClass.getDeclaredConstructors()[0];
ctor.setAccessible(true);
return (R) ctor.newInstance();
}
}