package org.dayatang.dbunit; import com.mchange.v2.c3p0.ComboPooledDataSource; import org.dayatang.JdbcConstants; import org.dayatang.utils.PropertiesReader; import org.dbunit.database.DatabaseSequenceFilter; import org.dbunit.database.IDatabaseConnection; import org.dbunit.dataset.*; import org.dbunit.dataset.stream.IDataSetProducer; import org.dbunit.dataset.stream.StreamingDataSet; import org.dbunit.dataset.xml.FlatDtdDataSet; import org.dbunit.dataset.xml.FlatXmlDataSet; import org.dbunit.dataset.xml.FlatXmlProducer; import org.dbunit.operation.DatabaseOperation; import org.xml.sax.InputSource; import javax.sql.DataSource; import java.beans.PropertyVetoException; import java.io.*; import java.util.Properties; /** * DBUnit实用工具类。用于的在数据库和XML数据文件之间的相互输入输出。 * XML数据文件只支持FlatXml格式,则数据库中的每行数据表示为一个XML元素, * 每个列成为该元素的一个属性。 * @author yyang (<a href="mailto:gdyangyu@gmail.com">gdyangyu@gmail.com</a>) * */ public class DbUnitUtils { private DataSource dataSource; /** * 从类路径属性文件中读入JDBC连接信息 * @param resourceFile * @return */ public static DbUnitUtils configFromClasspath(String resourceFile) { Properties jdbcProperties = PropertiesReader.readPropertiesFromClasspath(resourceFile); return new DbUnitUtils(createDataSource(jdbcProperties)); } /** * 从磁盘文件中读入JDBC连接信息 * @param configFile * @return */ public static DbUnitUtils configFromFile(String configFile) { Properties jdbcProperties = PropertiesReader.readPropertiesFromFile(configFile); return new DbUnitUtils(createDataSource(jdbcProperties)); } public DbUnitUtils(DataSource dataSource) { this.dataSource = dataSource; } /** * 从XML数据文件中读入数据集,写入到数据库。数据库表中的原有数据将被清除。 * @param flatXmlDataFile XML数据文件 */ public void importDataFromClasspath(final String flatXmlDataFile) { new DbUnitTemplate(dataSource).execute(new DbUnitCallback() { @Override public void doInDbUnit(IDatabaseConnection connection) throws Exception { DatabaseOperation.CLEAN_INSERT.execute(connection, getDatasetFromFile(flatXmlDataFile)); } }); } /** * 将数据库中的数据导出到XML文件中。 * @param dir 文件存放目录 * @param fileName 生成的XML数据文件名。 */ public void exportData(final String dir, final String fileName) { new DbUnitTemplate(dataSource).execute(new DbUnitCallback() { @Override public void doInDbUnit(IDatabaseConnection connection) throws Exception { FilteredDataSet dataSet = new FilteredDataSet( new DatabaseSequenceFilter(connection), connection.createDataSet()); File parent = new File(dir); if (!parent.exists()) { parent.mkdirs(); } Writer out = new FileWriter(new File(dir, fileName)); //OutputStream out = new FileOutputStream(new File(dir, fileName)); FlatXmlDataSet.write(new SortedDataSet(dataSet), out, "UTF-8"); } }); } /** * 读取数据库表结构,生成XML数据文件的DTD文件 * @param dir 文件存放目录 * @param fileName 生成的XML数据文件名。 */ public void exportDtd(final String dir, final String fileName) { new DbUnitTemplate(dataSource).execute(new DbUnitCallback() { @Override public void doInDbUnit(IDatabaseConnection connection) throws Exception { IDataSet dataSet = new FilteredDataSet(new DatabaseSequenceFilter(connection), connection.createDataSet()); File parent = new File(dir); if (!parent.exists()) { parent.mkdirs(); } Writer out = new OutputStreamWriter(new FileOutputStream(new File(dir, fileName))); FlatDtdDataSet.write(new SortedDataSet(dataSet), out); } }); } /** * 从XML数据文件中读入数据集,写入到数据库。数据库表中的原有数据将被清除。 * @param flatXmlDataFile XML数据文件 */ public void refreshData(final String flatXmlDataFile) { new DbUnitTemplate(dataSource).execute(new DbUnitCallback() { @Override public void doInDbUnit(IDatabaseConnection connection) throws Exception { InputStream in = getClass().getResourceAsStream(flatXmlDataFile); IDataSetProducer producer = new FlatXmlProducer(new InputSource(in), false); IDataSet dataSet = new StreamingDataSet(producer); DatabaseOperation.REFRESH.execute(connection, dataSet); } }); } private IDataSet getDatasetFromFile(String flatXmlDataFile) { return getDatasetFromInputStream(getClass().getResourceAsStream(flatXmlDataFile)); } private IDataSet getDatasetFromInputStream(InputStream in) { IDataSetProducer producer = new FlatXmlProducer(new InputSource(in), false); try { return new CachedDataSet(producer); } catch (DataSetException e) { throw new RuntimeException("Cannot get dataset.", e); } } private static DataSource createDataSource(Properties jdbcProperties) { ComboPooledDataSource dataSource = new ComboPooledDataSource(); try { dataSource.setDriverClass(jdbcProperties.getProperty(JdbcConstants.JDBC_DRIVER)); } catch (PropertyVetoException e) { throw new RuntimeException("Cannot create C3P0 data source", e); } dataSource.setJdbcUrl(jdbcProperties.getProperty(JdbcConstants.JDBC_URL)); dataSource.setUser(jdbcProperties.getProperty(JdbcConstants.JDBC_USERNAME)); dataSource.setPassword(jdbcProperties.getProperty(JdbcConstants.JDBC_PASSWORD)); return dataSource; } }