/*
* Copyright (c) 2016 OBiBa. All rights reserved.
*
* This program and the accompanying materials
* are made available under the terms of the GNU Public License v3.0.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.obiba.core.test.spring;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.sql.SQLException;
import javax.sql.DataSource;
import org.dbunit.DatabaseUnitException;
import org.dbunit.database.DatabaseConfig;
import org.dbunit.database.DatabaseDataSourceConnection;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.xml.FlatXmlDataSet;
import org.dbunit.ext.hsqldb.HsqldbDataTypeFactory;
import org.dbunit.operation.DatabaseOperation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.TestExecutionListeners;
import org.springframework.test.context.support.AbstractTestExecutionListener;
import org.springframework.util.ReflectionUtils;
/**
 * TestExecutionListener implementation that handles the {@link Dataset} and {@link Datasets} annotations.
 * <p>
 * Using this listener (through {@link TestExecutionListeners}) allows seeding a test
 * database before executing unit test methods, and applying a dataset operation after them.
 * </p>
 * <p>
 * Class-level annotations are processed in {@link #prepareTestInstance(TestContext)}; method-level annotations are
 * processed in {@link #beforeTestMethod(TestContext)} and {@link #afterTestMethod(TestContext)}.
 * </p>
 */
public class DbUnitAwareTestExecutionListener extends AbstractTestExecutionListener {

  private static final Logger log = LoggerFactory.getLogger(DbUnitAwareTestExecutionListener.class);

  /** Prefix of the test context attribute marking that the class-level datasets were already applied. */
  private static final String CLASS_HANDLED_ATTRIBUTE_PREFIX = "dbUnit";

  /**
   * Applies the "after" operation of the datasets declared on the test method that was just executed.
   *
   * @param context the current Spring test context
   * @throws Exception if the dataset operation fails
   */
  @Override
  public void afterTestMethod(TestContext context) throws Exception {
    log.debug("{}.afterTestMethod() for context {}", getClass().getSimpleName(), context);
    DbUnitTestContextAdapter adapter = new DbUnitTestContextAdapter(context);
    handleElement(adapter, adapter.getTestMethod(), false);
  }

  /**
   * Applies the "before" operation of the datasets declared on the test method about to be executed.
   *
   * @param context the current Spring test context
   * @throws Exception if the dataset operation fails
   */
  @Override
  public void beforeTestMethod(TestContext context) throws Exception {
    log.debug("{}.beforeTestMethod() for context {}", getClass().getSimpleName(), context);
    DbUnitTestContextAdapter adapter = new DbUnitTestContextAdapter(context);
    handleElement(adapter, adapter.getTestMethod(), true);
  }

  /**
   * Applies the "before" operation of the datasets declared on the test class. A marker attribute is stored in the
   * test context afterwards so that the class-level datasets are skipped when the marker is already present.
   *
   * @param context the current Spring test context
   * @throws Exception if the dataset operation fails
   */
  @Override
  public void prepareTestInstance(TestContext context) throws Exception {
    log.debug("{}.prepareTestInstance() for context {}", getClass().getSimpleName(), context);
    DbUnitTestContextAdapter adapter = new DbUnitTestContextAdapter(context);
    // Always go through the adapter (never call the TestContext directly) so that every access is reflective.
    String attributeName = CLASS_HANDLED_ATTRIBUTE_PREFIX + adapter.getTestClass();
    if(adapter.getAttribute(attributeName) == null) {
      handleElement(adapter, adapter.getTestClass(), true);
      adapter.setAttribute(attributeName, new Object());
    }
  }

  /**
   * Looks up dataset annotations on the given element and handles each of them. A {@link Datasets} annotation takes
   * precedence over a single {@link Dataset} annotation; when neither is present nothing is done.
   *
   * @param contextAdapter adapter over the current test context
   * @param element the annotated test class or test method
   * @param before {@code true} to run the "before" operation, {@code false} to run the "after" operation
   * @throws Exception if a dataset operation fails
   */
  private void handleElement(DbUnitTestContextAdapter contextAdapter, AnnotatedElement element, boolean before)
      throws Exception {
    Datasets ds = element.getAnnotation(Datasets.class);
    if(ds != null) {
      for(Dataset dataset : ds.value()) {
        handleAnnotation(contextAdapter, dataset, before);
      }
      return;
    }

    Dataset da = element.getAnnotation(Dataset.class);
    if(da != null) {
      handleAnnotation(contextAdapter, da, before);
    } else {
      log.debug("No {} annotation found on element {}.", Dataset.class.getSimpleName(), element);
    }
  }

  /**
   * Opens a DbUnit connection on the data source named by the annotation and applies every dataset file it lists.
   * When the annotation lists no file, {@code <TestClassSimpleName>.xml} is used.
   *
   * @param contextAdapter adapter over the current test context
   * @param datasetAnnotation the annotation describing the data source, files and operations
   * @param before {@code true} to run the "before" operation, {@code false} to run the "after" operation
   * @throws Exception if the connection cannot be obtained or a dataset operation fails
   */
  private void handleAnnotation(DbUnitTestContextAdapter contextAdapter, Dataset datasetAnnotation, boolean before)
      throws Exception {
    log.debug("Handling annotation {}", datasetAnnotation);
    String className = contextAdapter.getTestClass().getSimpleName();
    String dataSourceBeanName = datasetAnnotation.dataSourceBean();
    DataSource dataSource = (DataSource) contextAdapter.getApplicationContext().getBean(dataSourceBeanName);
    IDatabaseConnection connection = new DatabaseDataSourceConnection(dataSource);
    try {
      // Configured inside the try block so that the connection is released even if configuration fails.
      connection.getConfig().setProperty(DatabaseConfig.PROPERTY_DATATYPE_FACTORY, new HsqldbDataTypeFactory());

      String[] filenames = datasetAnnotation.filenames();
      if(filenames == null || filenames.length == 0) {
        filenames = new String[] { className + ".xml" };
      }
      for(String filename : filenames) {
        seedDatabase(contextAdapter, datasetAnnotation, before, className, connection, filename);
      }
    } finally {
      try {
        connection.close();
      } catch(SQLException e) {
        // Not rethrown so we don't hide the pertinent exception if any...
        log.debug("Ignoring exception while closing the DbUnit connection", e);
      }
    }
  }

  /**
   * Loads one flat XML dataset from the classpath (relative to the test class) and executes the configured DbUnit
   * operation with it. A missing resource is logged as an error and otherwise skipped.
   *
   * @param contextAdapter adapter over the current test context
   * @param datasetAnnotation the annotation providing the before/after operation types
   * @param before {@code true} to run the "before" operation, {@code false} to run the "after" operation
   * @param className simple name of the test class, used in log messages
   * @param connection the DbUnit connection to operate on
   * @param filename name of the dataset resource
   * @throws IOException if the dataset cannot be read
   * @throws SQLException if the database operation fails
   * @throws DatabaseUnitException if DbUnit fails to parse or apply the dataset
   */
  private void seedDatabase(DbUnitTestContextAdapter contextAdapter, Dataset datasetAnnotation, boolean before,
      String className, IDatabaseConnection connection, String filename)
      throws IOException, SQLException, DatabaseUnitException {
    log.debug("Seeding database with dataset {}.", filename);
    InputStream is = contextAdapter.getTestClass().getResourceAsStream(filename);
    if(is == null) {
      log.error("Test case {}: cannot find resource {}.", className, filename);
      return;
    }

    try {
      IDataSet dataset = new FlatXmlDataSet(is);
      try {
        getDbUnitOp(before ? datasetAnnotation.beforeOperation() : datasetAnnotation.afterOperation())
            .execute(connection, dataset);
      } catch(DatabaseUnitException e) {
        // The exception is passed as the last argument so that SLF4J logs its message and stack trace.
        log.error("Exception while inserting dataset filename {} for test case {}", filename, className, e);
        throw e;
      }
    } finally {
      closeQuietly(is);
    }
  }

  /**
   * Closes the given stream, logging (instead of propagating) any failure so that it cannot mask an exception
   * thrown while the stream was in use.
   *
   * @param is the stream to close
   */
  private static void closeQuietly(InputStream is) {
    try {
      is.close();
    } catch(IOException e) {
      log.debug("Ignoring exception while closing dataset stream", e);
    }
  }

  /**
   * Converts a {@link DatasetOperationType} into a DbUnit {@link DatabaseOperation}
   *
   * @param type the dataset type
   * @return the corresponding {@link DatabaseOperation}
   * @throws IllegalArgumentException if the type has no corresponding operation
   */
  private DatabaseOperation getDbUnitOp(DatasetOperationType type) {
    switch(type) {
      case CLEAN_INSERT:
        return DatabaseOperation.CLEAN_INSERT;
      case DELETE:
        return DatabaseOperation.DELETE;
      case DELETE_ALL:
        return DatabaseOperation.DELETE_ALL;
      case INSERT:
        return DatabaseOperation.INSERT;
      case NONE:
        return DatabaseOperation.NONE;
      case REFRESH:
        return DatabaseOperation.REFRESH;
      case TRUNCATE_TABLE:
        return DatabaseOperation.TRUNCATE_TABLE;
      case UPDATE:
        return DatabaseOperation.UPDATE;
      default:
        throw new IllegalArgumentException("Invalid DatasetOperationType [" + type + "]");
    }
  }

  /**
   * Adapter class that accesses Spring's {@link TestContext} exclusively through reflection. Spring 4.0 changed
   * TestContext from a class to an interface, so the accessor methods are looked up once and invoked reflectively
   * instead of being called directly.
   */
  private static class DbUnitTestContextAdapter {

    private static final Method GET_TEST_CLASS;

    private static final Method GET_TEST_METHOD;

    private static final Method GET_TEST_EXCEPTION;

    private static final Method GET_APPLICATION_CONTEXT;

    private static final Method GET_ATTRIBUTE;

    private static final Method SET_ATTRIBUTE;

    static {
      try {
        GET_TEST_CLASS = TestContext.class.getMethod("getTestClass");
        GET_TEST_METHOD = TestContext.class.getMethod("getTestMethod");
        GET_TEST_EXCEPTION = TestContext.class.getMethod("getTestException");
        GET_APPLICATION_CONTEXT = TestContext.class.getMethod("getApplicationContext");
        GET_ATTRIBUTE = TestContext.class.getMethod("getAttribute", String.class);
        SET_ATTRIBUTE = TestContext.class.getMethod("setAttribute", String.class, Object.class);
      } catch(Exception ex) {
        throw new IllegalStateException(ex);
      }
    }

    private final TestContext testContext;

    private DbUnitTestContextAdapter(TestContext testContext) {
      this.testContext = testContext;
    }

    /** @return the result of {@code getTestClass()} on the wrapped test context */
    public Class<?> getTestClass() {
      return (Class<?>) ReflectionUtils.invokeMethod(GET_TEST_CLASS, testContext);
    }

    /** @return the result of {@code getTestMethod()} on the wrapped test context */
    public Method getTestMethod() {
      return (Method) ReflectionUtils.invokeMethod(GET_TEST_METHOD, testContext);
    }

    /** @return the result of {@code getTestException()} on the wrapped test context */
    public Throwable getTestException() {
      return (Throwable) ReflectionUtils.invokeMethod(GET_TEST_EXCEPTION, testContext);
    }

    /** @return the result of {@code getApplicationContext()} on the wrapped test context */
    public ApplicationContext getApplicationContext() {
      return (ApplicationContext) ReflectionUtils.invokeMethod(GET_APPLICATION_CONTEXT, testContext);
    }

    /**
     * @param name the attribute name
     * @return the result of {@code getAttribute(name)} on the wrapped test context
     */
    public Object getAttribute(String name) {
      return ReflectionUtils.invokeMethod(GET_ATTRIBUTE, testContext, name);
    }

    /**
     * Invokes {@code setAttribute(name, value)} on the wrapped test context.
     *
     * @param name the attribute name
     * @param value the attribute value
     */
    public void setAttribute(String name, Object value) {
      ReflectionUtils.invokeMethod(SET_ATTRIBUTE, testContext, name, value);
    }
  }
}