/** * Copyright (c) 2000-present Liferay, Inc. All rights reserved. * * This library is free software; you can redistribute it and/or modify it under * the terms of the GNU Lesser General Public License as published by the Free * Software Foundation; either version 2.1 of the License, or (at your option) * any later version. * * This library is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more * details. */ package com.liferay.portal.kernel.util; import com.liferay.portal.kernel.log.Log; import com.liferay.portal.kernel.log.LogFactoryUtil; import java.lang.reflect.Field; import java.lang.reflect.Modifier; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; /** * @author Shuyang Zhou */ public class DefaultThreadLocalBinder implements ThreadLocalBinder { public void afterPropertiesSet() throws Exception { if (_threadLocalSources == null) { throw new IllegalArgumentException("Thread local sources is null"); } init(getClassLoader()); } @Override public void bind() { Map<ThreadLocal<?>, ?> threadLocalValues = _threadLocalValues.get(); for (Map.Entry<ThreadLocal<?>, ?> entry : threadLocalValues.entrySet()) { ThreadLocal<Object> threadLocal = (ThreadLocal<Object>)entry.getKey(); Object value = entry.getValue(); threadLocal.set(value); } } @Override public void cleanUp() { for (ThreadLocal<?> threadLocal : _threadLocals) { threadLocal.remove(); } } public ClassLoader getClassLoader() { if (_classLoader == null) { Thread currentThread = Thread.currentThread(); _classLoader = currentThread.getContextClassLoader(); } return _classLoader; } public void init(ClassLoader classLoader) throws Exception { for (Map.Entry<String, String> entry : _threadLocalSources.entrySet()) { String className = entry.getKey(); String fieldName = entry.getValue(); Class<?> clazz = classLoader.loadClass(className); Field field = ReflectionUtil.getDeclaredField(clazz, fieldName); if (!ThreadLocal.class.isAssignableFrom(field.getType())) { if (_log.isWarnEnabled()) { _log.warn( fieldName + " is not of type ThreadLocal. Skip binding."); } continue; } if (!Modifier.isStatic(field.getModifiers())) { if (_log.isWarnEnabled()) { _log.warn( fieldName + " is not a static ThreadLocal. Skip binding."); } continue; } ThreadLocal<?> threadLocal = (ThreadLocal<?>)field.get(null); if (threadLocal == null) { if (_log.isWarnEnabled()) { _log.warn(fieldName + " is not initialized. Skip binding."); } continue; } _threadLocals.add(threadLocal); } } @Override public void record() { Map<ThreadLocal<?>, Object> threadLocalValues = new HashMap<>(); for (ThreadLocal<?> threadLocal : _threadLocals) { Object value = threadLocal.get(); threadLocalValues.put(threadLocal, value); } _threadLocalValues.set(threadLocalValues); } public void setClassLoader(ClassLoader classLoader) { _classLoader = classLoader; } public void setThreadLocalSources(Map<String, String> threadLocalSources) { _threadLocalSources = threadLocalSources; } private static final Log _log = LogFactoryUtil.getLog( DefaultThreadLocalBinder.class); private static final ThreadLocal<Map<ThreadLocal<?>, ?>> _threadLocalValues = new AutoResetThreadLocal<Map<ThreadLocal<?>, ?>>( DefaultThreadLocalBinder.class + "._threadLocalValueMap") { @Override protected Map<ThreadLocal<?>, ?> copy( Map<ThreadLocal<?>, ?> threadLocalValueMap) { return threadLocalValueMap; } }; private ClassLoader _classLoader; private final Set<ThreadLocal<?>> _threadLocals = new HashSet<>(); private Map<String, String> _threadLocalSources; }