/*
* Original work Copyright 1999-2017 The Apache Software Foundation
* Modified work Copyright (c) 2017, Hazelcast, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package classloading;
import java.lang.ref.Reference;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.ConcurrentModificationException;
import java.util.Iterator;
import static java.lang.String.format;
import static org.junit.Assert.fail;
/**
* Contains detection logic for {@link ThreadLocal} leaks.
*
* Adapted from the WebappClassLoader of the Apache Tomcat project.
*
* @see <a href="https://github.com/apache/tomcat/blob/811450a84ca29e38d42e041be85b2deed4058ebb/java/org/apache/catalina/loader/WebappClassLoaderBase.java#L1823">WebappClassLoaderBase.java</a>
*/
public final class ThreadLocalLeakTestUtils {

    /**
     * Defines a list of value types which are explicitly being allowed to be detected in a {@link ThreadLocal}.
     * <p>
     * Entries are compared against the canonical class name of the value (or its binary name when the class
     * has no canonical name), see {@link #getPrettyClassName(Class)}.
     */
    private static final String[] ACCEPTED_THREAD_LOCAL_VALUE_TYPES = new String[]{
            "org.mockito.configuration.DefaultMockitoConfiguration",
            "org.mockito.internal.progress.MockingProgressImpl",
    };

    /**
     * Inspects the {@code threadLocals} and {@code inheritableThreadLocals} maps of every thread in the root
     * thread group and fails if any entry holds a value loaded by the given classloader (or one of its children).
     * <p>
     * Stale entries are expunged from each map before it is inspected. Entries where only the key (but not the
     * value) belongs to the application are merely reported on stdout.
     *
     * @param cl the classloader whose classes must not be retained by any thread local value
     * @throws Exception if the JDK internals cannot be accessed reflectively
     */
    public static void checkThreadLocalsForLeaks(ClassLoader cl) throws Exception {
        Thread[] threads = getThreads();

        // make the fields in the Thread class that store ThreadLocals accessible
        Field threadLocalsField = Thread.class.getDeclaredField("threadLocals");
        threadLocalsField.setAccessible(true);
        Field inheritableThreadLocalsField = Thread.class.getDeclaredField("inheritableThreadLocals");
        inheritableThreadLocalsField.setAccessible(true);
        // both maps are checked per thread, in this order
        Field[] threadLocalMapFields = new Field[]{threadLocalsField, inheritableThreadLocalsField};

        // make the underlying array of ThreadLocal.ThreadLocalMap.Entry objects accessible
        Class<?> tlmClass = Class.forName("java.lang.ThreadLocal$ThreadLocalMap");
        Field tableField = tlmClass.getDeclaredField("table");
        tableField.setAccessible(true);
        Method expungeStaleEntriesMethod = tlmClass.getDeclaredMethod("expungeStaleEntries");
        expungeStaleEntriesMethod.setAccessible(true);

        for (Thread thread : threads) {
            // getThreads() returns an array that may have unused (null) trailing slots
            if (thread == null) {
                continue;
            }
            for (Field threadLocalMapField : threadLocalMapFields) {
                Object threadLocalMap = threadLocalMapField.get(thread);
                if (threadLocalMap != null) {
                    // drop stale entries before inspecting the map
                    expungeStaleEntriesMethod.invoke(threadLocalMap);
                    checkThreadLocalMapForLeaks(cl, threadLocalMap, tableField);
                }
            }
        }
    }

    /**
     * Get the set of current threads as an array.
     * <p>
     * Enumerates the root thread group (including its subgroups). The returned array may contain
     * {@code null} slots after the last enumerated thread.
     */
    private static Thread[] getThreads() {
        // find the root thread group
        ThreadGroup threadGroup = Thread.currentThread().getThreadGroup();
        try {
            while (threadGroup.getParent() != null) {
                threadGroup = threadGroup.getParent();
            }
        } catch (SecurityException se) {
            fail(format("Unable to obtain the parent for ThreadGroup [%s]. It will not be possible to check all threads"
                    + " for potential memory leaks [%s]", threadGroup.getName(), se.getMessage()));
        }

        int threadCountGuess = threadGroup.activeCount() + 50;
        Thread[] threads = new Thread[threadCountGuess];
        int threadCountActual = threadGroup.enumerate(threads);
        // threadGroup.enumerate(Thread[]) silently ignores any threads that can't fit into the array,
        // so keep growing the array until it is no longer completely filled
        while (threadCountActual == threadCountGuess) {
            threadCountGuess *= 2;
            threads = new Thread[threadCountGuess];
            threadCountActual = threadGroup.enumerate(threads);
        }
        return threads;
    }

    /**
     * Analyzes the given thread local map object. Also pass in the field that points
     * to the internal table to save re-calculating it on every call to this method.
     * <p>
     * Every non-null entry of the table is inspected, even after an entry with an accepted value type
     * has been found.
     *
     * @param cl                 the classloader of the application under test
     * @param map                a {@code ThreadLocal.ThreadLocalMap} instance, may be {@code null}
     * @param internalTableField the accessible {@code ThreadLocalMap.table} field
     */
    private static void checkThreadLocalMapForLeaks(ClassLoader cl, Object map, Field internalTableField)
            throws Exception {
        if (map == null) {
            return;
        }
        Object[] table = (Object[]) internalTableField.get(map);
        if (table == null) {
            return;
        }
        for (Object entry : table) {
            if (entry != null) {
                checkThreadLocalMapEntryForLeaks(cl, entry);
            }
        }
    }

    /**
     * Checks a single {@code ThreadLocalMap.Entry}.
     * <p>
     * Fails if the entry's value was loaded by {@code cl} (or a child of it) and the value type is not listed in
     * {@link #ACCEPTED_THREAD_LOCAL_VALUE_TYPES}. If only the key belongs to the application, the entry is
     * reported on stdout but not treated as a failure.
     *
     * @param cl    the classloader of the application under test
     * @param entry a non-null {@code ThreadLocalMap.Entry}, which is a {@link Reference} to its ThreadLocal key
     */
    private static void checkThreadLocalMapEntryForLeaks(ClassLoader cl, Object entry) throws Exception {
        // check the key (the entry itself is a reference to the ThreadLocal key)
        Object key = ((Reference<?>) entry).get();
        boolean keyLoadedByApplication = cl.equals(key) || loadedByThisOrChild(key, cl);

        // check the value
        Field valueField = entry.getClass().getDeclaredField("value");
        valueField.setAccessible(true);
        Object value = valueField.get(entry);
        boolean valueLoadedByApplication = cl.equals(value) || loadedByThisOrChild(value, cl);

        if (!keyLoadedByApplication && !valueLoadedByApplication) {
            return;
        }

        String keyType = null;
        String keyString = null;
        if (key != null) {
            keyType = getPrettyClassName(key.getClass());
            keyString = safeToString(key, "key", keyType);
        }
        String valueType = null;
        String valueString = null;
        if (value != null) {
            valueType = getPrettyClassName(value.getClass());
            valueString = safeToString(value, "value", valueType);
        }

        if (valueLoadedByApplication) {
            String message = format("Application created a ThreadLocal with key of type [%s] (value [%s]) and a value of"
                            + " type [%s] (value [%s]) but failed to remove it when the application was stopped.",
                    keyType, keyString, valueType, valueString);
            if (isAcceptedValueType(valueType)) {
                // only this entry is accepted; the caller continues with the remaining entries
                System.out.println(message + " But the value type is explicitly allowed, so this is no failure.");
                return;
            }
            fail(message);
        } else if (value == null) {
            System.out.printf("Application created a ThreadLocal with key of type [%s] (value [%s]). The ThreadLocal"
                    + " has been correctly set to null and the key will be removed by GC.%n", keyType, keyString);
        } else {
            System.out.printf("Application created a ThreadLocal with key of type [%s] (value [%s]) and a value of type"
                    + " [%s] (value [%s]). Since keys are only weakly held by the ThreadLocalMap this is not a memory"
                    + " leak.%n", keyType, keyString, valueType, valueString);
        }
    }

    /**
     * Returns {@code o.toString()}, or {@code "unknown"} (after logging to stderr) if that call throws.
     *
     * @param o        the non-null object to render
     * @param role     either {@code "key"} or {@code "value"}, used in the diagnostic message
     * @param typeName the pretty class name of {@code o}, used in the diagnostic message
     */
    private static String safeToString(Object o, String role, String typeName) {
        try {
            return o.toString();
        } catch (Exception e) {
            System.err.printf("Unable to determine string representation of %s of type [%s]%n", role, typeName);
            return "unknown";
        }
    }

    /**
     * Returns {@code true} if the given type name is listed in {@link #ACCEPTED_THREAD_LOCAL_VALUE_TYPES}.
     *
     * @param valueType the pretty class name of a thread local value, may be {@code null}
     */
    private static boolean isAcceptedValueType(String valueType) {
        for (String acceptedType : ACCEPTED_THREAD_LOCAL_VALUE_TYPES) {
            if (acceptedType.equals(valueType)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the canonical name of the class, or its binary name if it has no canonical name
     * (e.g. anonymous or local classes).
     */
    private static String getPrettyClassName(Class<?> clazz) {
        String name = clazz.getCanonicalName();
        if (name == null) {
            name = clazz.getName();
        }
        return name;
    }

    /**
     * Checks whether an object (or, if it is a {@link Class}, the class itself) belongs to the given classloader.
     *
     * @param o  object to test, may be null
     * @param cl the classloader to look for
     * @return <code>true</code> if the class of o (or o itself, if it is a Class) has been loaded by {@code cl}
     * or one of its descendants, or if o is a {@link Collection} containing such an element.
     */
    private static boolean loadedByThisOrChild(Object o, ClassLoader cl) {
        if (o == null) {
            return false;
        }
        Class<?> clazz;
        if (o instanceof Class) {
            clazz = (Class<?>) o;
        } else {
            clazz = o.getClass();
        }

        // walk up from the class's loader; finding cl means the class was loaded by cl or a child of it
        ClassLoader clazzClassloader = clazz.getClassLoader();
        while (clazzClassloader != null) {
            if (clazzClassloader == cl) {
                return true;
            }
            clazzClassloader = clazzClassloader.getParent();
        }

        // collections may hold application objects even if the collection class itself is a JDK class
        if (o instanceof Collection<?>) {
            Iterator<?> iter = ((Collection<?>) o).iterator();
            try {
                while (iter.hasNext()) {
                    Object entry = iter.next();
                    if (loadedByThisOrChild(entry, cl)) {
                        return true;
                    }
                }
            } catch (ConcurrentModificationException e) {
                fail(format("Failed to fully check the entries in an instance of [%s] for potential memory leaks",
                        clazz.getName()));
            }
        }
        return false;
    }
}