/*
* Copyright 2014 NAVER Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.navercorp.pinpoint.profiler.instrument.aspect;
import com.navercorp.pinpoint.bootstrap.instrument.aspect.Aspect;
import com.navercorp.pinpoint.bootstrap.instrument.aspect.JointPoint;
import com.navercorp.pinpoint.bootstrap.instrument.aspect.PointCut;
import com.navercorp.pinpoint.profiler.instrument.MethodNameReplacer;
import com.navercorp.pinpoint.profiler.instrument.interceptor.CodeBuilder;
import javassist.*;
import javassist.expr.ExprEditor;
import javassist.expr.MethodCall;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
/**
 * Weaves an {@link Aspect}-annotated advice class into a source class by editing bytecode with Javassist.
 * <p>
 * For every {@link PointCut} method declared on the advice class:
 * <ol>
 *   <li>the source method with the same name and parameter types is looked up and its descriptor is
 *       required to match the advice method's descriptor;</li>
 *   <li>the source method is copied into a new private method whose name comes from the
 *       {@link MethodNameReplacer} (this copy keeps the original body);</li>
 *   <li>the source method's body is replaced by the advice method's body;</li>
 *   <li>inside that new body, every call to a {@link JointPoint} method (matched by name and
 *       descriptor) is rewritten into a call to the private copy, i.e. to the original code.</li>
 * </ol>
 * Every other method declared on the advice class must be private; such helpers are copied into the
 * source class unchanged.
 *
 * @author emeroad
 */
public class AspectWeaverClass {

    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    private static final MethodNameReplacer DEFAULT_METHOD_NAME_REPLACER = new DefaultMethodNameReplacer();

    /** Produces the name of the private method that preserves each woven source method's original body. */
    private final MethodNameReplacer methodNameReplacer;

    public AspectWeaverClass() {
        this.methodNameReplacer = DEFAULT_METHOD_NAME_REPLACER;
    }

    /**
     * Weaves {@code adviceClass} into {@code sourceClass}, modifying {@code sourceClass} in place.
     *
     * @param sourceClass the class whose methods are rewritten; must not be null
     * @param adviceClass the advice class; must be annotated with {@link Aspect} and must either be a
     *                    subclass of {@code sourceClass} or extend {@code java.lang.Object} directly;
     *                    must not be null
     * @throws NotFoundException      if a {@link PointCut} method has no method with the same name and
     *                                parameter types declared on {@code sourceClass}, or a referenced
     *                                type cannot be resolved
     * @throws CannotCompileException on an invalid advice class hierarchy, a descriptor mismatch, a
     *                                non-private helper method, or a Javassist compilation failure
     * @throws IllegalArgumentException if {@code adviceClass} is not annotated with {@link Aspect}
     * @throws NullPointerException   if either argument is null
     */
    public void weaving(CtClass sourceClass, CtClass adviceClass) throws NotFoundException, CannotCompileException {
        if (sourceClass == null) {
            throw new NullPointerException("sourceClass must not be null");
        }
        if (adviceClass == null) {
            throw new NullPointerException("adviceClass must not be null");
        }
        if (logger.isInfoEnabled()) {
            logger.info("weaving sourceClass:{} advice:{}", sourceClass.getName(), adviceClass.getName());
        }
        if (!isAspectClass(adviceClass)) {
            // IllegalArgumentException is a RuntimeException, so callers catching RuntimeException still work.
            throw new IllegalArgumentException("@Aspect not found. adviceClass:" + adviceClass);
        }

        // Advice class hierarchy check: the advice either extends the source class, or its direct
        // superclass must be java.lang.Object. Any other superclass is rejected.
        final boolean isSubClass = adviceClass.subclassOf(sourceClass);
        if (!isSubClass) {
            final CtClass superClass = adviceClass.getSuperclass();
            if (!superClass.getName().equals("java.lang.Object")) {
                throw new CannotCompileException("invalid class hierarchy. " + sourceClass.getName() + " adviceSuperClass:" + superClass.getName());
            }
        }

        // Private helpers must be present on the source class before the advice bodies that call them are installed.
        copyUtilMethod(sourceClass, adviceClass);

        final List<CtMethod> pointCutMethodList = findAnnotationMethod(adviceClass, PointCut.class);
        final List<CtMethod> jointPointList = findAnnotationMethod(adviceClass, JointPoint.class);

        for (CtMethod adviceMethod : pointCutMethodList) {
            final CtMethod sourceMethod = sourceClass.getDeclaredMethod(adviceMethod.getName(), adviceMethod.getParameterTypes());
            // Parameter types already match; this additionally compares the return type via the full descriptor.
            if (!sourceMethod.getSignature().equals(adviceMethod.getSignature())) {
                throw new CannotCompileException("Signature miss match. method:" + adviceMethod.getName() + " source:" + sourceMethod.getSignature() + " advice:" + adviceMethod.getSignature());
            }
            if (logger.isInfoEnabled()) {
                logger.info("weaving method:{}{}", sourceMethod.getName(), sourceMethod.getSignature());
            }
            weavingMethod(sourceClass, sourceMethod, adviceMethod, jointPointList, isSubClass);
        }
    }

    /**
     * Copies every helper (non-{@link PointCut}, non-{@link JointPoint}) method of the advice class into
     * the source class under its original name.
     */
    private void copyUtilMethod(CtClass sourceClass, CtClass adviceClass) throws CannotCompileException {
        final List<CtMethod> utilMethodList = findUtilMethod(adviceClass);
        for (CtMethod method : utilMethodList) {
            final CtMethod copyMethod = CtNewMethod.copy(method, method.getName(), sourceClass, null);
            sourceClass.addMethod(copyMethod);
        }
    }

    /**
     * Collects the helper methods declared on the advice class, i.e. those carrying neither
     * {@link PointCut} nor {@link JointPoint}.
     *
     * @throws CannotCompileException if any such helper method is not private
     */
    private List<CtMethod> findUtilMethod(CtClass adviceClass) throws CannotCompileException {
        final List<CtMethod> utilMethodList = new ArrayList<CtMethod>();
        for (CtMethod method : adviceClass.getDeclaredMethods()) {
            if (method.hasAnnotation(PointCut.class) || method.hasAnnotation(JointPoint.class)) {
                continue;
            }
            final int modifiers = method.getModifiers();
            if (!Modifier.isPrivate(modifiers)) {
                throw new CannotCompileException("non private UtilMethod unsupported. method:" + method.getLongName());
            }
            utilMethodList.add(method);
        }
        return utilMethodList;
    }

    /** Returns whether the given class is annotated with {@link Aspect}. */
    private boolean isAspectClass(CtClass aspectClass) {
        return aspectClass.hasAnnotation(Aspect.class);
    }

    /**
     * Weaves one point-cut method. The original body is preserved in a private copy, the advice body is
     * installed into {@code sourceMethod}, and join-point calls inside it are redirected to the copy.
     */
    private void weavingMethod(CtClass sourceClass, CtMethod sourceMethod, CtMethod adviceMethod, List<CtMethod> jointPointList, boolean isSubClass) throws CannotCompileException {
        final CtMethod copyMethod = copyMethod(sourceClass, sourceMethod);
        sourceClass.addMethod(copyMethod);
        sourceMethod.setBody(adviceMethod, null);
        sourceMethod.instrument(new JointPointMethodEditor(sourceClass, sourceMethod, copyMethod, jointPointList, isSubClass));
    }

    /**
     * Rewrites the method calls found in a freshly woven method body.
     * <p>
     * A call that matches a {@link JointPoint} method by name and descriptor is replaced with a call to
     * {@code replaceMethod}, the private copy holding the original source body. When the advice class
     * is a subclass of the source class, every other call whose declaring class satisfies
     * {@code sourceClass.subclassOf(declaringClass)} must also be resolvable through
     * {@code sourceClass.getMethod(name, descriptor)}; otherwise weaving fails.
     */
    public class JointPointMethodEditor extends ExprEditor {

        private final CtClass sourceClass;
        private final CtMethod sourceMethod;
        private final CtMethod replaceMethod;
        private final List<CtMethod> jointPointList;
        private final boolean isSubClass;

        public JointPointMethodEditor(CtClass sourceClass, CtMethod sourceMethod, CtMethod replaceMethod, List<CtMethod> jointPointList, boolean isSubClass) {
            if (replaceMethod == null) {
                throw new NullPointerException("replaceMethod must not be null");
            }
            this.sourceClass = sourceClass;
            this.sourceMethod = sourceMethod;
            this.replaceMethod = replaceMethod;
            this.jointPointList = jointPointList;
            this.isSubClass = isSubClass;
        }

        @Override
        public void edit(MethodCall methodCall) throws CannotCompileException {
            final boolean joinPointMethod = isJoinPointMethod(jointPointList, methodCall.getMethodName(), methodCall.getSignature());
            if (joinPointMethod) {
                if (!methodCall.getSignature().equals(replaceMethod.getSignature())) {
                    // Report the call-site descriptor that actually failed the check. replaceMethod is a
                    // copy of sourceMethod, so their descriptors are identical and cannot explain the mismatch.
                    throw new CannotCompileException("Signature miss match. method:" + sourceMethod.getName() + " source:" + replaceMethod.getSignature() + " jointPoint:" + methodCall.getSignature());
                }
                final String invokeSource = invokeSourceMethod();
                if (logger.isDebugEnabled()) {
                    logger.debug("JointPoint method {}{} -> invokeOriginal:{}", methodCall.getMethodName(), methodCall.getSignature(), invokeSource);
                }
                methodCall.replace(invokeSource);
            } else if (isSubClass) {
                // Validate that calls into the source class hierarchy resolve on the source class itself.
                try {
                    final CtMethod method = methodCall.getMethod();
                    final CtClass declaringClass = method.getDeclaringClass();
                    if (sourceClass.subclassOf(declaringClass)) {
                        // Only the NotFoundException matters here; the returned method is unused.
                        sourceClass.getMethod(methodCall.getMethodName(), methodCall.getSignature());
                    }
                } catch (NotFoundException e) {
                    throw new CannotCompileException(e.getMessage(), e);
                }
            }
        }

        /** Returns whether some join-point method has exactly the given name and descriptor. */
        private boolean isJoinPointMethod(List<CtMethod> jointPointList, String methodName, String methodSignature) {
            for (CtMethod method : jointPointList) {
                if (method.getName().equals(methodName) && method.getSignature().equals(methodSignature)) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Builds the Javassist replacement source for a join-point call. It forwards all arguments
         * ({@code $$}) to the preserved original method and assigns the result to {@code $_} unless
         * the method returns void.
         */
        private String invokeSourceMethod() {
            final CodeBuilder builder = new CodeBuilder(32);
            if (!isVoid(replaceMethod.getSignature())) {
                builder.append("$_=");
            }
            builder.format("%1$s($$);", methodNameReplacer.replaceMethodName(sourceMethod.getName()));
            return builder.toString();
        }

        /**
         * Returns whether a JVM method descriptor has a void return type. Non-void descriptors end in a
         * primitive code other than {@code V} or in {@code ;}, so checking the last character is enough.
         */
        public boolean isVoid(String signature) {
            return signature.endsWith("V");
        }
    }

    /**
     * Creates (but does not add) a private copy of {@code sourceMethod} on {@code sourceClass}, named by
     * the {@link MethodNameReplacer}. The copy preserves the method's original body.
     */
    private CtMethod copyMethod(CtClass sourceClass, CtMethod sourceMethod) throws CannotCompileException {
        // NOTE(review): the name is derived only from the method name, so overloads of the same method
        // would map to the same copy name; they differ only by descriptor. Confirm whether a unique id is needed.
        final String copyMethodName = methodNameReplacer.replaceMethodName(sourceMethod.getName());
        final CtMethod copy = CtNewMethod.copy(sourceMethod, copyMethodName, sourceClass, null);
        // Make the preserved original private so it is reachable only from the woven method.
        final int modifiers = copy.getModifiers();
        copy.setModifiers(Modifier.setPrivate(modifiers));
        return copy;
    }

    /**
     * Returns the methods declared directly on {@code ctClass} that carry the given annotation.
     *
     * @throws NullPointerException if either argument is null
     */
    private List<CtMethod> findAnnotationMethod(CtClass ctClass, Class<?> annotation) {
        if (ctClass == null) {
            throw new NullPointerException("ctClass must not be null");
        }
        if (annotation == null) {
            throw new NullPointerException("annotation must not be null");
        }
        final List<CtMethod> annotationList = new ArrayList<CtMethod>();
        for (CtMethod method : ctClass.getDeclaredMethods()) {
            if (method.hasAnnotation(annotation)) {
                annotationList.add(method);
            }
        }
        return annotationList;
    }

    /**
     * Default naming scheme for preserved original methods: {@code "__" + methodName + "_$$pinpoint"}.
     */
    public static class DefaultMethodNameReplacer implements MethodNameReplacer {

        public static final String PREFIX = "__";
        public static final String POSTFIX = "_$$pinpoint";

        public String replaceMethodName(String methodName) {
            if (methodName == null) {
                throw new NullPointerException("methodName must not be null");
            }
            return PREFIX + methodName + POSTFIX;
        }
    }
}