package com.lambdaworks; import java.io.ByteArrayOutputStream; import java.io.PrintStream; import java.lang.management.ManagementFactory; import java.lang.management.ThreadInfo; import java.lang.management.ThreadMXBean; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.junit.rules.MethodRule; import org.junit.runners.model.FrameworkMethod; import org.junit.runners.model.Statement; /** * @author Mark Paluch */ public class LoggingTestRule implements MethodRule { private boolean threadDumpOnFailure = false; public LoggingTestRule(boolean threadDumpOnFailure) { this.threadDumpOnFailure = threadDumpOnFailure; } @Override public Statement apply(Statement base, FrameworkMethod method, Object target) { return new Statement() { @Override public void evaluate() throws Throwable { Logger logger = LogManager.getLogger(method.getMethod().getDeclaringClass()); logger.info("---------------------------------------"); logger.info("-- Invoke method " + method.getMethod().getDeclaringClass().getSimpleName() + "." + method.getName()); logger.info("---------------------------------------"); try { base.evaluate(); } catch (Throwable t) { if (threadDumpOnFailure) { printThreadDump(logger); } throw t; } finally { logger.info("---------------------------------------"); logger.info("-- Finished method " + method.getMethod().getDeclaringClass().getSimpleName() + "." + method.getName()); logger.info("---------------------------------------"); } } }; } private void printThreadDump(Logger logger) { logger.info("---------------------------------------"); ByteArrayOutputStream buffer = getThreadDump(); logger.info("-- Thread dump: " + buffer.toString()); logger.info("---------------------------------------"); } private ByteArrayOutputStream getThreadDump() { ThreadMXBean threadBean = ManagementFactory.getThreadMXBean(); long[] threadIds = threadBean.getAllThreadIds(); ByteArrayOutputStream buffer = new ByteArrayOutputStream(); PrintStream stream = new PrintStream(buffer); for (long tid : threadIds) { ThreadInfo info = threadBean.getThreadInfo(tid, 50); if (info == null) { stream.println(" Inactive"); continue; } stream.println("Thread " + getTaskName(info.getThreadId(), info.getThreadName()) + ":"); Thread.State state = info.getThreadState(); stream.println(" State: " + state); stream.println(" Blocked count: " + info.getBlockedCount()); stream.println(" Waited count: " + info.getWaitedCount()); if (state == Thread.State.WAITING) { stream.println(" Waiting on " + info.getLockName()); } else if (state == Thread.State.BLOCKED) { stream.println(" Blocked on " + info.getLockName()); stream.println(" Blocked by " + getTaskName(info.getLockOwnerId(), info.getLockOwnerName())); } stream.println(" Stack:"); for (StackTraceElement frame : info.getStackTrace()) { stream.println(" " + frame.toString()); } } stream.flush(); return buffer; } private static String getTaskName(long id, String name) { if (name == null) { return Long.toString(id); } return id + " (" + name + ")"; } }