package org.activityinfo.server.database;
/*
* #%L
* ActivityInfo Server
* %%
* Copyright (C) 2009 - 2013 UNICEF
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public
* License along with this program. If not, see
* <http://www.gnu.org/licenses/gpl-3.0.html>.
* #L%
*/
import com.bedatadriven.rebar.sql.server.jdbc.JdbcScheduler;
import com.google.inject.Provider;
import org.dbunit.DatabaseUnitException;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.dataset.DataSetException;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.LowerCaseDataSet;
import org.dbunit.dataset.xml.FlatXmlDataSetBuilder;
import org.dbunit.ext.mssql.InsertIdentityOperation;
import org.dbunit.ext.mysql.MySqlConnection;
import org.dbunit.operation.DatabaseOperation;
import org.junit.internal.runners.model.MultipleFailureException;
import org.junit.runners.model.Statement;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;
/**
 * JUnit {@link Statement} that prepares the database before running the
 * wrapped statement: it cleans the database, loads a DBUnit flat-XML data set
 * from the classpath, inserts it, and then evaluates the next statement.
 */
public class LoadDataSet extends Statement {

    private static final Logger LOGGER = Logger.getLogger(LoadDataSet.class.getName());

    private final Statement next;
    private final Object target;
    private final Provider<Connection> connectionProvider;
    private final String name;

    /**
     * @param connectionProvider supplies the JDBC connections used to clean and
     *                           populate the database
     * @param next               the statement to evaluate once the data set is loaded
     * @param name               resource name of the flat-XML data set, resolved
     *                           relative to the class of {@code target}
     * @param target             the object whose class is used to locate the resource
     */
    public LoadDataSet(Provider<Connection> connectionProvider, Statement next,
                       String name, Object target) {
        this.next = next;
        this.target = target;
        this.connectionProvider = connectionProvider;
        this.name = name;
    }

    /**
     * Cleans the database, loads and inserts the data set, then evaluates the
     * wrapped statement. Any failure raised while populating the database or
     * evaluating the wrapped statement is rethrown to the caller.
     *
     * @throws Throwable whatever the setup steps or the wrapped statement throw
     */
    @Override
    public void evaluate() throws Throwable {
        // NOTE(review): presumably releases resources held by the scheduler from
        // a previous test - confirm against JdbcScheduler.
        JdbcScheduler.get().forceCleanup();

        LOGGER.info("Removing all rows");
        removeAllRows();

        LOGGER.info("DBUnit: loading " + name + " into the database.");
        IDataSet data = loadDataSet();

        List<Throwable> errors = new ArrayList<>();
        try {
            populate(data);
            next.evaluate();
        } catch (Throwable e) {
            errors.add(e);
        }
        MultipleFailureException.assertEmpty(errors);
    }

    /**
     * Reads the flat-XML data set named {@code name} from the classpath,
     * relative to the class of {@code target}.
     *
     * @return the parsed data set wrapped in a {@link LowerCaseDataSet}
     * @throws IOException      if closing the resource stream fails
     * @throws DataSetException if DBUnit cannot parse the document
     * @throws Error            if the resource cannot be found
     */
    private IDataSet loadDataSet() throws IOException, DataSetException {
        // try-with-resources closes the stream once the builder has returned;
        // a null resource is permitted and is simply not closed.
        try (InputStream in = target.getClass().getResourceAsStream(name)) {
            if (in == null) {
                throw new Error("Could not find resource '" + name + "'");
            }
            // Hand the raw stream to the builder so the XML parser selects the
            // encoding from the document rather than the platform default charset.
            return new LowerCaseDataSet(new FlatXmlDataSetBuilder()
                    .setDtdMetadata(true)
                    .setColumnSensing(true)
                    .build(in));
        }
    }

    /**
     * Inserts every row of the given data set into the database.
     *
     * @param dataSet the rows to insert
     */
    private void populate(final IDataSet dataSet) throws DatabaseUnitException,
            SQLException {
        executeOperation(InsertIdentityOperation.INSERT, dataSet);
    }

    /**
     * Delegates to {@link DatabaseCleaner} to clean the database before loading.
     */
    private void removeAllRows() {
        DatabaseCleaner cleaner = new DatabaseCleaner(connectionProvider);
        cleaner.clean();
    }

    /**
     * Runs a DBUnit operation against a freshly obtained connection, which is
     * closed when the operation finishes or fails.
     *
     * @param op      the DBUnit operation to execute
     * @param dataSet the data set the operation is applied to
     */
    private void executeOperation(final DatabaseOperation op,
                                  final IDataSet dataSet) throws DatabaseUnitException, SQLException {
        try (Connection connection = connectionProvider.get()) {
            IDatabaseConnection dbUnitConnection = new MySqlConnection(
                    connection, null);
            op.execute(dbUnitConnection, dataSet);
        }
    }
}