/** * Licensed to Apereo under one or more contributor license * agreements. See the NOTICE file distributed with this work * for additional information regarding copyright ownership. * Apereo licenses this file to you under the Apache License, * Version 2.0 (the "License"); you may not use this file * except in compliance with the License. You may obtain a * copy of the License at the following location: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.jasig.portlet.blackboardvcportlet.dao.impl; import java.lang.reflect.AccessibleObject; import java.lang.reflect.Method; import java.util.HashSet; import java.util.LinkedHashSet; import java.util.LinkedList; import java.util.List; import java.util.Set; import java.util.concurrent.Callable; import javax.persistence.EntityManager; import javax.persistence.EntityManagerFactory; import javax.persistence.PersistenceContext; import javax.persistence.Query; import javax.persistence.metamodel.EntityType; import javax.persistence.metamodel.Metamodel; import org.aopalliance.intercept.MethodInvocation; import org.jasig.jpa.BaseJpaDao; import org.jasig.springframework.mockito.MockitoFactoryBean; import org.junit.After; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.dao.DataIntegrityViolationException; import org.springframework.orm.jpa.JpaInterceptor; import org.springframework.transaction.TransactionStatus; import org.springframework.transaction.support.TransactionCallback; import org.springframework.transaction.support.TransactionOperations; /** * Base class for JPA based unit tests that want TX and entity manager support. * Also deletes all hibernate managed data from the database after each test execution * * @author Eric Dalquist */ public abstract class BaseJpaDaoTest { protected final Logger logger = LoggerFactory.getLogger(getClass()); @SuppressWarnings("deprecation") protected JpaInterceptor jpaInterceptor; protected TransactionOperations transactionOperations; private EntityManager entityManager; @PersistenceContext(unitName = BaseJpaDao.PERSISTENCE_UNIT_NAME) public final void setEntityManager(EntityManager entityManager) { this.entityManager = entityManager; } protected final EntityManager getEntityManager() { return this.entityManager; } @Autowired public final void setJpaInterceptor(@SuppressWarnings("deprecation") JpaInterceptor jpaInterceptor) { this.jpaInterceptor = jpaInterceptor; } @Autowired public void setTransactionOperations(TransactionOperations transactionOperations) { this.transactionOperations = transactionOperations; } /** * Deletes ALL entities from the database */ @After public final void deleteAllEntities() { final EntityManager entityManager = getEntityManager(); final EntityManagerFactory entityManagerFactory = entityManager.getEntityManagerFactory(); final Metamodel metamodel = entityManagerFactory.getMetamodel(); Set<EntityType<?>> entityTypes = new LinkedHashSet<EntityType<?>>(metamodel.getEntities()); do { final Set<EntityType<?>> failedEntitieTypes = new HashSet<EntityType<?>>(); for (final EntityType<?> entityType : entityTypes) { final String entityClassName = entityType.getBindableJavaType().getName(); try { this.executeInTransaction(new Callable<Object>() { @Override public Object call() throws Exception { logger.trace("Purging all: " + entityClassName); final Query query = entityManager.createQuery("SELECT e FROM " + entityClassName + " AS e"); final List<?> entities = query.getResultList(); logger.trace("Found " + entities.size() + " " + entityClassName + " to delete"); for (final Object entity : entities) { entityManager.remove(entity); } return null; } }); } catch (DataIntegrityViolationException e) { logger.trace("Failed to delete " + entityClassName + ". Must be a dependency of another entity"); failedEntitieTypes.add(entityType); } } entityTypes = failedEntitieTypes; } while (!entityTypes.isEmpty()); //Reset all spring managed mocks after every test MockitoFactoryBean.resetAllMocks(); } /** * Executes the callback inside of a {@link JpaInterceptor}. */ @SuppressWarnings({ "unchecked", "deprecation" }) public final <T> T execute(final Callable<T> callable) { try { return (T)this.jpaInterceptor.invoke(new MethodInvocationCallable<T>(callable)); } catch (Throwable e) { if (e instanceof RuntimeException) { throw (RuntimeException)e; } if (e instanceof Error) { throw (Error)e; } throw new RuntimeException(e); } } /** * Executes the callback inside of a {@link JpaInterceptor} inside of a {@link TransactionCallback} */ public final <T> T executeInTransaction(final Callable<T> callable) { return execute(new Callable<T>() { @Override public T call() throws Exception { return transactionOperations.execute(new TransactionCallback<T>() { @Override public T doInTransaction(TransactionStatus status) { try { return callable.call(); } catch (RuntimeException e) { throw e; } catch (Exception e) { throw new RuntimeException(e); } } }); } }); } /** * Executes the callback in a new thread inside of a {@link JpaInterceptor}. Waits for the * Thread to return. */ public final <T> T executeInThread(String name, final Callable<T> callable) { final List<RuntimeException> exception = new LinkedList<RuntimeException>(); final List<T> retVal = new LinkedList<T>(); final Thread t2 = new Thread(new Runnable() { @Override public void run() { try { final T val = execute(callable); retVal.add(val); } catch (Throwable e) { if (e instanceof RuntimeException) { exception.add((RuntimeException)e); } else { exception.add(new RuntimeException(e)); } } } }, name); t2.start(); try { t2.join(); } catch (InterruptedException e) { throw new RuntimeException(e); } if (exception.size() == 1) { throw exception.get(0); } return retVal.get(0); } private static final class MethodInvocationCallable<T> implements MethodInvocation { private final Callable<T> callable; private MethodInvocationCallable(Callable<T> callable) { this.callable = callable; } @Override public Object proceed() throws Throwable { return callable.call(); } @Override public Object getThis() { throw new UnsupportedOperationException(); } @Override public AccessibleObject getStaticPart() { throw new UnsupportedOperationException(); } @Override public Object[] getArguments() { throw new UnsupportedOperationException(); } @Override public Method getMethod() { throw new UnsupportedOperationException(); } } }