/*******************************************************************************
* This file is part of OpenNMS(R).
*
* Copyright (C) 2008-2011 The OpenNMS Group, Inc.
* OpenNMS(R) is Copyright (C) 1999-2011 The OpenNMS Group, Inc.
*
* OpenNMS(R) is a registered trademark of The OpenNMS Group, Inc.
*
* OpenNMS(R) is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published
* by the Free Software Foundation, either version 3 of the License,
* or (at your option) any later version.
*
* OpenNMS(R) is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with OpenNMS(R). If not, see:
* http://www.gnu.org/licenses/
*
* For more information contact:
* OpenNMS(R) Licensing <license@opennms.org>
* http://www.opennms.org/
* http://www.opennms.com/
*******************************************************************************/
package org.opennms.netmgt.dao.db;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;
import javax.sql.DataSource;
import org.junit.Test;
import org.opennms.core.utils.LogUtils;
import org.opennms.netmgt.config.DataSourceFactory;
import org.springframework.jdbc.datasource.DelegatingDataSource;
import org.springframework.jdbc.datasource.LazyConnectionDataSourceProxy;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.TestExecutionListener;
import org.springframework.test.context.support.AbstractTestExecutionListener;
import org.springframework.test.context.support.DependencyInjectionTestExecutionListener;
import org.springframework.util.Assert;
import com.mchange.v2.c3p0.DataSources;
import com.mchange.v2.c3p0.PooledDataSource;
/**
 * This {@link TestExecutionListener} creates temporary databases and then
 * registers one as the default datasource inside {@link DataSourceFactory} by
 * using {@link DataSourceFactory#setInstance(DataSource)}.
 *
 * To change the settings for the temporary database, use the
 * {@link JUnitTemporaryDatabase} annotation on the test class or method.
 *
 * <p>Lifecycle as implemented here:</p>
 * <ul>
 * <li>{@link #beforeTestClass(TestContext)} creates, in parallel, one database per
 * {@link Test}-annotated method covered by a {@link JUnitTemporaryDatabase}
 * annotation, and queues them.</li>
 * <li>{@link #prepareTestInstance(TestContext)} takes the next queued database, wraps it
 * in a c3p0 pool and a {@link LazyConnectionDataSourceProxy}, and installs that proxy
 * in {@link DataSourceFactory}.</li>
 * <li>{@link #beforeTestMethod(TestContext)} hands the current database to tests that
 * implement {@link TemporaryDatabaseAware}.</li>
 * <li>{@link #afterTestMethod(TestContext)} hard-resets the pool. When database reuse is
 * disabled, it also drops the database. When needed, it marks the Spring context
 * dirty.</li>
 * </ul>
 *
 * @author <a href="mailto:brozow@opennms.org">Mathew Brozowski</a>
 */
public class TemporaryDatabaseExecutionListener extends AbstractTestExecutionListener {

    /** {@link TestContext} attribute key under which the per-test c3p0 pool is stored. */
    private static final String POOLED_DATA_SOURCE_ATTRIBUTE = "org.opennms.netmgt.dao.db.TemporaryDatabaseExecutionListener.pooledDataSource";

    /**
     * Process-wide counter mixed into generated database names.
     *
     * Databases are created concurrently from the same annotation, so the time and hash
     * components alone can collide.
     */
    private static final AtomicLong DATABASE_NAME_COUNTER = new AtomicLong();

    /**
     * Set when the class-level annotation has {@code reuseDatabase() == false}.
     * Each test then gets its own database, which is dropped after the test.
     */
    private boolean m_createNewDatabases = false;

    /** The database assigned to the current test instance by {@link #prepareTestInstance(TestContext)}. */
    private TemporaryDatabase m_database;

    /** Databases pre-created by {@link #beforeTestClass(TestContext)}, consumed in FIFO order. */
    private final Queue<TemporaryDatabase> m_databases = new ConcurrentLinkedQueue<TemporaryDatabase>();

    /**
     * Cleans up after a test method that ran against a temporary database.
     *
     * Steps:
     * <ol>
     * <li>Hard-resets the c3p0 pool that {@link #prepareTestInstance(TestContext)} stored
     * in the context.</li>
     * <li>Drops the database if databases are not being reused.</li>
     * <li>Marks the application context dirty and requests dependency re-injection, either
     * when the annotation asks for it or when the next queued database differs from the
     * one currently installed.</li>
     * </ol>
     *
     * @param testContext the Spring test context
     * @throws Exception if resetting the pool or dropping the database fails
     */
    @Override
    public void afterTestMethod(final TestContext testContext) throws Exception {
        System.err.println(String.format("TemporaryDatabaseExecutionListener.afterTestMethod(%s)", testContext));
        final JUnitTemporaryDatabase jtd = findAnnotation(testContext);
        if (jtd == null) return;

        try {
            // Reset inside the try: if the reset fails, the finally below must still mark
            // the context dirty rather than leave a broken context for the next test.
            final PooledDataSource pds = (PooledDataSource)testContext.getAttribute(POOLED_DATA_SOURCE_ATTRIBUTE);
            if (pds != null) pds.hardReset();

            // Only drop the database when every test gets its own. Dropping a reused
            // class-level database would break the following tests; reused databases are
            // left for the ShutdownHook to remove.
            // {@link TemporaryDatabase#createTestDatabase()}
            if (m_createNewDatabases) {
                final DataSource dataSource = DataSourceFactory.getInstance();
                final TemporaryDatabase tempDb = findTemporaryDatabase(dataSource);
                if (tempDb != null) {
                    tempDb.drop();
                }
            }
        } finally {
            // We must mark the application context as dirty so that the DataSourceFactoryBean is
            // correctly pointed at the next temporary database.
            //
            // If the next database is the same as the current database, then do not rewire.
            // NOTE: This does not work because the Hibernate objects need to be reinjected or they
            // will reject database operations because they think that the database rows already
            // exist even if they were rolled back after a previous test execution.
            //
            if (jtd.dirtiesContext()) {
                testContext.markApplicationContextDirty();
                testContext.setAttribute(DependencyInjectionTestExecutionListener.REINJECT_DEPENDENCIES_ATTRIBUTE, Boolean.TRUE);
            } else {
                final DataSource dataSource = DataSourceFactory.getInstance();
                final TemporaryDatabase tempDb = findTemporaryDatabase(dataSource);
                // Identity comparison on purpose: the same instance is queued when reused.
                if (tempDb != m_databases.peek()) {
                    testContext.markApplicationContextDirty();
                    testContext.setAttribute(DependencyInjectionTestExecutionListener.REINJECT_DEPENDENCIES_ATTRIBUTE, Boolean.TRUE);
                }
            }
        }
    }

    /**
     * Unwraps {@link DelegatingDataSource} layers until a {@link TemporaryDatabase} is found.
     *
     * @param dataSource the data source to inspect; may be {@code null}
     * @return the wrapped temporary database, or {@code null} if there is none
     */
    private static TemporaryDatabase findTemporaryDatabase(final DataSource dataSource) {
        if (dataSource instanceof TemporaryDatabase) {
            return (TemporaryDatabase) dataSource;
        } else if (dataSource instanceof DelegatingDataSource) {
            return findTemporaryDatabase(((DelegatingDataSource) dataSource).getTargetDataSource());
        } else {
            return null;
        }
    }

    /**
     * Looks up the effective {@link JUnitTemporaryDatabase} annotation.
     *
     * The current test method's annotation wins. If the method has none (or there is no
     * current method), the test class's annotation is used.
     *
     * @param testContext the Spring test context
     * @return the applicable annotation, or {@code null} if neither method nor class has one
     */
    private static JUnitTemporaryDatabase findAnnotation(final TestContext testContext) {
        JUnitTemporaryDatabase jtd = null;
        final Method testMethod = testContext.getTestMethod();
        if (testMethod != null) {
            jtd = testMethod.getAnnotation(JUnitTemporaryDatabase.class);
        }
        if (jtd == null) {
            final Class<?> testClass = testContext.getTestClass();
            jtd = testClass.getAnnotation(JUnitTemporaryDatabase.class);
        }
        return jtd;
    }

    /**
     * Injects the current temporary database into tests that implement
     * {@link TemporaryDatabaseAware}.
     *
     * @param testContext the Spring test context
     * @throws Exception never thrown here; declared by the overridden method
     */
    @Override
    public void beforeTestMethod(final TestContext testContext) throws Exception {
        System.err.println(String.format("TemporaryDatabaseExecutionListener.beforeTestMethod(%s)", testContext));
        // FIXME: Is there a better way to inject the instance into the test class?
        if (testContext.getTestInstance() instanceof TemporaryDatabaseAware<?>) {
            System.err.println("injecting TemporaryDatabase into TemporaryDatabaseAware test: "
                    + testContext.getTestInstance().getClass().getSimpleName() + "."
                    + testContext.getTestMethod().getName());
            injectTemporaryDatabase(testContext);
        }
    }

    /** Passes {@link #m_database} to the {@link TemporaryDatabaseAware} test instance. */
    @SuppressWarnings("unchecked")
    private void injectTemporaryDatabase(final TestContext testContext) {
        ((TemporaryDatabaseAware) testContext.getTestInstance()).setTemporaryDatabase(m_database);
    }

    /**
     * Pre-creates, in parallel, the temporary databases the test class will need and
     * queues them in {@link #m_databases}.
     *
     * For every public {@link Test} method:
     * <ul>
     * <li>A method-level annotation gets a fresh database built from that annotation.</li>
     * <li>Otherwise, if the class is annotated, the method gets either a fresh database
     * (reuse disabled) or the shared class-level database.</li>
     * </ul>
     *
     * @param testContext the Spring test context
     * @throws Exception if creating any database fails (wrapped by {@link Future#get()})
     */
    @Override
    public void beforeTestClass(final TestContext testContext) throws Exception {
        // Fire up a thread pool for each CPU to create test databases
        final ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        try {
            final JUnitTemporaryDatabase classJtd = testContext.getTestClass().getAnnotation(JUnitTemporaryDatabase.class);
            final Future<TemporaryDatabase> classDs;
            if (classJtd != null) {
                classDs = pool.submit(new CreateNewDatabaseCallable(classJtd));
                if (classJtd.reuseDatabase() == false) {
                    m_createNewDatabases = true;
                }
            } else {
                classDs = null;
            }

            // NOTE(review): Class#getMethods() returns methods in no particular order, so the
            // queue order need not match the order in which the runner executes the tests.
            // With differing method-level annotations, a test may receive a database built
            // from another method's annotation. Confirm against the runner in use.
            final List<Future<TemporaryDatabase>> futures = new ArrayList<Future<TemporaryDatabase>>();
            for (final Method method : testContext.getTestClass().getMethods()) {
                if (method == null || method.getAnnotation(Test.class) == null) {
                    continue;
                }
                final JUnitTemporaryDatabase methodJtd = method.getAnnotation(JUnitTemporaryDatabase.class);
                if (methodJtd != null) {
                    // Create a new database based on the method-specific annotation
                    final Future<TemporaryDatabase> submit = pool.submit(new CreateNewDatabaseCallable(methodJtd));
                    Assert.notNull(submit, "pool.submit(new CreateNewDatabaseCallable(methodJtd = " + methodJtd + ")");
                    futures.add(submit);
                } else if (classJtd != null) {
                    if (m_createNewDatabases) {
                        // Create a new database based on the test class' annotation
                        final Future<TemporaryDatabase> submit = pool.submit(new CreateNewDatabaseCallable(classJtd));
                        Assert.notNull(submit, "pool.submit(new CreateNewDatabaseCallable(classJtd = " + classJtd + ")");
                        futures.add(submit);
                    } else {
                        // Reuse the database based on the test class' annotation
                        Assert.notNull(classDs, "classDs");
                        futures.add(classDs);
                    }
                }
            }

            for (final Future<TemporaryDatabase> db : futures) {
                m_databases.add(db.get());
            }
        } finally {
            // Release the worker threads; previously the pool was never shut down and
            // leaked its threads for every test class.
            pool.shutdown();
        }
    }

    /**
     * Installs the next queued temporary database as the {@link DataSourceFactory} instance.
     *
     * Steps:
     * <ol>
     * <li>Wraps the database in a c3p0 pool and registers a JVM shutdown hook that closes
     * the pool.</li>
     * <li>Wraps the pool in a {@link LazyConnectionDataSourceProxy} and installs it.</li>
     * <li>Stores the pool in the context so {@link #afterTestMethod(TestContext)} can
     * reset it.</li>
     * </ol>
     *
     * @param testContext the Spring test context
     * @throws Exception if the pool cannot be created
     * @throws java.util.NoSuchElementException if no database was queued for this test
     */
    @Override
    public void prepareTestInstance(final TestContext testContext) throws Exception {
        System.err.println(String.format("TemporaryDatabaseExecutionListener.prepareTestInstance(%s)", testContext));
        final JUnitTemporaryDatabase jtd = findAnnotation(testContext);
        if (jtd == null) {
            return;
        }
        m_database = m_databases.remove();
        final PooledDataSource pooledDataSource = (PooledDataSource)DataSources.pooledDataSource(m_database);
        Runtime.getRuntime().addShutdownHook(new Thread() {
            @Override
            public void run() {
                try { pooledDataSource.close(); }
                catch (final Throwable t) { LogUtils.debugf(this, t, "failed to close pooled data source"); }
            }
        });
        final LazyConnectionDataSourceProxy proxy = new LazyConnectionDataSourceProxy(pooledDataSource);
        DataSourceFactory.setInstance(proxy);
        testContext.setAttribute(POOLED_DATA_SOURCE_ATTRIBUTE, pooledDataSource);
        System.err.println(String.format("TemporaryDatabaseExecutionListener.prepareTestInstance(%s) prepared db %s", testContext, m_database.toString()));
        System.err.println("Temporary Database Name: " + m_database.getTestDatabase());
    }

    /** Task that creates one temporary database from an annotation on a pool thread. */
    private static class CreateNewDatabaseCallable implements Callable<TemporaryDatabase> {
        private final JUnitTemporaryDatabase m_jtd;

        public CreateNewDatabaseCallable(JUnitTemporaryDatabase jtd) {
            m_jtd = jtd;
        }

        @Override
        public TemporaryDatabase call() throws Exception {
            return createNewDatabase(m_jtd);
        }
    }

    /**
     * Instantiates and creates a temporary database as described by {@code jtd}.
     *
     * If {@code useExistingDatabase()} is non-empty, that database name is used and the
     * schema is not populated. Otherwise a fresh unique name is generated and the schema
     * is populated when {@code createSchema()} is true.
     *
     * @param jtd the annotation describing the database
     * @return the created database
     * @throws Exception if the {@code tempDbClass()} cannot be instantiated or creation fails
     */
    private static TemporaryDatabase createNewDatabase(JUnitTemporaryDatabase jtd) throws Exception {
        boolean useExisting = false;
        if (jtd.useExistingDatabase() != null) {
            useExisting = !jtd.useExistingDatabase().equals("");
        }
        final String dbName = useExisting ? jtd.useExistingDatabase() : getDatabaseName(jtd);
        final TemporaryDatabase retval = ((jtd.tempDbClass()).getConstructor(String.class, Boolean.TYPE).newInstance(dbName, useExisting));
        retval.setPopulateSchema(jtd.createSchema() && !useExisting);
        retval.create();
        return retval;
    }

    /**
     * Builds a unique name for a new test database, using only letters, digits and underscores.
     *
     * The name combines {@link System#nanoTime()}, the non-negative hash code of
     * {@code hashMe}, and a process-wide counter. The counter is required because
     * {@link #beforeTestClass(TestContext)} creates several databases concurrently from the
     * same annotation, and an annotation's hash code depends only on its member values, so
     * time and hash alone can collide.
     *
     * @param hashMe object whose hash code is mixed into the name
     * @return a name of the form {@code opennms_test_<nanos>_<hash>_<counter>}
     */
    private static String getDatabaseName(final Object hashMe) {
        // Widen before abs(): Math.abs(Integer.MIN_VALUE) is still negative and would put a
        // '-' into the name.
        final long hash = Math.abs((long) hashMe.hashCode());
        return String.format("opennms_test_%s_%s_%s", System.nanoTime(), hash, DATABASE_NAME_COUNTER.incrementAndGet());
    }
}