/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.jpa.test.enhancement;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.instrument.IllegalClassFormatException;
import java.util.List;

import org.hibernate.bytecode.enhance.spi.DefaultEnhancementContext;
import org.hibernate.bytecode.enhance.spi.EnhancementContext;
import org.hibernate.bytecode.enhance.spi.UnloadedClass;
import org.hibernate.jpa.internal.enhance.EnhancingClassTransformerImpl;
/**
* @author Emmanuel Bernard
* @author Dustin Schultz
*/
public class InstrumentedClassLoader extends ClassLoader {
	// Names of classes the enhancer may treat as entities / composites; set via setEntities.
	private List<String> entities;

	public InstrumentedClassLoader(ClassLoader parent) {
		super( parent );
	}

	/**
	 * Loads {@code name} child-first: its bytecode is read through this loader's resources,
	 * passed through the Hibernate enhancer and defined in this loader. Classes in {@code java.*}
	 * packages are delegated to the parent loader untouched, because
	 * {@link ClassLoader#defineClass(String, byte[], int, int)} refuses to define them.
	 *
	 * @param name binary name of the class to load
	 *
	 * @return the loaded class
	 *
	 * @throws ClassNotFoundException if the class cannot be found, read or enhanced
	 */
	@Override
	public Class<?> loadClass(String name) throws ClassNotFoundException {
		// Do not instrument JDK classes: defineClass throws SecurityException for any "java." name.
		// (This prefix covers the previously listed "java.lang." and "java.util." packages.)
		if ( name != null && name.startsWith( "java." ) ) {
			return getParent().loadClass( name );
		}
		Class<?> c = findLoadedClass( name );
		if ( c != null ) {
			return c;
		}
		byte[] transformed = loadClassBytes( name );
		return defineClass( name, transformed, 0, transformed.length );
	}

	/**
	 * Specialized {@link ClassLoader#loadClass(String)} that returns the class
	 * as a byte array, after running it through the Hibernate enhancer.
	 *
	 * @param name binary name of the class, e.g. {@code com.acme.Foo}
	 *
	 * @return the enhanced bytecode, or the original bytecode if the enhancer returned {@code null}
	 *
	 * @throws ClassNotFoundException if the class file is missing or cannot be read, or if the
	 * enhancer rejects it with an {@link IllegalClassFormatException}
	 */
	public byte[] loadClassBytes(String name) throws ClassNotFoundException {
		byte[] originalClass = readClassFile( name );
		EnhancingClassTransformerImpl t = new EnhancingClassTransformerImpl( getEnhancementContext( getParent(), entities ) );
		try {
			byte[] transformed = t.transform(
					getParent(),
					name,
					null,
					null,
					originalClass
			);
			// A null result means the transformer left the class unchanged.
			return transformed == null ? originalClass : transformed;
		}
		catch (IllegalClassFormatException e) {
			throw new ClassNotFoundException( name + " not found", e );
		}
	}

	/**
	 * Reads the complete class file for {@code name} via {@link #getResourceAsStream(String)},
	 * closing the stream whether or not reading succeeds.
	 */
	private byte[] readClassFile(String name) throws ClassNotFoundException {
		InputStream is = this.getResourceAsStream( name.replace( '.', '/' ) + ".class" );
		if ( is == null ) {
			throw new ClassNotFoundException( name );
		}
		try (InputStream in = is) {
			ByteArrayOutputStream out = new ByteArrayOutputStream();
			byte[] buffer = new byte[8192];
			int r;
			// A single read() may return fewer bytes than the stream holds, so read until EOF.
			while ( ( r = in.read( buffer ) ) != -1 ) {
				out.write( buffer, 0, r );
			}
			return out.toByteArray();
		}
		catch (IOException e) {
			// Covers both read and close failures, as before.
			throw new ClassNotFoundException( name + " not found", e );
		}
	}

	/**
	 * Sets the class names the enhancement context will accept as entity or composite classes.
	 *
	 * @param entities class names as reported by {@link UnloadedClass#getName()}
	 */
	public void setEntities(List<String> entities) {
		this.entities = entities;
	}

	/**
	 * Builds an enhancement context that loads through {@code cl} and restricts entity/composite
	 * detection to classes named in {@code entities} (and accepted by the default rules).
	 * NOTE(review): {@code entities} is dereferenced without a null check, so calling this before
	 * {@link #setEntities(List)} presumably fails once the enhancer queries the context — confirm.
	 *
	 * @param cl class loader returned from {@code getLoadingClassLoader()}
	 * @param entities allowed class names
	 *
	 * @return the enhancement context
	 */
	public EnhancementContext getEnhancementContext(final ClassLoader cl, final List<String> entities) {
		return new DefaultEnhancementContext() {
			@Override
			public ClassLoader getLoadingClassLoader() {
				return cl;
			}

			@Override
			public boolean isEntityClass(UnloadedClass classDescriptor) {
				return entities.contains( classDescriptor.getName() ) && super.isEntityClass( classDescriptor );
			}

			@Override
			public boolean isCompositeClass(UnloadedClass classDescriptor) {
				return entities.contains( classDescriptor.getName() ) && super.isCompositeClass( classDescriptor );
			}
		};
	}
}