package org.test4j.module.database.environment; import java.lang.reflect.Method; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; import javax.sql.DataSource; import org.springframework.jdbc.datasource.DataSourceTransactionManager; import org.springframework.jdbc.datasource.DataSourceUtils; import org.springframework.transaction.PlatformTransactionManager; import org.test4j.module.Test4JException; import org.test4j.module.core.TestContext; import org.test4j.module.core.utility.MessageHelper; import org.test4j.module.database.annotations.Transactional.TransactionMode; import org.test4j.module.database.environment.typesmap.AbstractTypeMap; import org.test4j.module.database.transaction.DefaultTransactionManager; import org.test4j.module.database.transaction.TransactionManagementConfiguration; import org.test4j.module.database.transaction.TransactionManager; import org.test4j.module.database.utility.DataSourceType; import org.test4j.tools.commons.ExceptionWrapper; public abstract class BaseEnvironment implements DBEnvironment { /** * Set of possible providers of a spring * <code>PlatformTransactionManager</code> */ protected Set<TransactionManagementConfiguration> transactionManagementConfigurations = new HashSet<TransactionManagementConfiguration>(); protected final String dataSourceName; protected final String dataSourceFrom; protected DataSourceType dataSourceType; private Test4JDataSource dataSource; protected AbstractTypeMap typeMap; protected BaseEnvironment(DataSourceType dataSourceType, String dataSourceName, String dataSourceFrom) { this.dataSourceName = dataSourceName; this.dataSourceFrom = dataSourceFrom; this.dataSourceType = dataSourceType; } @Override public void setDataSource(String driver, String url, String schemas, String username, String password) { this.dataSource = new Test4JDataSource(dataSourceType, driver, url, schemas, username, password); } @Override public DataSource getDataSource() { return this.dataSource; } /** * Returns the <code>DataSource</code> that provides connection to the unit * test database. When invoked the first time, the DBMaintainer is invoked * to make sure the test database is up-to-date (if database updating is * enabled) * * @return The <code>DataSource</code> */ @Override public DataSource getDataSourceAndActivateTransactionIfNeeded() { ThreadTransactionManager currMethodConnection = threadTransactionManager.get(); if (currMethodConnection != null) { currMethodConnection.activateTransaction(); } return dataSource; } @Override public void registerTransactionManagementConfiguration(TransactionManagementConfiguration transactionManagementConfiguration) { if (transactionManagementConfiguration == null) { transactionManagementConfigurations.add(new TransactionManagementConfiguration() { @Override public boolean isApplicableFor(Object testObject) { return true; } @Override public PlatformTransactionManager getSpringPlatformTransactionManager(Object testObject) { DataSource dataSource = getDataSourceAndActivateTransactionIfNeeded(); return new DataSourceTransactionManager(dataSource); } @Override public boolean isTransactionalResourceAvailable(Object testObject) { return true; } @Override public Integer getPreference() { return 1; } }); } else { transactionManagementConfigurations.add(transactionManagementConfiguration); } } @Override public void startTransaction() { this.threadTransactionManager.set(new ThreadTransactionManager()); this.threadTransactionManager.get().startTransaction(); } @Override public void endTransaction() { ThreadTransactionManager currTransactionManager = this.threadTransactionManager.get(); if (currTransactionManager != null) { currTransactionManager.endTransaction(); this.threadTransactionManager.remove(); } } /** * 当前线程的连接 * * @return */ @Override public Connection connect() { ThreadTransactionManager currMethodConnection = threadTransactionManager.get(); if (currMethodConnection == null) { currMethodConnection = new ThreadTransactionManager(); threadTransactionManager.set(currMethodConnection); } Connection connection = currMethodConnection.getConnection(); if (connection != null) { return connection; } DataSource dataSource = this.getDataSourceAndActivateTransactionIfNeeded(); try { connection = currMethodConnection.initMethodConnection(dataSource); return connection; } catch (SQLException e) { throw new Test4JException(e); } } /** * 事务的线程管理,在大部分情况下,测试只有一个线程。但有少数情况下例外(比如多线程测试)! */ private final ThreadLocal<ThreadTransactionManager> threadTransactionManager = new ThreadLocal<ThreadTransactionManager>(); private class ThreadTransactionManager { /** * The transaction manager */ private TransactionManager transactionManager; private Object testedObject; private Method testedMethod; private Connection connection; public Connection getConnection() { if (this.testedObject != TestContext.currTestedObject() || this.testedMethod != TestContext.currTestedMethod()) { this.connection = null; } return connection; } public Connection initMethodConnection(DataSource dataSource) throws SQLException { this.testedObject = TestContext.currTestedObject(); this.testedMethod = TestContext.currTestedMethod(); connection = DataSourceUtils.doGetConnection(dataSource); return connection; } public void release() { if (this.connection == null) { return; } DataSource dataSource = getDataSourceAndActivateTransactionIfNeeded(); try { if (connection.isClosed() == false) { DataSourceUtils.doReleaseConnection(connection, dataSource); } this.connection = null; } catch (SQLException e) { throw new Test4JException( String.format("close datasource[%s] connection error.", dataSource.toString()), e); } } boolean hasActivated = false; public void activateTransaction() { if (hasActivated == false && transactionManager != null) { transactionManager.activateTransactionIfNeeded(); this.hasActivated = true; } } /** * start transaction<br> * if transaction manager does not exist yet, then create one. */ public void startTransaction() { if (transactionManager == null) { transactionManager = new DefaultTransactionManager(); transactionManager.init(transactionManagementConfigurations); } this.transactionManager.startTransaction(); } public void endTransaction() { if (this.transactionManager == null) { this.release(); } else { TransactionMode mode = DBEnvironmentFactory.getTransactionMode(); if (mode == TransactionMode.COMMIT) { this.commit(); } else if (mode == TransactionMode.ROLLBACK) { this.rollback(); } } } public void rollback() { if (this.transactionManager != null) { this.transactionManager.rollback(); } this.release(); } public void commit() { if (this.transactionManager != null) { transactionManager.commit(); } this.release(); } } /** * 是否是默认的数据源 * * @return */ public boolean isDefaultDBEnvironment() { boolean isDefault = DEFAULT_DATASOURCE_NAME.equals(this.dataSourceName) && DEFAULT_DATASOURCE_FROM.equals(dataSourceFrom); return isDefault; } /** * any processing required to turn a string into something jdbc driver can * process, can be used to clean up CRLF, externalise parameters if required * etc. */ protected String parseCommandText(String commandText, String[] vars) { return commandText; } @Override public void commit() { ThreadTransactionManager currTransactionManager = threadTransactionManager.get(); if (currTransactionManager != null) { currTransactionManager.commit(); } } @Override public void rollback() { ThreadTransactionManager currTransactionManager = threadTransactionManager.get(); if (currTransactionManager != null) { currTransactionManager.rollback(); } } @Override public int getExceptionCode(SQLException dbException) { return dbException.getErrorCode(); } /** * by default, this is set to false. * * @see org.test4j.module.database.environment.DBEnvironment#supportsOuputOnInsert() */ public boolean supportsOuputOnInsert() { return false; } @Override public PreparedStatement createStatementWithBoundFixtureSymbols(String commandText) throws SQLException { Connection connection = this.connect(); PreparedStatement cs = connection.prepareStatement(commandText); return cs; } private final Map<String, TableMeta> metas = new HashMap<String, TableMeta>(); /** * 获得数据表的元信息 * * @param table * @return * @throws Exception */ @Override public TableMeta getTableMetaData(String table) { TableMeta meta = metas.get(table); if (meta == null) { try { String query = "select * from " + table + " where 1!=1"; PreparedStatement st = this.createStatementWithBoundFixtureSymbols(query); ResultSet rs = st.executeQuery(); meta = new TableMeta(table, rs.getMetaData(), this); metas.put(table, meta); } catch (Exception e) { throw ExceptionWrapper.getUndeclaredThrowableExceptionCaused(e); } } return meta; } @Override public Object getDefaultValue(String javaType) { Object value = this.typeMap.getDefaultValue(javaType); return value; } @Override public Object toObjectValue(String input, String javaType) { try { Object value = this.typeMap.toObjectByType(input, javaType); return value; } catch (Exception e) { MessageHelper.info("convert input[" + input + "] to type[" + javaType + "] error, so return input value.\n" + e.getMessage()); return input; } } /** * {@inheritDoc} <br> * <br> */ @Override @SuppressWarnings("rawtypes") public Object converToSqlValue(Object value) { if (value instanceof Enum) { return ((Enum) value).name(); } return value; } }