/* * Copyright 1999-2015 dangdang.com. * <p> * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * </p> */ package com.dangdang.ddframe.rdb.integrate; import com.dangdang.ddframe.rdb.sharding.constants.DatabaseType; import org.apache.commons.dbcp.BasicDataSource; import org.dbunit.DatabaseUnitException; import org.dbunit.IDatabaseTester; import org.dbunit.database.IDatabaseConnection; import org.dbunit.dataset.IDataSet; import org.dbunit.dataset.ITable; import org.dbunit.dataset.xml.FlatXmlDataSetBuilder; import org.dbunit.ext.h2.H2Connection; import org.dbunit.ext.mysql.MySqlConnection; import org.dbunit.operation.DatabaseOperation; import org.h2.tools.RunScript; import org.junit.Before; import javax.sql.DataSource; import java.io.File; import java.io.InputStream; import java.io.InputStreamReader; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; import java.util.HashMap; import java.util.List; import java.util.Map; import static org.dbunit.Assertion.assertEquals; public abstract class AbstractDBUnitTest { protected static final DatabaseType CURRENT_DB_TYPE = DatabaseType.H2; protected static final Map<String, DataSource> DATA_SOURCES = new HashMap<>(); private final DataBaseEnvironment dbEnv = new DataBaseEnvironment(CURRENT_DB_TYPE); @Before public void createSchema() throws SQLException { for (String each : getSchemaFiles()) { Connection conn = createDataSource(each).getConnection(); RunScript.execute(conn, new InputStreamReader(AbstractDBUnitTest.class.getClassLoader().getResourceAsStream(each))); conn.close(); } } @Before public final void importDataSet() throws Exception { for (String each : getDataSetFiles()) { InputStream is = AbstractDBUnitTest.class.getClassLoader().getResourceAsStream(each); IDataSet dataSet = new FlatXmlDataSetBuilder().build(new InputStreamReader(is)); IDatabaseTester databaseTester = new ShardingJdbcDatabaseTester(dbEnv.getDriverClassName(), dbEnv.getURL(getFileName(each)), dbEnv.getUsername(), dbEnv.getPassword()); databaseTester.setSetUpOperation(DatabaseOperation.CLEAN_INSERT); databaseTester.setDataSet(dataSet); databaseTester.onSetup(); } } protected abstract List<String> getSchemaFiles(); protected abstract List<String> getDataSetFiles(); protected final Map<String, DataSource> createDataSourceMap(final String dataSourceNamePattern) { Map<String, DataSource> result = new HashMap<>(getDataSetFiles().size()); for (String each : getDataSetFiles()) { result.put(String.format(dataSourceNamePattern, getFileName(each)), createDataSource(each)); } return result; } private DataSource createDataSource(final String dataSetFile) { if (DATA_SOURCES.containsKey(dataSetFile)) { return DATA_SOURCES.get(dataSetFile); } BasicDataSource result = new BasicDataSource(); result.setDriverClassName(dbEnv.getDriverClassName()); result.setUrl(dbEnv.getURL(getFileName(dataSetFile))); result.setUsername(dbEnv.getUsername()); result.setPassword(dbEnv.getPassword()); result.setMaxActive(1000); DATA_SOURCES.put(dataSetFile, result); return result; } private String getFileName(final String dataSetFile) { String fileName = new File(dataSetFile).getName(); if (-1 == fileName.lastIndexOf(".")) { return fileName; } return fileName.substring(0, fileName.lastIndexOf(".")); } protected void assertDataSet(final String expectedDataSetFile, final Connection connection, final String actualTableName, final String sql, final Object... params) throws SQLException, DatabaseUnitException { try ( Connection conn = connection; PreparedStatement ps = conn.prepareStatement(sql)) { int i = 1; for (Object each : params) { ps.setObject(i++, each); } ITable actualTable = getConnection(connection).createTable(actualTableName, ps); IDataSet expectedDataSet = new FlatXmlDataSetBuilder().build(new InputStreamReader(AbstractDBUnitTest.class.getClassLoader().getResourceAsStream(expectedDataSetFile))); assertEquals(expectedDataSet.getTable(actualTableName), actualTable); } } protected void assertDataSet(final String expectedDataSetFile, final Connection connection, final String actualTableName, final String sql) throws SQLException, DatabaseUnitException { try (Connection conn = connection) { ITable actualTable = getConnection(conn).createQueryTable(actualTableName, sql); IDataSet expectedDataSet = new FlatXmlDataSetBuilder().build(new InputStreamReader(AbstractDBUnitTest.class.getClassLoader().getResourceAsStream(expectedDataSetFile))); assertEquals(expectedDataSet.getTable(actualTableName), actualTable); } } private IDatabaseConnection getConnection(final Connection connection) throws DatabaseUnitException { switch (dbEnv.getDatabaseType()) { case H2: return new H2Connection(connection, "PUBLIC"); case MySQL: return new MySqlConnection(connection, "PUBLIC"); default: throw new UnsupportedOperationException(dbEnv.getDatabaseType().name()); } } }