package org.buddycloud.channelserver.db.jdbc;
import static org.junit.Assert.assertEquals;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.URL;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
public class DatabaseTester {
public class Assertions {
private DatabaseTester tester;
protected Assertions(final DatabaseTester tester) {
this.tester = tester;
}
/**
* Asserts that the table contains exactly one row with field=>value pairs
* matching the given map.
* @param tableName the table to interrogate.
* @param values a map of field=>value of the expected values
* @throws SQLException
*/
public void assertTableContains(final String tableName, final Map<String, Object> values) throws SQLException {
assertTableContains(tableName, values, 1);
}
/**
* Asserts that the table contains the specified number of rows with field=>value pairs
* matching the given map.
* @param tableName the table to interrogate.
* @param values a map of field=>value of the expected values
* @param expectedRows the number of rows we expect to match exactly
* @throws SQLException
*/
public void assertTableContains(final String tableName, final Map<String, Object> values, final int expectedRows) throws SQLException {
Connection conn = tester.getConnection();
// We will rebuild the values as a list so we can have guaranteed ordering
List<Object> valueList = new ArrayList<Object>();
StringBuilder sql = new StringBuilder("SELECT COUNT(*) FROM \"");
sql.append(tableName);
sql.append("\" WHERE TRUE");
for (Entry<String, Object> field : values.entrySet()) {
valueList.add(field.getValue());
sql.append(" AND \"");
sql.append(field.getKey());
sql.append("\" = ?");
}
sql.append(";");
PreparedStatement stmt = conn.prepareStatement(sql.toString());
for (int i = 0; i < valueList.size(); ++i) {
stmt.setObject(i + 1, valueList.get(i));
}
ResultSet rs = stmt.executeQuery();
rs.next();
assertEquals("Expected query " + sql.toString() + " to return " + expectedRows, expectedRows, rs.getInt(1));
}
}
private Connection conn;
public DatabaseTester() throws SQLException, IOException, ClassNotFoundException {
Class.forName("org.hsqldb.jdbcDriver");
Class.forName("net.sf.log4jdbc.DriverSpy");
createSchema(getConnection());
}
public void initialise() throws SQLException, IOException {
close();
createSchema(getConnection());
}
public void close() throws SQLException {
if (conn != null) {
executeDDL(conn, "SHUTDOWN");
conn = null;
}
}
public Connection getConnection() throws SQLException {
if (conn == null) {
final Connection originalConn = DriverManager.getConnection("jdbc:log4jdbc:hsqldb:mem:test", "sa", "");
conn = Mockito.spy(originalConn);
Mockito.doAnswer(new Answer<PreparedStatement>() {
@Override
public PreparedStatement answer(InvocationOnMock invocation)
throws Throwable {
String originalSQL = (String) invocation.getArguments()[0];
String replacedSQL = originalSQL.replaceAll("(\\S+) ~ \\?", "regexp_matches($1, ?)");
replacedSQL = replacedSQL.replaceAll("(\\S+) !~ \\?", "regexp_matches($1, ?) = FALSE");
return originalConn.prepareStatement(replacedSQL);
}
}).when(conn).prepareStatement(Mockito.anyString());
executeDDL(conn, "drop schema public cascade;");
}
return conn;
}
private void createSchema(final Connection conn) throws SQLException, IOException {
executeDDL(conn, "SET DATABASE SQL SYNTAX PGS TRUE;");
executeDDL(conn, "SET DATABASE SQL REFERENCES TRUE;");
loadData("base");
}
public void loadData(final String scriptName) throws SQLException, IOException {
URL url = getClass().getResource("/org/buddycloud/channelserver/testing/jdbc/scripts/" + scriptName + ".sql");
runScript(conn, new InputStreamReader(url.openStream()));
}
private void executeDDL(final Connection conn, final String ddl)
throws SQLException {
Statement stmt = conn.createStatement();
stmt.executeUpdate(ddl);
stmt.close();
}
private void runScript(final Connection conn, final Reader script) throws SQLException, IOException {
// Now read line bye line
BufferedReader d = new BufferedReader(script);
String thisLine, sqlQuery;
Statement stmt = conn.createStatement();
sqlQuery = "";
while ((thisLine = d.readLine()) != null) {
// Skip comments and empty lines
if (thisLine.length() > 0 && thisLine.charAt(0) == '-'
|| thisLine.length() == 0) {
continue;
}
sqlQuery = sqlQuery + " " + thisLine;
// If one command complete
if (sqlQuery.charAt(sqlQuery.length() - 1) == ';') {
sqlQuery = sqlQuery.replace(';', ' '); // Remove the ; since
// jdbc complains
stmt.execute(sqlQuery);
sqlQuery = "";
}
}
}
public Assertions assertions() {
return new Assertions(this);
}
}