package com.w11k.lsql.dialects;
import com.google.common.base.Function;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import com.w11k.lsql.Column;
import com.w11k.lsql.LSql;
import com.w11k.lsql.Table;
import com.w11k.lsql.exceptions.DatabaseAccessException;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import static com.w11k.lsql.jdbc.ConnectionUtils.getConnection;
/**
 * Builds the SQL text and JDBC statements used by LSql for basic table operations:
 * insert, update, delete / select / count by primary key, and revision lookups.
 * <p>
 * Java identifiers (table and column names) are translated to SQL identifiers via
 * {@code LSql.identifierJavaToSql} before being embedded in SQL. Any
 * {@link SQLException} raised while creating a statement is wrapped in a
 * {@link DatabaseAccessException}. Statements returned by this class are not closed here.
 */
public class StatementCreator {

    /**
     * Turns a SQL column name into a parameterised equality {@code "<column>=?"}, as used
     * both in SET lists and in WHERE clauses.
     */
    private static final Function<String, String> APPEND_PARAMETER_PLACEHOLDER =
            new Function<String, String>() {
                @Override
                public String apply(String sqlColumnName) {
                    return sqlColumnName + "=?";
                }
            };

    /**
     * Creates a plain {@link Statement} on the connection obtained for {@code lSql}.
     *
     * @param lSql the LSql instance whose connection is used
     * @return a new, unclosed statement
     * @throws DatabaseAccessException if the JDBC driver fails to create the statement
     */
    public Statement createStatement(LSql lSql) {
        try {
            return getConnection(lSql).createStatement();
        } catch (SQLException e) {
            throw new DatabaseAccessException(e);
        }
    }

    /**
     * Prepares {@code sqlString} on the connection obtained for {@code lSql}.
     *
     * @param lSql                    the LSql instance whose connection is used
     * @param sqlString               the SQL to prepare
     * @param returnAutoGeneratedKeys if {@code true}, the statement is prepared with
     *                                {@link Statement#RETURN_GENERATED_KEYS}
     * @return a new, unclosed prepared statement
     * @throws DatabaseAccessException if the JDBC driver fails to prepare the statement
     */
    public PreparedStatement createPreparedStatement(LSql lSql, String sqlString, boolean returnAutoGeneratedKeys) {
        try {
            if (returnAutoGeneratedKeys) {
                return getConnection(lSql).prepareStatement(sqlString, Statement.RETURN_GENERATED_KEYS);
            } else {
                return getConnection(lSql).prepareStatement(sqlString);
            }
        } catch (SQLException e) {
            throw new DatabaseAccessException(e);
        }
    }

    /**
     * Prepares {@code SELECT <revision> FROM <table> WHERE <pk>=?}.
     * <p>
     * The table must have both a primary key column and a revision column; the
     * {@code get()} calls on them are unchecked.
     *
     * @param table the table to query
     * @return a prepared statement with one parameter: the primary key value
     */
    public PreparedStatement createRevisionQueryStatement(Table table) {
        String sqlTableName = toSqlIdentifier(table, table.getSchemaAndTableName());
        String revCol = getRevisionColumnSqlIdentifier(table);
        // The primary key is stored under its Java name (see table.column(...) usage in
        // createDeleteByIdStatement); translate it like every other identifier in SQL.
        String sqlIdName = toSqlIdentifier(table, table.getPrimaryKeyColumn().get());
        String sql = "SELECT " + revCol + " FROM " + sqlTableName + " WHERE " + sqlIdName + "=?";
        return createPreparedStatement(table.getlSql(), sql, false);
    }

    /**
     * Prepares {@code INSERT INTO <table>(c1,c2,...)VALUES(?,?,...);}, requesting
     * auto-generated keys.
     *
     * @param table   the target table
     * @param columns Java column names, in the order their values will be bound
     * @return a prepared statement with one parameter per column
     */
    public PreparedStatement createInsertStatement(final Table table, List<String> columns) {
        String sqlTableName = toSqlIdentifier(table, table.getSchemaAndTableName());
        String sql = "INSERT INTO " + sqlTableName
                + "(" + Joiner.on(",").join(createSqlColumnNames(table, columns)) + ")"
                + "VALUES(" + Joiner.on(",").join(Collections.nCopies(columns.size(), "?")) + ");";
        return createPreparedStatement(table.getlSql(), sql, true);
    }

    /**
     * Prepares an UPDATE of {@code columns}, restricted by {@code whereColumns} (ANDed).
     * <p>
     * If the table has a revision column, the statement also sets
     * {@code rev=rev+1} and appends {@code AND rev=?} to the WHERE clause.
     * Parameter order is: the values for {@code columns}, then for {@code whereColumns},
     * then (if present) the expected revision.
     *
     * @param table        the target table
     * @param columns      Java names of the columns to set
     * @param whereColumns Java names of the columns that must match
     * @return the prepared update statement
     */
    public PreparedStatement createUpdateStatement(Table table, List<String> columns, List<String> whereColumns) {
        String sqlTableName = toSqlIdentifier(table, table.getSchemaAndTableName());
        boolean hasRevision = table.getRevisionColumn().isPresent();

        String sql = "UPDATE " + sqlTableName + " SET ";
        if (hasRevision) {
            String revCol = getRevisionColumnSqlIdentifier(table);
            sql += revCol + "=" + revCol + "+1";
            if (columns.size() > 0) {
                sql += ",";
            }
        }
        sql += Joiner.on(",").join(createParameterAssignments(table, columns));

        sql += " WHERE ";
        sql += Joiner.on(" AND ").join(createParameterAssignments(table, whereColumns));
        if (hasRevision) {
            sql += " AND " + getRevisionColumnSqlIdentifier(table) + "=?";
        }
        sql += ";";
        return createPreparedStatement(table.getlSql(), sql, false);
    }

    /**
     * Builds (but does not prepare) {@code SELECT c1,c2,... FROM <table> WHERE <id>=?;}.
     * Columns for which {@code isIgnored()} returns {@code true} are skipped.
     *
     * @param table    the table to select from
     * @param idColumn the column compared against the single parameter
     * @param columns  the candidate columns to select
     * @return the SQL string
     */
    public String createSelectByIdStatement(Table table, Column idColumn, Collection<Column> columns) {
        String sqlTableName = toSqlIdentifier(table, table.getSchemaAndTableName());
        String sqlColumnName = idColumn.getTable().getlSql().identifierJavaToSql(idColumn.getJavaColumnName());
        List<String> selectedColumns = Lists.newArrayList();
        for (Column column : columns) {
            if (!column.isIgnored()) {
                selectedColumns.add(column.getSqlColumnName());
            }
        }
        return "SELECT " + Joiner.on(",").join(selectedColumns)
                + " FROM " + sqlTableName + " WHERE " + sqlColumnName + "=?;";
    }

    /**
     * Prepares {@code DELETE FROM <table> WHERE <pk>=?}, with an additional
     * {@code AND <revision>=?} when the table has a revision column.
     * The table must have a primary key column; the {@code get()} call is unchecked.
     *
     * @param table the table to delete from
     * @return the prepared delete statement (parameters: primary key, then revision if present)
     */
    public PreparedStatement createDeleteByIdStatement(Table table) {
        Column idColumn = table.column(table.getPrimaryKeyColumn().get());
        String sqlTableName = toSqlIdentifier(table, table.getSchemaAndTableName());
        String sqlIdName = idColumn.getTable().getlSql().identifierJavaToSql(idColumn.getJavaColumnName());
        String sql = "DELETE FROM " + sqlTableName + " WHERE " + sqlIdName + "=?";
        if (table.getRevisionColumn().isPresent()) {
            sql += " AND " + getRevisionColumnSqlIdentifier(table) + "=?";
        }
        sql += ";";
        return createPreparedStatement(table.getlSql(), sql, false);
    }

    /**
     * Prepares {@code select count(<pk>) from <table> where <pk>=?}.
     * The table must have a primary key column; the {@code get()} call is unchecked.
     *
     * @param table the table to count in
     * @return the prepared count statement (one parameter: the primary key value)
     * @throws SQLException declared for source compatibility only; failures are actually
     *                      reported as {@link DatabaseAccessException} by
     *                      {@link #createPreparedStatement}
     */
    public PreparedStatement createCountForIdStatement(Table table) throws SQLException {
        Column idColumn = table.column(table.getPrimaryKeyColumn().get());
        String sqlTableName = toSqlIdentifier(table, table.getSchemaAndTableName());
        String sqlColumnName = idColumn.getTable().getlSql().identifierJavaToSql(idColumn.getJavaColumnName());
        String sql = "select count(" + sqlColumnName + ") from " + sqlTableName + " where " +
                sqlColumnName + "=?";
        return createPreparedStatement(table.getlSql(), sql, false);
    }

    /** Translates a Java identifier to its SQL form using the table's LSql instance. */
    private String toSqlIdentifier(Table table, String javaIdentifier) {
        return table.getlSql().identifierJavaToSql(javaIdentifier);
    }

    /** Returns the SQL identifier of the table's revision column (must be present). */
    private String getRevisionColumnSqlIdentifier(Table table) {
        return toSqlIdentifier(table, table.getRevisionColumn().get().getJavaColumnName());
    }

    /** Returns a lazily transformed view of {@code columns} with each Java name translated to SQL. */
    private List<String> createSqlColumnNames(final Table table, List<String> columns) {
        return Lists.transform(columns, new Function<String, String>() {
            @Override
            public String apply(String input) {
                return toSqlIdentifier(table, input);
            }
        });
    }

    /** Returns {@code "<sqlColumn>=?"} for each Java column name, in order. */
    private List<String> createParameterAssignments(Table table, List<String> javaColumnNames) {
        return Lists.transform(createSqlColumnNames(table, javaColumnNames), APPEND_PARAMETER_PLACEHOLDER);
    }
}