package com.xiaoleilu.hutool.db;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ParameterMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.RowId;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Map.Entry;
import javax.naming.InitialContext;
import javax.naming.NamingException;
import javax.sql.DataSource;
import com.xiaoleilu.hutool.db.dialect.Dialect;
import com.xiaoleilu.hutool.db.dialect.DialectFactory;
import com.xiaoleilu.hutool.db.ds.DSFactory;
import com.xiaoleilu.hutool.db.meta.Column;
import com.xiaoleilu.hutool.db.meta.Table;
import com.xiaoleilu.hutool.db.sql.Condition;
import com.xiaoleilu.hutool.db.sql.Condition.LikeType;
import com.xiaoleilu.hutool.db.sql.SqlFormatter;
import com.xiaoleilu.hutool.log.Log;
import com.xiaoleilu.hutool.log.StaticLog;
import com.xiaoleilu.hutool.util.ArrayUtil;
import com.xiaoleilu.hutool.util.CharsetUtil;
import com.xiaoleilu.hutool.util.StrUtil;
/**
 * Database utility class.
 * <p>
 * Static helpers for creating SQL runners and sessions, looking up data sources,
 * reading database metadata, filling statement parameters and building simple SQL fragments.
 *
 * @author Luxiaolei
 *
 */
public final class DbUtil {
	private final static Log log = StaticLog.get();

	private DbUtil() {}

	/**
	 * Creates a new SQL runner bound to the given SQL dialect.
	 *
	 * @param dialect SQL dialect
	 * @return SQL runner
	 */
	public static SqlConnRunner newSqlConnRunner(Dialect dialect) {
		return SqlConnRunner.create(dialect);
	}

	/**
	 * Creates a new SQL runner for the given data source.
	 *
	 * @param ds data source
	 * @return SQL runner
	 */
	public static SqlConnRunner newSqlConnRunner(DataSource ds) {
		return SqlConnRunner.create(ds);
	}

	/**
	 * Creates a new SQL runner whose dialect is obtained from
	 * {@link DialectFactory#newDialect(Connection)} for the given connection.
	 *
	 * @param conn database connection
	 * @return SQL runner
	 */
	public static SqlConnRunner newSqlConnRunner(Connection conn) {
		return SqlConnRunner.create(DialectFactory.newDialect(conn));
	}

	/**
	 * Creates a new SQL runner using the default data source ({@link #getDs()}).
	 *
	 * @return SQL runner
	 */
	public static SqlRunner newSqlRunner() {
		return SqlRunner.create(getDs());
	}

	/**
	 * Creates a new SQL runner for the given data source.
	 *
	 * @param ds data source
	 * @return SQL runner
	 */
	public static SqlRunner newSqlRunner(DataSource ds) {
		return SqlRunner.create(ds);
	}

	/**
	 * Creates a new SQL runner for the given data source and SQL dialect.
	 *
	 * @param ds data source
	 * @param dialect SQL dialect
	 * @return SQL runner
	 */
	public static SqlRunner newSqlRunner(DataSource ds, Dialect dialect) {
		return SqlRunner.create(ds, dialect);
	}

	/**
	 * Creates a new database session using the default data source ({@link #getDs()}).
	 *
	 * @return database session
	 */
	public static Session newSession() {
		return Session.create(getDs());
	}

	/**
	 * Creates a new database session for the given data source.
	 *
	 * @param ds data source
	 * @return database session
	 */
	public static Session newSession(DataSource ds) {
		return Session.create(ds);
	}

	/**
	 * Creates a new database session on top of an existing connection.
	 *
	 * @param conn database connection
	 * @return database session
	 */
	public static Session newSession(Connection conn) {
		return Session.create(conn);
	}

	/**
	 * Closes a series of JDBC objects, one after another, in the order given.<br/>
	 * Callers must pass them in dependency order (e.g. ResultSet, then Statement, then Connection).
	 * <p>
	 * Closing is best-effort: {@code null} entries are skipped, a failure to close one object is
	 * logged and does not prevent the remaining objects from being closed, and unsupported object
	 * types are logged and ignored.
	 *
	 * @param objsToClose objects to close; may be {@code null}
	 */
	public static void close(Object... objsToClose) {
		if (null == objsToClose) {
			return;
		}
		for (Object obj : objsToClose) {
			if (obj == null) {
				continue;
			}
			try {
				if (obj instanceof ResultSet) {
					((ResultSet) obj).close();
				} else if (obj instanceof Statement) {
					// Also covers PreparedStatement / CallableStatement, which extend Statement.
					((Statement) obj).close();
				} else if (obj instanceof Connection) {
					((Connection) obj).close();
				} else {
					log.warn("Object " + obj.getClass().getName() + " not a ResultSet or Statement or PreparedStatement or Connection!");
				}
			} catch (SQLException e) {
				// Best-effort: record the failure and keep closing the remaining objects.
				log.warn("Close " + obj.getClass().getName() + " error: " + e.getMessage());
			}
		}
	}

	/**
	 * Returns the default data source from {@link DSFactory#get()}.
	 *
	 * @return default data source
	 */
	public static DataSource getDs() {
		return DSFactory.get();
	}

	/**
	 * Returns the data source configured for the given group, via {@link DSFactory#get(String)}.
	 *
	 * @param group group name
	 * @return data source
	 */
	public static DataSource getDs(String group) {
		return DSFactory.get(group);
	}

	/**
	 * Looks up a JNDI data source; on failure the error is logged and {@code null} is returned
	 * instead of throwing.
	 *
	 * @param jndiName JNDI name
	 * @return data source, or {@code null} if the lookup failed
	 */
	public static DataSource getJndiDsWithLog(String jndiName) {
		try {
			return getJndiDs(jndiName);
		} catch (DbRuntimeException e) {
			log.error(e.getCause(), "Find JNDI datasource error!");
		}
		return null;
	}

	/**
	 * Looks up a JNDI data source through a new {@link InitialContext}.
	 *
	 * @param jndiName JNDI name
	 * @return data source
	 * @throws DbRuntimeException wrapping the {@link NamingException} if the lookup fails
	 */
	public static DataSource getJndiDs(String jndiName) {
		try {
			return (DataSource) new InitialContext().lookup(jndiName);
		} catch (NamingException e) {
			throw new DbRuntimeException(e);
		}
	}

	/**
	 * Lists the names of all tables (JDBC table type {@code "TABLE"}) in the connection's current catalog.
	 *
	 * @param ds data source
	 * @return non-blank table names; empty if none were found
	 * @throws DbRuntimeException if the metadata cannot be read
	 */
	public static List<String> getTables(DataSource ds) {
		final List<String> tables = new ArrayList<String>();
		Connection conn = null;
		ResultSet rs = null;
		try {
			conn = ds.getConnection();
			final DatabaseMetaData metaData = conn.getMetaData();
			// "TABLE" is the standard JDBC table type name; "TABLES" is not a type drivers report.
			rs = metaData.getTables(conn.getCatalog(), null, null, new String[] { "TABLE" });
			if (rs != null) {
				while (rs.next()) {
					final String table = rs.getString("TABLE_NAME");
					if (!StrUtil.isBlank(table)) {
						tables.add(table);
					}
				}
			}
		} catch (Exception e) {
			throw new DbRuntimeException("Get tables error!", e);
		} finally {
			close(rs, conn);
		}
		return tables;
	}

	/**
	 * Returns the column labels of a result set, in column order.
	 *
	 * @param rs result set
	 * @return column labels (from {@link ResultSetMetaData#getColumnLabel(int)})
	 * @throws DbRuntimeException if the metadata cannot be read
	 */
	public static String[] getColumnNames(ResultSet rs) {
		try {
			final ResultSetMetaData rsmd = rs.getMetaData();
			final String[] labelNames = new String[rsmd.getColumnCount()];
			for (int i = 0; i < labelNames.length; i++) {
				// JDBC column indexes are 1-based
				labelNames[i] = rsmd.getColumnLabel(i + 1);
			}
			return labelNames;
		} catch (Exception e) {
			throw new DbRuntimeException("Get columns error!", e);
		}
	}

	/**
	 * Returns the column names of a table, as reported by {@link DatabaseMetaData#getColumns}.
	 *
	 * @param ds data source
	 * @param tableName table name
	 * @return column names
	 * @throws DbRuntimeException if the metadata cannot be read
	 */
	public static String[] getColumnNames(DataSource ds, String tableName) {
		final List<String> columnNames = new ArrayList<String>();
		Connection conn = null;
		ResultSet rs = null;
		try {
			conn = ds.getConnection();
			final DatabaseMetaData metaData = conn.getMetaData();
			rs = metaData.getColumns(conn.getCatalog(), null, tableName, null);
			while (rs.next()) {
				columnNames.add(rs.getString("COLUMN_NAME"));
			}
			return columnNames.toArray(new String[columnNames.size()]);
		} catch (Exception e) {
			throw new DbRuntimeException("Get columns error!", e);
		} finally {
			close(rs, conn);
		}
	}

	/**
	 * Creates an {@link Entity} restricted to the columns of the given table.<br>
	 * The table's column list is read from the database and set as the entity's field names;
	 * per the original documentation, keys outside those fields are then ignored when the entity
	 * is populated (behavior of {@code Entity.setFieldNames} — not visible here).
	 *
	 * @param ds data source
	 * @param tableName table name
	 * @return Entity object
	 */
	public static Entity createLimitedEntity(DataSource ds, String tableName) {
		final String[] columnNames = getColumnNames(ds, tableName);
		return Entity.create(tableName).setFieldNames(columnNames);
	}

	/**
	 * Reads the metadata of a table: its primary-key columns and all of its columns.
	 *
	 * @param ds data source
	 * @param tableName table name
	 * @return Table object
	 * @throws DbRuntimeException if the metadata cannot be read
	 */
	public static Table getTableMeta(DataSource ds, String tableName) {
		final Table table = Table.create(tableName);
		Connection conn = null;
		ResultSet rs = null;
		try {
			conn = ds.getConnection();
			final DatabaseMetaData metaData = conn.getMetaData();

			// Primary keys: add the actual column name of each key row.
			rs = metaData.getPrimaryKeys(conn.getCatalog(), null, tableName);
			while (rs.next()) {
				table.addPk(rs.getString("COLUMN_NAME"));
			}
			// Close the primary-key result set before reusing the variable; otherwise it leaks.
			close(rs);
			rs = null;

			// Columns
			rs = metaData.getColumns(conn.getCatalog(), null, tableName, null);
			while (rs.next()) {
				table.setColumn(Column.create(tableName, rs));
			}
		} catch (Exception e) {
			throw new DbRuntimeException("Get columns error!", e);
		} finally {
			close(rs, conn);
		}
		return table;
	}

	/**
	 * Fills the parameters of a SQL statement from a collection, in iteration order.
	 *
	 * @param ps PreparedStatement
	 * @param params SQL parameters; {@code null} or empty means no parameters
	 * @throws SQLException if setting a parameter fails
	 */
	public static void fillParams(PreparedStatement ps, Collection<Object> params) throws SQLException {
		if (null == params || params.isEmpty()) {
			return;// no parameters
		}
		fillParams(ps, params.toArray(new Object[params.size()]));
	}

	/**
	 * Fills the parameters of a SQL statement.
	 * <p>
	 * Non-null values are bound with {@link PreparedStatement#setObject(int, Object)}. For
	 * {@code null} values the parameter's SQL type is looked up from the statement's
	 * {@link ParameterMetaData}; if that lookup fails, {@link Types#VARCHAR} is used.
	 *
	 * @param ps PreparedStatement
	 * @param params SQL parameters
	 * @throws SQLException if setting a parameter fails
	 */
	public static void fillParams(PreparedStatement ps, Object... params) throws SQLException {
		if (ArrayUtil.isEmpty(params)) {
			return;// no parameters
		}
		// Fetched lazily: parameter metadata is only needed for null values, and
		// getParameterMetaData() itself may throw on drivers that do not support it.
		ParameterMetaData pmd = null;
		for (int i = 0; i < params.length; i++) {
			final int paramIndex = i + 1;
			final Object param = params[i];
			if (param != null) {
				ps.setObject(paramIndex, param);
				continue;
			}
			int sqlType = Types.VARCHAR;
			try {
				if (null == pmd) {
					pmd = ps.getParameterMetaData();
				}
				sqlType = pmd.getParameterType(paramIndex);
			} catch (SQLException e) {
				log.warn("Param get type fail, by: " + e.getMessage());
			}
			ps.setNull(paramIndex, sqlType);
		}
	}

	/**
	 * Returns the auto-generated key of the first generated row as a {@code Long}.<br>
	 * Does not work for Oracle.
	 *
	 * @param ps PreparedStatement
	 * @return the generated key, or {@code null} if there is none or it cannot be read as a number
	 * @throws SQLException if the generated keys cannot be obtained
	 */
	public static Long getGeneratedKeyOfLong(PreparedStatement ps) throws SQLException {
		ResultSet rs = null;
		try {
			rs = ps.getGeneratedKeys();
			Long generatedKey = null;
			if (rs != null && rs.next()) {
				try {
					generatedKey = rs.getLong(1);
				} catch (SQLException ignored) {
					// The generated key is not numeric, or is an Oracle rowid: leave it null.
				}
			}
			return generatedKey;
		} finally {
			close(rs);
		}
	}

	/**
	 * Returns all auto-generated keys, one per generated row (e.g. for a multi-row insert).
	 * The key of each row is read from its first column.
	 *
	 * @param ps PreparedStatement
	 * @return all generated keys; empty if none
	 * @throws SQLException if the generated keys cannot be read
	 */
	public static List<Object> getGeneratedKeys(PreparedStatement ps) throws SQLException {
		final List<Object> keys = new ArrayList<Object>();
		ResultSet rs = null;
		try {
			rs = ps.getGeneratedKeys();
			if (rs != null) {
				while (rs.next()) {
					keys.add(rs.getObject(1));
				}
			}
			return keys;
		} finally {
			close(rs);
		}
	}

	/**
	 * Builds a WHERE clause of equality conditions joined with "and".<br>
	 * Returns an empty string when there are no conditions. Column names are wrapped in
	 * backticks; the values are appended to {@code paramValues} and represented by {@code ?}.
	 *
	 * @param entity condition entity
	 * @param paramValues list that receives the condition values, in clause order
	 * @return SQL fragment starting with the WHERE keyword, or an empty string
	 */
	public static String buildEqualsWhere(Entity entity, List<Object> paramValues) {
		if (null == entity || entity.isEmpty()) {
			return StrUtil.EMPTY;
		}
		final StringBuilder sb = new StringBuilder(" WHERE ");
		boolean isNotFirst = false;
		for (Entry<String, Object> entry : entity.entrySet()) {
			if (isNotFirst) {
				sb.append(" and ");
			} else {
				isNotFirst = true;
			}
			sb.append("`").append(entry.getKey()).append("`").append(" = ?");
			paramValues.add(entry.getValue());
		}
		return sb.toString();
	}

	/**
	 * Builds one {@link Condition} per entry of the entity.
	 *
	 * @param entity entity object
	 * @return conditions, or {@code null} if the entity is {@code null} or empty
	 */
	public static Condition[] buildConditions(Entity entity) {
		if (null == entity || entity.isEmpty()) {
			return null;
		}
		final Condition[] conditions = new Condition[entity.size()];
		int i = 0;
		for (Entry<String, Object> entry : entity.entrySet()) {
			conditions[i++] = new Condition(entry.getKey(), entry.getValue());
		}
		return conditions;
	}

	/**
	 * Builds the value part of a LIKE expression, prefixed with {@code "LIKE "}.
	 * <ul>
	 * <li>{@link LikeType#StartWith}: {@code LIKE value%}</li>
	 * <li>{@link LikeType#EndWith}: {@code LIKE %value}</li>
	 * <li>{@link LikeType#Contains}: {@code LIKE %value%}</li>
	 * <li>{@code null} or any other type: {@code LIKE value} (no wildcards)</li>
	 * </ul>
	 *
	 * @param value value to search for
	 * @param likeType LIKE match type {@link LikeType}
	 * @return the LIKE fragment
	 */
	public static String buildLikeValue(String value, LikeType likeType) {
		final StringBuilder likeValue = StrUtil.builder("LIKE ");
		if (null == likeType) {
			return likeValue.append(value).toString();
		}
		switch (likeType) {
		case StartWith:
			// "starts with value": the wildcard goes after the value
			likeValue.append(value).append('%');
			break;
		case EndWith:
			// "ends with value": the wildcard goes before the value
			likeValue.append('%').append(value);
			break;
		case Contains:
			likeValue.append('%').append(value).append('%');
			break;
		default:
			likeValue.append(value);
			break;
		}
		return likeValue.toString();
	}

	/**
	 * Identifies the JDBC driver class name from a string containing the database product
	 * information (case-insensitive substring match).
	 *
	 * @param nameContainsProductInfo string containing the database identifier
	 * @return driver class name, or {@code null} if blank or not recognized
	 */
	public static String identifyDriver(String nameContainsProductInfo) {
		if (StrUtil.isBlank(nameContainsProductInfo)) {
			return null;
		}
		// Locale.ENGLISH keeps lower-casing stable regardless of the default locale
		// (e.g. under a Turkish locale "SQLITE" would not become "sqlite").
		final String name = nameContainsProductInfo.toLowerCase(Locale.ENGLISH);
		if (name.contains("mysql")) {
			return DialectFactory.DRIVER_MYSQL;
		} else if (name.contains("oracle")) {
			return DialectFactory.DRIVER_ORACLE;
		} else if (name.contains("postgresql")) {
			return DialectFactory.DRIVER_POSTGRESQL;
		} else if (name.contains("sqlite")) {
			return DialectFactory.DRIVER_SQLLITE3;
		}
		return null;
	}

	/**
	 * Identifies the JDBC driver class name by opening a connection from the data source.
	 *
	 * @param ds data source
	 * @return driver class name, or {@code null} if not recognized
	 * @throws DbRuntimeException if a connection or its metadata cannot be obtained
	 */
	public static String identifyDriver(DataSource ds) {
		Connection conn = null;
		try {
			conn = ds.getConnection();
			// identifyDriver(Connection) already wraps its own SQLExceptions
			return identifyDriver(conn);
		} catch (SQLException e) {
			throw new DbRuntimeException("Identify driver error!", e);
		} finally {
			close(conn);
		}
	}

	/**
	 * Identifies the JDBC driver class name from the connection metadata: the database product
	 * name is tried first, then the driver name.
	 *
	 * @param conn database connection
	 * @return driver class name, or {@code null} if not recognized
	 * @throws DbRuntimeException if the metadata cannot be read
	 */
	public static String identifyDriver(Connection conn) {
		String driver = null;
		try {
			final DatabaseMetaData meta = conn.getMetaData();
			driver = identifyDriver(meta.getDatabaseProductName());
			if (StrUtil.isBlank(driver)) {
				driver = identifyDriver(meta.getDriverName());
			}
		} catch (SQLException e) {
			throw new DbRuntimeException("Identify driver error!", e);
		}
		return driver;
	}

	/**
	 * Validates an entity: it must be non-null, have a non-blank table name and contain at least one field.
	 *
	 * @param entity entity object
	 * @throws DbRuntimeException if any check fails
	 */
	public static void validateEntity(Entity entity) {
		if (null == entity) {
			throw new DbRuntimeException("Entity is null !");
		}
		if (StrUtil.isBlank(entity.getTableName())) {
			throw new DbRuntimeException("Entity`s table name is null !");
		}
		if (entity.isEmpty()) {
			throw new DbRuntimeException("No field and value in this entity !");
		}
	}

	/**
	 * Converts a {@link RowId} to a string by decoding its bytes as ISO-8859-1.
	 *
	 * @param rowId RowId; may be {@code null}
	 * @return RowId string, or {@code null} if {@code rowId} is {@code null}
	 */
	public static String rowIdToString(RowId rowId) {
		if (null == rowId) {
			return null;
		}
		return StrUtil.str(rowId.getBytes(), CharsetUtil.CHARSET_ISO_8859_1);
	}

	/**
	 * Formats (pretty-prints) SQL using {@link SqlFormatter}.
	 *
	 * @param sql SQL
	 * @return formatted SQL
	 */
	public static String formatSql(String sql) {
		return SqlFormatter.format(sql);
	}
	//---------------------------------------------------------------------------- Private method start
	//---------------------------------------------------------------------------- Private method end
}