package liquibase.sqlgenerator.core; import liquibase.configuration.GlobalConfiguration; import liquibase.configuration.LiquibaseConfiguration; import liquibase.database.Database; import liquibase.database.core.DB2Database; import liquibase.database.core.MSSQLDatabase; import liquibase.database.core.OracleDatabase; import liquibase.exception.ValidationErrors; import liquibase.parser.ChangeLogParserCofiguration; import liquibase.sql.Sql; import liquibase.sql.UnparsedSql; import liquibase.sqlgenerator.SqlGeneratorChain; import liquibase.statement.core.CreateProcedureStatement; import liquibase.structure.core.Schema; import liquibase.structure.core.StoredProcedure; import liquibase.util.SqlParser; import liquibase.util.StringClauses; import liquibase.util.StringUtils; import java.util.ArrayList; import java.util.List; import java.util.regex.Pattern; public class CreateProcedureGenerator extends AbstractSqlGenerator<CreateProcedureStatement> { @Override public ValidationErrors validate(CreateProcedureStatement statement, Database database, SqlGeneratorChain sqlGeneratorChain) { ValidationErrors validationErrors = new ValidationErrors(); validationErrors.checkRequiredField("procedureText", statement.getProcedureText()); if (statement.getReplaceIfExists() != null) { if (database instanceof MSSQLDatabase) { if (statement.getReplaceIfExists() && statement.getProcedureName() == null) { validationErrors.addError("procedureName is required if replaceIfExists = true"); } } else { validationErrors.checkDisallowedField("replaceIfExists", statement.getReplaceIfExists(), null); } } return validationErrors; } @Override public Sql[] generateSql(CreateProcedureStatement statement, Database database, SqlGeneratorChain sqlGeneratorChain) { List<Sql> sql = new ArrayList<Sql>(); String schemaName = statement.getSchemaName(); if (schemaName == null && LiquibaseConfiguration.getInstance().getConfiguration(GlobalConfiguration.class).getAlwaysOverrideStoredLogicSchema()) { schemaName = database.getDefaultSchemaName(); } String procedureText = addSchemaToText(statement.getProcedureText(), schemaName, "PROCEDURE", database); if (statement.getReplaceIfExists() != null && statement.getReplaceIfExists()) { String fullyQualifiedName = database.escapeObjectName(statement.getProcedureName(), StoredProcedure.class); if (schemaName != null) { fullyQualifiedName = database.escapeObjectName(schemaName, Schema.class) + "." + fullyQualifiedName; } sql.add(new UnparsedSql("if object_id('" + fullyQualifiedName + "', 'p') is null exec ('create procedure " + fullyQualifiedName + " as select 1 a')")); StringClauses parsedSql = SqlParser.parse(procedureText, true, true); StringClauses.ClauseIterator clauseIterator = parsedSql.getClauseIterator(); Object next = "START"; while (next != null && !(next.toString().equalsIgnoreCase("create") || next.toString().equalsIgnoreCase("alter")) && clauseIterator.hasNext()) { next = clauseIterator.nextNonWhitespace(); } clauseIterator.replace("ALTER"); procedureText = parsedSql.toString(); } procedureText = removeTrailingDelimiter(procedureText, statement.getEndDelimiter()); if (database instanceof MSSQLDatabase && procedureText.toLowerCase().contains("merge") && !procedureText.endsWith(";")) { //mssql "AS MERGE" procedures need a trailing ; (regardless of the end delimiter) StringClauses parsed = SqlParser.parse(procedureText); StringClauses.ClauseIterator clauseIterator = parsed.getClauseIterator(); boolean reallyMerge = false; while (clauseIterator.hasNext()) { Object clause = clauseIterator.nextNonWhitespace(); if (((String) clause).equalsIgnoreCase("merge")) { reallyMerge = true; } } if (reallyMerge) { procedureText = procedureText + ";"; } } sql.add(new UnparsedSql(procedureText, statement.getEndDelimiter())); surroundWithSchemaSets(sql, statement.getSchemaName(), database); return sql.toArray(new Sql[sql.size()]); } public static String removeTrailingDelimiter(String procedureText, String endDelimiter) { if (procedureText == null) { return null; } if (endDelimiter == null) { return procedureText; } String fixedText = procedureText; while (fixedText.length() > 0) { String lastChar = fixedText.substring(fixedText.length() - 1); if (lastChar.equals(" ") || lastChar.equals("\n") || lastChar.equals("\r") || lastChar.equals("\t")) { fixedText = fixedText.substring(0, fixedText.length() - 1); } else { break; } } endDelimiter = endDelimiter.replace("\\r", "\r").replace("\\n", "\n"); if (fixedText.endsWith(endDelimiter)) { return fixedText.substring(0, fixedText.length() - endDelimiter.length()); } else { return procedureText; } } /** * Convenience method for when the schemaName is set but we don't want to parse the body */ public static void surroundWithSchemaSets(List<Sql> sql, String schemaName, Database database) { if ((StringUtils.trimToNull(schemaName) != null) && !LiquibaseConfiguration.getInstance().getProperty(ChangeLogParserCofiguration.class, ChangeLogParserCofiguration.USE_PROCEDURE_SCHEMA).getValue(Boolean.class)) { String defaultSchema = database.getDefaultSchemaName(); if (database instanceof OracleDatabase) { sql.add(0, new UnparsedSql("ALTER SESSION SET CURRENT_SCHEMA=" + database.escapeObjectName(schemaName, Schema.class))); sql.add(new UnparsedSql("ALTER SESSION SET CURRENT_SCHEMA=" + database.escapeObjectName(defaultSchema, Schema.class))); } else if (database instanceof DB2Database) { sql.add(0, new UnparsedSql("SET CURRENT SCHEMA " + schemaName)); sql.add(new UnparsedSql("SET CURRENT SCHEMA " + defaultSchema)); } } } /** * Convenience method for other classes similar to this that want to be able to modify the procedure text to add the schema */ public static String addSchemaToText(String procedureText, String schemaName, String keywordBeforeName, Database database) { if (schemaName == null) { return procedureText; } if ((StringUtils.trimToNull(schemaName) != null) && LiquibaseConfiguration.getInstance().getProperty(ChangeLogParserCofiguration.class, ChangeLogParserCofiguration.USE_PROCEDURE_SCHEMA).getValue(Boolean.class)) { StringClauses parsedSql = SqlParser.parse(procedureText, true, true); StringClauses.ClauseIterator clauseIterator = parsedSql.getClauseIterator(); Object next = "START"; while (next != null && !next.toString().equalsIgnoreCase(keywordBeforeName) && clauseIterator.hasNext()) { if (!keywordBeforeName.equalsIgnoreCase("PACKAGE") && ((String) next).equalsIgnoreCase("PACKAGE")) { return procedureText; } next = clauseIterator.nextNonWhitespace(); } if (next != null && clauseIterator.hasNext()) { Object procNameClause = clauseIterator.nextNonWhitespace(); if (procNameClause instanceof String) { String[] nameParts = ((String) procNameClause).split("\\."); String finalName; if (nameParts.length == 1) { finalName = database.escapeObjectName(schemaName, Schema.class) + "." + nameParts[0]; } else if (nameParts.length == 2) { finalName = database.escapeObjectName(schemaName, Schema.class) + "." + nameParts[1]; } else if (nameParts.length == 3) { finalName = nameParts[0] + "." + database.escapeObjectName(schemaName, Schema.class) + "." + nameParts[2]; } else { finalName = (String) procNameClause; //just go with what was there } clauseIterator.replace(finalName); } procedureText = parsedSql.toString(); } } return procedureText; } }