package fi.otavanopisto.pyramus;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.time.LocalDateTime;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.util.regex.Pattern;

import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.TestName;

/**
 * Base class for integration tests.
 *
 * <p>Before and after each test it looks up the running test method and, when that method
 * carries a {@link SqlBefore} / {@link SqlAfter} annotation, executes the listed classpath SQL
 * scripts over a JDBC connection. It also exposes the {@code it.*} system properties (JDBC
 * settings, host, ports, keystore) through accessor methods.
 */
public abstract class AbstractIntegrationTest {

  /** Matches a {@code --} line comment up to the end of its line. */
  private static final Pattern SQL_LINE_COMMENT = Pattern.compile("--.*$", Pattern.MULTILINE);

  /**
   * Matches a {@code ;} that is followed by an even number of single quotes up to the end of
   * the input, i.e. a semicolon that is not inside a single-quoted literal.
   */
  private static final Pattern SQL_STATEMENT_SEPARATOR =
      Pattern.compile(";(?=([^\']*\'[^\']*\')*[^\']*$)");

  @Rule
  public TestName testName = new TestName();

  /**
   * Runs the SQL scripts named by the current test method's {@link SqlBefore} annotation, if any.
   *
   * @throws Exception if the test method cannot be resolved or a script fails to load or execute
   */
  @Before
  public void baseSetupSql() throws Exception {
    SqlBefore annotation = resolveTestMethod().getAnnotation(SqlBefore.class);
    if (annotation != null) {
      runSqlFiles(annotation.value());
    }
  }

  /**
   * Runs the SQL scripts named by the current test method's {@link SqlAfter} annotation, if any.
   *
   * @throws Exception if the test method cannot be resolved or a script fails to load or execute
   */
  @After
  public void baseTearDownSql() throws Exception {
    SqlAfter annotation = resolveTestMethod().getAnnotation(SqlAfter.class);
    if (annotation != null) {
      runSqlFiles(annotation.value());
    }
  }

  /**
   * Returns the number of rows reported by {@code select count(*)} on the given table.
   *
   * <p>The name is concatenated into the SQL text as-is, so it must come from test code only.
   *
   * @param entity table name to count rows from
   * @return the count, or {@code 0} when the query yields no row
   * @throws SQLException if the connection cannot be opened or the query fails
   * @throws ClassNotFoundException if the configured JDBC driver class cannot be loaded
   */
  protected int getEntityCount(String entity) throws SQLException, ClassNotFoundException {
    // All three JDBC resources are closed here; previously none of them were.
    try (Connection connection = getConnection();
        Statement statement = connection.createStatement()) {
      statement.execute("select count(*) as c from " + entity);
      try (ResultSet rs = statement.getResultSet()) {
        if (rs.next()) {
          return rs.getInt(1);
        }
        return 0;
      }
    }
  }

  /**
   * Resolves the public no-argument method of this class that matches the running test's name.
   * Anything from the first {@code '['} onwards in the reported name is dropped before lookup.
   */
  private Method resolveTestMethod() throws NoSuchMethodException {
    String methodName = testName.getMethodName();
    int paramIndex = methodName.indexOf('[');
    if (paramIndex > 0) {
      methodName = methodName.substring(0, paramIndex);
    }
    return getClass().getMethod(methodName, new Class<?>[] {});
  }

  /**
   * Executes the given scripts in order on one connection and commits afterwards.
   * Does nothing when the array is {@code null} or empty.
   */
  private void runSqlFiles(String[] sqlFiles)
      throws IOException, SQLException, ClassNotFoundException {
    if (sqlFiles == null || sqlFiles.length == 0) {
      return;
    }

    try (Connection connection = getConnection()) {
      for (String sqlFile : sqlFiles) {
        runSql(connection, sqlFile);
      }
      connection.commit();
    }
  }

  /**
   * Loads a classpath SQL script, strips {@code --} comments, splits it into statements and
   * executes every non-blank statement on the given connection.
   *
   * <p>Comment stripping runs before the quote-aware split, so a {@code --} sequence inside a
   * string literal is removed as well.
   *
   * @throws FileNotFoundException if the resource is not found on the classpath
   */
  private void runSql(Connection connection, String file) throws IOException, SQLException {
    ClassLoader classLoader = getClass().getClassLoader();
    InputStream resource = classLoader.getResourceAsStream(file);
    if (resource == null) {
      throw new FileNotFoundException(file);
    }

    try (InputStream sqlStream = resource) {
      // NOTE(review): decodes with the platform default charset - confirm the script encoding.
      String sqlString = IOUtils.toString(sqlStream);
      sqlString = SQL_LINE_COMMENT.matcher(sqlString).replaceAll("");

      for (String sql : SQL_STATEMENT_SEPARATOR.split(sqlString)) {
        sql = sql.trim();
        if (StringUtils.isNotBlank(sql)) {
          // Close each statement once it has run instead of leaking one per fragment.
          try (Statement statement = connection.createStatement()) {
            statement.execute(sql);
          }
        }
      }
    }
  }

  /**
   * Loads the configured JDBC driver class and opens a new connection with the configured URL,
   * username and password. The caller is responsible for closing the returned connection.
   *
   * @throws SQLException if the connection cannot be established
   * @throws ClassNotFoundException if the driver class cannot be loaded
   */
  protected Connection getConnection() throws SQLException, ClassNotFoundException {
    Class.forName(getJdbcDriver());
    return DriverManager.getConnection(getJdbcUrl(), getJdbcUsername(), getJdbcPassword());
  }

  /** Returns the plain-HTTP base URL; same as {@code getAppUrl(false)}. */
  protected String getAppUrl() {
    return getAppUrl(false);
  }

  /**
   * Builds {@code scheme://host:port} from the configured host and port.
   *
   * @param secure {@code true} for {@code https} with the HTTPS port, {@code false} for
   *     {@code http} with the HTTP port
   */
  protected String getAppUrl(boolean secure) {
    return (secure ? "https://" : "http://") + getHost() + ':'
        + (secure ? getPortHttps() : getPortHttp());
  }

  /** Returns the {@code it.jdbc.driver} system property. */
  protected String getJdbcDriver() {
    return System.getProperty("it.jdbc.driver");
  }

  /** Returns the {@code it.jdbc.url} system property. */
  protected String getJdbcUrl() {
    return System.getProperty("it.jdbc.url");
  }

  /** Returns the {@code it.jdbc.jndi} system property. */
  protected String getJdbcJndi() {
    return System.getProperty("it.jdbc.jndi");
  }

  /** Returns the {@code it.jdbc.username} system property. */
  protected String getJdbcUsername() {
    return System.getProperty("it.jdbc.username");
  }

  /** Returns the {@code it.jdbc.password} system property. */
  protected String getJdbcPassword() {
    return System.getProperty("it.jdbc.password");
  }

  /** Returns the {@code it.host} system property. */
  protected String getHost() {
    return System.getProperty("it.host");
  }

  /**
   * Returns the {@code it.port.http} system property parsed as an int.
   *
   * @throws NumberFormatException if the property is unset or not a valid integer
   */
  protected int getPortHttp() {
    return Integer.parseInt(System.getProperty("it.port.http"));
  }

  /**
   * Returns the {@code it.port.https} system property parsed as an int.
   *
   * @throws NumberFormatException if the property is unset or not a valid integer
   */
  protected int getPortHttps() {
    return Integer.parseInt(System.getProperty("it.port.https"));
  }

  /** Returns the {@code it.keystore.file} system property. */
  protected String getKeystoreFile() {
    return System.getProperty("it.keystore.file");
  }

  /** Returns the {@code it.keystore.alias} system property. */
  protected String getKeystoreAlias() {
    return System.getProperty("it.keystore.alias");
  }

  /** Returns the {@code it.keystore.storepass} system property. */
  protected String getKeystorePass() {
    return System.getProperty("it.keystore.storepass");
  }

  /**
   * Returns midnight of the given date as an {@link OffsetDateTime}, using the offset that the
   * system default time zone's rules give for that local date-time.
   *
   * @param year the year
   * @param monthOfYear the month, 1-12
   * @param dayOfMonth the day of the month
   */
  protected OffsetDateTime getDateToOffsetDateTime(int year, int monthOfYear, int dayOfMonth) {
    LocalDateTime localDateTime = LocalDateTime.of(year, monthOfYear, dayOfMonth, 0, 0);
    ZoneId systemId = ZoneId.systemDefault();
    ZoneOffset offset = systemId.getRules().getOffset(localDateTime);
    return localDateTime.atOffset(offset);
  }

  /** Delegates to {@link #getDateToOffsetDateTime(int, int, int)} with the same arguments. */
  protected OffsetDateTime getDate(int year, int monthOfYear, int dayOfMonth) {
    return getDateToOffsetDateTime(year, monthOfYear, dayOfMonth);
  }
}