/*
 * Copyright Terracotta, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.ehcache.impl.serialization;

import java.io.IOException;
import java.io.InputStream;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Map.Entry;

import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.commons.Remapper;
import org.objectweb.asm.commons.RemappingClassAdapter;

/**
 * Test helpers that load classes under a rewritten name and manage the thread context class loader.
 * <p>
 * Classes whose simple name carries an underscore suffix (for example {@code Foo_A}) can be
 * loaded, via {@link #createClassNameRewritingLoader(Class, Class...)}, under the suffix-less
 * name ({@code Foo}) with every reference to them rewritten in the bytecode. Presumably this lets
 * tests make several differently-shaped variants share one class name across serialization
 * round trips.
 *
 * @author cdennis
 */
public final class SerializerTestUtilities {

  private SerializerTestUtilities() {
    //no instances please
  }

  /**
   * Creates a class loader that exposes {@code initial} and {@code more} (plus their declared
   * member classes and enum-constant body classes) under the names produced by
   * {@link #newClassName(Class)}.
   * <p>
   * The returned loader delegates to {@code initial}'s class loader for everything it does not
   * rewrite.
   *
   * @param initial the first class to rename; its class loader becomes the parent loader
   * @param more additional classes to rename
   * @return a loader that defines the renamed classes from rewritten bytecode
   */
  public static ClassLoader createClassNameRewritingLoader(Class<?> initial, Class<?> ... more) {
    ClassLoader loader = initial.getClassLoader();
    Map<String, String> remapping = new HashMap<String, String>();
    remapping.putAll(createRemappings(initial));
    for (Class<?> klazz : more) {
      remapping.putAll(createRemappings(klazz));
    }
    return new RewritingClassloader(loader, remapping);
  }

  /**
   * Builds the original-name to rewritten-name mapping for a class and its related classes.
   * <p>
   * The related classes are the directly declared member classes and, for enums, the per-constant
   * body classes (these are anonymous, so {@code getDeclaredClasses()} does not report them).
   */
  private static Map<String, String> createRemappings(Class<?> initial) {
    HashMap<String, String> remappings = new HashMap<String, String>();
    remappings.put(initial.getName(), newClassName(initial));
    for (Class<?> inner : initial.getDeclaredClasses()) {
      remappings.put(inner.getName(), newClassName(inner));
    }
    if (initial.isEnum()) {
      for (Object e : initial.getEnumConstants()) {
        Class<?> eClass = e.getClass();
        // Constants with a body are instances of a synthetic subclass; plain constants are not.
        if (eClass != initial) {
          remappings.put(eClass.getName(), newClassName(eClass));
        }
      }
    }
    return remappings;
  }

  /**
   * Computes the rewritten binary name of a class.
   * <p>
   * The method strips the text from the last underscore up to, but not including, the next
   * {@code '$'} (or to the end of the name).
   * <ul>
   *   <li>{@code Foo_A} becomes {@code Foo}</li>
   *   <li>{@code Foo_A$Inner} becomes {@code Foo$Inner}</li>
   *   <li>{@code Foo} is returned unchanged</li>
   * </ul>
   * NOTE(review): an underscore in the package name, or in a nested class name, is also picked up
   * by {@code lastIndexOf}. Callers are expected to avoid this naming.
   *
   * @param initial the class whose name should be rewritten
   * @return the rewritten binary class name
   */
  public static String newClassName(Class<?> initial) {
    String initialName = initial.getName();
    int lastUnderscore = initialName.lastIndexOf('_');
    if (lastUnderscore == -1) {
      return initialName;
    } else {
      int nextDollar = initialName.indexOf('$', lastUnderscore);
      if (nextDollar == -1) {
        return initialName.substring(0, lastUnderscore);
      } else {
        return initialName.substring(0, lastUnderscore).concat(initialName.substring(nextDollar));
      }
    }
  }

  /*
   * Per-thread stack of previously installed context class loaders.
   * LinkedList (not ArrayDeque) is deliberate: getContextClassLoader() may return null, and
   * ArrayDeque rejects null elements.
   */
  private static final ThreadLocal<Deque<ClassLoader>> tcclStacks = new ThreadLocal<Deque<ClassLoader>>() {
    @Override
    protected Deque<ClassLoader> initialValue() {
      return new LinkedList<ClassLoader>();
    }
  };

  /**
   * Installs {@code loader} as the current thread's context class loader, remembering the
   * previous one so that {@link #popTccl()} can restore it.
   *
   * @param loader the class loader to install
   */
  public static void pushTccl(ClassLoader loader) {
    tcclStacks.get().push(Thread.currentThread().getContextClassLoader());
    Thread.currentThread().setContextClassLoader(loader);
  }

  /**
   * Restores the context class loader saved by the matching {@link #pushTccl(ClassLoader)} call.
   * Calling this without a matching push throws, because the per-thread stack is empty.
   */
  public static void popTccl() {
    Thread.currentThread().setContextClassLoader(tcclStacks.get().pop());
  }

  /**
   * Class loader that defines each class named as a value in {@code remappings}.
   * <p>
   * It reads the bytecode of the corresponding key class, rewrites every type reference found in
   * the mapping, and defines the result under the new name. All other requests go through normal
   * parent delegation.
   */
  static class RewritingClassloader extends ClassLoader {

    /** Original binary name to rewritten binary name (dot-separated). */
    private final Map<String, String> remappings;

    RewritingClassloader(ClassLoader parent, Map<String, String> remappings) {
      super(parent);
      this.remappings = Collections.unmodifiableMap(new HashMap<String, String>(remappings));
    }

    /**
     * Loads rewritten names locally (child-first) and delegates everything else to the parent.
     * <p>
     * Child-first is required for rewritten names. Otherwise, when the rewritten name equals the
     * original one (names with no underscore), the parent's unrewritten class would win.
     */
    @Override
    protected synchronized Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
      Class<?> c = findLoadedClass(name);
      if (c == null) {
        if (remappings.containsValue(name)) {
          c = findClass(name);
          if (resolve) {
            resolveClass(c);
          }
        } else {
          return super.loadClass(name, resolve);
        }
      }
      return c;
    }

    /**
     * Defines {@code name} from the rewritten bytecode of the class mapped to it.
     *
     * @throws ClassNotFoundException if no original class maps to {@code name}, if the original
     *         class file cannot be found, or if reading it fails
     */
    @Override
    protected Class<?> findClass(String name) throws ClassNotFoundException {
      for (Entry<String, String> mapping : remappings.entrySet()) {
        if (name.equals(mapping.getValue())) {
          String path = mapping.getKey().replace('.', '/').concat(".class");
          InputStream resource = getResourceAsStream(path);
          if (resource == null) {
            // Without this check the finally block below would NPE on close() and hide the cause.
            throw new ClassNotFoundException("Cannot find class file " + path + " to define " + name);
          }
          try {
            try {
              byte[] classBytes = rewrite(new ClassReader(resource));
              return defineClass(name, classBytes, 0, classBytes.length);
            } finally {
              resource.close();
            }
          } catch (IOException e) {
            throw new ClassNotFoundException("IOException while loading", e);
          }
        }
      }
      return super.findClass(name);
    }

    /** Copies the class read by {@code reader}, renaming every type that appears in {@link #remappings}. */
    private byte[] rewrite(ClassReader reader) {
      ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_MAXS);
      ClassVisitor visitor = new RemappingClassAdapter(writer, new Remapper() {
        @Override
        public String map(String from) {
          // ASM supplies internal (slash-separated) names; the table holds dot-separated binary names.
          String to = remappings.get(from.replace('/', '.'));
          if (to == null) {
            return from;
          } else {
            return to.replace('.', '/');
          }
        }
      });
      reader.accept(visitor, ClassReader.EXPAND_FRAMES);
      return writer.toByteArray();
    }
  }
}