/**
* Copyright (C) 2009 - present by OpenGamma Inc. and the OpenGamma group of companies
*
* Please see distribution for license.
*/
package com.opengamma.util.db.management;
import static com.opengamma.util.RegexUtils.matches;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;
import org.apache.commons.lang.ObjectUtils;
import org.hibernate.id.enhanced.SequenceStructure;
import org.hibernate.mapping.ForeignKey;
import org.hibernate.mapping.Table;
import com.opengamma.OpenGammaRuntimeException;
import com.opengamma.util.tuple.FirstThenSecondPairComparator;
import com.opengamma.util.tuple.Pair;
import com.opengamma.util.tuple.Pairs;
/**
* Abstract implementation of database management.
*/
public abstract class AbstractDbManagement implements DbManagement {
/**
 * The schema version table suffix.
 * <p>
 * Not referenced by this class itself; it is {@code protected} so subclasses can use it.
 */
protected static final String SCHEMA_VERSION_TABLE_SUFFIX = "_schema_version";
/**
 * The database server.
 * <p>
 * Set by {@code initialise} and returned by {@code getDbHost()}, which in turn supplies
 * the leading part of the connection string built by {@code getCatalogToConnectTo}.
 */
private String _dbServerHost;
/**
 * The user name.
 * <p>
 * Set by {@code initialise} and passed to {@code DriverManager.getConnection} in {@code connect}.
 */
private String _user;
/**
 * The password.
 * <p>
 * Set by {@code initialise} and passed to {@code DriverManager.getConnection} in {@code connect}.
 */
private String _password;
//-------------------------------------------------------------------------
/**
 * Stores the connection details and loads the JDBC driver.
 * <p>
 * The class returned by {@code getJDBCDriverClass()} is instantiated once through its
 * no-argument constructor so that the driver is loaded before any connection is requested.
 *
 * @param dbServerHost  the database server, subsequently returned by {@link #getDbHost()}
 * @param user  the user name used when opening connections
 * @param password  the password used when opening connections
 * @throws OpenGammaRuntimeException if the driver class cannot be instantiated
 */
@Override
public void initialise(String dbServerHost, String user, String password) {
  _dbServerHost = dbServerHost;
  _user = user;
  _password = password;
  try {
    // Class.newInstance() is deprecated: it rethrows constructor exceptions unchecked.
    // Going through the Constructor wraps them in InvocationTargetException instead.
    getJDBCDriverClass().getDeclaredConstructor().newInstance(); // load the driver
  } catch (Exception ex) {
    throw new OpenGammaRuntimeException("Cannot load JDBC driver", ex);
  }
}
/**
 * Builds the name of the catalog used for testing.
 * <p>
 * The name is {@code test_} followed by the value of the {@code user.name} system
 * property with every dot replaced by an underscore.
 *
 * @return the test catalog name
 */
@Override
public String getTestCatalog() {
  final String userName = System.getProperty("user.name");
  final String sanitisedUserName = userName.replace('.', '_');
  return "test_" + sanitisedUserName;
}
/**
 * Gets the schema used for testing.
 *
 * @return always {@code null} in this base implementation, meaning the default schema is used
 */
@Override
public String getTestSchema() {
return null; // use default
}
/**
 * Resets the given catalog.
 * <p>
 * This base implementation does nothing.
 *
 * @param catalog  the catalog name, ignored here
 */
@Override
public void reset(String catalog) {
// by default, do nothing
}
/**
 * Shuts down the given catalog.
 * <p>
 * This base implementation does nothing.
 *
 * @param catalog  the catalog name, ignored here
 */
@Override
public void shutdown(String catalog) {
// by default, do nothing
}
/**
 * Gets the database server passed to {@code initialise}.
 *
 * @return the database server, null if {@code initialise} has not stored one
 */
public String getDbHost() {
return _dbServerHost;
}
/**
 * Gets the user name passed to {@code initialise}.
 *
 * @return the user name, null if {@code initialise} has not stored one
 */
public String getUser() {
return _user;
}
/**
 * Gets the password passed to {@code initialise}.
 *
 * @return the password, null if {@code initialise} has not stored one
 */
public String getPassword() {
return _password;
}
//-------------------------------------------------------------------------
/**
 * Generic representation of a column.
 * <p>
 * Holds four strings - name, data type, default value and null-ability - exactly as they
 * were supplied to the constructor. All fields are final.
 * Equality and hash code use all four values, whereas the natural ordering
 * compares the name only.
 */
protected class ColumnDefinition implements Comparable<ColumnDefinition> {
  private final String _name;
  private final String _dataType;
  private final String _defaultValue;
  private final String _allowsNull;

  /**
   * Creates a column definition.
   *
   * @param name  the column name
   * @param dataType  the data type
   * @param defaultValue  the default value, null to omit it from {@link #toString()}
   * @param allowsNull  the null-ability indicator, null to omit it from {@link #toString()}
   */
  protected ColumnDefinition(final String name, final String dataType, final String defaultValue, final String allowsNull) {
    _name = name;
    _dataType = dataType;
    _defaultValue = defaultValue;
    _allowsNull = allowsNull;
  }

  /**
   * Gets the column name.
   *
   * @return the name as supplied to the constructor
   */
  public String getName() {
    return _name;
  }

  /**
   * Gets the data type.
   *
   * @return the data type as supplied to the constructor
   */
  public String getDataType() {
    return _dataType;
  }

  /**
   * Gets the default value.
   *
   * @return the default value, may be null
   */
  public String getDefaultValue() {
    return _defaultValue;
  }

  /**
   * Gets the null-ability indicator.
   *
   * @return the indicator as supplied to the constructor, may be null
   */
  public String getAllowsNull() {
    return _allowsNull;
  }

  /**
   * Formats the column as {@code NAME=TYPE}, followed by {@code ;NULL=...} and
   * {@code ;DEFAULT=...} when those values are non-null.
   * <p>
   * Name and type are upper-cased with {@link Locale#ROOT} so that the text does not
   * depend on the default locale of the JVM.
   *
   * @return the description of the column
   */
  @Override
  public String toString() {
    final StringBuilder sb = new StringBuilder();
    sb.append(getName().toUpperCase(Locale.ROOT)).append('=').append(getDataType().toUpperCase(Locale.ROOT));
    if (getAllowsNull() != null) {
      sb.append(";NULL=").append(getAllowsNull());
    }
    if (getDefaultValue() != null) {
      sb.append(";DEFAULT=").append(getDefaultValue());
    }
    return sb.toString();
  }

  /**
   * Compares name, data type, null-ability and default value, treating two nulls as equal.
   *
   * @param obj  the object to compare to, may be null
   * @return true if {@code obj} is a column definition with the same four values
   */
  @Override
  public boolean equals(final Object obj) {
    if (obj == this) {
      return true;
    }
    if (obj instanceof ColumnDefinition) {
      ColumnDefinition c = (ColumnDefinition) obj;
      return ObjectUtils.equals(getName(), c.getName())
          && ObjectUtils.equals(getDataType(), c.getDataType())
          && ObjectUtils.equals(getAllowsNull(), c.getAllowsNull())
          && ObjectUtils.equals(getDefaultValue(), c.getDefaultValue());
    }
    return false;
  }

  /**
   * Combines the hash codes of the same four values used by {@link #equals(Object)}.
   *
   * @return the hash code
   */
  @Override
  public int hashCode() {
    int hc = 1;
    hc = hc * 17 + ObjectUtils.hashCode(getName());
    hc = hc * 17 + ObjectUtils.hashCode(getDataType());
    hc = hc * 17 + ObjectUtils.hashCode(getAllowsNull());
    hc = hc * 17 + ObjectUtils.hashCode(getDefaultValue());
    return hc;
  }

  /**
   * Orders columns by name only, so the ordering is coarser than {@link #equals(Object)}.
   *
   * @param c  the other column, not null
   * @return the result of comparing the two names
   */
  @Override
  public int compareTo(final ColumnDefinition c) {
    return getName().compareTo(c.getName());
  }
}
/**
 * Gets the query listing the schemas of a catalog.
 * <p>
 * {@code getAllSchemas} runs it with {@code executeQuery} and reads each schema name
 * from the result column {@code name}.
 *
 * @param catalog  the catalog name
 * @return the SQL query
 */
public abstract String getAllSchemasSQL(String catalog);
/**
 * Gets the query listing the tables of a schema.
 * <p>
 * {@code getAllTables} runs it with {@code executeQuery} and reads each table name
 * from the result column {@code name}.
 *
 * @param catalog  the catalog name
 * @param schema  the schema name, callers in this class may pass null
 * @return the SQL query
 */
public abstract String getAllTablesSQL(String catalog, String schema);
/**
 * Gets the query listing the views of a schema.
 * <p>
 * {@code getAllViews} runs it with {@code executeQuery} and reads each view name
 * from the result column {@code name}.
 *
 * @param catalog  the catalog name
 * @param schema  the schema name, callers in this class may pass null
 * @return the SQL query
 */
public abstract String getAllViewsSQL(String catalog, String schema);
/**
 * Gets the query listing the columns of a table.
 * <p>
 * {@code getAllColumns} reads the result columns {@code name}, {@code datatype},
 * {@code defaultvalue} and {@code allowsnull} from each row.
 *
 * @param catalog  the catalog name
 * @param schema  the schema name, callers in this class may pass null
 * @param table  the table name
 * @return the SQL query
 */
public abstract String getAllColumnsSQL(String catalog, String schema, String table);
/**
 * Gets the query listing the sequences of a schema.
 * <p>
 * {@code getAllSequences} reads each sequence name from the result column {@code name}.
 * A null return is accepted and treated as "no sequences".
 *
 * @param catalog  the catalog name
 * @param schema  the schema name, callers in this class may pass null
 * @return the SQL query, or null
 */
public abstract String getAllSequencesSQL(String catalog, String schema);
/**
 * Gets the query listing the foreign key constraints of a schema.
 * <p>
 * {@code getAllForeignKeyConstraints} reads the result columns {@code name} and
 * {@code table_name} from each row. A null return is accepted and treated as
 * "no constraints".
 *
 * @param catalog  the catalog name
 * @param schema  the schema name, callers in this class may pass null
 * @return the SQL query, or null
 */
public abstract String getAllForeignKeyConstraintsSQL(String catalog, String schema);
/**
 * Gets the statement that creates a schema.
 * <p>
 * {@code createSchema} runs it with {@code executeUpdate} when the schema is not
 * already listed by {@code getAllSchemas}.
 *
 * @param catalog  the catalog name
 * @param schema  the schema name, not null when called from {@code createSchema}
 * @return the SQL statement
 */
public abstract String getCreateSchemaSQL(String catalog, String schema);
/**
 * Gets the name of the table holding the version of a schema group.
 * <p>
 * {@code getSchemaGroupVersion} looks this name up, by exact match, in the list
 * returned by {@code getAllTables}.
 *
 * @param schemaGroupName  the schema group name
 * @return the table name
 */
public abstract String getSchemaVersionTable(String schemaGroupName);
/**
 * Gets the query returning the version of a schema group.
 * <p>
 * {@code getSchemaGroupVersion} reads the result column {@code version_value} from the
 * first row and rejects results with more than one row.
 *
 * @param catalog  the catalog name
 * @param schemaGroupName  the schema group name
 * @return the SQL query
 */
public abstract String getSchemaVersionSQL(String catalog, String schemaGroupName);
/**
 * Gets the strategy used by this class to test whether a catalog exists
 * ({@code catalogExists}) and to create one ({@code create}).
 *
 * @return the catalog creation strategy
 */
public abstract CatalogCreationStrategy getCatalogCreationStrategy();
/**
 * Hook called on a freshly opened connection before it is used to clear tables, drop a
 * schema, execute SQL or read the schema version.
 * <p>
 * This base implementation does nothing.
 *
 * @param connection  the open connection
 * @param schema  the schema name, may be null
 * @throws SQLException declared for the benefit of overriding implementations
 */
public void setActiveSchema(Connection connection, String schema) throws SQLException {
// override in subclasses as necessary
}
/**
 * Opens a connection to a catalog with auto-commit switched on.
 * <p>
 * The connection string comes from {@link #getCatalogToConnectTo(String)}; the stored
 * user name and password are used as credentials. The caller must close the connection.
 *
 * @param catalog  the catalog name
 * @return the open connection
 * @throws SQLException if the connection cannot be opened or configured
 */
protected Connection connect(String catalog) throws SQLException {
  final String url = getCatalogToConnectTo(catalog);
  final Connection connection = DriverManager.getConnection(url, _user, _password);
  connection.setAutoCommit(true);
  return connection;
}
/**
 * Builds the connection string for a catalog by joining the database host and the
 * catalog name with a forward slash.
 *
 * @param catalog  the catalog name
 * @return the host, a slash, then the catalog
 */
@Override
public String getCatalogToConnectTo(String catalog) {
  final StringBuilder url = new StringBuilder();
  url.append(getDbHost());
  url.append("/");
  url.append(catalog);
  return url.toString();
}
/**
 * Lists the tables of a schema.
 * <p>
 * Runs the query from {@code getAllTablesSQL} and collects the {@code name} column
 * of every row, in result order.
 *
 * @param catalog  the catalog name
 * @param schema  the schema name
 * @param statement  the statement used to run the query, left open
 * @return a new modifiable list of table names
 * @throws SQLException if the query fails
 */
protected List<String> getAllTables(String catalog, String schema, Statement statement) throws SQLException {
  final List<String> tableNames = new LinkedList<>();
  final String query = getAllTablesSQL(catalog, schema);
  try (ResultSet rows = statement.executeQuery(query)) {
    while (rows.next()) {
      final String tableName = rows.getString("name");
      tableNames.add(tableName);
    }
  }
  return tableNames;
}
/**
 * Lists the views of a schema.
 * <p>
 * Runs the query from {@code getAllViewsSQL} and collects the {@code name} column
 * of every row, in result order.
 *
 * @param catalog  the catalog name
 * @param schema  the schema name
 * @param statement  the statement used to run the query, left open
 * @return a new modifiable list of view names
 * @throws SQLException if the query fails
 */
protected List<String> getAllViews(String catalog, String schema, Statement statement) throws SQLException {
  final List<String> viewNames = new LinkedList<>();
  final String query = getAllViewsSQL(catalog, schema);
  try (ResultSet rows = statement.executeQuery(query)) {
    while (rows.next()) {
      final String viewName = rows.getString("name");
      viewNames.add(viewName);
    }
  }
  return viewNames;
}
/**
 * Lists the columns of a table.
 * <p>
 * Runs the query from {@code getAllColumnsSQL} and turns every row into a
 * {@code ColumnDefinition} built from the {@code name}, {@code datatype},
 * {@code defaultvalue} and {@code allowsnull} columns.
 *
 * @param catalog  the catalog name
 * @param schema  the schema name
 * @param table  the table name
 * @param statement  the statement used to run the query, left open
 * @return a new modifiable list of column definitions, in result order
 * @throws SQLException if the query fails
 */
private List<ColumnDefinition> getAllColumns(String catalog, String schema, String table, Statement statement) throws SQLException {
  final List<ColumnDefinition> definitions = new LinkedList<>();
  final String query = getAllColumnsSQL(catalog, schema, table);
  try (ResultSet rows = statement.executeQuery(query)) {
    while (rows.next()) {
      final String name = rows.getString("name");
      final String dataType = rows.getString("datatype");
      final String defaultValue = rows.getString("defaultvalue");
      final String allowsNull = rows.getString("allowsnull");
      definitions.add(new ColumnDefinition(name, dataType, defaultValue, allowsNull));
    }
  }
  return definitions;
}
/**
 * Deletes the contents of the tables of a schema.
 * <p>
 * Nothing happens if the catalog does not exist. Otherwise a script is built from the
 * commands returned by {@link #getClearTablesCommand(String, List)} for every table
 * whose lower-cased name is not contained in {@code ignoredTables}, plus an
 * {@code INSERT ... values ( 1 )} for every such table whose name ends in
 * {@code hibernate_sequence}. Statements that fail are moved to the back of the
 * script and retried, on the assumption that the failure was a constraint violation
 * that clearing other tables first will resolve.
 *
 * @param catalog  the catalog name
 * @param schema  the schema name, may be null
 * @param ignoredTables  the lower-case names of the tables to leave untouched, not null
 * @throws OpenGammaRuntimeException if the database cannot be accessed, or statements
 *  are still failing once the retry budget is exhausted
 */
@Override
public void clearTables(String catalog, String schema, Collection<String> ignoredTables) {
  LinkedList<String> script = new LinkedList<String>();
  try {
    if (!getCatalogCreationStrategy().catalogExists(catalog)) {
      return; // nothing to clear
    }
    try (Connection conn = connect(catalog)) {
      setActiveSchema(conn, schema);
      try (Statement statement = conn.createStatement()) {
        // Clear tables SQL
        List<String> tablesToClear = new ArrayList<String>();
        for (String name : getAllTables(catalog, schema, statement)) {
          // Locale.ROOT keeps the comparison stable whatever the default locale of the JVM
          if (!ignoredTables.contains(name.toLowerCase(Locale.ROOT))) {
            tablesToClear.add(name);
          }
        }
        List<String> clearTablesCommands = getClearTablesCommand(schema, tablesToClear);
        script.addAll(clearTablesCommands);
        final Pattern sequenceTablePattern = Pattern.compile(".*?hibernate_sequence");
        for (String name : tablesToClear) {
          Table table = new Table(name);
          if (matches(table.getName().toLowerCase(Locale.ROOT), sequenceTablePattern)) { // if it's a sequence table, reset it
            script.add("INSERT INTO " + table.getQualifiedName(getHibernateDialect(), null, schema) + " values ( 1 )");
          }
        }
        // Now execute it all. Constraints are taken into account by retrying the failed statement after all
        // dependent tables have been cleared first.
        int i = 0;
        int maxAttempts = script.size() * 3; // make sure the loop eventually terminates. Important if there's a cycle in the table dependency graph
        SQLException latestException = null;
        while (i < maxAttempts && !script.isEmpty()) {
          String sql = script.remove();
          try {
            statement.executeUpdate(sql);
          } catch (SQLException e) {
            // assume it failed because of a constraint violation
            // try deleting other tables first - make this the new last statement
            latestException = e;
            script.add(sql);
          }
          i++;
        }
        if (i == maxAttempts && !script.isEmpty()) {
          throw new OpenGammaRuntimeException("Failed to clear tables - is there a cycle in the table dependency graph?", latestException);
        }
      }
    }
  } catch (SQLException e) {
    throw new OpenGammaRuntimeException("Failed to clear tables", e);
  }
}
/**
 * Builds one {@code DELETE FROM} statement per table.
 * <p>
 * Each table name is qualified through Hibernate's {@code Table.getQualifiedName}
 * using the Hibernate dialect, no catalog and the given schema.
 *
 * @param schema  the schema name, may be null
 * @param tablesToClear  the names of the tables to clear, not null
 * @return a new list of statements, in the order of {@code tablesToClear}
 */
protected List<String> getClearTablesCommand(String schema, List<String> tablesToClear) {
  final List<String> deleteStatements = new ArrayList<>();
  for (String tableName : tablesToClear) {
    final String qualifiedName = new Table(tableName).getQualifiedName(getHibernateDialect(), null, schema);
    deleteStatements.add("DELETE FROM " + qualifiedName);
  }
  return deleteStatements;
}
/**
 * Lists the schemas of a catalog.
 * <p>
 * Runs the query from {@code getAllSchemasSQL} and collects the {@code name} column
 * of every row, in result order.
 *
 * @param catalog  the catalog name
 * @param stmt  the statement used to run the query, left open
 * @return a new modifiable list of schema names
 * @throws SQLException if the query fails
 */
protected List<String> getAllSchemas(final String catalog, final Statement stmt) throws SQLException {
  final List<String> schemaNames = new LinkedList<>();
  final String query = getAllSchemasSQL(catalog);
  try (ResultSet rows = stmt.executeQuery(query)) {
    while (rows.next()) {
      final String schemaName = rows.getString("name");
      schemaNames.add(schemaName);
    }
  }
  return schemaNames;
}
/**
 * Creates a catalog and, optionally, a schema within it.
 * <p>
 * The catalog creation strategy is always asked to create the catalog. If
 * {@code schema} is non-null, a connection to the catalog is opened and the statement
 * from {@code getCreateSchemaSQL} is executed, unless {@code getAllSchemas} already
 * lists the schema.
 *
 * @param catalog  the catalog name
 * @param schema  the schema name, null to create the catalog only
 * @throws OpenGammaRuntimeException if a database error occurs
 */
@Override
public void createSchema(String catalog, String schema) {
  try {
    getCatalogCreationStrategy().create(catalog);
    if (schema != null) {
      // Connect to the new catalog and create the schema
      try (Connection conn = connect(catalog);
          Statement statement = conn.createStatement()) {
        Collection<String> schemas = getAllSchemas(catalog, statement);
        if (!schemas.contains(schema)) {
          String createSchemaSql = getCreateSchemaSQL(catalog, schema);
          statement.executeUpdate(createSchemaSql);
        }
      }
    }
  } catch (SQLException e) {
    throw new OpenGammaRuntimeException("Failed to create schema", e);
  }
}
/**
 * Lists the sequences of a schema.
 * <p>
 * If {@code getAllSequencesSQL} returns null the result is empty; otherwise the query
 * is run and the {@code name} column of every row is collected, in result order.
 *
 * @param catalog  the catalog name
 * @param schema  the schema name
 * @param stmt  the statement used to run the query, left open
 * @return a new modifiable list of sequence names
 * @throws SQLException if the query fails
 */
protected List<String> getAllSequences(final String catalog, final String schema, final Statement stmt) throws SQLException {
  final List<String> sequenceNames = new LinkedList<>();
  final String query = getAllSequencesSQL(catalog, schema);
  if (query == null) {
    return sequenceNames;
  }
  try (ResultSet rows = stmt.executeQuery(query)) {
    while (rows.next()) {
      final String sequenceName = rows.getString("name");
      sequenceNames.add(sequenceName);
    }
  }
  return sequenceNames;
}
/**
 * Lists the foreign key constraints of a schema.
 * <p>
 * If {@code getAllForeignKeyConstraintsSQL} returns null the result is empty; otherwise
 * the query is run and every row becomes a pair whose first element is the {@code name}
 * column and whose second element is the {@code table_name} column.
 *
 * @param catalog  the catalog name
 * @param schema  the schema name
 * @param stmt  the statement used to run the query, left open
 * @return a new modifiable list of (constraint name, table name) pairs, in result order
 * @throws SQLException if the query fails
 */
protected List<Pair<String, String>> getAllForeignKeyConstraints(final String catalog, final String schema, final Statement stmt) throws SQLException {
  final List<Pair<String, String>> constraints = new LinkedList<>();
  final String query = getAllForeignKeyConstraintsSQL(catalog, schema);
  if (query == null) {
    return constraints;
  }
  try (ResultSet rows = stmt.executeQuery(query)) {
    while (rows.next()) {
      final String constraintName = rows.getString("name");
      final String tableName = rows.getString("table_name");
      constraints.add(Pairs.of(constraintName, tableName));
    }
  }
  return constraints;
}
/**
 * Drops the foreign key constraints, views, tables and sequences of a schema.
 * <p>
 * A message is printed to standard output, and nothing is dropped, if the catalog does
 * not exist or if a non-null schema is not listed by {@code getAllSchemas}.
 * Constraints are only dropped when the Hibernate dialect reports
 * {@code dropConstraints()}. Constraints, views and tables are dropped first, in that
 * order; sequences are looked up and dropped afterwards.
 * Triggers and stored procedures are not handled.
 *
 * @param catalog  the catalog name
 * @param schema  the schema name, may be null
 * @throws OpenGammaRuntimeException if a database error occurs
 */
@Override
public void dropSchema(String catalog, String schema) {
  final List<String> pending = new ArrayList<>();
  try {
    if (!getCatalogCreationStrategy().catalogExists(catalog)) {
      System.out.println("Catalog " + catalog + " does not exist");
      return; // nothing to drop
    }
    try (Connection connection = connect(catalog)) {
      if (schema != null) {
        try (Statement lookup = connection.createStatement()) {
          final boolean schemaKnown = getAllSchemas(catalog, lookup).contains(schema);
          if (!schemaKnown) {
            System.out.println("Schema " + schema + " does not exist");
            return; // nothing to drop
          }
        }
      }
      setActiveSchema(connection, schema);

      // Phase 1: collect the drop statements for constraints, views and tables
      try (Statement lookup = connection.createStatement()) {
        if (getHibernateDialect().dropConstraints()) {
          for (Pair<String, String> nameAndTable : getAllForeignKeyConstraints(catalog, schema, lookup)) {
            final ForeignKey foreignKey = new ForeignKey();
            foreignKey.setName(nameAndTable.getFirst());
            foreignKey.setTable(new Table(nameAndTable.getSecond()));
            pending.add(foreignKey.sqlDropString(getHibernateDialect(), null, schema));
          }
        }
        for (String viewName : getAllViews(catalog, schema, lookup)) {
          // Hibernate only knows how to drop tables, so rewrite the generated statement
          final String dropAsTable = new Table(viewName).sqlDropString(getHibernateDialect(), null, schema);
          pending.add(dropAsTable.replaceAll("drop table", "drop view"));
        }
        for (String tableName : getAllTables(catalog, schema, lookup)) {
          pending.add(new Table(tableName).sqlDropString(getHibernateDialect(), null, schema));
        }
      }
      executeDropScript(connection, pending);

      // Phase 2: collect and run the drop statements for sequences
      try (Statement lookup = connection.createStatement()) {
        pending.clear();
        for (String sequenceName : getAllSequences(catalog, schema, lookup)) {
          final SequenceStructure structure = new SequenceStructure(getHibernateDialect(), sequenceName, 0, 1, Long.class);
          pending.addAll(Arrays.asList(structure.sqlDropStrings(getHibernateDialect())));
        }
      }
      executeDropScript(connection, pending);
    }
  } catch (SQLException e) {
    throw new OpenGammaRuntimeException("Failed to drop schema", e);
  }
}

/**
 * Runs every statement of a script, in order, on a new statement of the connection.
 *
 * @param connection  the open connection
 * @param script  the SQL statements to run, may be empty
 * @throws SQLException if creating the statement or running any SQL fails
 */
private static void executeDropScript(Connection connection, List<String> script) throws SQLException {
  try (Statement executor = connection.createStatement()) {
    for (String sql : script) {
      executor.executeUpdate(sql);
    }
  }
}
/**
 * Splits a SQL script into statements and executes them one by one.
 * <p>
 * The script is split by {@link #splitSqlStatements(String)}; the statements are then
 * run, in order, on a single connection to the catalog after
 * {@link #setActiveSchema(Connection, String)} has been applied.
 *
 * @param catalog  the catalog name
 * @param schema  the schema name, may be null
 * @param sql  the script, not null
 * @throws OpenGammaRuntimeException if connecting fails or any statement fails; in the
 *  latter case the message includes the database host and the failing statement
 */
@Override
public void executeSql(String catalog, String schema, String sql) {
  List<String> sqlStatements = splitSqlStatements(sql);
  try (Connection conn = connect(catalog)) {
    setActiveSchema(conn, schema);
    // try-with-resources so the statement is closed even when a statement fails
    try (Statement statement = conn.createStatement()) {
      for (String sqlStatement : sqlStatements) {
        try {
          statement.execute(sqlStatement);
        } catch (SQLException e) {
          throw new OpenGammaRuntimeException("Failed to execute statement (" + getDbHost() + ") " + sqlStatement, e);
        }
      }
    }
  } catch (SQLException e) {
    throw new OpenGammaRuntimeException("Failed to execute statement", e);
  }
}

/**
 * Splits a script into individual statements at semicolons.
 * <p>
 * Rules applied while scanning:
 * <ul>
 * <li>text between a pair of {@code $$} delimiters is copied verbatim, including
 *  semicolons, line ends and {@code --};
 * <li>outside such a quote, {@code --} starts a comment that is dropped up to the
 *  end of the line;
 * <li>every line end outside a quote becomes a single space, so tokens on
 *  neighbouring lines never run together;
 * <li>statements are trimmed and empty ones are discarded.
 * </ul>
 * Quoted string literals and tagged dollar quotes ({@code $tag$}) are not recognised.
 *
 * @param sql  the script, not null
 * @return the statements in script order, without the terminating semicolons
 */
private static List<String> splitSqlStatements(String sql) {
  List<String> statements = new ArrayList<String>();
  StringBuilder current = new StringBuilder();
  boolean inDollarQuote = false;
  boolean inComment = false;
  for (int idx = 0; idx < sql.length(); idx++) {
    char currentChar = sql.charAt(idx);
    char nextChar = idx + 1 < sql.length() ? sql.charAt(idx + 1) : 0;
    boolean isDollarDelimiter = currentChar == '$' && nextChar == '$';
    if (inDollarQuote) {
      // Add everything verbatim until the end-of-quote $$
      current.append(currentChar);
      if (isDollarDelimiter) {
        // consume the delimiter as a unit so its second '$' cannot start a new delimiter
        current.append(nextChar);
        idx++;
        inDollarQuote = false;
      }
      continue;
    }
    boolean isLineEnd = currentChar == '\r' || currentChar == '\n';
    if (currentChar == '\r' && nextChar == '\n') {
      idx++; // treat CRLF as a single line end
    }
    if (isLineEnd) {
      // a line end terminates any comment and always separates tokens
      inComment = false;
      current.append(' ');
      continue;
    }
    if (inComment) {
      // Ignore everything until the next new line
      continue;
    }
    if (currentChar == ';') {
      String finished = current.toString().trim();
      if (!finished.isEmpty()) {
        statements.add(finished);
      }
      current = new StringBuilder();
      continue;
    }
    if (currentChar == '-' && nextChar == '-') {
      inComment = true;
      continue;
    }
    if (isDollarDelimiter) {
      current.append(currentChar).append(nextChar);
      idx++;
      inDollarQuote = true;
      continue;
    }
    current.append(currentChar);
  }
  // the final statement need not be terminated by a semicolon
  String trailing = current.toString().trim();
  if (!trailing.isEmpty()) {
    statements.add(trailing);
  }
  return statements;
}
/**
 * Describes the structure of a catalog.
 * <p>
 * Delegates to {@code describeDatabase(catalog, null)}.
 *
 * @param catalog  the catalog name
 * @return the description produced by the two-argument overload
 */
@Override
public String describeDatabase(final String catalog) {
return describeDatabase(catalog, null);
}
/**
 * Describes the structure of a catalog as text.
 * <p>
 * For every schema, in sorted order, the output contains a {@code schema:} line
 * followed by a {@code table:} line per table with a {@code column:} line per column,
 * then a {@code sequence:} line per sequence and a {@code foreign key:} line per
 * constraint, written as constraint name, a dot, then table name. Each list is sorted
 * and each line ends with CR LF. If the catalog reports no schemas, a single null
 * schema is described.
 * <p>
 * NOTE(review): {@code prefix} is not used by this implementation - confirm whether
 * filtering by prefix is expected here.
 *
 * @param catalog  the catalog name
 * @param prefix  currently ignored
 * @return the description, not null
 * @throws OpenGammaRuntimeException if a database error occurs
 */
@Override
public String describeDatabase(final String catalog, final String prefix) {
  final StringBuilder description = new StringBuilder();
  // both resources are closed on every path, including failures
  try (Connection conn = connect(catalog);
      Statement stmt = conn.createStatement()) {
    final List<String> schemas = getAllSchemas(catalog, stmt);
    Collections.sort(schemas);
    if (schemas.size() == 0) {
      schemas.add(null); // describe the default schema
    }
    for (String schema : schemas) {
      description.append("schema: ").append(schema).append("\r\n");
      final List<String> tables = getAllTables(catalog, schema, stmt);
      Collections.sort(tables);
      for (String table : tables) {
        description.append("table: ").append(table).append("\r\n");
        final List<ColumnDefinition> columns = getAllColumns(catalog, schema, table, stmt);
        Collections.sort(columns);
        for (ColumnDefinition column : columns) {
          description.append("column: ").append(column).append("\r\n");
        }
      }
      final List<String> sequences = getAllSequences(catalog, schema, stmt);
      Collections.sort(sequences);
      for (String sequence : sequences) {
        description.append("sequence: ").append(sequence).append("\r\n");
      }
      final List<Pair<String, String>> foreignKeys = getAllForeignKeyConstraints(catalog, schema, stmt);
      Collections.sort(foreignKeys, FirstThenSecondPairComparator.INSTANCE);
      for (Pair<String, String> foreignKey : foreignKeys) {
        description.append("foreign key: ").append(foreignKey.getFirst()).append('.').append(foreignKey.getSecond()).append("\r\n");
      }
    }
  } catch (SQLException e) {
    // the cause travels with the exception, so nothing is printed here
    throw new OpenGammaRuntimeException("SQL exception", e);
  }
  return description.toString();
}
/**
 * Lists the tables of a catalog.
 * <p>
 * Delegates to {@code getAllTables} with a null schema on a fresh connection.
 *
 * @param catalog  the catalog name
 * @return a new list of table names
 * @throws OpenGammaRuntimeException if a database error occurs
 */
@Override
public List<String> listTables(final String catalog) {
  // both resources are closed on every path, including failures
  try (Connection conn = connect(catalog);
      Statement stmt = conn.createStatement()) {
    return getAllTables(catalog, null, stmt);
  } catch (SQLException e) {
    // the cause travels with the exception, so nothing is printed here
    throw new OpenGammaRuntimeException("SQL exception", e);
  }
}
/**
 * Reads the version of a schema group.
 * <p>
 * Returns null when the table named by {@code getSchemaVersionTable} is not among the
 * tables of the schema. Otherwise the query from {@code getSchemaVersionSQL} is run;
 * it must yield exactly one row, whose {@code version_value} column is parsed as an
 * integer.
 *
 * @param catalog  the catalog name
 * @param schema  the schema name, may be null
 * @param schemaGroupName  the schema group name
 * @return the version, null if the version table does not exist
 * @throws OpenGammaRuntimeException if a database error occurs, or the version query
 *  returns no row or more than one row
 * @throws NumberFormatException if the stored value is not an integer
 */
@Override
public Integer getSchemaGroupVersion(String catalog, String schema, String schemaGroupName) {
  try (Connection conn = connect(catalog)) {
    setActiveSchema(conn, schema);
    // try-with-resources so the statement is closed on every path
    try (Statement statement = conn.createStatement()) {
      List<String> tables = getAllTables(catalog, schema, statement);
      if (!tables.contains(getSchemaVersionTable(schemaGroupName))) {
        return null;
      }
      String version;
      try (ResultSet rs = statement.executeQuery(getSchemaVersionSQL(catalog, schemaGroupName))) {
        if (!rs.next()) {
          // reading from an empty result would otherwise fail with an opaque driver error
          throw new OpenGammaRuntimeException("Expected one schema version entry for group " + schemaGroupName + " but found none");
        }
        version = rs.getString("version_value");
        if (rs.next()) {
          throw new OpenGammaRuntimeException("Expected one schema version entry for group " + schemaGroupName + " but found multiple");
        }
      }
      return Integer.parseInt(version);
    }
  } catch (SQLException e) {
    // the cause travels with the exception, so nothing is printed here
    throw new OpenGammaRuntimeException("SQL exception", e);
  }
}
}