/*
 * Copyright 2016 NAVER Corp.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.navercorp.pinpoint.profiler.instrument;

import com.navercorp.pinpoint.bootstrap.instrument.InstrumentContext;
import com.navercorp.pinpoint.profiler.util.JavaAssistUtils;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.util.CheckClassAdapter;
import org.objectweb.asm.util.TraceClassVisitor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.List;

import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
 * Test-only helper that reads classes from the classpath as ASM tree nodes and
 * offers a class loader which lets a callback rewrite a class before it is defined.
 *
 * @author jaehong.kim
 */
public class ASMClassNodeLoader {

    /**
     * Mocked context handed to every {@link ASMClassWriter} created here. Only
     * {@code injectClass} and {@code getResourceAsStream} are stubbed; both simply
     * delegate to the class loader they are called with.
     */
    private static final InstrumentContext pluginContext = mock(InstrumentContext.class);

    static {
        when(pluginContext.injectClass(any(ClassLoader.class), any(String.class))).thenAnswer(new Answer<Class<?>>() {
            @Override
            public Class<?> answer(InvocationOnMock invocation) throws Throwable {
                final Object[] arguments = invocation.getArguments();
                final ClassLoader loader = orSystemClassLoader((ClassLoader) arguments[0]);
                final String name = (String) arguments[1];
                return loader.loadClass(name);
            }
        });

        when(pluginContext.getResourceAsStream(any(ClassLoader.class), any(String.class))).thenAnswer(new Answer<InputStream>() {
            @Override
            public InputStream answer(InvocationOnMock invocation) throws Throwable {
                final Object[] arguments = invocation.getArguments();
                final ClassLoader loader = orSystemClassLoader((ClassLoader) arguments[0]);
                final String name = (String) arguments[1];
                return loader.getResourceAsStream(name);
            }
        });
    }

    /**
     * Returns {@code loader}, or the system class loader when {@code loader} is {@code null}.
     */
    private static ClassLoader orSystemClassLoader(final ClassLoader loader) {
        if (loader == null) {
            return ClassLoader.getSystemClassLoader();
        }
        return loader;
    }

    /**
     * Reads the class file of {@code className} through the thread context class loader
     * (system class loader when none is set) and parses it into a {@link ClassNode}
     * with expanded frames. Only use for test.
     *
     * @param className class name; it is passed through
     *                  {@link JavaAssistUtils#javaNameToJvmName(String)} and suffixed with
     *                  {@code ".class"} to form the resource name
     * @return the parsed class
     * @throws Exception if the class resource cannot be found or read
     */
    public static ClassNode get(final String className) throws Exception {
        final ClassLoader classLoader = orSystemClassLoader(Thread.currentThread().getContextClassLoader());
        final String resourceName = JavaAssistUtils.javaNameToJvmName(className) + ".class";

        final InputStream in = classLoader.getResourceAsStream(resourceName);
        if (in == null) {
            throw new IOException("Class not found: " + resourceName);
        }

        final ClassReader cr;
        try {
            // The reader consumes the stream in its constructor, so it can be released right away.
            cr = new ClassReader(in);
        } finally {
            in.close();
        }

        final ClassNode classNode = new ClassNode();
        cr.accept(classNode, ClassReader.EXPAND_FRAMES);
        return classNode;
    }

    /**
     * Finds a method by name in the class loaded via {@link #get(String)}.
     *
     * @param classInternalName name of the class to read
     * @param methodName        method name to look for; descriptors are not compared, so
     *                          the first method with this name wins
     * @return the first matching method, or {@code null} if the class has none
     * @throws Exception if the class cannot be read
     */
    public static MethodNode get(final String classInternalName, final String methodName) throws Exception {
        final ClassNode classNode = get(classInternalName);
        final List<MethodNode> methodNodes = classNode.methods;
        for (MethodNode methodNode : methodNodes) {
            if (methodNode.name.equals(methodName)) {
                return methodNode;
            }
        }
        return null;
    }

    /**
     * Creates a fresh, unconfigured {@link TestClassLoader}.
     */
    public static TestClassLoader getClassLoader() {
        return new TestClassLoader();
    }

    /**
     * Class loader that re-reads a class with ASM, lets an optional
     * {@link CallbackHandler} modify it and defines the resulting bytecode itself.
     * Classes that are not the configured target are delegated to the parent logic.
     */
    public static class TestClassLoader extends ClassLoader {
        private final Logger logger = LoggerFactory.getLogger(this.getClass());

        private String targetClassName;
        private CallbackHandler callbackHandler;
        private boolean trace;
        private boolean verify;

        /**
         * Restricts rewriting to the given class name. While this is {@code null},
         * every name passed to {@link #loadClass(String)} takes the rewriting path.
         */
        public void setTargetClassName(String targetClassName) {
            this.targetClassName = targetClassName;
        }

        /**
         * Sets the handler invoked with the parsed class before bytecode is generated.
         */
        public void setCallbackHandler(CallbackHandler callbackHandler) {
            this.callbackHandler = callbackHandler;
        }

        /**
         * Enables dumping the original and the modified class to {@code System.out}.
         */
        public void setTrace(boolean trace) {
            this.trace = trace;
        }

        /**
         * Enables running {@link CheckClassAdapter} over the generated bytecode.
         */
        public void setVerify(boolean verify) {
            this.verify = verify;
        }

        /**
         * Loads {@code name}; the target class is read, passed to the callback,
         * rewritten with {@code COMPUTE_FRAMES} and defined by this loader.
         *
         * @param name binary name of the class
         * @return the loaded class; repeated requests for an already defined class
         *         return the same instance
         * @throws ClassNotFoundException if reading, transforming or defining the class
         *                                fails (the original exception is the cause)
         */
        @Override
        public Class<?> loadClass(final String name) throws ClassNotFoundException {
            if (targetClassName != null && !name.equals(targetClassName)) {
                return super.loadClass(name);
            }

            // Defining the same class twice in one loader raises LinkageError, so reuse it.
            final Class<?> alreadyDefined = findLoadedClass(name);
            if (alreadyDefined != null) {
                return alreadyDefined;
            }

            try {
                final ClassNode classNode = ASMClassNodeLoader.get(JavaAssistUtils.javaNameToJvmName(name));
                if (this.trace) {
                    traceOriginal(classNode);
                }

                if (callbackHandler != null) {
                    callbackHandler.handle(classNode);
                }

                final byte[] bytecode = toBytecode(classNode);
                if (this.verify) {
                    CheckClassAdapter.verify(new ClassReader(bytecode), false, new PrintWriter(System.out));
                }

                return super.defineClass(name, bytecode, 0, bytecode.length);
            } catch (Exception ex) {
                throw new ClassNotFoundException("Load error: " + ex.toString(), ex);
            }
        }

        /**
         * Prints the class as it was read, before the callback has touched it.
         */
        private void traceOriginal(final ClassNode classNode) {
            logger.debug("## original #############################################################");
            final ASMClassWriter cw = new ASMClassWriter(pluginContext, classNode.name, classNode.superName, 0, null);
            final TraceClassVisitor tcv = new TraceClassVisitor(cw, new PrintWriter(System.out));
            classNode.accept(tcv);
        }

        /**
         * Serializes the (possibly modified) class, printing it on the way when tracing is on.
         */
        private byte[] toBytecode(final ClassNode classNode) {
            final ASMClassWriter cw = new ASMClassWriter(pluginContext, classNode.name, classNode.superName, ClassWriter.COMPUTE_FRAMES, null);
            if (this.trace) {
                logger.debug("## modified #############################################################");
                final TraceClassVisitor tcv = new TraceClassVisitor(cw, new PrintWriter(System.out));
                classNode.accept(tcv);
            } else {
                classNode.accept(cw);
            }
            return cw.toByteArray();
        }
    }

    /**
     * Hook that receives the parsed class so a test can modify it in place.
     */
    public interface CallbackHandler {
        void handle(ClassNode classNode);
    }
}