/**
* Copyright 2013-2014 Guoqiang Chen, Shanghai, China. All rights reserved.
*
* Email: subchen@gmail.com
* URL: http://subchen.github.io/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jetbrick.dao.orm;
import java.sql.*;
import java.util.*;
import java.util.Date;
import javax.sql.DataSource;
import jetbrick.dao.DbException;
import jetbrick.dao.TransactionException;
import jetbrick.dao.dialect.SqlDialect;
import jetbrick.dao.orm.handlers.*;
import jetbrick.dao.orm.mappers.*;
import jetbrick.dao.orm.tx.*;
import jetbrick.dao.orm.utils.DbUtils;
import jetbrick.dao.orm.utils.PreparedStatementCreator;
import jetbrick.lang.Validate;
/**
 * Database access helper built on top of a {@link DataSource}: queries, updates, batch
 * updates and thread-bound transactions. Intended to be used as a singleton.
 */
public class DbHelper {
    /**
     * Whether nested transactions are allowed. They are allowed unless the system property
     * {@code jetbrick.orm.transaction.nested.disabled} is set (to any value).
     */
    private static final boolean ALLOW_NESTED_TRANSACTION = System.getProperty("jetbrick.orm.transaction.nested.disabled") == null;

    /** Transaction bound to the current thread; {@code null} when no transaction is active. */
    private final ThreadLocal<JdbcTransaction> transactionHandler = new ThreadLocal<JdbcTransaction>();
    private final DataSource dataSource;
    private final SqlDialect dialect;

    /**
     * Creates a helper for the given data source.
     * <p>
     * One connection is opened (and closed) eagerly to detect the SQL dialect from the
     * database product name.
     *
     * @param dataSource the data source connections are obtained from
     * @throws DbException if the connection or its metadata cannot be obtained
     */
    public DbHelper(DataSource dataSource) {
        Validate.notNull(dataSource, "dataSource is null.");
        this.dataSource = dataSource;
        this.dialect = doGetDialect();
    }

    public DataSource getDataSource() {
        return dataSource;
    }

    /**
     * Begins a transaction bound to the current thread.
     * <p>
     * If a transaction is already active on this thread, a nested transaction sharing the
     * outer transaction's connection is returned, unless nesting has been disabled via the
     * {@code jetbrick.orm.transaction.nested.disabled} system property.
     *
     * @return the new (possibly nested) transaction
     * @throws TransactionException if a transaction is already active and nesting is
     *         disabled, or if a connection cannot be opened
     */
    public Transaction transaction() {
        JdbcTransaction current = transactionHandler.get();
        if (current != null) {
            if (ALLOW_NESTED_TRANSACTION) {
                return new JdbcNestedTransaction(current.getConnection());
            }
            throw new TransactionException("Can't begin a nested transaction.");
        }
        Connection conn = null;
        try {
            conn = dataSource.getConnection();
            // The ThreadLocal is handed over, presumably so the transaction can unbind
            // itself when it finishes -- confirm in JdbcTransaction.
            JdbcTransaction tx = new JdbcTransaction(conn, transactionHandler);
            transactionHandler.set(tx);
            return tx;
        } catch (SQLException e) {
            // Do not leak the freshly opened connection if the transaction could not be set up.
            DbUtils.closeQuietly(conn);
            throw new TransactionException(e);
        } catch (RuntimeException e) {
            DbUtils.closeQuietly(conn);
            throw e;
        }
    }

    /**
     * Returns the current thread's transactional connection when a transaction is active,
     * otherwise a new connection from the data source. Every call must be paired with
     * {@link #closeConnection(Connection)}.
     */
    private Connection getConnection() {
        JdbcTransaction tx = transactionHandler.get();
        if (tx != null) {
            return tx.getConnection();
        }
        try {
            return dataSource.getConnection();
        } catch (SQLException e) {
            throw new DbException(e);
        }
    }

    /**
     * Releases a connection obtained from {@link #getConnection()}: it is closed when no
     * transaction is active on this thread; otherwise it belongs to the transaction and is
     * left open.
     */
    private void closeConnection(Connection conn) {
        if (transactionHandler.get() == null) {
            // not in a transaction
            DbUtils.closeQuietly(conn);
        }
    }

    /** Runs the query and maps every row with {@code rowMapper}. */
    public <T> List<T> queryAsList(RowMapper<T> rowMapper, String sql, Object... parameters) {
        Validate.notNull(rowMapper, "rowMapper is null.");
        ResultSetHandler<List<T>> rsh = new RowListHandler<T>(rowMapper);
        return query(rsh, sql, parameters);
    }

    /** Runs the query and maps every row to {@code beanClass} (see {@link #getRowMapper(Class)}). */
    public <T> List<T> queryAsList(Class<T> beanClass, String sql, Object... parameters) {
        Validate.notNull(beanClass, "beanClass is null.");
        RowMapper<T> rowMapper = getRowMapper(beanClass);
        return queryAsList(rowMapper, sql, parameters);
    }

    /**
     * Runs the query and maps a single row with {@code rowMapper}. What happens with zero
     * or several rows is decided by {@code SingleRowHandler} (not visible here).
     */
    public <T> T queryAsObject(RowMapper<T> rowMapper, String sql, Object... parameters) {
        Validate.notNull(rowMapper, "rowMapper is null.");
        ResultSetHandler<T> rsh = new SingleRowHandler<T>(rowMapper);
        return query(rsh, sql, parameters);
    }

    /** Runs the query and maps a single row to {@code beanClass}. */
    public <T> T queryAsObject(Class<T> beanClass, String sql, Object... parameters) {
        Validate.notNull(beanClass, "beanClass is null.");
        RowMapper<T> rowMapper = getRowMapper(beanClass);
        return queryAsObject(rowMapper, sql, parameters);
    }

    public Integer queryAsInt(String sql, Object... parameters) {
        return queryAsObject(Integer.class, sql, parameters);
    }

    public Long queryAsLong(String sql, Object... parameters) {
        return queryAsObject(Long.class, sql, parameters);
    }

    public String queryAsString(String sql, Object... parameters) {
        return queryAsObject(String.class, sql, parameters);
    }

    public Boolean queryAsBoolean(String sql, Object... parameters) {
        return queryAsObject(Boolean.class, sql, parameters);
    }

    public Date queryAsDate(String sql, Object... parameters) {
        return queryAsObject(Date.class, sql, parameters);
    }

    /** Runs the query and maps a single row to a column-name to value map. */
    @SuppressWarnings("unchecked") // Map.class is raw; MapRowMapper produces Map<String, Object>
    public Map<String, Object> queryAsMap(String sql, Object... parameters) {
        return queryAsObject(Map.class, sql, parameters);
    }

    /**
     * Runs the query and maps a single row to an array.
     *
     * @param arrayComponentClass the array's component type, e.g. {@code Object.class}
     */
    @SuppressWarnings("unchecked") // Array.newInstance(c, 0).getClass() is exactly c[].class
    public <T> T[] queryAsArray(Class<T> arrayComponentClass, String sql, Object... parameters) {
        Validate.notNull(arrayComponentClass, "arrayComponentClass is null.");
        // Class.forName("[" + name) is wrong: array binary names look like "[Ljava.lang.String;",
        // so derive the array class reflectively instead.
        Class<T[]> arrayClass = (Class<T[]>) java.lang.reflect.Array.newInstance(arrayComponentClass, 0).getClass();
        return queryAsObject(arrayClass, sql, parameters);
    }

    /** Paged query mapping rows to {@code beanClass}. */
    public <T> Pagelist<T> queryAsPagelist(PageInfo pageInfo, Class<T> beanClass, String sql, Object... parameters) {
        Validate.notNull(beanClass, "beanClass is null.");
        RowMapper<T> rowMapper = getRowMapper(beanClass);
        return queryAsPagelist(pageInfo, rowMapper, sql, parameters);
    }

    /**
     * Paged query.
     * <p>
     * If {@code pageInfo} carries a negative total count, a derived COUNT query is run
     * first. When the total is positive, the page is fetched with the dialect's paging SQL;
     * if the dialect returns {@code null} (no native paging), the full query is run and the
     * handler skips to the first row of the page itself.
     */
    public <T> Pagelist<T> queryAsPagelist(PageInfo pageInfo, RowMapper<T> rowMapper, String sql, Object... parameters) {
        Validate.notNull(pageInfo, "pageInfo is null.");
        Validate.notNull(rowMapper, "rowMapper is null.");
        PagelistImpl<T> pagelist = new PagelistImpl<T>(pageInfo);
        if (pageInfo.getTotalCount() < 0) {
            String countSql = DbUtils.get_sql_select_count(sql);
            Integer count = queryAsInt(countSql, parameters);
            // Guard against auto-unboxing a null result.
            pagelist.setTotalCount(count == null ? 0 : count.intValue());
        }
        List<T> items = Collections.emptyList();
        if (pagelist.getTotalCount() > 0) {
            String pageSql = dialect.sql_pagelist(sql, pagelist.getFirstResult(), pagelist.getPageSize());
            PagelistHandler<T> rsh = new PagelistHandler<T>(rowMapper);
            if (pageSql == null) {
                // No native paging: let the handler skip rows client-side
                // (originally documented as ResultSet.absolute(first)).
                rsh.setFirstResult(pagelist.getFirstResult());
            } else {
                // Native paging SQL already returns just the requested page.
                rsh.setFirstResult(0);
                sql = pageSql;
            }
            rsh.setMaxResults(pagelist.getPageSize());
            items = query(rsh, sql, parameters);
        }
        pagelist.setItems(items);
        return pagelist;
    }

    /**
     * Runs a parameterized query and hands the result set to {@code rsh}. The statement
     * and result set are always closed; the connection is closed only outside a transaction.
     *
     * @throws DbException wrapping any {@link SQLException}, annotated with the SQL and parameters
     */
    public <T> T query(ResultSetHandler<T> rsh, String sql, Object... parameters) {
        Validate.notNull(rsh, "rsh is null.");
        Validate.notNull(sql, "sql is null.");
        Connection conn = null;
        PreparedStatement ps = null;
        ResultSet rs = null;
        T result = null;
        try {
            conn = getConnection();
            ps = PreparedStatementCreator.createPreparedStatement(conn, sql, parameters);
            rs = ps.executeQuery();
            result = rsh.handle(rs);
        } catch (SQLException e) {
            throw new DbException(e).set("sql", sql).set("parameters", parameters);
        } finally {
            DbUtils.closeQuietly(rs);
            DbUtils.closeQuietly(ps);
            closeConnection(conn);
        }
        return result;
    }

    /**
     * Runs a parameterized update/DDL statement.
     *
     * @return the update count reported by the driver
     * @throws DbException wrapping any {@link SQLException}, annotated with the SQL and parameters
     */
    public int execute(String sql, Object... parameters) {
        Validate.notNull(sql, "sql is null.");
        Connection conn = null;
        PreparedStatement ps = null;
        int rows = 0;
        try {
            conn = getConnection();
            ps = PreparedStatementCreator.createPreparedStatement(conn, sql, parameters);
            rows = ps.executeUpdate();
        } catch (SQLException e) {
            throw new DbException(e).set("sql", sql).set("parameters", parameters);
        } finally {
            DbUtils.closeQuietly(ps);
            closeConnection(conn);
        }
        return rows;
    }

    /**
     * Runs {@code sql} as a JDBC batch: each array in {@code parameters} is bound
     * positionally (1-based) and added as one batch entry.
     *
     * @return the per-entry update counts reported by the driver
     * @throws DbException wrapping any {@link SQLException}, annotated with the SQL and parameters
     */
    public int[] executeBatch(String sql, List<Object[]> parameters) {
        Validate.notNull(sql, "sql is null.");
        Validate.notNull(parameters, "parameters is null.");
        Connection conn = null;
        PreparedStatement ps = null;
        int[] rows;
        try {
            conn = getConnection();
            ps = conn.prepareStatement(sql);
            for (Object[] parameter : parameters) {
                for (int i = 0; i < parameter.length; i++) {
                    ps.setObject(i + 1, parameter[i]);
                }
                ps.addBatch();
            }
            rows = ps.executeBatch();
        } catch (SQLException e) {
            throw new DbException(e).set("sql", sql).set("parameters", parameters);
        } finally {
            DbUtils.closeQuietly(ps);
            closeConnection(conn);
        }
        return rows;
    }

    /**
     * Hands a connection (the transactional one, if a transaction is active) to
     * {@code callback}; it is released afterwards per {@link #closeConnection(Connection)}.
     */
    public void execute(ConnectionCallback callback) {
        Validate.notNull(callback, "callback is null.");
        Connection conn = null;
        try {
            conn = getConnection();
            callback.execute(conn);
        } catch (SQLException e) {
            throw new DbException(e);
        } finally {
            closeConnection(conn);
        }
    }

    /**
     * Checks whether a table with the given name exists.
     * <p>
     * The name is converted to the case the database reports it stores unquoted
     * identifiers in (upper, lower, or unchanged), using locale-independent case mapping.
     */
    public boolean tableExist(String name) {
        Validate.notNull(name, "name is null.");
        Connection conn = null;
        ResultSet rs = null;
        try {
            conn = getConnection();
            DatabaseMetaData metaData = conn.getMetaData();
            rs = metaData.getTables(null, null, toStoredIdentifierCase(metaData, name), new String[] { "TABLE" });
            return rs.next();
        } catch (SQLException e) {
            throw new DbException(e);
        } finally {
            DbUtils.closeQuietly(rs);
            closeConnection(conn);
        }
    }

    /** Maps {@code name} to the identifier case the database stores unquoted names in. */
    private static String toStoredIdentifierCase(DatabaseMetaData metaData, String name) throws SQLException {
        if (metaData.storesUpperCaseIdentifiers()) {
            return name.toUpperCase(Locale.ROOT);
        }
        if (metaData.storesLowerCaseIdentifiers()) {
            return name.toLowerCase(Locale.ROOT);
        }
        return name;
    }

    public SqlDialect getDialect() {
        return dialect;
    }

    /** Opens a short-lived connection and resolves the dialect from the database product name. */
    private SqlDialect doGetDialect() {
        Connection conn = null;
        try {
            conn = dataSource.getConnection();
            String name = conn.getMetaData().getDatabaseProductName();
            return SqlDialect.getDialect(name);
        } catch (SQLException e) {
            throw new DbException(e);
        } finally {
            DbUtils.closeQuietly(conn);
        }
    }

    /**
     * Chooses a row mapper for the target type: arrays use {@code ArrayRowMapper},
     * {@code java.util.Map} uses {@code MapRowMapper}, any other {@code java.*} type is
     * treated as a single-column value, and everything else is mapped as a bean.
     */
    @SuppressWarnings("unchecked") // the array/map mappers are cast to the caller's T
    public <T> RowMapper<T> getRowMapper(Class<T> beanClass) {
        Validate.notNull(beanClass, "beanClass is null.");
        RowMapper<T> rowMapper;
        if (beanClass.isArray()) {
            rowMapper = (RowMapper<T>) new ArrayRowMapper();
        } else if (beanClass.getName().equals("java.util.Map")) {
            rowMapper = (RowMapper<T>) new MapRowMapper();
        } else if (beanClass.getName().startsWith("java.")) {
            rowMapper = new SingleColumnRowMapper<T>(beanClass);
        } else {
            rowMapper = new BeanRowMapper<T>(beanClass);
        }
        return rowMapper;
    }
}