package liquibase.sqlgenerator.core;

import java.util.Arrays;
import java.util.Date;
import java.util.HashSet;

import liquibase.database.Database;
import liquibase.database.typeconversion.TypeConverter;
import liquibase.database.typeconversion.TypeConverterFactory;
import liquibase.exception.LiquibaseException;
import liquibase.exception.ValidationErrors;
import liquibase.sql.Sql;
import liquibase.sql.UnparsedSql;
import liquibase.sqlgenerator.SqlGeneratorChain;
import liquibase.statement.core.InsertOrUpdateStatement;
import liquibase.statement.core.UpdateStatement;
/**
 * Template for generating a single "insert or update" SQL script for an {@link InsertOrUpdateStatement}.
 * <p>
 * The generated script consists of, in order:
 * <ol>
 *   <li>a database-specific record check ({@link #getRecordCheck}),</li>
 *   <li>the INSERT statement ({@link #getInsertStatement}),</li>
 *   <li>when at least one non-primary-key column exists, a database-specific separator
 *       ({@link #getElse}) followed by the UPDATE statement ({@link #getUpdateStatement}),</li>
 *   <li>optional trailing SQL ({@link #getPostUpdateStatements}).</li>
 * </ol>
 * Subclasses supply the database-specific pieces.
 */
public abstract class InsertOrUpdateGenerator extends AbstractSqlGenerator<InsertOrUpdateStatement> {

    /**
     * Returns the database-specific SQL emitted before the INSERT statement that checks whether
     * a row matching {@code whereClause} already exists.
     *
     * @param insertOrUpdateStatement the statement being generated
     * @param database                the target database
     * @param whereClause             the primary-key match condition built by {@link #getWhereClause}
     * @return SQL text placed at the start of the generated script
     */
    protected abstract String getRecordCheck(InsertOrUpdateStatement insertOrUpdateStatement, Database database, String whereClause);

    /**
     * Returns the database-specific SQL placed between the INSERT and the UPDATE statements.
     * {@link #generateSql} only appends it when an UPDATE statement could be built.
     *
     * @param database the target database
     * @return SQL text separating the insert and update branches
     */
    protected abstract String getElse(Database database);

    /**
     * Returns SQL appended at the very end of the generated script.
     *
     * @return the empty string by default; subclasses may override
     */
    protected String getPostUpdateStatements(){
        return "";
    }

    /**
     * @return {@code PRIORITY_DATABASE}
     */
    @Override
    public int getPriority() {
        return PRIORITY_DATABASE;
    }

    /**
     * Requires the table name, the column values and the primary key to be set on the statement.
     */
    public ValidationErrors validate(InsertOrUpdateStatement statement, Database database, SqlGeneratorChain sqlGeneratorChain) {
        ValidationErrors validationErrors = new ValidationErrors();
        validationErrors.checkRequiredField("tableName", statement.getTableName());
        validationErrors.checkRequiredField("columns", statement.getColumnValues());
        validationErrors.checkRequiredField("primaryKey", statement.getPrimaryKey());
        return validationErrors;
    }

    /**
     * Splits the statement's comma-separated primary key into column names, trimming surrounding
     * whitespace so that {@code "id, name"} and {@code "id,name"} are treated the same.
     */
    private String[] getPrimaryKeyColumns(InsertOrUpdateStatement insertOrUpdateStatement) {
        String[] pkColumns = insertOrUpdateStatement.getPrimaryKey().split(",");
        for (int i = 0; i < pkColumns.length; i++) {
            pkColumns[i] = pkColumns[i].trim();
        }
        return pkColumns;
    }

    /**
     * Builds the condition that matches the statement's row by its primary-key columns:
     * one {@code <escaped column> <comparison>} term per key column, joined with {@code AND}.
     * <p>
     * Values are rendered as literals:
     * <ul>
     *   <li>{@code null} or the text {@code "NULL"}: {@code IS NULL}</li>
     *   <li>a String the database says should be quoted: an escaped, single-quoted literal</li>
     *   <li>a {@link Date}: the database's date literal</li>
     *   <li>a {@link Boolean}: the database's true/false literal</li>
     *   <li>anything else: its {@code toString()}</li>
     * </ul>
     *
     * @param insertOrUpdateStatement the statement whose primary key and values are used
     * @param database                the target database
     * @return the where-clause text (without the {@code WHERE} keyword)
     */
    protected String getWhereClause(InsertOrUpdateStatement insertOrUpdateStatement, Database database) {
        StringBuilder where = new StringBuilder();
        String separator = "";
        for (String thisPkColumn : getPrimaryKeyColumns(insertOrUpdateStatement)) {
            where.append(separator);
            separator = " AND ";

            where.append(database.escapeColumnName(insertOrUpdateStatement.getSchemaName(), insertOrUpdateStatement.getTableName(), thisPkColumn));
            Object newValue = insertOrUpdateStatement.getColumnValues().get(thisPkColumn);
            if (newValue == null || newValue.toString().equals("NULL")) {
                // "col = NULL" never evaluates to true in SQL, so NULL keys must use IS NULL
                where.append(" IS NULL");
                continue;
            }

            where.append(" = ");
            if (newValue instanceof String && database.shouldQuoteValue(((String) newValue))) {
                where.append("'").append(database.escapeStringForDatabase((String) newValue)).append("'");
            } else if (newValue instanceof Date) {
                where.append(database.getDateLiteral(((Date) newValue)));
            } else if (newValue instanceof Boolean) {
                TypeConverter typeConverter = TypeConverterFactory.getInstance().findTypeConverter(database);
                if (((Boolean) newValue)) {
                    where.append(typeConverter.getBooleanType().getTrueBooleanValue());
                } else {
                    where.append(typeConverter.getBooleanType().getFalseBooleanValue());
                }
            } else {
                where.append(newValue);
            }
        }
        return where.toString();
    }

    /**
     * Generates the INSERT statement(s) via {@link InsertGenerator}, terminating each with {@code ;}
     * and ending the result with a newline.
     *
     * @param insertOrUpdateStatement the statement to insert
     * @param database                the target database
     * @param sqlGeneratorChain       the generator chain passed through to {@link InsertGenerator}
     * @return the INSERT SQL text
     */
    protected String getInsertStatement(InsertOrUpdateStatement insertOrUpdateStatement, Database database, SqlGeneratorChain sqlGeneratorChain) {
        StringBuilder insertBuffer = new StringBuilder();
        InsertGenerator insert = new InsertGenerator();
        Sql[] insertSql = insert.generateSql(insertOrUpdateStatement,database,sqlGeneratorChain);
        for(Sql s:insertSql)
        {
            insertBuffer.append(s.toSql());
            insertBuffer.append(";");
        }
        insertBuffer.append("\n");
        return insertBuffer.toString();
    }

    /**
     * Generates an UPDATE of every non-primary-key column for the row matched by {@code whereClause}.
     *
     * @param insertOrUpdateStatement the statement supplying table, primary key and column values
     * @param database                the target database
     * @param whereClause             the primary-key match condition built by {@link #getWhereClause}
     * @param sqlGeneratorChain       the generator chain passed through to {@link UpdateGenerator}
     * @return the UPDATE SQL text, ending with a newline
     * @throws LiquibaseException if there are no non-primary-key columns, i.e. nothing to put in
     *                            the SET clause
     */
    protected String getUpdateStatement(InsertOrUpdateStatement insertOrUpdateStatement,Database database, String whereClause, SqlGeneratorChain sqlGeneratorChain) throws LiquibaseException {
        UpdateStatement updateStatement = new UpdateStatement(insertOrUpdateStatement.getSchemaName(),insertOrUpdateStatement.getTableName());
        // The statement terminator is embedded in the where clause, so the generated UPDATE ends with ";\n".
        updateStatement.setWhereClause(whereClause + ";\n");

        HashSet<String> hashPkFields = new HashSet<String>(Arrays.asList(getPrimaryKeyColumns(insertOrUpdateStatement)));
        for(String columnKey:insertOrUpdateStatement.getColumnValues().keySet())
        {
            if (!hashPkFields.contains(columnKey)) {
                updateStatement.addNewColumnValue(columnKey,insertOrUpdateStatement.getColumnValue(columnKey));
            }
        }
        // An UPDATE without a SET clause cannot be generated; signal "nothing to update" to the caller.
        if(updateStatement.getNewColumnValues().isEmpty()) {
            throw new LiquibaseException("No fields to update in set clause");
        }

        Sql[] updateSql = new UpdateGenerator().generateSql(updateStatement, database, sqlGeneratorChain);
        StringBuilder updateSqlString = new StringBuilder();
        // Join the generated statements with ";" (no trailing separator); safe for an empty result.
        for (int i = 0; i < updateSql.length; i++) {
            if (i > 0) {
                updateSqlString.append(";");
            }
            updateSqlString.append(updateSql[i].toSql());
        }
        updateSqlString.append("\n");
        return updateSqlString.toString();
    }

    /**
     * Assembles the complete insert-or-update script as a single unparsed SQL statement.
     *
     * @param insertOrUpdateStatement the statement to generate
     * @param database                the target database
     * @param sqlGeneratorChain       the generator chain passed through to the insert/update generators
     * @return a one-element array holding the whole script
     */
    public Sql[] generateSql(InsertOrUpdateStatement insertOrUpdateStatement, Database database, SqlGeneratorChain sqlGeneratorChain) {
        StringBuilder completeSql = new StringBuilder();
        String whereClause = getWhereClause(insertOrUpdateStatement, database);
        completeSql.append(getRecordCheck(insertOrUpdateStatement, database, whereClause));
        completeSql.append(getInsertStatement(insertOrUpdateStatement, database, sqlGeneratorChain));
        try {
            String updateStatement = getUpdateStatement(insertOrUpdateStatement,database,whereClause,sqlGeneratorChain);
            completeSql.append(getElse(database));
            completeSql.append(updateStatement);
        } catch (LiquibaseException ignored) {
            // Deliberate best-effort: this class's getUpdateStatement throws when every column is part
            // of the primary key. In that case there is nothing to update, so only the insert branch is emitted.
        }
        completeSql.append(getPostUpdateStatements());
        return new Sql[]{
            new UnparsedSql(completeSql.toString())
        };
    }
}