package org.infinispan.server.test.util.jdbc;
import static org.infinispan.server.test.util.ITestUtils.sleepForSecs;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import org.infinispan.commons.equivalence.ByteArrayEquivalence;
/**
 * Test helper that gives integration tests direct JDBC access to the database tables named by the
 * caller (one "bucket" table and one "string" table).
 * <p>
 * Connection settings are read from the {@code connection.url}, {@code username}, {@code password}
 * and {@code driver.class} system properties.
 *
 * @author <a href="mailto:mgencur@redhat.com">Martin Gencur</a>
 * @author <a href="mailto:vchepeli@redhat.com">Vitalii Chepeliuk</a>
 * @since 7.0
 */
public class DBServer {

    /** Upper bound, in milliseconds, on how long {@link TableManipulation#getValueByKeyAwait(String)} polls. */
    public static final long TIMEOUT = 15000;

    public String connectionUrl;
    public String username;
    public String password;
    public String bucketTableName;
    public String stringTableName;
    /** Accessor for the bucket table; stays {@code null} when no bucket table name was supplied. */
    public TableManipulation bucketTable;
    /** Accessor for the string table; stays {@code null} when no string table name was supplied. */
    public TableManipulation stringTable;

    /**
     * Creates an instance with every field left {@code null}.
     *
     * @return a new, unconfigured {@code DBServer}
     */
    public static DBServer create() {
        return new DBServer();
    }

    private DBServer() {
    }

    /**
     * Reads the connection settings from system properties and creates a {@link TableManipulation}
     * for each table name that is not {@code null}.
     *
     * @param bucketTableName name of the bucket table, or {@code null} to skip it
     * @param stringTableName name of the string table, or {@code null} to skip it
     * @param idColumnName    name of the key column, shared by both tables
     * @param dataColumnName  name of the value column, shared by both tables
     * @throws NullPointerException if a table name is given but the {@code connection.url} system
     *                              property is not set
     */
    public DBServer(String bucketTableName, String stringTableName, String idColumnName, String dataColumnName) {
        this.connectionUrl = System.getProperty("connection.url");
        this.username = System.getProperty("username");
        this.password = System.getProperty("password");
        this.bucketTableName = bucketTableName;
        this.stringTableName = stringTableName;

        String driver = System.getProperty("driver.class");
        if (bucketTableName != null) {
            bucketTable = new DBServer.TableManipulation(driver, connectionUrl, username, password,
                    bucketTableName, idColumnName, dataColumnName);
        }
        if (stringTableName != null) {
            stringTable = new DBServer.TableManipulation(driver, connectionUrl, username, password,
                    stringTableName, idColumnName, dataColumnName);
        }
    }

    /**
     * Runs simple queries and clean-up statements against a single table.
     * <p>
     * Every operation obtains a connection from the internal {@code SimpleConnectionFactory}, closes
     * the JDBC statement and result set it opened, and hands the connection back to the factory, even
     * when the operation fails.
     */
    public static class TableManipulation {

        /** Pause between polling attempts in {@link #withAwait(Callable)}, in seconds. */
        private static final long RETRY_TIME = 1;

        private final SimpleConnectionFactory factory;
        private final String idColumnName;
        private final String dataColumnName;
        /** Sanitized table name, wrapped in the identifier quote string. */
        private final String tableName;
        private final String connectionUrl;
        private final String username;
        private final String password;
        private final String identifierQuoteString;
        private final String getRowByKeySql;
        private final String getAllRowsSql;
        private final String deleteAllRowsSql;
        private final String dropTableSql;

        /**
         * Builds the SQL statements for the given table and starts the connection factory.
         *
         * @param driverClass    JDBC driver class name passed to the connection factory
         * @param connectionUrl  JDBC URL; also inspected to pick vendor-specific SQL
         * @param username       database user
         * @param password       database password
         * @param tableName      raw table name; non-alphanumeric characters are replaced by {@code _}
         * @param idColumnName   name of the key column
         * @param dataColumnName name of the value column
         * @throws NullPointerException if {@code connectionUrl} or {@code tableName} is {@code null}
         */
        TableManipulation(String driverClass, String connectionUrl, String username, String password, String tableName,
                          String idColumnName, String dataColumnName) {
            // Fail up front with a readable message instead of an anonymous NPE further down.
            this.connectionUrl = Objects.requireNonNull(connectionUrl, "connectionUrl must not be null");
            Objects.requireNonNull(tableName, "tableName must not be null");
            this.idColumnName = idColumnName;
            this.dataColumnName = dataColumnName;
            this.username = username;
            this.password = password;
            this.identifierQuoteString = quoteStringFor(connectionUrl);

            // inappropriate table name characters filter: https://github.com/infinispan/infinispan/pull/1610
            this.tableName = identifierQuoteString + tableName.replaceAll("[^\\p{Alnum}]", "_") + identifierQuoteString;

            // Some databases need the bound key converted to VARCHAR explicitly.
            final String keyPlaceholder;
            if (connectionUrl.contains("sybase")) {
                keyPlaceholder = "convert(VARCHAR(255),?)";
            } else if (connectionUrl.contains("postgre") || connectionUrl.contains("edb")) {
                keyPlaceholder = "cast(? as VARCHAR(255))";
            } else {
                keyPlaceholder = "?";
            }
            this.getRowByKeySql = "SELECT " + idColumnName + ", " + dataColumnName + " FROM " + this.tableName
                    + " WHERE " + idColumnName + " = " + keyPlaceholder;
            this.getAllRowsSql = "SELECT " + dataColumnName + "," + idColumnName + " FROM " + this.tableName;
            this.deleteAllRowsSql = "DELETE from " + this.tableName;
            this.dropTableSql = "DROP TABLE " + this.tableName;

            factory = new SimpleConnectionFactory(connectionUrl, username, password);
            factory.start(driverClass);
        }

        /**
         * Same as the full constructor, taking the connection settings from an existing {@link DBServer}.
         *
         * @param driverClass    JDBC driver class name
         * @param dbServer       source of the connection URL, user name and password
         * @param tableName      raw table name
         * @param idColumnName   name of the key column
         * @param dataColumnName name of the value column
         */
        public TableManipulation(String driverClass, DBServer dbServer, String tableName, String idColumnName, String dataColumnName) {
            this(driverClass, dbServer.connectionUrl, dbServer.username, dbServer.password, tableName, idColumnName, dataColumnName);
        }

        /** Picks the identifier quote: a backtick for URLs containing "mysql", a double quote otherwise. */
        private static String quoteStringFor(String connectionUrl) {
            return connectionUrl.contains("mysql") ? "`" : "\"";
        }

        /**
         * Polls for the value stored under {@code key} until a row is found or {@link DBServer#TIMEOUT}
         * elapses. Query failures during polling are treated as "not there yet" and retried.
         *
         * @param key the key to look up
         * @return the data column value of the matching row, or {@code null} if none showed up in time
         * @throws Exception if the connection or the statement cannot be obtained
         */
        public Object getValueByKeyAwait(String key) throws Exception {
            final Connection connection = factory.getConnection();
            try (PreparedStatement ps = connection.prepareStatement(getRowByKeySql)) {
                ps.setString(1, key);
                return withAwait(() -> {
                    try (ResultSet rs = ps.executeQuery()) {
                        return rs.next() ? rs.getObject(dataColumnName) : null;
                    }
                });
            } finally {
                factory.releaseConnection(connection);
            }
        }

        /**
         * Looks up the value stored under {@code key} with a single query.
         *
         * @param key the key to look up
         * @return the data column value of the matching row, or {@code null} if there is no such row
         * @throws Exception if the query fails
         */
        public Object getValueByKey(String key) throws Exception {
            final Connection connection = factory.getConnection();
            try (PreparedStatement ps = connection.prepareStatement(getRowByKeySql)) {
                ps.setString(1, key);
                try (ResultSet rs = ps.executeQuery()) {
                    return rs.next() ? rs.getObject(dataColumnName) : null;
                }
            } finally {
                factory.releaseConnection(connection);
            }
        }

        /**
         * Returns one entry per row of the table, rendered as {@code id=data}.
         *
         * @return a list with one string per row; empty if the table has no rows
         * @throws Exception if the query fails
         */
        public List<String> getAllRows() throws Exception {
            final Connection connection = factory.getConnection();
            try (Statement s = connection.createStatement();
                 ResultSet rs = s.executeQuery(getAllRowsSql)) {
                List<String> rows = new ArrayList<>();
                while (rs.next()) {
                    // Read the columns in the order they are selected: data first, then id.
                    Object data = rs.getObject(dataColumnName);
                    Object id = rs.getObject(idColumnName);
                    rows.add(id + "=" + data);
                }
                return rows;
            } finally {
                factory.releaseConnection(connection);
            }
        }

        /**
         * Returns the string form of the id column of every row.
         *
         * @return the keys of all rows; empty if the table has no rows
         * @throws Exception if the query fails
         */
        public List<String> getAllKeys() throws Exception {
            final Connection connection = factory.getConnection();
            try (Statement s = connection.createStatement();
                 ResultSet rs = s.executeQuery(getAllRowsSql)) {
                List<String> keys = new ArrayList<>();
                while (rs.next()) {
                    keys.add(rs.getObject(idColumnName).toString());
                }
                return keys;
            } finally {
                factory.releaseConnection(connection);
            }
        }

        /**
         * Deletes every row of the table.
         *
         * @throws Exception if the statement fails
         */
        public void deleteAllRows() throws Exception {
            executeUpdate(deleteAllRowsSql);
        }

        /**
         * Drops the table.
         *
         * @throws Exception if the statement fails
         */
        public void dropTable() throws Exception {
            executeUpdate(dropTableSql);
        }

        /**
         * Tells whether the database metadata lists a table with this table's (unquoted) name.
         *
         * @return {@code true} if the table is listed
         * @throws Exception if the metadata cannot be read
         */
        public boolean exists() throws Exception {
            // tableName carries one quote character on each side; strip both before comparing.
            String unquoted = tableName.substring(1, tableName.length() - 1);
            return getTableNames().contains(unquoted);
        }

        public String getConnectionUrl() {
            return connectionUrl;
        }

        public String getUsername() {
            return username;
        }

        public String getPassword() {
            return password;
        }

        /** Runs a single update statement on a fresh connection and always releases that connection. */
        private void executeUpdate(String sql) throws Exception {
            final Connection connection = factory.getConnection();
            try (Statement s = connection.createStatement()) {
                s.executeUpdate(sql);
            } finally {
                factory.releaseConnection(connection);
            }
        }

        /** Lists the names of all entries of type TABLE reported by the database metadata. */
        private List<String> getTableNames() throws Exception {
            final Connection connection = factory.getConnection();
            try {
                DatabaseMetaData dbm = connection.getMetaData();
                String[] types = {"TABLE"};
                List<String> tables = new ArrayList<>();
                try (ResultSet rs = dbm.getTables(null, null, "%", types)) {
                    while (rs.next()) {
                        tables.add(rs.getString("TABLE_NAME"));
                    }
                }
                return tables;
            } finally {
                factory.releaseConnection(connection);
            }
        }

        /**
         * Calls {@code c} repeatedly until it yields a non-null result or {@link DBServer#TIMEOUT}
         * milliseconds have passed, pausing {@link #RETRY_TIME} seconds after every attempt that
         * produced nothing.
         *
         * @return the first non-null result, or {@code null} on timeout
         */
        private <T> T withAwait(Callable<T> c) {
            final long deadline = System.currentTimeMillis() + TIMEOUT;
            T result = null;
            while (result == null && System.currentTimeMillis() < deadline) {
                try {
                    result = c.call();
                } catch (Exception ignored) {
                    // Best effort: a failing attempt is treated like "no result yet" and retried.
                }
                if (result == null) {
                    // Back off after an empty attempt as well, so the database is not polled in a tight loop.
                    sleepForSecs(RETRY_TIME);
                }
            }
            return result;
        }
    }
}