package com.github.forax.smartass.rt;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodHandles.Lookup;
import java.lang.invoke.MethodType;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Parameter;
import java.lang.reflect.Proxy;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.stream.Collectors;import java.util.stream.IntStream;
import java.util.stream.Stream;
/**
 * Bridges plain Java classes into the runtime.
 *
 * <p>{@link #javaClasstoKlass} builds a {@code Klass} whose functions are method handles on the
 * public constructors, public methods and public static fields of a Java class. Overloads sharing
 * the same name and arity are resolved at call time by a decision tree of
 * {@code Class.isInstance} tests on the arguments.
 */
final class JavaBridge {
  /**
   * Creates the klass describing a Java class, i.e. a class not generated by the runtime.
   *
   * <p>Two klasses are created, each with an initializer that fills its function map:
   * <ul>
   *   <li>a {@code "static-" + className} klass holding the public static fields (as getters),
   *       the public static methods and the methods of the klass of {@code Klass};</li>
   *   <li>the class klass itself, holding the public constructors (named {@code "@init"} unless
   *       renamed by {@code MethodInfo}) and the public instance methods.</li>
   * </ul>
   *
   * @param type the Java class to expose.
   * @param klasses the klass cache, used to retrieve the klass of {@code Klass}.
   * @return the klass of {@code type}, whose static klass is the "static-" klass.
   */
  static Klass javaClasstoKlass(Class<?> type, ClassValue<Klass> klasses) {
    String klassName = type.getName();
    HashMap<String, Function[]> classFunMap = new HashMap<>();
    HashMap<String, Function[]> staticFunMap = new HashMap<>();
    Klass staticKlass = Klass.create("static-" + klassName, null, Collections.emptyList(), staticFunMap);
    staticKlass.registerInitializer(klazz -> {
      HashMap<String, ArrayList<Executable>[]> staticMap = new HashMap<>();
      gatherMethods(Arrays.<Executable>stream(type.getMethods()).filter(JavaBridge::isStatic), staticMap);
      // fields are registered first, so a static method with the same name replaces the getter
      populateKlassWithStaticFields(type, staticFunMap);
      populateKlassWithMethods(staticFunMap, staticMap);
      // the methods of the klass of Klass are added last and win on name clashes
      Klass klassKlass = klasses.get(Klass.class);
      klassKlass.initialize();
      staticFunMap.putAll(klassKlass.getMethodMap());
    });
    Klass klass = Klass.create(klassName, staticKlass, Collections.emptyList(), classFunMap);
    klass.registerInitializer(klazz -> {
      HashMap<String, ArrayList<Executable>[]> classMap = new HashMap<>();
      gatherMethods(Arrays.<Executable>stream(type.getConstructors()), classMap);
      gatherMethods(Arrays.<Executable>stream(type.getMethods()).filter(JavaBridge::isNotStatic), classMap);
      populateKlassWithMethods(classFunMap, classMap);
    });
    return klass;
  }

  /** Returns true if the executable is declared {@code static}. */
  private static boolean isStatic(Executable executable) {
    return Modifier.isStatic(executable.getModifiers());
  }

  /** Returns true if the executable is not declared {@code static}. */
  private static boolean isNotStatic(Executable executable) {
    return !Modifier.isStatic(executable.getModifiers());
  }

  /**
   * Registers one zero-parameter getter function per public static field of {@code type}.
   *
   * <p>Instance fields are skipped. A getter on an instance field needs a receiver, which does
   * not fit the zero-parameter shape used here. If a field is not accessible through the public
   * lookup, its function throws a {@link LinkageError} whose cause is the access failure.
   *
   * @param type the class whose public fields are exposed.
   * @param staticFunMap the function map of the static klass, filled by this method.
   */
  private static void populateKlassWithStaticFields(Class<?> type,
      HashMap<String, Function[]> staticFunMap) {
    for(Field field: type.getFields()) {
      if (!Modifier.isStatic(field.getModifiers())) {
        continue;
      }
      staticFunMap.put(field.getName(),
          new Function[] { Function.createFromMH(Collections.emptyList(),
              fun -> {
                MethodHandle mh;
                try {
                  mh = MethodHandles.publicLookup().unreflectGetter(field);
                } catch (IllegalAccessException e) {
                  // bind the error so the handle has the same ()Object shape as a static getter
                  LinkageError error = new LinkageError(
                      "field " + field.getDeclaringClass().getName() + '.' + field.getName() + " not accessible", e);
                  mh = MethodHandles.throwException(Object.class, LinkageError.class).bindTo(error);
                }
                mh = mh.asType(mh.type().generic());
                // a static method should skip the first parameter
                return MethodHandles.dropArguments(mh, 0, Object.class);
              })});
    }
  }

  /**
   * Converts the gathered executables into functions, one array slot per arity.
   *
   * @param funMap the function map to fill; each name is mapped to an array indexed by arity
   *     whose empty slots stay {@code null}.
   * @param classMap the executables grouped by name then by arity (see {@link #gatherMethods}).
   */
  private static void populateKlassWithMethods(HashMap<String, Function[]> funMap,
      HashMap<String, ArrayList<Executable>[]> classMap) {
    classMap.forEach((name, methodByArity) -> {
      Function[] functions = new Function[methodByArity.length];
      for(int arity = 0; arity < methodByArity.length; arity++) {
        ArrayList<Executable> executables = methodByArity[arity];
        if (executables == null) {
          continue;
        }
        // parameter names are taken from the first overload
        List<String> params = Arrays.stream(executables.get(0).getParameters())
            .map(Parameter::getName)
            .collect(Collectors.toList());
        Function function = Function.createFromMH(params, fun -> createOverloadMH(executables));
        function.setNameHint(name);
        functions[arity] = function;
      }
      funMap.put(name, functions);
    });
  }

  /**
   * Creates the method handle for a set of overloads sharing the same name and arity.
   * A single overload is called directly. Several overloads are dispatched by a decision tree.
   */
  private static MethodHandle createOverloadMH(List<Executable> executables) {
    int size = executables.size();
    if (size == 1) {
      return createMH(executables.get(0));
    }
    Class<?>[][] parameters = new Class<?>[size][];
    MethodHandle[] mhs = new MethodHandle[size];
    for(int i = 0; i < size; i++) {
      Executable executable = executables.get(i);
      parameters[i] = executable.getParameterTypes();
      mhs[i] = createMH(executable);
    }
    int[] indexes = IntStream.range(0, size).toArray();
    return createDecisionTree(parameters[0].length, parameters, 0, mhs, indexes, 0, size - 1);
  }

  /**
   * Builds a decision tree that selects, among the candidates {@code indexes[start..end]}, the
   * method handle to call. It tests the runtime class of the arguments, one parameter position
   * at a time.
   *
   * <p>{@code indexes[start..end]} is reordered in place so that candidates sharing the same
   * parameter type at {@code parameterIndex} are contiguous.
   *
   * @param length the number of parameters of every candidate.
   * @param parameters the parameter types of each candidate, indexed by candidate index.
   * @param parameterIndex the parameter position currently discriminated on.
   * @param mhs the method handle of each candidate, indexed by candidate index.
   * @param indexes the candidate indexes; only the range {@code [start, end]} is considered.
   * @param start first position (inclusive) of the candidate range in {@code indexes}.
   * @param end last position (inclusive) of the candidate range in {@code indexes}.
   * @return a generic method handle dispatching to one of the candidates.
   */
  private static MethodHandle createDecisionTree(int length, Class<?>[][] parameters, int parameterIndex,
      MethodHandle[] mhs, int[] indexes, int start, int end) {
    // One candidate left, or no parameter left to discriminate on (this happens with duplicated
    // signatures, e.g. the same method reported twice with covariant return types).
    if (start == end || parameterIndex >= length) {
      return mhs[indexes[start]];
    }
    // only consider the types of the candidates in the current range
    HashSet<Class<?>> typeSet = new HashSet<>();
    for(int i = start; i <= end; i++) {
      typeSet.add(parameters[indexes[i]][parameterIndex]);
    }
    if (typeSet.size() == 1) { // parameter types are all the same, look for the next parameter
      return createDecisionTree(length, parameters, parameterIndex + 1, mhs, indexes, start, end);
    }
    Class<?>[] types = typeSet.toArray(new Class<?>[0]);
    // the most general types come first
    Arrays.sort(types, JavaBridge::compareByType);
    MethodHandle[] targets = new MethodHandle[types.length];
    int groupStart = start;
    for(int j = types.length; --j >= 0;) {
      Class<?> type = types[j];
      // Partition indexes[groupStart..end]: candidates whose parameter is exactly 'type' stay at
      // the front (positions groupStart..last), the others are swapped towards the end.
      int last = end;
      for(int i = groupStart; i <= last;) {
        int index = indexes[i];
        if (parameters[index][parameterIndex] == type) {
          i++;
          continue;
        }
        indexes[i] = indexes[last];
        indexes[last] = index;
        last--;
      }
      // every type comes from the range, so the group [groupStart, last] is never empty
      targets[j] = createDecisionTree(length, parameters, parameterIndex + 1, mhs, indexes, groupStart, last);
      groupStart = last + 1;
    }
    return createTreeFromMHs(parameterIndex, types, targets);
  }

  /**
   * Chains guardWithTest so that the argument at {@code parameterIndex} is tested against
   * {@code types[n-1]}, then {@code types[n-2]}, and so on. {@code targets[0]} is the fallback
   * when no test succeeds. Primitive types are tested through their wrapper class.
   */
  private static MethodHandle createTreeFromMHs(int parameterIndex, Class<?>[] types, MethodHandle[] targets) {
    MethodHandle tree = targets[0];
    for(int i = 1; i < types.length; i++) {
      Class<?> type = types[i];
      if (type.isPrimitive()) {
        type = boxedType(type);
      }
      // ignore the first (receiver) argument and the arguments before parameterIndex
      MethodHandle test = MethodHandles.dropArguments(IS_INSTANCE.bindTo(type), 0,
          Collections.nCopies(1 + parameterIndex, Object.class));
      tree = MethodHandles.guardWithTest(test, targets[i], tree);
    }
    return tree;
  }

  /**
   * Orders two distinct types so that the more general one comes first. The order is:
   * {@code Object} before primitives, wider primitives before narrower ones, supertypes before
   * subtypes, then the class name as a tie-breaker.
   */
  private static int compareByType(Class<?> type1, Class<?> type2) {
    assert type1 != type2; // the types must be different
    if (type1.isPrimitive()) {
      if (type2.isPrimitive()) {
        if (type1 != boolean.class && type2 != boolean.class) {
          int ordinal1 = ordinalType(type1);
          int ordinal2 = ordinalType(type2);
          if (ordinal1 != ordinal2) { // equal ordinals are not comparable, like short and char
            return Integer.compare(ordinal1, ordinal2);
          }
        }
      } else if (type2 == Object.class) {
        return 1;
      }
    } else if (type1 == Object.class && type2.isPrimitive()) {
      return -1;
    }
    if (type1.isAssignableFrom(type2)) {
      return -1;
    }
    if (type2.isAssignableFrom(type1)) {
      return 1;
    }
    return type1.getName().compareTo(type2.getName()); // make it stable
  }

  /**
   * Returns the rank of a non-boolean primitive type: the lower the rank, the wider the type.
   * {@code char} and {@code short} share the same rank.
   */
  private static int ordinalType(Class<?> type) {
    if (type == double.class) {
      return 1;
    }
    if (type == float.class) {
      return 2;
    }
    if (type == long.class) {
      return 3;
    }
    if (type == int.class) {
      return 4;
    }
    if (type == byte.class) {
      return 6;
    }
    return 5; // char or short
  }

  /** Returns the wrapper class of a primitive type, or the type itself if it is a reference type. */
  private static Class<?> boxedType(Class<?> type) {
    if (type == int.class) {
      return Integer.class;
    }
    if (type == double.class) {
      return Double.class;
    }
    if (type == char.class) {
      return Character.class;
    }
    if (type == byte.class) {
      return Byte.class;
    }
    if (type == long.class) {
      return Long.class;
    }
    if (type == float.class) {
      return Float.class;
    }
    if (type == short.class) {
      return Short.class;
    }
    if (type == boolean.class) {
      return Boolean.class;
    }
    if (type == void.class) {
      return Void.class;
    }
    return type;
  }

  /**
   * Creates a generic method handle for a method or constructor. The first parameter is the
   * receiver; it is ignored for static methods and constructors. Interface-typed arguments that
   * are runtime functions are converted to proxies (see {@link #conversionFilter}).
   *
   * <p>If the executable is not accessible, the returned handle throws the
   * {@code IllegalAccessException} when called.
   */
  private static MethodHandle createMH(Executable executable) {
    boolean isStatic = Modifier.isStatic(executable.getModifiers());
    Class<?> declaringClass = executable.getDeclaringClass();
    MethodHandle mh;
    try {
      if (executable instanceof Method) {
        Method method = (Method)executable;
        if (isStatic || Modifier.isPublic(declaringClass.getModifiers())) {
          mh = MethodHandles.publicLookup().unreflect(method);
          if (isStatic) {
            // a static method should skip the first parameter
            mh = MethodHandles.dropArguments(mh, 0, Object.class);
          }
        } else {
          // a public method of a non-public class, look for it in a public supertype
          mh = findVirtualMHInHierarchy(declaringClass, method.getName(),
              MethodType.methodType(method.getReturnType(), method.getParameterTypes()));
        }
      } else {
        Constructor<?> constructor = (Constructor<?>)executable;
        mh = MethodHandles.publicLookup().unreflectConstructor(constructor);
        // a constructor method should skip the first parameter
        mh = MethodHandles.dropArguments(mh, 0, Object.class);
      }
    } catch (IllegalAccessException e) {
      return willThrowAnException(executable, e);
    }
    return conversionFilter(mh);
  }

  /**
   * Returns a generic method handle with the same shape as the one {@link #createMH} produces
   * for {@code method}: the receiver plus one {@code Object} per parameter, returning
   * {@code Object}. When called, it throws {@code e}. Because the shape is generic, the handle
   * can be mixed with accessible overloads in a guardWithTest tree.
   */
  private static MethodHandle willThrowAnException(Executable method, Exception e) {
    Class<?> returnType = (method instanceof Method) ? ((Method)method).getReturnType() : method.getDeclaringClass();
    MethodHandle mh = MethodHandles.throwException(returnType, IllegalAccessException.class).bindTo(e);
    mh = MethodHandles.dropArguments(mh, 0,
        MethodType.genericMethodType(1 + method.getParameterCount()).parameterList());
    return mh.asType(mh.type().generic());
  }

  /**
   * Finds a virtual method through a public supertype of {@code declaringClass}. The first public
   * superclass is tried first, then the interfaces collected along the way and their
   * superinterfaces.
   *
   * @throws IllegalAccessException if no public supertype exposes the method.
   */
  private static MethodHandle findVirtualMHInHierarchy(Class<?> declaringClass, String name, MethodType methodType)
      throws IllegalAccessException {
    // first lookup in super class hierarchy
    LinkedHashSet<Class<?>> interfaces = new LinkedHashSet<>();
    Collections.addAll(interfaces, declaringClass.getInterfaces());
    for(Class<?> type = declaringClass.getSuperclass(); type != null; type = type.getSuperclass()) {
      if (Modifier.isPublic(type.getModifiers())) {
        try {
          return MethodHandles.publicLookup().findVirtual(type, name, methodType);
        } catch(NoSuchMethodException e) {
          // not in the first public superclass, so not in its superclasses either
          Collections.addAll(interfaces, type.getInterfaces());
          break;
        }
      }
      Collections.addAll(interfaces, type.getInterfaces());
    }
    // not found in the super class hierarchy, try interface hierarchy (breadth first)
    ArrayDeque<Class<?>> queue = new ArrayDeque<>(interfaces);
    Class<?> interfaceType;
    while((interfaceType = queue.poll()) != null) {
      if (Modifier.isPublic(interfaceType.getModifiers())) {
        try {
          return MethodHandles.publicLookup().findVirtual(interfaceType, name, methodType);
        } catch(NoSuchMethodException ignored) {
          // not declared here, keep searching the superinterfaces
        }
      }
      for(Class<?> iType: interfaceType.getInterfaces()) {
        if (interfaces.add(iType)) { // only add unknown interface to the queue
          queue.add(iType);
        }
      }
    }
    throw new IllegalAccessException("method " + declaringClass.getName() + '.' + name + methodType + " not visible");
  }

  /** {@code Class.isInstance} as a (Class, Object)boolean method handle. */
  private static final MethodHandle IS_INSTANCE;
  /** (Object)boolean, true if the argument is a runtime {@code Function}. */
  private static final MethodHandle FUNCTION_TEST;
  /** {@link #proxyFunction} as a (Class, Object)Object method handle. */
  private static final MethodHandle PROXY_FUNCTION;
  /** (Object)Object identity. */
  private static final MethodHandle IDENTITY;

  static {
    Lookup lookup = MethodHandles.lookup();
    try {
      IS_INSTANCE = lookup.findVirtual(Class.class, "isInstance",
          MethodType.methodType(boolean.class, Object.class));
      PROXY_FUNCTION = lookup.findStatic(JavaBridge.class, "proxyFunction",
          MethodType.methodType(Object.class, Class.class, Object.class));
    } catch (NoSuchMethodException | IllegalAccessException e) {
      throw new AssertionError(e);
    }
    FUNCTION_TEST = IS_INSTANCE.bindTo(Function.class);
    IDENTITY = MethodHandles.identity(Object.class);
  }

  /**
   * Converts {@code mh} to its generic type. For every interface-typed parameter, it also adds
   * a filter that wraps a runtime {@code Function} argument into a proxy of that interface.
   * Other arguments are passed through unchanged.
   */
  private static MethodHandle conversionFilter(MethodHandle mh) {
    MethodType type = mh.type();
    int parameterCount = type.parameterCount();
    MethodHandle[] filters = new MethodHandle[parameterCount]; // null means no filter
    for(int i = 0; i < parameterCount; i++) {
      Class<?> parameterType = type.parameterType(i);
      if (parameterType.isInterface()) {
        filters[i] = MethodHandles.guardWithTest(FUNCTION_TEST, PROXY_FUNCTION.bindTo(parameterType), IDENTITY);
      }
    }
    mh = mh.asType(type.generic());
    return MethodHandles.filterArguments(mh, 0, filters);
  }

  /**
   * Wraps a runtime {@code Function} into a proxy implementing {@code interfaceType}.
   *
   * <p>Abstract interface methods call the function's target, with the proxy as first argument
   * followed by the call arguments. {@code equals}, {@code hashCode} and {@code toString} use
   * the proxy identity. Default methods are not supported: calling them on the proxy would
   * re-enter this handler forever, and {@code InvocationHandler.invokeDefault} requires Java 16.
   */
  @SuppressWarnings("unused") // called using a method handle
  private static Object proxyFunction(Class<?> interfaceType, Object fun) {
    Function function = (Function)fun;
    return Proxy.newProxyInstance(interfaceType.getClassLoader(), new Class<?>[]{interfaceType},
        (Object proxy, Method method, Object[] args) -> {
          if (method.getDeclaringClass() == Object.class) {
            return invokeObjectMethod(interfaceType, proxy, method, args);
          }
          if (method.isDefault()) {
            throw new UnsupportedOperationException(
                "default method " + method + " can not be called on a function proxy");
          }
          Object[] arguments;
          if (args == null) {
            arguments = new Object[] { proxy };
          } else {
            arguments = new Object[1 + args.length];
            arguments[0] = proxy;
            System.arraycopy(args, 0, arguments, 1, args.length);
          }
          return function.getTarget().invokeWithArguments(arguments);
        });
  }

  /**
   * Implements the {@code java.lang.Object} methods that a proxy forwards to its invocation
   * handler ({@code equals}, {@code hashCode}, {@code toString}) using the proxy identity.
   * Calling them reflectively on the proxy would dispatch back to the handler and never end.
   */
  private static Object invokeObjectMethod(Class<?> interfaceType, Object proxy, Method method, Object[] args) {
    switch(method.getName()) {
    case "equals":
      return proxy == args[0];
    case "hashCode":
      return System.identityHashCode(proxy);
    case "toString":
      return interfaceType.getName() + '@' + Integer.toHexString(System.identityHashCode(proxy));
    default:
      throw new AssertionError("unexpected method " + method);
    }
  }

  /**
   * Groups executables by name, then by parameter count, into {@code map}.
   *
   * <p>The methods declared by {@code java.lang.Object} are skipped, as are executables whose
   * {@code MethodInfo} annotation marks them hidden. The annotation's name, if present, replaces
   * the Java name. Constructors without the annotation are named {@code "@init"}.
   *
   * @param executables the executables to gather.
   * @param map maps a name to an array indexed by parameter count; the arrays grow as needed.
   */
  private static void gatherMethods(Stream<Executable> executables, HashMap<String, ArrayList<Executable>[]> map) {
    executables.forEach(executable -> {
      if (executable.getDeclaringClass() == Object.class) {
        return; // skip java.lang.Object's methods, objects have no identity
      }
      String name;
      MethodInfo methodInfo = executable.getAnnotation(MethodInfo.class);
      if (methodInfo != null) {
        if (methodInfo.hidden()) {
          return; // skip methods marked as hidden
        }
        name = methodInfo.name();
      } else {
        name = (executable instanceof Method) ? executable.getName() : "@init";
      }
      int count = executable.getParameterCount();
      ArrayList<Executable>[] methodByArity = map.get(name);
      if (methodByArity == null) {
        methodByArity = newMethodByArity(1 + count);
        map.put(name, methodByArity);
      } else if (methodByArity.length <= count) {
        methodByArity = Arrays.copyOf(methodByArity, 1 + count);
        map.put(name, methodByArity);
      }
      ArrayList<Executable> list = methodByArity[count];
      if (list == null) {
        list = new ArrayList<>();
        methodByArity[count] = list;
      }
      list.add(executable);
    });
  }

  /** Creates an empty array of executable lists; generic array creation requires an unchecked cast. */
  @SuppressWarnings("unchecked")
  private static ArrayList<Executable>[] newMethodByArity(int size) {
    return (ArrayList<Executable>[]) new ArrayList<?>[size];
  }
}