/**
* Copyright 2011-2017 Asakusa Framework Team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.asakusafw.windgate.jdbc;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.rules.TestWatcher;
import org.junit.runner.Description;
/**
 * Keeps a connection to an H2 'in memory' database while a test is running.
 * The connection is opened in {@link #starting(Description)} and closed in
 * {@link #finished(Description)}.
 */
public class H2Resource extends TestWatcher {

    private final String name;

    private Class<?> context;

    private Connection connection;

    /**
     * Creates a new instance.
     * The target URL will be {@code "jdbc:h2:mem:<name>"}.
     * @param name simple name of database
     */
    public H2Resource(String name) {
        this.name = name;
    }

    /**
     * Loads the H2 driver, opens the shared connection, and then runs
     * the leak check and {@link #before()}.
     * If either of them fails, the connection is closed before the failure is propagated.
     * @param description the description of the starting test
     */
    @Override
    protected void starting(Description description) {
        org.h2.Driver.load();
        this.context = description.getTestClass();
        this.connection = open();
        boolean green = false;
        try {
            leakcheck();
            before();
            green = true;
        } catch (Exception e) {
            throw new AssertionError(e);
        } finally {
            // 'finally' (not 'catch') so that Errors such as the AssertionError
            // raised by leakcheck() also trigger the cleanup
            if (!green) {
                finished(description);
            }
        }
    }

    /**
     * Creates a marker table on the shared connection.
     * NOTE(review): presumably this detects a database kept alive by a connection leaked from
     * an earlier test, because the CREATE TABLE would then fail - confirm.
     */
    private void leakcheck() {
        try {
            execute0("CREATE TABLE H2_TEST_DUPCHECK (SID IDENTITY PRIMARY KEY)");
        } catch (SQLException e) {
            throw new AssertionError(e);
        }
    }

    /**
     * Runs before each test is executed.
     * This implementation does nothing; subclasses may override it to prepare the database.
     * @throws Exception if failed
     */
    protected void before() throws Exception {
        // do nothing by default
    }

    /**
     * Creates a new connection to {@link #getJdbcUrl() the target URL}.
     * The caller is responsible for closing the returned connection.
     * @return the created connection
     * @throws AssertionError if the connection could not be established
     */
    public Connection open() {
        try {
            return DriverManager.getConnection(getJdbcUrl());
        } catch (SQLException e) {
            throw new AssertionError(e);
        }
    }

    /**
     * Returns the target URL.
     * @return target URL
     */
    public String getJdbcUrl() {
        return "jdbc:h2:mem:" + name;
    }

    /**
     * Executes a query and returns all of its result rows.
     * @param sql target SQL
     * @return the list of result rows, each of which is the list of its column values
     * @throws AssertionError if the query failed
     */
    public List<List<Object>> query(String sql) {
        try {
            return query0(sql);
        } catch (Exception e) {
            throw new AssertionError(e);
        }
    }

    /**
     * Executes a query that must produce exactly one row, and returns that row.
     * @param sql target SQL
     * @return the column values of the only result row
     * @throws AssertionError if the query failed, or the number of result rows is not one
     */
    public List<Object> single(String sql) {
        try {
            List<List<Object>> query = query0(sql);
            assertThat(sql, query.size(), is(1));
            return query.get(0);
        } catch (Exception e) {
            throw new AssertionError(e);
        }
    }

    /**
     * Count rows in the table.
     * @param table target table
     * @return number of row in the table, or -1 if failed
     */
    public int count(String table) {
        try {
            List<List<Object>> r = query0(MessageFormat.format("SELECT COUNT(*) FROM {0}", table));
            if (r.size() != 1) {
                return -1;
            }
            return ((Number) r.get(0).get(0)).intValue();
        } catch (Exception e) {
            // failures are reported through the documented -1 result instead of an error
            e.printStackTrace();
            return -1;
        }
    }

    /**
     * Executes a query on the shared connection and copies every row into a list.
     * @param sql target SQL
     * @return the list of result rows, each of which is a fixed-size list of its column values
     * @throws SQLException if the query failed
     */
    private List<List<Object>> query0(String sql) throws SQLException {
        try (Statement s = connection.createStatement();
                ResultSet rs = s.executeQuery(sql)) {
            ResultSetMetaData meta = rs.getMetaData();
            int size = meta.getColumnCount();
            List<List<Object>> results = new ArrayList<>();
            while (rs.next()) {
                Object[] columns = new Object[size];
                for (int i = 0; i < size; i++) {
                    // JDBC column indices are 1-origin
                    columns[i] = rs.getObject(i + 1);
                }
                results.add(Arrays.asList(columns));
            }
            return results;
        }
    }

    /**
     * Executes a SQL statement on the shared connection, and then commits.
     * @param sql the SQL statement
     * @throws AssertionError if the statement failed
     */
    public void execute(String sql) {
        try {
            execute0(sql);
        } catch (Exception e) {
            throw new AssertionError(e);
        }
    }

    /**
     * Executes a SQL statement on the shared connection, and then commits.
     * @param sql the SQL statement
     * @throws SQLException if the statement or the commit failed
     */
    private void execute0(String sql) throws SQLException {
        try (PreparedStatement ps = connection.prepareStatement(sql)) {
            ps.execute();
            connection.commit();
        }
    }

    /**
     * Executes the whole contents of the target resource file as a single SQL statement.
     * The resource is resolved relative to the current test class.
     * @param sqlFile resource file
     * @throws AssertionError if the resource is not found, or the execution failed
     */
    public void executeFile(String sqlFile) {
        String content = load(sqlFile);
        execute(content);
    }

    /**
     * Reads the whole contents of a resource of the current test class as UTF-8 text.
     * @param resource the resource path, resolved by {@link Class#getResourceAsStream(String)}
     * @return the resource contents
     * @throws AssertionError if the resource is not found, or could not be read
     */
    private String load(String resource) {
        try (InputStream source = context.getResourceAsStream(resource)) {
            assertThat(resource, source, is(not(nullValue())));
            StringBuilder buf = new StringBuilder();
            try (Reader reader = new InputStreamReader(source, StandardCharsets.UTF_8)) {
                char[] cbuf = new char[1024];
                while (true) {
                    int read = reader.read(cbuf);
                    if (read < 0) {
                        break;
                    }
                    buf.append(cbuf, 0, read);
                }
            }
            return buf.toString();
        } catch (Exception e) {
            throw new AssertionError(e);
        }
    }

    /**
     * Closes the shared connection if it has been opened.
     * @param description the description of the finished test
     * @throws AssertionError if the connection could not be closed
     */
    @Override
    public void finished(Description description) {
        if (connection != null) {
            try {
                connection.close();
            } catch (SQLException e) {
                throw new AssertionError(e);
            }
        }
    }
}