/* * Copyright (C) 2015 SoftIndex LLC. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.datakernel.codegen; import io.datakernel.codegen.utils.DefiningClassWriter; import io.datakernel.codegen.utils.Preconditions; import org.objectweb.asm.Type; import org.objectweb.asm.commons.GeneratorAdapter; import org.objectweb.asm.commons.Method; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.FileOutputStream; import java.io.IOException; import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.nio.file.Path; import java.util.*; import java.util.concurrent.atomic.AtomicInteger; import static io.datakernel.codegen.Utils.loadAndCast; import static java.util.Arrays.asList; import static org.objectweb.asm.Opcodes.*; import static org.objectweb.asm.Type.getInternalName; import static org.objectweb.asm.Type.getType; import static org.objectweb.asm.commons.Method.getMethod; /** * Intends for dynamic description of the behaviour of the object in runtime * * @param <T> type of item */ @SuppressWarnings("unchecked") public final class ClassBuilder<T> { private final Logger logger = LoggerFactory.getLogger(this.getClass()); public static final String DEFAULT_CLASS_NAME = ClassBuilder.class.getPackage().getName() + ".Class"; private static final AtomicInteger COUNTER = new AtomicInteger(); private final DefiningClassLoader classLoader; private final Class<T> mainClass; private final List<Class<?>> otherClasses; private Path bytecodeSaveDir; private String className; private final Map<String, Class<?>> fields = new LinkedHashMap<>(); private final Map<String, Class<?>> staticFields = new LinkedHashMap<>(); private final Map<String, Object> staticConstants = new LinkedHashMap<>(); private final Map<Method, Expression> methods = new LinkedHashMap<>(); private final Map<Method, Expression> staticMethods = new LinkedHashMap<>(); public static class AsmClassKey<T> { private final Class<T> mainClass; private final List<Class<?>> otherClasses; private final Map<String, Class<?>> fields; private final Map<Method, Expression> expressionMap; private final Map<Method, Expression> expressionStaticMap; public AsmClassKey(Class<T> mainClass, List<Class<?>> otherClasses, Map<String, Class<?>> fields, Map<Method, Expression> expressionMap, Map<Method, Expression> expressionStaticMap) { this.mainClass = mainClass; this.otherClasses = otherClasses; this.fields = fields; this.expressionMap = expressionMap; this.expressionStaticMap = expressionStaticMap; } public Class<T> getMainClass() { return mainClass; } public List<Class<?>> getOtherClasses() { return otherClasses; } @Override public String toString() { return "AsmClassKey{" + "mainType=" + mainClass + ", otherTypes" + otherClasses + ", fields=" + fields + ", expressionMap=" + expressionMap + ", expressionStaticMap=" + expressionStaticMap + '}'; } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; AsmClassKey that = (AsmClassKey) o; return Objects.equals(mainClass, that.mainClass) && Objects.equals(otherClasses, that.otherClasses) && Objects.equals(fields, that.fields) && Objects.equals(expressionMap, that.expressionMap) && Objects.equals(expressionStaticMap, that.expressionStaticMap); } @Override public int hashCode() { return Objects.hash(mainClass, otherClasses, fields, expressionMap, expressionStaticMap); } } // region builders /** * Creates a new instance of AsmFunctionFactory * * @param classLoader class loader * @param type type of dynamic class */ private ClassBuilder(DefiningClassLoader classLoader, Class<T> type) { this(classLoader, type, Collections.EMPTY_LIST); } private ClassBuilder(DefiningClassLoader classLoader, Class<T> mainType, List<Class<?>> types) { this.classLoader = classLoader; this.mainClass = mainType; this.otherClasses = types; } public static <T> ClassBuilder<T> create(DefiningClassLoader classLoader, Class<T> type) { return new ClassBuilder<>(classLoader, type); } public static <T> ClassBuilder<T> create(DefiningClassLoader classLoader, Class<T> mainType, List<Class<?>> types) { return new ClassBuilder<T>(classLoader, mainType, types); } public ClassBuilder<T> withBytecodeSaveDir(Path bytecodeSaveDir) { this.bytecodeSaveDir = bytecodeSaveDir; return this; } /** * Creates a new field for a dynamic class * * @param field name of field * @param fieldClass type of field * @return changed AsmFunctionFactory */ public ClassBuilder<T> withField(String field, Class<?> fieldClass) { fields.put(field, fieldClass); return this; } public ClassBuilder<T> withFields(Map<String, Class<?>> fieldsClasses) { fields.putAll(fieldsClasses); return this; } /** * Creates a new method for a dynamic class * * @param method new method for class * @param expression function which will be processed * @return changed AsmFunctionFactory */ public ClassBuilder<T> withMethod(Method method, Expression expression) { methods.put(method, expression); return this; } public ClassBuilder<T> withMethods(Map<Method, Expression> methods) { this.methods.putAll(methods); return this; } public ClassBuilder<T> withStaticMethod(Method method, Expression expression) { this.staticMethods.put(method, expression); return this; } public ClassBuilder<T> withStaticMethods(Map<Method, Expression> staticMethods) { this.staticMethods.putAll(staticMethods); return this; } /** * Creates a new method for a dynamic class * * @param methodName name of method * @param returnClass type which returns this method * @param argumentTypes list of types of arguments * @param expression function which will be processed * @return changed AsmFunctionFactory */ public ClassBuilder<T> withMethod(String methodName, Class<?> returnClass, List<? extends Class<?>> argumentTypes, Expression expression) { Type[] types = new Type[argumentTypes.size()]; for (int i = 0; i < argumentTypes.size(); i++) { types[i] = getType(argumentTypes.get(i)); } return withMethod(new Method(methodName, getType(returnClass), types), expression); } public ClassBuilder<T> withStaticMethod(String methodName, Class<?> returnClass, List<? extends Class<?>> argumentTypes, Expression expression) { Type[] types = new Type[argumentTypes.size()]; for (int i = 0; i < argumentTypes.size(); i++) { types[i] = getType(argumentTypes.get(i)); } return withStaticMethod(new Method(methodName, getType(returnClass), types), expression); } public ClassBuilder<T> withStaticField(String fieldName, Class<?> type, Object value) { this.staticFields.put(fieldName, type); this.staticConstants.put(fieldName, value); return this; } /** * CCreates a new method for a dynamic class * * @param methodName name of method * @param expression function which will be processed * @return changed AsmFunctionFactory */ public ClassBuilder<T> withMethod(String methodName, Expression expression) { if (methodName.contains("(")) { Method method = Method.getMethod(methodName); return withMethod(method, expression); } Method foundMethod = null; LinkedHashSet<java.lang.reflect.Method> methods = new LinkedHashSet<>(); List<List<java.lang.reflect.Method>> listOfMethods = new ArrayList<>(); listOfMethods.add(asList(Object.class.getMethods())); listOfMethods.add(asList(mainClass.getMethods())); listOfMethods.add(asList(mainClass.getDeclaredMethods())); for (Class<?> type : otherClasses) { listOfMethods.add(asList(type.getMethods())); listOfMethods.add(asList(type.getDeclaredMethods())); } for (List<java.lang.reflect.Method> list : listOfMethods) { for (java.lang.reflect.Method m : list) { if (m.getName().equals(methodName)) { Method method = getMethod(m); if (foundMethod != null && !method.equals(foundMethod)) throw new IllegalArgumentException("Method " + method + " collides with " + foundMethod); foundMethod = method; } } } Preconditions.check(foundMethod != null, "Could not find method '" + methodName + "'"); return withMethod(foundMethod, expression); } public ClassBuilder<T> withMethod(Map<String, Expression> expressions) { ClassBuilder<T> self = this; for (String methodName : expressions.keySet()) { self = self.withMethod(methodName, expressions.get(methodName)); } return self; } public ClassBuilder<T> withClassName(String name) { this.className = name; return this; } // endregion public Class<T> build() { synchronized (classLoader) { AsmClassKey key = new AsmClassKey(mainClass, otherClasses, fields, methods, staticMethods); Class<?> cachedClass = classLoader.getClassByKey(key); if (cachedClass != null) { logger.trace("Fetching {} for key {} from cache", cachedClass, key); return (Class<T>) cachedClass; } Class<T> newClass = defineNewClass(key); for (String staticField : staticConstants.keySet()) { Object staticValue = staticConstants.get(staticField); try { Field field = newClass.getField(staticField); field.set(null, staticValue); } catch (NoSuchFieldException | IllegalAccessException e) { throw new AssertionError(e); } } return newClass; } } private Class<T> defineNewClass(AsmClassKey key) { DefiningClassWriter cw = DefiningClassWriter.create(classLoader); String actualClassName; if (className == null) { actualClassName = DEFAULT_CLASS_NAME + COUNTER.incrementAndGet(); } else { actualClassName = className; } Type classType = getType('L' + actualClassName.replace('.', '/') + ';'); final String[] internalNames = new String[1 + otherClasses.size()]; internalNames[0] = getInternalName(mainClass); for (int i = 0; i < otherClasses.size(); i++) { internalNames[1 + i] = getInternalName(otherClasses.get(i)); } if (mainClass.isInterface()) { cw.visit(V1_6, ACC_PUBLIC + ACC_FINAL + ACC_SUPER, classType.getInternalName(), null, "java/lang/Object", internalNames); } else { cw.visit(V1_6, ACC_PUBLIC + ACC_FINAL + ACC_SUPER, classType.getInternalName(), null, internalNames[0], Arrays.copyOfRange(internalNames, 1, internalNames.length)); } { Method m = getMethod("void <init> ()"); GeneratorAdapter g = new GeneratorAdapter(ACC_PUBLIC, m, null, null, cw); g.loadThis(); if (mainClass.isInterface()) { g.invokeConstructor(getType(Object.class), m); } else { g.invokeConstructor(getType(mainClass), m); } g.returnValue(); g.endMethod(); } for (String field : fields.keySet()) { Class<?> fieldClass = fields.get(field); cw.visitField(ACC_PUBLIC, field, getType(fieldClass).getDescriptor(), null, null); } for (Method m : staticMethods.keySet()) { try { GeneratorAdapter g = new GeneratorAdapter(ACC_PUBLIC + ACC_STATIC + ACC_FINAL, m, null, null, cw); Context ctx = new Context(classLoader, g, classType, mainClass, otherClasses, fields, staticConstants, m.getArgumentTypes(), m, methods, staticMethods); Expression expression = staticMethods.get(m); loadAndCast(ctx, expression, m.getReturnType()); g.returnValue(); g.endMethod(); } catch (Exception e) { throw new RuntimeException(e); } } for (Method m : methods.keySet()) { try { GeneratorAdapter g = new GeneratorAdapter(ACC_PUBLIC, m, null, null, cw); Context ctx = new Context(classLoader, g, classType, mainClass, otherClasses, fields, staticConstants, m.getArgumentTypes(), m, methods, staticMethods); Expression expression = methods.get(m); loadAndCast(ctx, expression, m.getReturnType()); g.returnValue(); g.endMethod(); } catch (Exception e) { throw new RuntimeException(e); } } for (String staticField : staticFields.keySet()) { cw.visitField(ACC_PUBLIC + ACC_STATIC, staticField, getType(staticFields.get(staticField)).getDescriptor(), null, null); } for (String staticField : staticConstants.keySet()) { cw.visitField(ACC_PUBLIC + ACC_STATIC, staticField, getType(staticConstants.get(staticField).getClass()).getDescriptor(), null, null); } if (bytecodeSaveDir != null) { try (FileOutputStream fos = new FileOutputStream(bytecodeSaveDir.resolve(actualClassName + ".class").toFile())) { fos.write(cw.toByteArray()); } catch (IOException e) { throw new RuntimeException(e); } } cw.visitEnd(); Class<?> definedClass = classLoader.defineClass(actualClassName, key, cw.toByteArray()); logger.trace("Defined new {} for key {}", definedClass, key); return (Class<T>) definedClass; } public T buildClassAndCreateNewInstance() { try { return build().newInstance(); } catch (InstantiationException | IllegalAccessException e) { throw new RuntimeException(e); } } public T buildClassAndCreateNewInstance(Object... constructorParameters) { Class[] constructorParameterTypes = new Class[constructorParameters.length]; for (int i = 0; i < constructorParameters.length; i++) { constructorParameterTypes[i] = constructorParameters[i].getClass(); } return buildClassAndCreateNewInstance(constructorParameterTypes, constructorParameters); } public T buildClassAndCreateNewInstance(Class[] constructorParameterTypes, Object[] constructorParameters) { try { return build().getConstructor(constructorParameterTypes).newInstance(constructorParameters); } catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { throw new RuntimeException(e); } } }