/*
 * Hibernate, Relational Persistence for Idiomatic Java
 *
 * License: GNU Lesser General Public License (LGPL), version 2.1 or later.
 * See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
 */
package org.hibernate.testing.bytecode.enhancement;

import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.util.Arrays;

import javassist.ClassPool;
import javassist.CtClass;
import javassist.LoaderClassPath;

import org.hibernate.HibernateException;
import org.hibernate.LockMode;
import org.hibernate.bytecode.enhance.spi.EnhancementContext;
import org.hibernate.bytecode.enhance.spi.Enhancer;
import org.hibernate.cfg.Environment;
import org.hibernate.engine.internal.MutableEntityEntryFactory;
import org.hibernate.engine.spi.EntityEntry;
import org.hibernate.engine.spi.SelfDirtinessTracker;
import org.hibernate.engine.spi.Status;
import org.hibernate.internal.CoreLogging;
import org.hibernate.internal.CoreMessageLogger;

import org.hibernate.testing.junit4.BaseUnitTestCase;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * Utility class to use in bytecode enhancement tests.
 *
 * @author Steve Ebersole
 * @author Luis Barreiro
 */
public abstract class EnhancerTestUtils extends BaseUnitTestCase {

	private static final CoreMessageLogger log = CoreLogging.messageLogger( EnhancerTestUtils.class );

	/** Directory (the JVM's {@code java.io.tmpdir}) where enhanced class files are dumped. */
	private static final String workingDir = System.getProperty( "java.io.tmpdir" );

	/**
	 * Enhances a class, dumps the enhanced bytecode under the working directory and defines
	 * the enhanced class in the given class loader.
	 * <p>
	 * The dumped class file is handed to {@code DecompileUtils.decompileDumpedClass}; per the
	 * original author's note this checks the signature of the enhanced entity's methods using
	 * the 'javap' decompiler.
	 *
	 * @param classToEnhance the class whose bytecode is to be enhanced
	 * @param cl the class loader used to resolve references and in which the enhanced class is defined
	 * @return the enhanced class, never {@code null}
	 * @throws Exception if reading, enhancing, dumping or defining the class fails
	 */
	public static Class<?> enhanceAndDecompile(Class<?> classToEnhance, ClassLoader cl) throws Exception {
		CtClass entityCtClass = generateCtClassForAnEntity( classToEnhance );

		byte[] original = entityCtClass.toBytecode();
		byte[] enhanced = Environment.getBytecodeProvider()
				.getEnhancer( new EnhancerTestContext() )
				.enhance( entityCtClass.getName(), original );
		assertNotNull( "entity was not enhanced", enhanced );
		log.infof( "enhanced entity [%s]", entityCtClass.getName() );

		ClassPool cp = new ClassPool( false );
		cp.appendClassPath( new LoaderClassPath( cl ) );
		CtClass enhancedCtClass = cp.makeClass( new ByteArrayInputStream( enhanced ) );

		enhancedCtClass.debugWriteFile( workingDir );
		DecompileUtils.decompileDumpedClass( workingDir, classToEnhance.getName() );

		Class<?> enhancedClass = enhancedCtClass.toClass( cl, EnhancerTestUtils.class.getProtectionDomain() );
		assertNotNull( enhancedClass );
		return enhancedClass;
	}

	/**
	 * Builds a {@link CtClass} from the {@code .class} resource of the given class, looked up
	 * through this utility class's own class loader.
	 *
	 * @throws IOException if the class file resource cannot be found
	 */
	private static CtClass generateCtClassForAnEntity(Class<?> entityClassToEnhance) throws Exception {
		String resourceName = entityClassToEnhance.getName().replace( '.', '/' ) + ".class";
		ClassLoader cl = EnhancerTestUtils.class.getClassLoader();
		// close the resource stream once the class file has been parsed
		try ( InputStream classFile = cl.getResourceAsStream( resourceName ) ) {
			if ( classFile == null ) {
				throw new IOException( "class file not found: " + resourceName );
			}
			return new ClassPool( false ).makeClass( classFile );
		}
	}

	/* --- */

	/**
	 * Runs an enhancer test task using a default {@link EnhancerTestContext}.
	 *
	 * @param task the task class to run
	 * @see #runEnhancerTestTask(Class, EnhancementContext)
	 */
	public static <T extends EnhancerTestTask> void runEnhancerTestTask(Class<T> task) {
		runEnhancerTestTask( task, new EnhancerTestContext() );
	}

	/**
	 * Runs an enhancer test task inside an enhancing class loader.
	 * <p>
	 * The task class and its annotated classes are loaded through a class loader that enhances
	 * every class in the task's package. That loader is installed as the thread context class
	 * loader while a fresh task instance is prepared and executed. {@code complete()} is always
	 * invoked on the task (if it was created) and the previous context class loader is restored.
	 *
	 * @param task the task class to run
	 * @param context the enhancement context handed to the enhancer
	 * @throws HibernateException wrapping any exception raised while setting up or running the task
	 */
	public static <T extends EnhancerTestTask> void runEnhancerTestTask(Class<T> task, EnhancementContext context) {
		EnhancerTestTask taskObject = null;
		ClassLoader defaultCL = Thread.currentThread().getContextClassLoader();
		try {
			ClassLoader cl = EnhancerTestUtils.getEnhancerClassLoader( context, task.getPackage().getName() );
			EnhancerTestUtils.setupClassLoader( cl, task );
			EnhancerTestUtils.setupClassLoader( cl, task.newInstance().getAnnotatedClasses() );

			Thread.currentThread().setContextClassLoader( cl );
			taskObject = cl.loadClass( task.getName() ).asSubclass( EnhancerTestTask.class ).newInstance();

			taskObject.prepare();
			taskObject.execute();
		}
		catch (Exception e) {
			throw new HibernateException( "could not execute task", e );
		}
		finally {
			try {
				if ( taskObject != null ) {
					taskObject.complete();
				}
			}
			catch (Throwable ignored) {
				// best-effort cleanup: a failure here must not mask the outcome of the task itself
			}
			Thread.currentThread().setContextClassLoader( defaultCL );
		}
	}

	/**
	 * Eagerly loads the given classes through the class loader. A class that cannot be
	 * loaded is logged and skipped.
	 */
	private static void setupClassLoader(ClassLoader cl, Class<?>... classesToLoad) {
		for ( Class<?> classToLoad : classesToLoad ) {
			try {
				cl.loadClass( classToLoad.getName() );
			}
			catch (ClassNotFoundException e) {
				log.warnf( e, "could not load class [%s]", classToLoad.getName() );
			}
		}
	}

	/**
	 * Creates a class loader that enhances every class whose name starts with the given package
	 * name; all other classes are delegated to the parent loader.
	 * <p>
	 * When the enhancer returns enhanced bytecode it is also written under the working directory;
	 * when it returns {@code null} the original bytecode is defined unchanged.
	 */
	private static ClassLoader getEnhancerClassLoader(EnhancementContext context, String packageName) {
		return new ClassLoader() {
			private final Enhancer enhancer = Environment.getBytecodeProvider().getEnhancer( context );

			@Override
			public Class<?> loadClass(String name) throws ClassNotFoundException {
				if ( !name.startsWith( packageName ) ) {
					return getParent().loadClass( name );
				}
				Class<?> c = findLoadedClass( name );
				if ( c != null ) {
					return c;
				}

				InputStream is = getResourceAsStream( name.replace( '.', '/' ) + ".class" );
				if ( is == null ) {
					throw new ClassNotFoundException( name + " not found" );
				}

				try {
					// available() is only an estimate and a single read() may return fewer bytes,
					// so the class file is read until end of stream
					byte[] original = readFully( is );
					byte[] enhanced = enhancer.enhance( name, original );
					if ( enhanced != null ) {
						dumpEnhancedClass( name, enhanced );
					}
					else {
						enhanced = original;
					}
					return defineClass( name, enhanced, 0, enhanced.length );
				}
				catch (Throwable t) {
					throw new ClassNotFoundException( name + " not found", t );
				}
				finally {
					try {
						is.close();
					}
					catch (IOException ignored) {
						// nothing useful can be done if closing the resource stream fails
					}
				}
			}
		};
	}

	/**
	 * Reads the stream until end of stream and returns everything that was read.
	 * The stream is not closed; that is left to the caller.
	 */
	private static byte[] readFully(InputStream in) throws IOException {
		BufferedInputStream buffered = new BufferedInputStream( in );
		ByteArrayOutputStream bytes = new ByteArrayOutputStream();
		byte[] chunk = new byte[4096];
		int read;
		while ( ( read = buffered.read( chunk ) ) != -1 ) {
			bytes.write( chunk, 0, read );
		}
		return bytes.toByteArray();
	}

	/**
	 * Writes enhanced bytecode to {@code <workingDir>/<package path>/<class>.class},
	 * creating parent directories as needed.
	 */
	@SuppressWarnings("ResultOfMethodCallIgnored")
	private static void dumpEnhancedClass(String name, byte[] enhanced) throws IOException {
		File f = new File( workingDir + File.separator + name.replace( ".", File.separator ) + ".class" );
		f.getParentFile().mkdirs();
		// the stream is closed even if the write fails
		try ( FileOutputStream out = new FileOutputStream( f ) ) {
			out.write( enhanced );
		}
	}

	/**
	 * Reads the value of a field declared on the entity's runtime class, bypassing access checks.
	 * Fails the test if the field does not exist or cannot be read.
	 *
	 * @param entity the object to read from
	 * @param fieldName the name of a field declared on {@code entity.getClass()}
	 * @return the field value
	 */
	public static Object getFieldByReflection(Object entity, String fieldName) {
		try {
			Field field = entity.getClass().getDeclaredField( fieldName );
			field.setAccessible( true );
			return field.get( entity );
		}
		catch (NoSuchFieldException | IllegalAccessException e) {
			fail( "Fail to get field '" + fieldName + "' in entity " + entity );
		}
		// only here to satisfy the compiler: fail() always throws
		return null;
	}

	/**
	 * Clears the dirty set for an entity.
	 *
	 * @param entityInstance an enhanced entity; must implement {@link SelfDirtinessTracker}
	 */
	public static void clearDirtyTracking(Object entityInstance) {
		( (SelfDirtinessTracker) entityInstance ).$$_hibernate_clearDirtyAttributes();
	}

	/**
	 * Compares the dirty fields of an entity with a set of expected values.
	 * <p>
	 * Asserts that the entity reports dirty attributes exactly when {@code dirtyFields} is
	 * non-empty, that the number of tracked attributes equals the number of expected ones, and
	 * that every expected name is among the tracked ones.
	 *
	 * @param entityInstance an enhanced entity; must implement {@link SelfDirtinessTracker}
	 * @param dirtyFields the attribute names expected to be dirty
	 */
	public static void checkDirtyTracking(Object entityInstance, String... dirtyFields) {
		SelfDirtinessTracker selfDirtinessTracker = (SelfDirtinessTracker) entityInstance;
		assertEquals( dirtyFields.length > 0, selfDirtinessTracker.$$_hibernate_hasDirtyAttributes() );

		String[] tracked = selfDirtinessTracker.$$_hibernate_getDirtyAttributes();
		assertEquals( dirtyFields.length, tracked.length );
		assertTrue( Arrays.asList( tracked ).containsAll( Arrays.asList( dirtyFields ) ) );
	}

	/**
	 * Creates an {@link EntityEntry} through {@link MutableEntityEntryFactory} with status
	 * {@link Status#MANAGED}, lock mode {@link LockMode#NONE} and otherwise null/false arguments.
	 *
	 * @return the new entity entry
	 */
	public static EntityEntry makeEntityEntry() {
		return MutableEntityEntryFactory.INSTANCE.createEntityEntry(
				Status.MANAGED,
				null,
				null,
				1,
				null,
				LockMode.NONE,
				false,
				null,
				false,
				null
		);
	}

}