/*
 * Hibernate, Relational Persistence for Idiomatic Java
 *
 * License: GNU Lesser General Public License (LGPL), version 2.1 or later.
 * See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
 */
package org.hibernate.jpa.test.enhancement;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.instrument.IllegalClassFormatException;
import java.util.List;

import org.hibernate.bytecode.enhance.spi.DefaultEnhancementContext;
import org.hibernate.bytecode.enhance.spi.EnhancementContext;
import org.hibernate.bytecode.enhance.spi.UnloadedClass;
import org.hibernate.jpa.internal.enhance.EnhancingClassTransformerImpl;

/**
 * Class loader that defines classes itself from their {@code .class} resources after passing the
 * bytes through an {@link EnhancingClassTransformerImpl}. Only classes in {@code java.lang.} and
 * {@code java.util.} are delegated to the parent loader.
 *
 * @author Emmanuel Bernard
 * @author Dustin Schultz
 */
public class InstrumentedClassLoader extends ClassLoader {
	/** Size of the chunk buffer used while draining a class resource stream. */
	private static final int BUFFER_SIZE = 8192;

	/** Names of the classes to treat as enhancement candidates; may be {@code null} until set. */
	private List<String> entities;

	/**
	 * Creates a loader that delegates to the given parent.
	 *
	 * @param parent the parent class loader, also handed to the transformer as the loading class loader
	 */
	public InstrumentedClassLoader(ClassLoader parent) {
		super( parent );
	}

	/**
	 * Loads the named class. Classes in {@code java.lang.} and {@code java.util.} come from the parent;
	 * a class already defined by this loader is returned as-is; anything else is read via
	 * {@link #loadClassBytes(String)} and defined by this loader.
	 *
	 * @param name the binary name of the class
	 *
	 * @return the loaded class
	 *
	 * @throws ClassNotFoundException if the class resource cannot be found, read or transformed
	 */
	@Override
	public Class<?> loadClass(String name) throws ClassNotFoundException {
		// Do not instrument the following packages
		if ( name != null && ( name.startsWith( "java.lang." ) || name.startsWith( "java.util." ) ) ) {
			return getParent().loadClass( name );
		}

		Class<?> alreadyLoaded = findLoadedClass( name );
		if ( alreadyLoaded != null ) {
			return alreadyLoaded;
		}

		byte[] transformed = loadClassBytes( name );
		return defineClass( name, transformed, 0, transformed.length );
	}

	/**
	 * Specialized {@link ClassLoader#loadClass(String)} that returns the class
	 * as a byte array.
	 * <p>
	 * The class resource is read completely, then passed to the enhancing transformer. If the
	 * transformer returns {@code null}, the original bytes are returned unchanged.
	 *
	 * @param name the binary name of the class (dots are mapped to {@code /} to locate the resource)
	 *
	 * @return the transformed class bytes, or the original bytes when the transformer returns {@code null}
	 *
	 * @throws ClassNotFoundException if the resource does not exist, cannot be read or closed, or the
	 * transformer rejects the class format
	 */
	public byte[] loadClassBytes(String name) throws ClassNotFoundException {
		InputStream is = this.getResourceAsStream( name.replace( ".", "/" ) + ".class" );
		if ( is == null ) {
			throw new ClassNotFoundException( name );
		}

		byte[] originalClass = readFully( is, name );

		EnhancingClassTransformerImpl t = new EnhancingClassTransformerImpl(
				getEnhancementContext( getParent(), entities )
		);
		try {
			byte[] transformed = t.transform( getParent(), name, null, null, originalClass );
			return transformed == null ? originalClass : transformed;
		}
		catch (IllegalClassFormatException e) {
			throw new ClassNotFoundException( name + " not found", e );
		}
	}

	/**
	 * Drains the stream into a byte array and closes it, whether or not reading succeeds.
	 *
	 * @param is the stream to read; always closed on return
	 * @param name the class name, used only for the exception message
	 *
	 * @return every byte the stream produced
	 *
	 * @throws ClassNotFoundException if reading or closing the stream fails
	 */
	private static byte[] readFully(InputStream is, String name) throws ClassNotFoundException {
		try {
			try {
				ByteArrayOutputStream collected = new ByteArrayOutputStream();
				byte[] buffer = new byte[BUFFER_SIZE];
				int read;
				// A single read() may return fewer bytes than remain, so keep reading until end of stream.
				while ( ( read = is.read( buffer ) ) != -1 ) {
					collected.write( buffer, 0, read );
				}
				return collected.toByteArray();
			}
			finally {
				is.close();
			}
		}
		catch (IOException e) {
			throw new ClassNotFoundException( name + " not found", e );
		}
	}

	/**
	 * Sets the names of the classes that the enhancement context should consider.
	 *
	 * @param entities class names eligible for enhancement
	 */
	public void setEntities(List<String> entities) {
		this.entities = entities;
	}

	/**
	 * Builds an enhancement context that reports {@code cl} as the loading class loader and limits
	 * entity/composite detection to classes whose name is in {@code entities}.
	 *
	 * @param cl the class loader returned by {@link EnhancementContext#getLoadingClassLoader()}
	 * @param entities names of the classes eligible for enhancement; {@code null} is treated as empty
	 *
	 * @return the enhancement context
	 */
	public EnhancementContext getEnhancementContext(final ClassLoader cl, final List<String> entities) {
		return new DefaultEnhancementContext() {
			@Override
			public ClassLoader getLoadingClassLoader() {
				return cl;
			}

			@Override
			public boolean isEntityClass(UnloadedClass classDescriptor) {
				// The null check covers the case where setEntities was never called.
				return entities != null
						&& entities.contains( classDescriptor.getName() )
						&& super.isEntityClass( classDescriptor );
			}

			@Override
			public boolean isCompositeClass(UnloadedClass classDescriptor) {
				return entities != null
						&& entities.contains( classDescriptor.getName() )
						&& super.isCompositeClass( classDescriptor );
			}
		};
	}
}