/* * Copyright 2014 NAVER Corp. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.navercorp.pinpoint.test.junit4; import java.lang.reflect.Method; import com.navercorp.pinpoint.test.MockApplicationContext; import org.junit.internal.runners.model.EachTestNotifier; import org.junit.runner.notification.RunNotifier; import org.junit.runners.BlockJUnit4ClassRunner; import org.junit.runners.model.FrameworkMethod; import org.junit.runners.model.InitializationError; import org.junit.runners.model.Statement; import org.junit.runners.model.TestClass; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.navercorp.pinpoint.bootstrap.context.SpanRecorder; import com.navercorp.pinpoint.bootstrap.context.Trace; import com.navercorp.pinpoint.bootstrap.context.TraceContext; import com.navercorp.pinpoint.common.trace.ServiceType; /** * @author hyungil.jeong * @author emeroad */ public final class PinpointJUnit4ClassRunner extends BlockJUnit4ClassRunner { private static final Logger logger = LoggerFactory.getLogger(PinpointJUnit4ClassRunner.class); private static TestContext testContext; public PinpointJUnit4ClassRunner(Class<?> clazz) throws InitializationError { super(clazz); if (logger.isDebugEnabled()) { logger.debug("PinpointJUnit4ClassRunner constructor called with [{}].", clazz); } } private void beforeTestClass() { try { // TODO fix static TestContext if (testContext == null) { logger.debug("traceContext is null"); testContext = new TestContext(); } } catch (Throwable ex) { throw new RuntimeException(ex.getMessage(), ex); } } protected TestClass createTestClass(Class<?> testClass) { logger.debug("createTestClass {}", testClass); beforeTestClass(); return testContext.createTestClass(testClass); } private TraceContext getTraceContext() { MockApplicationContext mockApplicationContext = testContext.getMockApplicationContext(); return mockApplicationContext.getTraceContext(); } @Override protected void runChild(FrameworkMethod method, RunNotifier notifier) { beginTracing(method); final Thread thread = Thread.currentThread(); ClassLoader originalClassLoader = thread.getContextClassLoader(); try { thread.setContextClassLoader(testContext.getClassLoader()); super.runChild(method, notifier); } finally { thread.setContextClassLoader(originalClassLoader); endTracing(method, notifier); } } private void beginTracing(FrameworkMethod method) { if (shouldCreateNewTraceObject(method)) { TraceContext traceContext = getTraceContext(); Trace trace = traceContext.newTraceObject(); SpanRecorder recorder = trace.getSpanRecorder(); recorder.recordServiceType(ServiceType.TEST); } } private void endTracing(FrameworkMethod method, RunNotifier notifier) { if (shouldCreateNewTraceObject(method)) { TraceContext traceContext = getTraceContext(); try { Trace trace = traceContext.currentRawTraceObject(); if (trace == null) { // Trace is already detached from the ThreadLocal storage. // Happens when root trace method is tested without @IsRootSpan. EachTestNotifier testMethodNotifier = new EachTestNotifier(notifier, super.describeChild(method)); String traceObjectAlreadyDetachedMessage = "Trace object already detached. If you're testing a trace root, please add @IsRootSpan to the test method"; testMethodNotifier.addFailure(new IllegalStateException(traceObjectAlreadyDetachedMessage)); } else { trace.close(); } } finally { traceContext.removeTraceObject(); } } } private boolean shouldCreateNewTraceObject(FrameworkMethod method) { IsRootSpan isRootSpan = method.getAnnotation(IsRootSpan.class); return isRootSpan == null || !isRootSpan.value(); } @Override protected Statement methodInvoker(FrameworkMethod method, Object test) { return super.methodInvoker(method, test); } @Override protected Statement withBefores(FrameworkMethod method, final Object target, Statement statement) { Statement before = super.withBefores(method, target, statement); BeforeCallbackStatement callbackStatement = new BeforeCallbackStatement(before, new Statement() { @Override public void evaluate() throws Throwable { setupBaseTest(target); } }); return callbackStatement; } private void setupBaseTest(Object test) { logger.debug("setupBaseTest"); // It's safe to cast final Class<?> baseTestClass = testContext.getBaseTestClass(); if (baseTestClass.isInstance(test)) { try { Method reset = baseTestClass.getDeclaredMethod("setup", TestContext.class); reset.invoke(test, testContext); } catch (Exception e) { throw new RuntimeException("setCurrentHolder Error. Caused by:" + e.getMessage(), e); } } } @Override protected Statement withBeforeClasses(Statement statement) { final Statement beforeClasses = super.withBeforeClasses(statement); return new BeforeCallbackStatement(beforeClasses, new Statement() { @Override public void evaluate() throws Throwable { beforeClass(); } }); } public void beforeClass() throws Throwable { logger.debug("beforeClass"); // TODO MockApplicationContext.start(); } @Override protected Statement withAfterClasses(Statement statement) { final Statement afterClasses = super.withAfterClasses(statement); return new AfterCallbackStatement(afterClasses, new Statement() { @Override public void evaluate() throws Throwable { afterClass(); } }); } public void afterClass() throws Throwable { logger.debug("afterClass"); // TODO MockApplicationContext.close() } }