package com.w11k.lsql.dialects;

import com.google.common.base.Function;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import com.w11k.lsql.Column;
import com.w11k.lsql.LSql;
import com.w11k.lsql.Table;
import com.w11k.lsql.exceptions.DatabaseAccessException;

import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

import static com.w11k.lsql.jdbc.ConnectionUtils.getConnection;

/**
 * Builds the JDBC statements (and SQL strings) used for the basic table
 * operations: insert, update, select-by-id, delete-by-id, count-by-id and
 * revision lookup.
 *
 * <p>Identifiers are translated with {@code LSql#identifierJavaToSql} before
 * they are placed into SQL text; values are always bound through {@code ?}
 * placeholders. Any {@link SQLException} raised while creating a statement is
 * rethrown as a {@link DatabaseAccessException}.
 *
 * <p>The class has no fields of its own.
 */
public class StatementCreator {

    /**
     * Creates a plain {@link Statement} on the connection obtained for {@code lSql}.
     *
     * @param lSql the LSql instance whose connection is used
     * @return a new statement
     * @throws DatabaseAccessException if the JDBC driver reports an {@link SQLException}
     */
    public Statement createStatement(LSql lSql) {
        try {
            return getConnection(lSql).createStatement();
        } catch (SQLException e) {
            throw new DatabaseAccessException(e);
        }
    }

    /**
     * Prepares {@code sqlString} on the connection obtained for {@code lSql}.
     *
     * @param lSql                     the LSql instance whose connection is used
     * @param sqlString                the SQL text to prepare
     * @param returnAutoGeneratedKeys  when {@code true}, the statement is prepared with
     *                                 {@link Statement#RETURN_GENERATED_KEYS}
     * @return the prepared statement
     * @throws DatabaseAccessException if the JDBC driver reports an {@link SQLException}
     */
    public PreparedStatement createPreparedStatement(LSql lSql,
                                                     String sqlString,
                                                     boolean returnAutoGeneratedKeys) {
        try {
            if (returnAutoGeneratedKeys) {
                return getConnection(lSql).prepareStatement(sqlString, Statement.RETURN_GENERATED_KEYS);
            }
            return getConnection(lSql).prepareStatement(sqlString);
        } catch (SQLException e) {
            throw new DatabaseAccessException(e);
        }
    }

    /**
     * Prepares {@code SELECT <revision> FROM <table> WHERE <pk>=?}.
     *
     * <p>The primary key identifier is translated to its SQL form in the same
     * way as in {@link #createDeleteByIdStatement(Table)} and
     * {@link #createCountForIdStatement(Table)}.
     *
     * @param table the table to query; its revision and primary key columns are
     *              read with {@code get()} without a presence check
     * @return the prepared statement with one placeholder (the id)
     */
    public PreparedStatement createRevisionQueryStatement(Table table) {
        String sqlTableName = getSqlTableName(table);
        String revCol = getRevisionColumnSqlIdentifier(table);
        String sqlIdName = getPrimaryKeySqlIdentifier(table);

        String sql = "SELECT " + revCol + " FROM " + sqlTableName + " WHERE " + sqlIdName + "=?";
        return createPreparedStatement(table.getlSql(), sql, false);
    }

    /**
     * Prepares {@code INSERT INTO <table>(<columns>)VALUES(?,...);} with one
     * placeholder per entry in {@code columns}. The statement is prepared so
     * that auto-generated keys are returned.
     *
     * @param table   the target table
     * @param columns Java-side column names, in placeholder order
     * @return the prepared insert statement
     */
    public PreparedStatement createInsertStatement(final Table table, List<String> columns) {
        String sqlTableName = getSqlTableName(table);

        StringBuilder sql = new StringBuilder();
        sql.append("INSERT INTO ").append(sqlTableName);
        sql.append("(");
        sql.append(Joiner.on(",").join(createSqlColumnNames(table, columns)));
        sql.append(")VALUES(");
        sql.append(Joiner.on(",").join(Collections.nCopies(columns.size(), "?")));
        sql.append(");");

        return createPreparedStatement(table.getlSql(), sql.toString(), true);
    }

    /**
     * Prepares an {@code UPDATE} statement.
     *
     * <p>Placeholder order is: one per entry in {@code columns}, then one per
     * entry in {@code whereColumns}, then - only if the table has a revision
     * column - one for the expected revision. When a revision column is
     * present the statement also increments it ({@code rev=rev+1}).
     *
     * @param table        the target table
     * @param columns      Java-side names of the columns to set
     * @param whereColumns Java-side names of the columns compared with {@code =?}
     *                     and combined with {@code AND}
     * @return the prepared update statement
     */
    public PreparedStatement createUpdateStatement(Table table,
                                                   List<String> columns,
                                                   List<String> whereColumns) {
        String sqlTableName = getSqlTableName(table);
        boolean hasRevisionColumn = table.getRevisionColumn().isPresent();

        StringBuilder sql = new StringBuilder("UPDATE ").append(sqlTableName);
        sql.append(" SET ");

        // Bump the revision as part of the same statement.
        if (hasRevisionColumn) {
            String revCol = getRevisionColumnSqlIdentifier(table);
            sql.append(revCol).append("=").append(revCol).append("+1");
            if (columns.size() > 0) {
                sql.append(",");
            }
        }

        // New values.
        sql.append(Joiner.on(",").join(createAssignments(table, columns)));

        // Row selection.
        sql.append(" WHERE ");
        sql.append(Joiner.on(" AND ").join(createAssignments(table, whereColumns)));
        if (hasRevisionColumn) {
            // The expected revision is bound last.
            sql.append(" AND ");
            sql.append(getRevisionColumnSqlIdentifier(table)).append("=?");
        }
        sql.append(";");

        return createPreparedStatement(table.getlSql(), sql.toString(), false);
    }

    /**
     * Builds (but does not prepare) {@code SELECT <cols> FROM <table> WHERE <id>=?;}.
     *
     * <p>Columns for which {@code isIgnored()} is true are left out of the
     * select list. The select list is terminated by dropping the last
     * character of the text built so far; if no column was appended that
     * character is the space after {@code SELECT}.
     *
     * @param table    the table to select from
     * @param idColumn the column used in the WHERE clause
     * @param columns  candidate columns for the select list
     * @return the SQL text
     */
    public String createSelectByIdStatement(Table table, Column idColumn, Collection<Column> columns) {
        String sqlTableName = getSqlTableName(table);
        String sqlColumnName =
                idColumn.getTable().getlSql().identifierJavaToSql(idColumn.getJavaColumnName());

        StringBuilder sql = new StringBuilder("SELECT ");
        for (Column column : columns) {
            if (!column.isIgnored()) {
                sql.append(column.getSqlColumnName()).append(",");
            }
        }
        // Drop the trailing separator.
        sql.setLength(sql.length() - 1);

        sql.append(" FROM ").append(sqlTableName);
        sql.append(" WHERE ").append(sqlColumnName).append("=?;");
        return sql.toString();
    }

    /**
     * Prepares {@code DELETE FROM <table> WHERE <pk>=?}, followed by
     * {@code AND <revision>=?} when the table has a revision column, and a
     * terminating {@code ;}.
     *
     * @param table the table to delete from; its primary key column is read
     *              with {@code get()} without a presence check
     * @return the prepared delete statement
     */
    public PreparedStatement createDeleteByIdStatement(Table table) {
        String sqlIdName = getPrimaryKeySqlIdentifier(table);
        String sqlTableName = getSqlTableName(table);

        StringBuilder sql = new StringBuilder("DELETE FROM ");
        sql.append(sqlTableName);
        sql.append(" WHERE ");
        sql.append(sqlIdName).append("=?");
        if (table.getRevisionColumn().isPresent()) {
            String sqlRevisionName = getRevisionColumnSqlIdentifier(table);
            sql.append(" AND ").append(sqlRevisionName).append("=?");
        }
        sql.append(";");

        return createPreparedStatement(table.getlSql(), sql.toString(), false);
    }

    /**
     * Prepares {@code select count(<pk>) from <table> where <pk>=?}.
     *
     * @param table the table to count in; its primary key column is read with
     *              {@code get()} without a presence check
     * @return the prepared count statement
     * @throws SQLException declared in the signature and kept for existing
     *                      callers; statement preparation failures are
     *                      rethrown as {@link DatabaseAccessException} by
     *                      {@link #createPreparedStatement(LSql, String, boolean)}
     */
    public PreparedStatement createCountForIdStatement(Table table) throws SQLException {
        String sqlColumnName = getPrimaryKeySqlIdentifier(table);
        String sqlTableName = getSqlTableName(table);

        String sql = "select count(" + sqlColumnName + ") from " + sqlTableName
                + " where " + sqlColumnName + "=?";
        return createPreparedStatement(table.getlSql(), sql, false);
    }

    /** Translates the table's schema-and-table name to its SQL identifier. */
    private String getSqlTableName(Table table) {
        return table.getlSql().identifierJavaToSql(table.getSchemaAndTableName());
    }

    /** Translates the table's primary key column name to its SQL identifier. */
    private String getPrimaryKeySqlIdentifier(Table table) {
        Column idColumn = table.column(table.getPrimaryKeyColumn().get());
        return idColumn.getTable().getlSql().identifierJavaToSql(idColumn.getJavaColumnName());
    }

    /** Translates the table's revision column name to its SQL identifier. */
    private String getRevisionColumnSqlIdentifier(Table table) {
        return table.getlSql().identifierJavaToSql(table.getRevisionColumn().get().getJavaColumnName());
    }

    /** Returns a transformed view of {@code columns} holding {@code <sqlName>=?} fragments. */
    private List<String> createAssignments(Table table, List<String> columns) {
        return Lists.transform(createSqlColumnNames(table, columns), new Function<String, String>() {
            @Override
            public String apply(String input) {
                return input + "=?";
            }
        });
    }

    /** Returns a transformed view of {@code columns} holding the SQL identifiers. */
    private List<String> createSqlColumnNames(final Table table, List<String> columns) {
        return Lists.transform(columns, new Function<String, String>() {
            @Override
            public String apply(String input) {
                return table.getlSql().identifierJavaToSql(input);
            }
        });
    }
}