package org.hibernate.tools.test.util; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.io.InputStream; import java.lang.reflect.Field; import java.net.URL; import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.DriverManager; import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; import java.util.HashMap; import java.util.Properties; public class JdbcUtil { static HashMap<Object, Connection> CONNECTION_TABLE = new HashMap<>(); public static Properties getConnectionProperties() { Properties properties = new Properties(); InputStream inputStream = Thread .currentThread() .getContextClassLoader() .getResourceAsStream("hibernate.properties"); try { properties.load(inputStream); } catch (IOException e) { throw new RuntimeException(e); } Properties connectionProperties = new Properties(); connectionProperties.put( "url", properties.getProperty("hibernate.connection.url")); connectionProperties.put( "user", properties.getProperty("hibernate.connection.username")); connectionProperties.put( "password", properties.getProperty("hibernate.connection.password")); return connectionProperties; } public static void establishJdbcConnection(Object test) { try { CONNECTION_TABLE.put(test, createJdbcConnection()); } catch (SQLException e) { throw new RuntimeException(e); } } public static void releaseJdbcConnection(Object test) { Connection connection = CONNECTION_TABLE.get(test); CONNECTION_TABLE.remove(test); try { connection.close(); } catch (SQLException e) { throw new RuntimeException(e); } } public static void executeSql(Object test, String[] sqls) { try { executeSql(CONNECTION_TABLE.get(test), sqls); } catch (SQLException e) { throw new RuntimeException(e); } } public static String toIdentifier(Object test, String string) { Connection connection = CONNECTION_TABLE.get(test); try { DatabaseMetaData databaseMetaData = connection.getMetaData(); if (databaseMetaData.storesLowerCaseIdentifiers()) { return string.toLowerCase(); } else if (databaseMetaData.storesUpperCaseIdentifiers()) { return string.toUpperCase(); } else { return string; } } catch (SQLException e) { throw new RuntimeException(e); } } public static boolean isDatabaseOnline() { boolean result = false; try { Connection connection = createJdbcConnection(); result = connection.isValid(1); connection.commit(); connection.close(); } catch (SQLException e) { // this will happen when the connection cannot be created } return result; } public static void createDatabase(Object test) { establishJdbcConnection(test); executeSql(test, getSqls(test, "create.sql", "CREATE_SQL")); } public static void populateDatabase(Object test) { executeSql(test, getSqls(test, "data.sql", "DATA_SQL")); } public static void dropDatabase(Object test) { executeSql(test, getSqls(test, "drop.sql", "DROP_SQL")); releaseJdbcConnection(test); } private static String[] getSqls(Object test, String scriptName, String fieldName) { File createDatabaseScript = getSqlScript(test, scriptName); String[] sqls = null; if (createDatabaseScript != null && createDatabaseScript.exists()) { sqls = getSqlsFromFile(createDatabaseScript); } else { sqls = getSqlsFromField(test, fieldName); } return sqls; } private static String[] getSqlsFromField(Object test, String fieldName) { String[] result = new String[] {}; try { Field field = test.getClass().getDeclaredField(fieldName); field.setAccessible(true); result = (String[])field.get(null); } catch (NoSuchFieldException | IllegalAccessException e) { throw new RuntimeException(e); } return result; } private static String[] getSqlsFromFile(File file) { ArrayList<String> sqls = new ArrayList<String>(); try { FileReader fileReader = new FileReader(file); BufferedReader bufferedReader = new BufferedReader(fileReader); String line = null; while ((line = bufferedReader.readLine()) != null) { sqls.add(line); } bufferedReader.close(); } catch (IOException e) { new RuntimeException(e); } return sqls.toArray(new String[sqls.size()]); } private static File getSqlScript(Object test, String name) { File result = null; String fullName = getSqlScriptsLocation(test) + name; URL url = Thread .currentThread() .getContextClassLoader() .getResource(fullName); if (url != null) { result = new File(url.getFile()); } return result; } private static String getSqlScriptsLocation(Object test) { return test.getClass().getName().replace('.', '/') + '/'; } private static Connection createJdbcConnection() throws SQLException { Properties connectionProperties = getConnectionProperties(); String connectionUrl = (String)connectionProperties.remove("url"); return DriverManager .getDriver(connectionUrl) .connect(connectionUrl, connectionProperties); } private static void executeSql(Connection connection, String[] sqls) throws SQLException { Statement statement = connection.createStatement(); for (int i = 0; i < sqls.length; i++) { statement.execute(sqls[i]); } if (!connection.getAutoCommit()) { connection.commit(); } statement.close(); } }