package joist.jdbc;
import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import javax.sql.DataSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import joist.util.Reflection;
public class Jdbc {
public static final String trackStatsKey = "joist.jdbc.trackStats";
private static final Logger log = LoggerFactory.getLogger(Jdbc.class);
private static final boolean trackStats = Boolean.valueOf(System.getProperty(trackStatsKey, "false"));
private static final AtomicInteger queries = new AtomicInteger(0);
private static final AtomicInteger updates = new AtomicInteger(0);
private Jdbc() {
}
/** @return the number of queries, only if {@link #trackStats} is enabled. */
public static int numberOfQueries() {
return queries.get();
}
/** @return the number of updates, only if {@link #trackStats} is enabled. */
public static int numberOfUpdates() {
return updates.get();
}
/** Resets the stats (only meaningful if {@link #trackStats} is enabled. */
public static void resetStats() {
updates.set(0);
queries.set(0);
}
public static int queryForInt(Connection connection, String sql, Object... args) {
int value = -1;
PreparedStatement stmt = null;
ResultSet rs = null;
try {
log.trace("sql = {}", sql);
tickQueriesIfTracking();
stmt = connection.prepareStatement(sql);
for (int i = 1; i <= args.length; i++) {
stmt.setObject(i, args[i - 1]);
}
rs = stmt.executeQuery();
if (rs.next()) {
value = rs.getInt(1);
}
} catch (SQLException se) {
throw new JdbcException(se);
} finally {
Jdbc.closeSafely(rs, stmt);
}
return value;
}
public static int queryForInt(DataSource ds, String sql, Object... args) {
Connection connection = null;
try {
connection = ds.getConnection();
return Jdbc.queryForInt(connection, sql, args);
} catch (SQLException se) {
throw new JdbcException(se);
} finally {
Jdbc.closeSafely(connection);
}
}
public static Object[] queryForRow(Connection connection, String sql, Object... args) {
PreparedStatement stmt = null;
ResultSet rs = null;
try {
log.trace("sql = {}", sql);
tickQueriesIfTracking();
stmt = connection.prepareStatement(sql);
for (int i = 1; i <= args.length; i++) {
stmt.setObject(i, args[i - 1]);
}
rs = stmt.executeQuery();
int count = rs.getMetaData().getColumnCount();
Object[] objects = new Object[count];
if (rs.next()) {
for (int i = 0; i < count; i++) {
objects[i] = rs.getObject(i + 1);
}
}
return objects;
} catch (SQLException se) {
throw new JdbcException(se);
} finally {
Jdbc.closeSafely(rs, stmt);
}
}
public static Object[] queryForRow(DataSource ds, String sql, Object... args) {
Connection connection = null;
try {
connection = ds.getConnection();
return Jdbc.queryForRow(connection, sql, args);
} catch (SQLException se) {
throw new JdbcException(se);
} finally {
Jdbc.closeSafely(connection);
}
}
public static int update(Connection connection, String sql, Object... args) {
PreparedStatement stmt = null;
try {
log.trace("sql = {}", sql);
tickUpdatesIfTracking();
stmt = connection.prepareStatement(sql);
for (int i = 1; i <= args.length; i++) {
stmt.setObject(i, args[i - 1]);
}
return stmt.executeUpdate();
} catch (SQLException se) {
throw new JdbcException(se);
} finally {
Jdbc.closeSafely(stmt);
}
}
public static int update(DataSource ds, String sql, Object... args) {
Connection connection = null;
try {
connection = ds.getConnection();
return Jdbc.update(connection, sql, args);
} catch (SQLException se) {
throw new JdbcException(se);
} finally {
Jdbc.closeSafely(connection);
}
}
public static void query(DataSource ds, String sql, RowMapper rse) {
Connection connection = null;
try {
connection = ds.getConnection();
Jdbc.query(connection, sql, rse);
} catch (SQLException se) {
throw new JdbcException(se);
} finally {
Jdbc.closeSafely(connection);
}
}
public static void query(Connection connection, String sql, RowMapper rse) {
Statement s = null;
ResultSet rs = null;
try {
log.trace("sql = {}", sql);
tickQueriesIfTracking();
s = connection.createStatement();
rs = s.executeQuery(sql);
while (rs.next()) {
rse.mapRow(rs);
}
} catch (SQLException se) {
throw new JdbcException(se);
} finally {
Jdbc.closeSafely(rs, s);
}
}
public static void query(DataSource ds, String sql, List<? extends Object> parameters, RowMapper rse) {
Connection connection = null;
try {
connection = ds.getConnection();
Jdbc.query(connection, sql, parameters, rse);
} catch (SQLException se) {
throw new JdbcException(se);
} finally {
Jdbc.closeSafely(connection);
}
}
public static void query(Connection connection, String sql, List<? extends Object> parameters, RowMapper rse) {
PreparedStatement s = null;
ResultSet rs = null;
try {
log.trace("sql = {}", sql);
log.trace("parameters = {}", parameters);
tickQueriesIfTracking();
s = connection.prepareStatement(sql);
for (int i = 0; i < parameters.size(); i++) {
s.setObject(i + 1, parameters.get(i));
}
rs = s.executeQuery();
while (rs.next()) {
rse.mapRow(rs);
}
} catch (SQLException se) {
throw new JdbcException(se);
} finally {
Jdbc.closeSafely(rs, s);
}
}
public static <T> List<T> query(Connection connection, String sql, final Class<T> type) {
ReflectionRowMapper<T> mapper = new ReflectionRowMapper<T>(type);
Jdbc.query(connection, sql, mapper);
return mapper.results;
}
public static <T> List<T> query(DataSource ds, String sql, final Class<T> type) {
Connection connection = null;
try {
connection = ds.getConnection();
return Jdbc.query(connection, sql, type);
} catch (SQLException se) {
throw new JdbcException(se);
} finally {
Jdbc.closeSafely(connection);
}
}
public static <T> List<T> query(Connection connection, String sql, List<? extends Object> parameters, final Class<T> type) {
ReflectionRowMapper<T> mapper = new ReflectionRowMapper<T>(type);
Jdbc.query(connection, sql, parameters, mapper);
return mapper.results;
}
public static <T> List<T> query(DataSource ds, String sql, List<? extends Object> parameters, final Class<T> type) {
Connection connection = null;
try {
connection = ds.getConnection();
return Jdbc.query(connection, sql, parameters, type);
} catch (SQLException se) {
throw new JdbcException(se);
} finally {
Jdbc.closeSafely(connection);
}
}
public static int update(Connection connection, String sql, List<Object> parameters) {
PreparedStatement ps = null;
try {
log.trace("sql = {}", sql);
log.trace("parameters = {}", parameters);
tickUpdatesIfTracking();
ps = connection.prepareStatement(sql);
for (int i = 0; i < parameters.size(); i++) {
ps.setObject(i + 1, parameters.get(i));
}
return ps.executeUpdate();
} catch (SQLException se) {
throw new JdbcException(se);
} finally {
Jdbc.closeSafely(ps);
}
}
public static Long[] insertBatch(Connection connection, String sql, List<List<Object>> allParameters) {
PreparedStatement ps = null;
try {
log.trace("sql = {}", sql);
tickUpdatesIfTracking();
ps = connection.prepareStatement(sql, new String[] { "id" });
for (List<Object> parameters : allParameters) {
log.trace("parameters = {}", parameters);
for (int i = 0; i < parameters.size(); i++) {
ps.setObject(i + 1, parameters.get(i));
}
ps.addBatch();
}
ps.executeBatch();
Long[] keys = new Long[allParameters.size()];
ResultSet ks = ps.getGeneratedKeys();
int i = 0;
while (ks.next()) {
keys[i++] = ks.getLong(1);
}
ks.close();
return keys;
} catch (SQLException se) {
// note that MySQL's BatchUpdateException.getNextException is null and so unreliable
throw new JdbcException(Jdbc.nextUntilNotNull(se));
} finally {
Jdbc.closeSafely(ps);
}
}
public static int update(DataSource ds, String sql, List<Object> parameters) {
Connection connection = null;
try {
connection = ds.getConnection();
return Jdbc.update(connection, sql, parameters);
} catch (SQLException se) {
throw new JdbcException(se);
} finally {
Jdbc.closeSafely(connection);
}
}
public static List<Integer> updateBatch(Connection connection, String sql, List<List<Object>> allParameters) {
List<Integer> changed = new ArrayList<Integer>();
PreparedStatement ps = null;
try {
log.trace("sql = {}", sql);
tickUpdatesIfTracking();
ps = connection.prepareStatement(sql);
for (List<Object> parameters : allParameters) {
log.trace("parameters = {}", parameters);
for (int i = 0; i < parameters.size(); i++) {
ps.setObject(i + 1, parameters.get(i));
}
ps.addBatch();
}
int[] is = ps.executeBatch();
for (int i : is) {
changed.add(i);
}
return changed;
} catch (SQLException se) {
// note that MySQL's BatchUpdateException.getNextException is null and so unreliable
throw new JdbcException(Jdbc.nextUntilNotNull(se));
} finally {
Jdbc.closeSafely(ps);
}
}
public static SQLException nextUntilNotNull(SQLException current) {
while (current.getNextException() != null) {
current = current.getNextException();
}
return current;
}
public static void closeSafely(Connection conn) {
try {
if (conn != null) {
conn.close();
}
} catch (Exception e) {
log.error("Error occurred closing " + conn, e);
}
}
public static void closeSafely(PreparedStatement ps) {
try {
if (ps != null) {
ps.close();
}
} catch (Exception e) {
log.error("Error occurred closing " + ps, e);
}
}
public static void closeSafely(Statement stmt) {
try {
if (stmt != null) {
stmt.close();
}
} catch (Exception e) {
log.error("Error occurred closing " + stmt, e);
}
}
public static void closeSafely(ResultSet rs) {
try {
if (rs != null) {
rs.close();
}
} catch (Exception e) {
log.error("Error occurred closing " + rs, e);
}
}
public static void closeSafely(ResultSet rs, PreparedStatement ps) {
closeSafely(rs);
closeSafely(ps);
}
public static void closeSafely(ResultSet rs, Statement stmt) {
closeSafely(rs);
closeSafely(stmt);
}
public static void closeSafely(ResultSet rs, PreparedStatement ps, Connection conn) {
closeSafely(rs);
closeSafely(ps);
closeSafely(conn);
}
public static void closeSafely(ResultSet rs, Statement stmt, Connection conn) {
closeSafely(rs);
closeSafely(stmt);
closeSafely(conn);
}
public static <T> T inTransaction(DataSource ds, Function<Connection, T> function) {
Connection connection = null;
try {
connection = ds.getConnection();
connection.setAutoCommit(false);
T result = function.apply(connection);
connection.commit();
return result;
} catch (SQLException se) {
throw new RuntimeException(se);
} finally {
Jdbc.closeSafely(connection);
}
}
private static void tickQueriesIfTracking() {
if (trackStats) {
queries.incrementAndGet();
}
}
private static void tickUpdatesIfTracking() {
if (trackStats) {
updates.incrementAndGet();
}
}
private static final class ReflectionRowMapper<T> implements RowMapper {
private final List<T> results = new ArrayList<T>();
private final List<Field> fields = new ArrayList<Field>();
private final Class<T> type;
private boolean hadLoadedFields = false;
private ReflectionRowMapper(Class<T> type) {
this.type = type;
}
public void mapRow(ResultSet rs) throws SQLException {
T instance = Reflection.newInstance(this.type);
// only look up the Fields once
if (!this.hadLoadedFields) {
for (int i = 1; i <= rs.getMetaData().getColumnCount(); i++) {
this.fields.add(Reflection.getField(this.type, rs.getMetaData().getColumnLabel(i)));
}
this.hadLoadedFields = true;
}
// now we can do the actual set
for (int i = 0; i < this.fields.size(); i++) {
Reflection.set(this.fields.get(i), instance, getValueBasedOnType(rs, i, this.fields.get(i).getType()));
}
this.results.add(instance);
}
private static Object getValueBasedOnType(ResultSet rs, int i, Class<?> type) throws SQLException {
final Object value;
if (type.equals(Long.class)) {
value = rs.getLong(i + 1);
} else if (type.equals(Integer.class)) {
value = rs.getInt(i + 1);
} else if (type.equals(Boolean.class)) {
value = rs.getBoolean(i + 1);
} else if (type.equals(BigDecimal.class)) {
value = rs.getBigDecimal(i + 1);
} else if (type.equals(String.class)) {
value = rs.getString(i + 1);
} else if (type.equals(Byte.class)) {
value = rs.getByte(i + 1);
} else {
value = rs.getObject(i + 1);
}
return value;
}
}
}