package org.mariadb.jdbc; import org.junit.*; import org.junit.rules.TestRule; import org.junit.rules.TestWatcher; import org.junit.runner.Description; import org.mariadb.jdbc.failover.TcpProxy; import org.mariadb.jdbc.internal.failover.AbstractMastersListener; import org.mariadb.jdbc.internal.protocol.Protocol; import org.mariadb.jdbc.internal.util.Options; import java.io.IOException; import java.lang.reflect.Method; import java.net.InetAddress; import java.net.UnknownHostException; import java.sql.*; import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; import static org.junit.Assert.*; /** * Base util class. * For testing * mvn test -DdbUrl=jdbc:mariadb://localhost:3306/testj?user=root -DlogLevel=FINEST */ @Ignore public class BaseTest { protected static final String mDefUrl = "jdbc:mariadb://localhost:3306/testj?user=root"; protected static String connU; protected static String connUri; protected static String hostname; protected static int port; protected static String database; protected static String username; protected static String password; protected static String parameters; protected static boolean testSingleHost; protected static Connection sharedConnection; protected static boolean runLongTest = false; protected static boolean doPrecisionTest = true; private static Deque<String> tempTableList = new ArrayDeque<>(); private static Deque<String> tempViewList = new ArrayDeque<>(); private static Deque<String> tempProcedureList = new ArrayDeque<>(); private static Deque<String> tempFunctionList = new ArrayDeque<>(); private static TcpProxy proxy = null; private static UrlParser urlParser; @Rule public TestRule watcher = new TestWatcher() { protected void starting(Description description) { if (testSingleHost) { System.out.println("start test : " + description.getClassName() + "." + description.getMethodName()); } } //execute another query to ensure connection is stable protected void finished(Description description) { if (testSingleHost) { Random random = new Random(); int randInt = random.nextInt(); try (PreparedStatement preparedStatement = sharedConnection.prepareStatement("SELECT " + randInt)) { ResultSet rs = preparedStatement.executeQuery(); assertTrue(rs.next()); assertEquals(randInt, rs.getInt(1)); } catch (Exception e) { e.printStackTrace(); Assert.fail("Prepare after test fail for " + description.getClassName() + "." + description.getMethodName()); } } } protected void succeeded(Description description) { if (testSingleHost) { System.out.println("finished test success : " + description.getClassName() + "." + description.getMethodName()); } } protected void failed(Throwable throwable, Description description) { if (testSingleHost) { System.out.println("finished test failed : " + description.getClassName() + "." + description.getMethodName()); } } }; /** * Initialization. * * @throws SQLException exception */ @BeforeClass() public static void beforeClassBaseTest() throws SQLException { String url = System.getProperty("dbUrl", mDefUrl); runLongTest = Boolean.getBoolean(System.getProperty("runLongTest", "false")); testSingleHost = Boolean.parseBoolean(System.getProperty("testSingleHost", "true")); if (testSingleHost) { urlParser = UrlParser.parse(url); if (urlParser.getHostAddresses().size() > 0) { hostname = urlParser.getHostAddresses().get(0).host; port = urlParser.getHostAddresses().get(0).port; } else { hostname = null; port = 3306; } database = urlParser.getDatabase(); username = urlParser.getUsername(); password = urlParser.getPassword(); int separator = url.indexOf("//"); String urlSecondPart = url.substring(separator + 2); int dbIndex = urlSecondPart.indexOf("/"); int paramIndex = urlSecondPart.indexOf("?"); String additionalParameters; if ((dbIndex < paramIndex && dbIndex < 0) || (dbIndex > paramIndex && paramIndex > -1)) { additionalParameters = urlSecondPart.substring(paramIndex); } else if ((dbIndex < paramIndex && dbIndex > -1) || (dbIndex > paramIndex && paramIndex < 0)) { additionalParameters = urlSecondPart.substring(dbIndex); } else { additionalParameters = null; } if (additionalParameters != null) { String regex = "(\\/[^\\?]*)(\\?.+)*|(\\?[^\\/]*)(\\/.+)*"; Pattern pattern = Pattern.compile(regex); Matcher matcher = pattern.matcher(additionalParameters); if (matcher.find()) { String options1 = (matcher.group(2) != null) ? matcher.group(2).substring(1) : ""; String options2 = (matcher.group(3) != null) ? matcher.group(3).substring(1) : ""; parameters = (!options1.equals("")) ? options1 : options2; } } else { parameters = null; } setUri(); sharedConnection = DriverManager.getConnection(url); String dbVersion = sharedConnection.getMetaData().getDatabaseProductVersion(); doPrecisionTest = isMariadbServer() || !dbVersion.startsWith("5.5"); //MySQL 5.5 doesn't support precision } } private static void setUri() { connU = "jdbc:mariadb://" + ((hostname == null) ? "localhost" : hostname) + ":" + port + "/" + database; connUri = connU + "?user=" + username + (password != null && !"".equals(password) ? "&password=" + password : "") + (parameters != null ? "&" + parameters : ""); } /** * Destroy the test tables. * * @throws SQLException exception */ @AfterClass public static void afterClassBaseTest() throws SQLException { if (testSingleHost) { if (sharedConnection != null && !sharedConnection.isClosed()) { if (!tempViewList.isEmpty()) { Statement stmt = sharedConnection.createStatement(); String viewName; while ((viewName = tempViewList.poll()) != null) { try { stmt.execute("DROP VIEW IF EXISTS " + viewName); } catch (SQLException e) { //eat exception } } } if (!tempTableList.isEmpty()) { Statement stmt = sharedConnection.createStatement(); String tableName; while ((tableName = tempTableList.poll()) != null) { try { stmt.execute("DROP TABLE IF EXISTS " + tableName); } catch (SQLException e) { //eat exception } } } if (!tempProcedureList.isEmpty()) { Statement stmt = sharedConnection.createStatement(); String procedureName; while ((procedureName = tempProcedureList.poll()) != null) { try { stmt.execute("DROP procedure IF EXISTS " + procedureName); } catch (SQLException e) { //eat exception } } } if (!tempFunctionList.isEmpty()) { Statement stmt = sharedConnection.createStatement(); String functionName; while ((functionName = tempFunctionList.poll()) != null) { try { stmt.execute("DROP FUNCTION IF EXISTS " + functionName); } catch (SQLException e) { //eat exception } } } try { sharedConnection.close(); } catch (SQLException e) { e.printStackTrace(); } } } } // common function for logging information static void logInfo(String message) { System.out.println(message); } /** * Create a table that will be detroyed a the end of tests. * * @param tableName table name * @param tableColumns table columns * @throws SQLException exception */ public static void createTable(String tableName, String tableColumns) throws SQLException { createTable(tableName, tableColumns, null); } /** * Create a table that will be detroyed a the end of tests. * * @param tableName table name * @param tableColumns table columns * @param engine engine type * @throws SQLException exception */ public static void createTable(String tableName, String tableColumns, String engine) throws SQLException { if (testSingleHost) { Statement stmt = sharedConnection.createStatement(); stmt.execute("drop table if exists " + tableName); stmt.execute("create table " + tableName + " (" + tableColumns + ") " + ((engine != null) ? engine : "")); if (!tempFunctionList.contains(tableName)) tempTableList.add(tableName); } } /** * Create a view that will be detroyed a the end of tests. * * @param viewName table name * @param tableColumns table columns * @throws SQLException exception */ public static void createView(String viewName, String tableColumns) throws SQLException { if (testSingleHost) { Statement stmt = sharedConnection.createStatement(); stmt.execute("drop view if exists " + viewName); stmt.execute("create view " + viewName + " AS (" + tableColumns + ") "); if (!tempViewList.contains(viewName)) tempViewList.add(viewName); } } /** * Create procedure that will be delete on end of test. * * @param name procedure name * @param body procecure body * @throws SQLException exception */ public static void createProcedure(String name, String body) throws SQLException { if (testSingleHost) { Statement stmt = sharedConnection.createStatement(); stmt.execute("drop procedure IF EXISTS " + name); stmt.execute("create procedure " + name + body); if (!tempProcedureList.contains(name)) tempProcedureList.add(name); } } /** * Create function that will be delete on end of test. * * @param name function name * @param body function body * @throws SQLException exception */ public static void createFunction(String name, String body) throws SQLException { if (testSingleHost) { Statement stmt = sharedConnection.createStatement(); stmt.execute("drop function IF EXISTS " + name); stmt.execute("create function " + name + body); if (!tempFunctionList.contains(name)) tempFunctionList.add(name); } } /** * Create a connection with proxy. * * @param info additionnal properties * @return a proxyfied connection * @throws SQLException if any error occur */ public Connection createProxyConnection(Properties info) throws SQLException { UrlParser tmpUrlParser = UrlParser.parse(connUri); username = tmpUrlParser.getUsername(); hostname = tmpUrlParser.getHostAddresses().get(0).host; String sockethosts = ""; HostAddress hostAddress; try { hostAddress = tmpUrlParser.getHostAddresses().get(0); proxy = new TcpProxy(hostAddress.host, hostAddress.port); sockethosts += "address=(host=localhost)(port=" + proxy.getLocalPort() + ")" + ((hostAddress.type != null) ? "(type=" + hostAddress.type + ")" : ""); } catch (IOException e) { e.printStackTrace(); } return openConnection("jdbc:mariadb://" + sockethosts + "/" + connUri.split("/")[3], info); } /** * Stop proxy, and restart it after a certain amount of time. * * @param millissecond milliseconds */ public void stopProxy(long millissecond) { proxy.restart(millissecond); } /** * Stop proxy. */ public void stopProxy() { proxy.stop(); } /** * Restart proxy. */ public void restartProxy() { proxy.restart(); } /** * Clean proxies. * * @throws SQLException exception */ public void closeProxy() throws SQLException { try { proxy.stop(); } catch (Exception e) { //Eat exception } } @Before public void init() throws SQLException { Assume.assumeTrue(testSingleHost); } /** * Permit to assure that host are not in a blacklist after a test. * * @param connection connection */ public void assureBlackList(Connection connection) { AbstractMastersListener.clearBlacklist(); } protected Protocol getProtocolFromConnection(Connection conn) throws Throwable { Method getProtocol = MariaDbConnection.class.getDeclaredMethod("getProtocol", new Class[0]); getProtocol.setAccessible(true); Object obj = getProtocol.invoke(conn); return (Protocol) obj; } protected void setHostname(String hostname) throws SQLException { BaseTest.hostname = hostname; setUri(); setConnection(); } protected void setPort(int port) throws SQLException { BaseTest.port = port; setUri(); setConnection(); } protected void setDatabase(String database) throws SQLException { BaseTest.database = database; BaseTest.setUri(); setConnection(); } protected void setUsername(String username) throws SQLException { BaseTest.username = username; setUri(); setConnection(); } protected void setPassword(String password) throws SQLException { BaseTest.password = password; setUri(); setConnection(); } protected Connection setBlankConnection(String parameters) throws SQLException { return openConnection(connU + "?user=" + username + (password != null && !"".equals(password) ? "&password=" + password : "") + parameters, null); } protected Connection setConnection() throws SQLException { return openConnection(connUri, null); } protected Connection setConnection(Map<String, String> props) throws SQLException { Properties info = new Properties(); for (String key : props.keySet()) { info.setProperty(key, props.get(key)); } return openConnection(connU, info); } protected Connection setConnection(Properties info) throws SQLException { return openConnection(connUri, info); } protected Connection setConnection(String parameters) throws SQLException { return openConnection(connUri + parameters, null); } protected Connection setConnection(String additionnallParameters, String database) throws SQLException { String connU = "jdbc:mariadb://" + ((hostname == null) ? "localhost" : hostname) + ":" + port + "/" + database; String connUri = connU + "?user=" + username + (password != null && !"".equals(password) ? "&password=" + password : "") + (parameters != null ? "&" + parameters : ""); return openConnection(connUri + additionnallParameters, null); } /** * Permit to reconstruct a connection. * * @param uri base uri * @param info additionnal properties * @return A connection * @throws SQLException is any error occur */ public Connection openConnection(String uri, Properties info) throws SQLException { if (info == null) { return DriverManager.getConnection(uri); } else { return DriverManager.getConnection(uri, info); } } protected Connection openNewConnection(String url) throws SQLException { return DriverManager.getConnection(url); } protected Connection openNewConnection(String url, Properties info) throws SQLException { return DriverManager.getConnection(url, info); } boolean checkMaxAllowedPacketMore8m(String testName) throws SQLException { Statement st = sharedConnection.createStatement(); ResultSet rs = st.executeQuery("select @@max_allowed_packet"); rs.next(); long maxAllowedPacket = rs.getLong(1); rs = st.executeQuery("select @@innodb_log_file_size"); rs.next(); long innodbLogFileSize = rs.getLong(1); if (maxAllowedPacket < 8 * 1024 * 1024L) { System.out.println("test '" + testName + "' skipped due to server variable max_allowed_packet < 8M"); return false; } if (innodbLogFileSize < 80 * 1024 * 1024L) { System.out.println("test '" + testName + "' skipped due to server variable innodb_log_file_size < 80M"); return false; } return true; } boolean checkMaxAllowedPacketMore20m(String testName) throws SQLException { return checkMaxAllowedPacketMore20m(testName, true); } boolean checkMaxAllowedPacketMore20m(String testName, boolean displayMessage) throws SQLException { Statement st = sharedConnection.createStatement(); ResultSet rs = st.executeQuery("select @@max_allowed_packet"); rs.next(); long maxAllowedPacket = rs.getLong(1); rs = st.executeQuery("select @@innodb_log_file_size"); rs.next(); long innodbLogFileSize = rs.getLong(1); if (maxAllowedPacket < 20 * 1024 * 1024L) { if (displayMessage) { System.out.println("test '" + testName + "' skipped due to server variable max_allowed_packet < 20M"); } return false; } if (innodbLogFileSize < 200 * 1024 * 1024L) { if (displayMessage) { System.out.println("test '" + testName + "' skipped due to server variable innodb_log_file_size < 200M"); } return false; } return true; } boolean checkMaxAllowedPacketMore40m(String testName) throws SQLException { return checkMaxAllowedPacketMore40m(testName, true); } boolean checkMaxAllowedPacketMore40m(String testName, boolean displayMsg) throws SQLException { Statement st = sharedConnection.createStatement(); ResultSet rs = st.executeQuery("select @@max_allowed_packet"); rs.next(); long maxAllowedPacket = rs.getLong(1); rs = st.executeQuery("select @@innodb_log_file_size"); rs.next(); long innodbLogFileSize = rs.getLong(1); if (maxAllowedPacket < 40 * 1024 * 1024L) { if (displayMsg) { System.out.println("test '" + testName + "' skipped due to server variable max_allowed_packet < 40M"); } return false; } if (innodbLogFileSize < 400 * 1024 * 1024L) { if (displayMsg) { System.out.println("test '" + testName + "' skipped due to server variable innodb_log_file_size < 400M"); } return false; } return true; } //does the user have super privileges or not? boolean hasSuperPrivilege(String testName) throws SQLException { boolean superPrivilege = false; try (Statement st = sharedConnection.createStatement()) { // first test for specific user and host combination try (ResultSet rs = st.executeQuery("SELECT Super_Priv FROM mysql.user WHERE user = '" + username + "' AND host = '" + hostname + "'")) { if (rs.next()) { superPrivilege = (rs.getString(1).equals("Y")); } else { // then check for user on whatever (%) host try (ResultSet rs2 = st.executeQuery("SELECT Super_Priv FROM mysql.user WHERE user = '" + username + "' AND host = '%'")) { if (rs2.next()) superPrivilege = (rs2.getString(1).equals("Y")); } } } } if (!superPrivilege) { System.out.println("test '" + testName + "' skipped because user '" + username + "' doesn't have SUPER privileges"); } return superPrivilege; } //is the connection local? boolean isLocalConnection(String testName) { boolean isLocal = false; try { if (InetAddress.getByName(hostname).isAnyLocalAddress() || InetAddress.getByName(hostname).isLoopbackAddress()) { isLocal = true; } } catch (UnknownHostException e) { // for some reason it wasn't possible to parse the hostname // do nothing } if (!isLocal) { System.out.println("test '" + testName + "' skipped because connection is not local"); } return isLocal; } boolean haveSsl(Connection connection) { try { ResultSet rs = connection.createStatement().executeQuery("select @@have_ssl"); rs.next(); String value = rs.getString(1); return value.equals("YES"); } catch (Exception e) { return false; /* maybe 4.x ? */ } } /** * Check if version if at minimum the version asked. * * @param major database major version * @param minor database minor version * @throws SQLException exception */ public boolean minVersion(int major, int minor) throws SQLException { DatabaseMetaData md = sharedConnection.getMetaData(); int dbMajor = md.getDatabaseMajorVersion(); int dbMinor = md.getDatabaseMinorVersion(); return (dbMajor > major || (dbMajor == major && dbMinor >= minor)); } /** * Check if version if before the version asked. * * @param major database major version * @param minor database minor version * @throws SQLException exception */ public boolean strictBeforeVersion(int major, int minor) throws SQLException { DatabaseMetaData md = sharedConnection.getMetaData(); int dbMajor = md.getDatabaseMajorVersion(); int dbMinor = md.getDatabaseMinorVersion(); return (dbMajor < major || (dbMajor == major && dbMinor < minor)); } /** * Cancel if database version match. * * @param major db major version * @param minor db minor version * @throws SQLException exception */ public void cancelForVersion(int major, int minor) throws SQLException { String dbVersion = sharedConnection.getMetaData().getDatabaseProductVersion(); Assume.assumeFalse(dbVersion.startsWith(major + "." + minor)); } /** * Cancel if database version match. * * @param major db major version * @param minor db minor version * @param patch db patch version * @throws SQLException exception */ public void cancelForVersion(int major, int minor, int patch) throws SQLException { String dbVersion = sharedConnection.getMetaData().getDatabaseProductVersion(); Assume.assumeFalse(dbVersion.startsWith(major + "." + minor + "." + patch)); } void requireMinimumVersion(int major, int minor) throws SQLException { Assume.assumeTrue(minVersion(major, minor)); } /** * Check if current DB server is MariaDB. * * @return true if DB is mariadb * @throws SQLException exception */ static boolean isMariadbServer() throws SQLException { DatabaseMetaData md = sharedConnection.getMetaData(); return md.getDatabaseProductVersion().indexOf("MariaDB") != -1; } /** * Change session time zone. * * @param connection connection * @param timeZone timezone to set * @throws SQLException exception */ public void setSessionTimeZone(Connection connection, String timeZone) throws SQLException { try (Statement statement = connection.createStatement()) { statement.execute("set @@session.time_zone = '" + timeZone + "'"); } } /** * Get row number. * * @param tableName table name * @return resultset number in this table * @throws SQLException if error occur */ public int getRowCount(String tableName) throws SQLException { ResultSet rs = sharedConnection.createStatement().executeQuery("SELECT COUNT(*) FROM " + tableName); if (rs.next()) { return rs.getInt(1); } throw new SQLException("No table " + tableName + " found"); } /** * Permit to know if sharedConnection will use Prepare. * (in case dbUrl modify default options) * * @return true if PreparedStatement will use Prepare. */ public boolean sharedUsePrepare() { return urlParser.getOptions().useServerPrepStmts && !urlParser.getOptions().rewriteBatchedStatements; } /** * Permit access to current sharedConnection options. * * @return Options */ public Options sharedOptions() { return urlParser.getOptions(); } /** * Permit to know if sharedConnection use rewriteBatchedStatements. * * @return true if option rewriteBatchedStatements is set to true */ public boolean sharedIsRewrite() { return urlParser.getOptions().rewriteBatchedStatements; } /** * Has server bulk capacity. * * @return true if server has bulk capacity and option not disabled */ public boolean sharedBulkCapacity() { return urlParser.getOptions().useBatchMultiSend; } /** * Permit to know if sharedConnection use compression. * * @return true if option compression is set to true */ public boolean sharedUseCompression() { return urlParser.getOptions().useCompression; } }