/**
*
*/
package org.activejpa.enhancer;
import java.io.IOException;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javassist.CannotCompileException;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtMethod;
import javassist.CtNewMethod;
import javassist.LoaderClassPath;
import javassist.NotFoundException;
import javax.persistence.Entity;
import org.activejpa.entity.Filter;
import org.activejpa.entity.Model;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Adds static helper methods ({@code findById}, {@code all}, {@code count}, {@code deleteAll},
 * {@code exists}, {@code where}, {@code one}, {@code first}) to JPA {@link Entity} classes that
 * extend {@link Model}, by rewriting their bytecode with Javassist. Each generated method delegates
 * to the same-named method called with the entity's {@code Class} as first argument.
 *
 * <p>One {@link Context} is kept per class loader. A context for a loader with a parent loader
 * builds its {@link ClassPool} on top of the parent context's pool.
 *
 * @author ganeshs
 */
public class DomainClassEnhancer {

    private static final Logger logger = LoggerFactory.getLogger(DomainClassEnhancer.class);

    /**
     * Contexts keyed by class loader ({@code null} is an accepted key). Every access is synchronized
     * on the map itself: {@link HashMap} is not safe under concurrent modification and
     * {@link #getContext(ClassLoader)} is reachable from any thread calling {@link #enhance}.
     */
    private static final Map<ClassLoader, Context> contextMap = new HashMap<ClassLoader, Context>();

    /**
     * Enhances a class if it is an {@link Entity} extending {@link Model} and it has not already been
     * enhanced in the loader's context or any ancestor context.
     *
     * @param loader the loader to resolve the class against; may be {@code null}
     * @param className the class name, in internal ({@code a/b/C}) or binary ({@code a.b.C}) form
     * @return the enhanced bytecode, or {@code null} when the class was not enhanced (not eligible,
     *         already enhanced, or not found)
     * @throws RuntimeException wrapping any failure other than {@link NotFoundException}
     */
    public byte[] enhance(ClassLoader loader, String className) {
        Context context = getContext(loader);
        className = className.replace("/", ".");
        try {
            logger.trace("Attempting to enhance the class - {}", className);
            if (context.isClassLoaded(className)) {
                logger.info("Class already enhanced - {}", className);
                return null;
            }
            CtClass ctClass = context.getCtClass(className);
            if (!canEnhance(context, ctClass)) {
                return null;
            }
            logger.info("Transforming the class - {}", className);
            // Javassist freezes a CtClass once its bytecode has been produced; unfreeze before editing
            ctClass.defrost();
            createModelMethods(context, ctClass);
            byte[] byteCode = ctClass.toBytecode();
            context.addClass(className);
            return byteCode;
        } catch (NotFoundException e) {
            // Can't do much. Just log and ignore
            logger.trace("Failed while transforming the class " + className, e);
        } catch (Exception e) {
            logger.error("Failed while transforming the class " + className, e);
            throw new RuntimeException("Failed while transforming the class " + className, e);
        }
        return null;
    }

    /**
     * Returns the context for a class loader, creating it (and, recursively, the contexts of its
     * ancestor loaders) on first use.
     *
     * @param classLoader the loader; may be {@code null}
     * @return the cached or newly created context, never {@code null}
     */
    public Context getContext(ClassLoader classLoader) {
        synchronized (contextMap) {
            Context context = contextMap.get(classLoader);
            if (context == null) {
                Context parent = null;
                if (classLoader != null && classLoader.getParent() != null) {
                    // Re-enters this monitor on the same thread, which is allowed
                    parent = getContext(classLoader.getParent());
                }
                context = new Context(parent, classLoader);
                contextMap.put(classLoader, context);
            }
            return context;
        }
    }

    /**
     * Checks whether a class would be enhanced, resolving it through the current thread's context
     * class loader.
     *
     * @param className the class name, in internal ({@code a/b/C}) or binary ({@code a.b.C}) form
     * @return {@code true} if the class is an {@link Entity} extending {@link Model}; {@code false}
     *         otherwise or if the check fails for any reason
     */
    public boolean canEnhance(String className) {
        Context context = getContext(Thread.currentThread().getContextClassLoader());
        try {
            // Normalise like enhance() does; kept inside the try so a null name still yields false
            return canEnhance(context, context.getCtClass(className.replace("/", ".")));
        } catch (Exception e) {
            logger.trace("Error while checking is the class can be enhanced", e);
            return false;
        }
    }

    /**
     * Returns whether the class is annotated with {@link Entity} and has {@link Model} among its
     * superclasses.
     */
    protected boolean canEnhance(Context context, CtClass ctClass) throws IOException, NotFoundException {
        return isEntity(ctClass) && isExtendingModel(context, ctClass);
    }

    /**
     * Generates all static model helper methods on the class.
     *
     * <p>Argument types must be fully qualified: {@link #getMethod} resolves them with
     * {@link ClassPool#get(String)}, which does not apply {@code java.lang} imports. A bare
     * {@code Object[]} could never be resolved, so an existing declaration of that signature was
     * never removed and adding the generated method then failed.
     */
    private void createModelMethods(Context context, CtClass ctClass) throws CannotCompileException {
        String model = Model.class.getName();
        String id = "java.io.Serializable id";
        String filter = Filter.class.getName() + " filter";
        String params = "java.lang.Object[] paramValues";

        createMethod(context, ctClass, "findById", model, id);
        createMethod(context, ctClass, "all", "java.util.List");
        createMethod(context, ctClass, "count", "long");
        createMethod(context, ctClass, "count", "long", filter);
        createMethod(context, ctClass, "deleteAll", "void");
        createMethod(context, ctClass, "deleteAll", "void", filter);
        createMethod(context, ctClass, "exists", "boolean", id);
        createMethod(context, ctClass, "where", "java.util.List", params);
        createMethod(context, ctClass, "where", "java.util.List", filter);
        createMethod(context, ctClass, "one", model, params);
        createMethod(context, ctClass, "first", model, params);
    }

    /**
     * Adds {@code public static <returnType> <methodName>(<arguments>)} to the class. Any method the
     * class already declares with the same name and parameter types is replaced. The generated body
     * calls {@code methodName(<ThisClass>.class, <argument names>...)} and casts the result to
     * {@code returnType} unless it is {@code void}.
     *
     * @param arguments parameter declarations of the form {@code "<type> <name>"}
     */
    private void createMethod(Context context, CtClass ctClass, String methodName, String returnType, String... arguments) throws CannotCompileException {
        logger.info("Creating the method - {} under the class - {}", methodName, ctClass.getName());
        String[] args = (arguments == null) ? new String[0] : arguments;

        StringBuilder src = new StringBuilder();
        src.append("public static ").append(returnType).append(' ').append(methodName).append('(');
        for (int i = 0; i < args.length; i++) {
            if (i > 0) {
                src.append(", ");
            }
            src.append(args[i]);
        }
        src.append(") {");
        if (!"void".equals(returnType)) {
            src.append("return (").append(returnType).append(')');
        }
        src.append(methodName).append('(').append(ctClass.getName()).append(".class");
        for (String argument : args) {
            // Forward the parameter name (the token after the type)
            src.append(", ").append(argument.split(" ")[1]);
        }
        src.append(");}");

        try {
            // Throws NotFoundException rather than returning null when nothing matches
            ctClass.removeMethod(getMethod(context, ctClass, methodName, args));
        } catch (NotFoundException e) {
            logger.trace("Failed to get the method " + methodName, e);
            // Just ignore if the method doesn't exist already
        }

        String source = src.toString();
        logger.debug("Method src - {}", source);
        ctClass.addMethod(CtNewMethod.make(source, ctClass));
    }

    /**
     * Looks up a method declared directly on the class whose parameter types match the type part of
     * each {@code "<type> <name>"} argument.
     *
     * @throws NotFoundException if a type cannot be resolved or no such method is declared
     */
    private CtMethod getMethod(Context context, CtClass ctClass, String methodName, String... arguments) throws NotFoundException {
        List<CtClass> paramTypes = new ArrayList<CtClass>();
        if (arguments != null) {
            for (String argument : arguments) {
                paramTypes.add(context.getCtClass(argument.split(" ")[0]));
            }
        }
        return ctClass.getDeclaredMethod(methodName, paramTypes.toArray(new CtClass[0]));
    }

    /** Returns whether the class carries the {@link Entity} annotation. */
    protected boolean isEntity(CtClass ctClass) throws IOException {
        return ctClass.hasAnnotation(Entity.class);
    }

    /** Returns whether {@link Model}, as resolved in the given context, is a superclass of the class. */
    protected boolean isExtendingModel(Context context, CtClass ctClass) throws NotFoundException {
        return getSuperClasses(ctClass).contains(context.getCtClass(Model.class.getName()));
    }

    /**
     * Returns the superclasses of {@code modelClass} ordered from the root of the hierarchy down to
     * the direct superclass. When the whole chain resolves, {@link Object} is at index 0.
     * {@code modelClass} itself is not included.
     *
     * @param modelClass the class whose ancestors are collected
     * @return the superclasses, top to bottom; empty if the class has no superclass
     * @throws NotFoundException if a superclass cannot be resolved
     */
    protected List<CtClass> getSuperClasses(CtClass modelClass) throws NotFoundException {
        List<CtClass> superClasses = new ArrayList<CtClass>();
        CtClass superClass = getSuperClass(modelClass);
        if (superClass != null) {
            superClasses.addAll(getSuperClasses(superClass));
            superClasses.add(superClass);
        }
        return superClasses;
    }

    /** Returns the direct superclass, or {@code null} if there is none. */
    protected CtClass getSuperClass(CtClass modelClass) throws NotFoundException {
        return modelClass.getSuperclass();
    }

    /**
     * Per-class-loader state: the Javassist {@link ClassPool} used for lookups and the names of the
     * classes already enhanced through this context.
     */
    public static class Context {

        private Context parent;

        /** Synchronized because several threads may record or query enhanced classes concurrently. */
        private Set<String> loadedClasses = Collections.synchronizedSet(new HashSet<String>());

        private ClassPool classPool;

        /**
         * @param parent the context of the parent loader, or {@code null} for a root context
         * @param loader the loader whose class path is added to a child context's pool
         */
        public Context(Context parent, ClassLoader loader) {
            this.parent = parent;
            if (parent == null) {
                // NOTE(review): the loader is not consulted here, so a non-null loader with a null
                // parent only sees what the default pool sees — confirm this is intended.
                classPool = ClassPool.getDefault();
            } else {
                classPool = new ClassPool(parent.classPool);
                classPool.appendClassPath(new LoaderClassPath(loader));
            }
        }

        /**
         * @return the parent context, or {@code null} for a root context
         */
        public Context getParent() {
            return parent;
        }

        /**
         * @param parent the parent to set
         */
        public void setParent(Context parent) {
            this.parent = parent;
        }

        /**
         * @return the live set of class names enhanced through this context
         */
        public Set<String> getLoadedClasses() {
            return loadedClasses;
        }

        /**
         * @param loadedClasses the loadedClasses to set
         */
        public void setLoadedClasses(Set<String> loadedClasses) {
            this.loadedClasses = loadedClasses;
        }

        /**
         * @return the classPool
         */
        public ClassPool getClassPool() {
            return classPool;
        }

        /**
         * @param classPool the classPool to set
         */
        public void setClassPool(ClassPool classPool) {
            this.classPool = classPool;
        }

        /**
         * Returns whether the class was recorded via {@link #addClass} in this context or any
         * ancestor context.
         */
        public boolean isClassLoaded(String className) {
            return loadedClasses.contains(className)
                    || (parent != null && parent.isClassLoaded(className));
        }

        /**
         * Resolves a class in this context's pool.
         *
         * @throws NotFoundException if the pool cannot find the class
         */
        public CtClass getCtClass(String className) throws NotFoundException {
            return classPool.get(className);
        }

        /** Records the class as enhanced in this context. */
        public void addClass(String className) {
            loadedClasses.add(className);
        }
    }
}