package com.taobao.tddl.group.jdbc;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.sql.Statement;
import java.util.LinkedList;
import java.util.List;
import com.taobao.tddl.atom.jdbc.SqlMetaDataFactory;
import com.taobao.tddl.atom.jdbc.TStatement;
import com.taobao.tddl.common.jdbc.SqlTypeParser;
import com.taobao.tddl.common.model.SqlMetaData;
import com.taobao.tddl.common.model.SqlType;
import com.taobao.tddl.group.config.GroupIndex;
import com.taobao.tddl.group.dbselector.DBSelector.AbstractDataSourceTryer;
import com.taobao.tddl.group.dbselector.DBSelector.DataSourceTryer;
import com.taobao.tddl.group.utils.GroupHintParser;
import com.taobao.tddl.common.utils.logger.Logger;
import com.taobao.tddl.common.utils.logger.LoggerFactory;
/**
 * Group-level {@link java.sql.Statement} that routes each SQL to a concrete data source of a TDDL group.
 * <p>
 * If the owning {@link TGroupConnection} already holds a base connection for the SQL, the statement runs on
 * it directly. Otherwise it delegates to the group's DBSelector {@code tryExecute} with {@code retryingTimes}.
 * Only one underlying {@link Statement} (and at most one {@link ResultSet}) is held at a time; creating a new
 * one closes the previous one.
 *
 * @author linxuan
 * @author yangzhu
 */
public class TGroupStatement implements TStatement {

    private static final Logger log = LoggerFactory.getLogger(TGroupStatement.class);

    protected TGroupConnection tGroupConnection;
    protected TGroupDataSource tGroupDataSource;
    /** Retry count passed to the DBSelector; read once from the data source at construction time. */
    protected int retryingTimes;

    public TGroupStatement(TGroupDataSource tGroupDataSource, TGroupConnection tGroupConnection){
        this.tGroupDataSource = tGroupDataSource;
        this.tGroupConnection = tGroupConnection;
        this.retryingTimes = tGroupDataSource.getRetryingTimes();
    }

    /*
     * ========================================================================
     * Holder of the underlying (possibly not a real driver) Statement; its setter is package-private.
     * ======================================================================
     */
    private Statement baseStatement;

    /**
     * Sets the concrete Statement executed underneath. If a previous baseStatement is still held, it is closed
     * first (close failures are only logged).
     *
     * @param baseStatement the new underlying statement
     */
    void setBaseStatement(Statement baseStatement) {
        if (this.baseStatement != null) {
            try {
                this.baseStatement.close();
            } catch (SQLException e) {
                log.error("close baseStatement failed.", e);
            }
        }
        this.baseStatement = baseStatement;
    }

    /**
     * Query timeout in seconds; applied to every underlying statement created by this object.
     */
    protected int queryTimeout = 0;
    protected int fetchSize;
    protected int maxRows;

    /**
     * Result set of the last query, exposed through {@link #getResultSet()}. A statement holds at most one result
     * set at a time.
     */
    protected ResultSet currentResultSet;

    /**
     * Update count of the last execution only (earlier executions are overwritten). It is -1 when the last
     * execution was a query.
     */
    protected int updateCount;

    protected int resultSetType = ResultSet.TYPE_FORWARD_ONLY;
    protected int resultSetConcurrency = ResultSet.CONCUR_READ_ONLY;

    // The JDBC spec does not define a default for resultSetHoldability (should it be
    // ResultSet.CLOSE_CURSORS_AT_COMMIT?). -1 means "never set": the connection's holdability is used when the
    // underlying statement is created. TODO: decide whether -1 should be used uniformly.
    protected int resultSetHoldability = -1;

    /**
     * SQL metadata holder; filled lazily from the first SQL executed unless set via {@link #fillMetaData}.
     */
    protected SqlMetaData sqlMetaData = null;

    public boolean execute(String sql) throws SQLException {
        return executeInternal(sql, -1, null, null);
    }

    public boolean execute(String sql, int autoGeneratedKeys) throws SQLException {
        return executeInternal(sql, autoGeneratedKeys, null, null);
    }

    public boolean execute(String sql, int[] columnIndexes) throws SQLException {
        return executeInternal(sql, -1, columnIndexes, null);
    }

    public boolean execute(String sql, String[] columnNames) throws SQLException {
        return executeInternal(sql, -1, null, columnNames);
    }

    /**
     * Dispatches {@code execute(...)} to executeQuery or the matching executeUpdate overload.
     *
     * @return true when the SQL was run as a query, false when it was run as an update (JDBC convention)
     */
    private boolean executeInternal(String sql, int autoGeneratedKeys, int[] columnIndexes, String[] columnNames)
                                                                                                               throws SQLException {
        if (SqlTypeParser.isQuerySql(sql)) {
            executeQuery(sql);
            return true;
        }
        if (autoGeneratedKeys != -1) {
            executeUpdate(sql, autoGeneratedKeys);
        } else if (columnIndexes != null) {
            executeUpdate(sql, columnIndexes);
        } else if (columnNames != null) {
            executeUpdate(sql, columnNames);
        } else {
            executeUpdate(sql);
        }
        return false;
    }

    /*
     * ========================================================================
     * executeUpdate logic
     * ======================================================================
     */
    public int executeUpdate(String sql) throws SQLException {
        return executeUpdateInternal(sql, -1, null, null);
    }

    public int executeUpdate(String sql, int autoGeneratedKeys) throws SQLException {
        return executeUpdateInternal(sql, autoGeneratedKeys, null, null);
    }

    public int executeUpdate(String sql, int[] columnIndexes) throws SQLException {
        return executeUpdateInternal(sql, -1, columnIndexes, null);
    }

    public int executeUpdate(String sql, String[] columnNames) throws SQLException {
        return executeUpdateInternal(sql, -1, null, columnNames);
    }

    /**
     * Runs an update either on the connection already bound to the group connection or, when there is none,
     * through the write DBSelector with retries. The group hint is resolved to a data source index first (a
     * thread-local index is the fallback) and then stripped from the SQL.
     */
    private int executeUpdateInternal(String sql, int autoGeneratedKeys, int[] columnIndexes, String[] columnNames)
                                                                                                                   throws SQLException {
        checkClosed();
        ensureResultSetIsEmpty();

        Connection conn = tGroupConnection.getBaseConnection(sql, false);
        if (conn != null) {
            sql = GroupHintParser.removeTddlGroupHint(sql);
            this.updateCount = executeUpdateOnConnection(conn, sql, autoGeneratedKeys, columnIndexes, columnNames);
            return this.updateCount;
        }

        GroupIndex dataSourceIndex = GroupHintParser.convertHint2Index(sql);
        sql = GroupHintParser.removeTddlGroupHint(sql);
        if (dataSourceIndex == null) {
            dataSourceIndex = ThreadLocalDataSourceIndex.getIndex();
        }
        this.updateCount = this.tGroupDataSource.getDBSelector(false).tryExecute(null,
            executeUpdateTryer,
            retryingTimes,
            sql,
            autoGeneratedKeys,
            columnIndexes,
            columnNames,
            dataSourceIndex);
        return this.updateCount;
    }

    /**
     * Creates a fresh underlying statement on {@code conn} and calls the executeUpdate overload matching the
     * supplied key-generation arguments (the first non-default one wins).
     */
    private int executeUpdateOnConnection(Connection conn, String sql, int autoGeneratedKeys, int[] columnIndexes,
                                          String[] columnNames) throws SQLException {
        Statement stmt = createStatementInternal(conn, sql, false);
        if (autoGeneratedKeys != -1) {
            return stmt.executeUpdate(sql, autoGeneratedKeys);
        }
        if (columnIndexes != null) {
            return stmt.executeUpdate(sql, columnIndexes);
        }
        if (columnNames != null) {
            return stmt.executeUpdate(sql, columnNames);
        }
        return stmt.executeUpdate(sql);
    }

    /** Tryer used by the DBSelector: args are (sql, autoGeneratedKeys, columnIndexes, columnNames). */
    private DataSourceTryer<Integer> executeUpdateTryer = new AbstractDataSourceTryer<Integer>() {

        public Integer tryOnDataSource(DataSourceWrapper dsw, Object... args) throws SQLException {
            Connection conn = TGroupStatement.this.tGroupConnection.createNewConnection(dsw, false);
            return executeUpdateOnConnection(conn,
                (String) args[0],
                (Integer) args[1],
                (int[]) args[2],
                (String[]) args[3]);
        }
    };

    /**
     * Creates the underlying statement and applies this statement's settings to it. It calls
     * setBaseStatement, which closes any previously held Statement.
     *
     * @param isBatch when true a plain {@code createStatement()} is used and type/concurrency/holdability are
     *            not applied
     */
    private Statement createStatementInternal(Connection conn, String sql, boolean isBatch) throws SQLException {
        Statement stmt;
        if (isBatch) {
            stmt = conn.createStatement();
        } else {
            int holdability = this.resultSetHoldability;
            if (holdability == -1) {
                // setResultSetHoldability was never called: inherit the connection's holdability.
                holdability = conn.getHoldability();
            }
            stmt = conn.createStatement(this.resultSetType, this.resultSetConcurrency, holdability);
        }

        // Register first so the statement is tracked (and later closed) even if the setters below throw.
        setBaseStatement(stmt);
        stmt.setQueryTimeout(queryTimeout);
        stmt.setFetchSize(fetchSize);
        stmt.setMaxRows(maxRows);
        fillSqlMetaData(stmt, sql);
        return stmt;
    }

    /*
     * ========================================================================
     * executeBatch
     * ======================================================================
     */
    protected List<String> batchedArgs;

    public void addBatch(String sql) throws SQLException {
        checkClosed();
        if (batchedArgs == null) {
            batchedArgs = new LinkedList<String>();
        }
        if (sql != null) {
            batchedArgs.add(sql);
        }
    }

    public void clearBatch() throws SQLException {
        checkClosed();
        if (batchedArgs != null) {
            batchedArgs.clear();
        }
    }

    /**
     * Executes all batched SQL. The batch is cleared afterwards whether execution succeeds or fails.
     *
     * @return the driver's update counts, or an empty array when nothing was batched
     */
    public int[] executeBatch() throws SQLException {
        try {
            checkClosed();
            ensureResultSetIsEmpty();

            if (batchedArgs == null || batchedArgs.isEmpty()) {
                return new int[0];
            }

            Connection conn = tGroupConnection.getBaseConnection(null, false);
            if (conn != null) {
                // An existing connection means no retry. For updates, with or without a transaction, users expect
                // every operation after getConnection to run on that same database and connection.
                return executeBatchOnConnection(conn, this.batchedArgs);
            }
            return tGroupDataSource.getDBSelector(false).tryExecute(null, executeBatchTryer, retryingTimes);
        } finally {
            if (batchedArgs != null) {
                batchedArgs.clear();
            }
        }
    }

    private DataSourceTryer<int[]> executeBatchTryer = new AbstractDataSourceTryer<int[]>() {

        public int[] tryOnDataSource(DataSourceWrapper dsw, Object... args) throws SQLException {
            Connection conn = TGroupStatement.this.tGroupConnection.createNewConnection(dsw, false);
            return executeBatchOnConnection(conn, TGroupStatement.this.batchedArgs);
        }
    };

    /** Adds every SQL to one new underlying statement; the first SQL is used for metadata. */
    private int[] executeBatchOnConnection(Connection conn, List<String> batchedSqls) throws SQLException {
        Statement stmt = createStatementInternal(conn, batchedSqls.get(0), true);
        for (String sql : batchedSqls) {
            stmt.addBatch(sql);
        }
        return stmt.executeBatch();
    }

    /*
     * ========================================================================
     * Close logic
     * ======================================================================
     */
    /** Whether this statement has been closed. */
    protected boolean closed;

    public void close() throws SQLException {
        close(true);
    }

    /**
     * Closes the current result set (failures only logged) and the underlying statement. Repeated calls are
     * no-ops.
     *
     * @param removeThis whether to deregister this statement from the owning group connection
     */
    void close(boolean removeThis) throws SQLException {
        if (closed) {
            return;
        }
        closed = true;

        try {
            if (currentResultSet != null) {
                currentResultSet.close();
            }
        } catch (SQLException e) {
            log.warn("Close currentResultSet failed.", e);
        } finally {
            currentResultSet = null;
        }

        try {
            if (this.baseStatement != null) {
                this.baseStatement.close();
            }
        } finally {
            this.baseStatement = null;
            if (removeThis) {
                tGroupConnection.removeOpenedStatements(this);
            }
        }
    }

    protected void checkClosed() throws SQLException {
        if (closed) {
            throw new SQLException("No operations allowed after statement closed.");
        }
    }

    /**
     * Explicitly closes the previous result set before a new execution, as the JDBC spec requires. Close
     * failures are logged and the reference is always cleared.
     *
     * @throws SQLException never thrown by this implementation; kept for subclasses
     */
    protected void ensureResultSetIsEmpty() throws SQLException {
        if (currentResultSet != null) {
            try {
                currentResultSet.close();
            } catch (SQLException e) {
                log.error("exception on close last result set . can do nothing..", e);
            } finally {
                currentResultSet = null;
            }
        }
    }

    /*
     * ========================================================================
     * executeQuery logic
     * ======================================================================
     */
    /**
     * Runs a query. The read DBSelector is used only for SELECT statements while the connection is in autoCommit
     * mode; otherwise the write one is used. A group hint takes precedence over the thread-local data source
     * index.
     */
    public ResultSet executeQuery(String sql) throws SQLException {
        checkClosed();
        ensureResultSetIsEmpty();
        // JDBC: getUpdateCount() must return -1 when the current result is a ResultSet.
        this.updateCount = -1;

        boolean gotoRead = SqlType.SELECT.equals(SqlTypeParser.getSqlType(sql)) && tGroupConnection.getAutoCommit();
        Connection conn = tGroupConnection.getBaseConnection(sql, gotoRead);
        if (conn != null) {
            sql = GroupHintParser.removeTddlGroupHint(sql);
            return executeQueryOnConnection(conn, sql);
        }

        // The hint takes precedence.
        GroupIndex dataSourceIndex = GroupHintParser.convertHint2Index(sql);
        sql = GroupHintParser.removeTddlGroupHint(sql);
        if (dataSourceIndex == null) {
            dataSourceIndex = ThreadLocalDataSourceIndex.getIndex();
        }
        return this.tGroupDataSource.getDBSelector(gotoRead).tryExecute(executeQueryTryer,
            retryingTimes,
            sql,
            dataSourceIndex);
    }

    protected ResultSet executeQueryOnConnection(Connection conn, String sql) throws SQLException {
        Statement stmt = createStatementInternal(conn, sql, false);
        this.currentResultSet = stmt.executeQuery(sql);
        return this.currentResultSet;
    }

    /** Tryer used by the DBSelector: args[0] is the SQL (hint already removed). */
    protected DataSourceTryer<ResultSet> executeQueryTryer = new AbstractDataSourceTryer<ResultSet>() {

        public ResultSet tryOnDataSource(DataSourceWrapper dsw, Object... args) throws SQLException {
            String sql = (String) args[0];
            Connection conn = TGroupStatement.this.tGroupConnection.createNewConnection(dsw, true);
            return executeQueryOnConnection(conn, sql);
        }
    };

    public SQLWarning getWarnings() throws SQLException {
        checkClosed();
        if (baseStatement != null) {
            return baseStatement.getWarnings();
        }
        return null;
    }

    public void clearWarnings() throws SQLException {
        checkClosed();
        if (baseStatement != null) {
            baseStatement.clearWarnings();
        }
    }

    /*
     * ========================================================================
     * Methods with simple (partial) support
     * ======================================================================
     */
    /**
     * Multiple result sets apparently only occur with stored procedures, so they are not supported; this flag is
     * never set to true here.
     */
    protected boolean moreResults;

    public boolean getMoreResults() throws SQLException {
        return moreResults;
    }

    public int getQueryTimeout() throws SQLException {
        return queryTimeout;
    }

    public void setQueryTimeout(int queryTimeout) throws SQLException {
        this.queryTimeout = queryTimeout;
    }

    public ResultSet getResultSet() throws SQLException {
        return currentResultSet;
    }

    public int getUpdateCount() throws SQLException {
        return updateCount;
    }

    public int getResultSetConcurrency() throws SQLException {
        return resultSetConcurrency;
    }

    /** @return the configured holdability, or -1 if setResultSetHoldability was never called */
    public int getResultSetHoldability() throws SQLException {
        return resultSetHoldability;
    }

    public int getResultSetType() throws SQLException {
        return resultSetType;
    }

    public void setResultSetType(int resultSetType) {
        this.resultSetType = resultSetType;
    }

    public void setResultSetConcurrency(int resultSetConcurrency) {
        this.resultSetConcurrency = resultSetConcurrency;
    }

    public void setResultSetHoldability(int resultSetHoldability) {
        this.resultSetHoldability = resultSetHoldability;
    }

    public Connection getConnection() throws SQLException {
        return tGroupConnection;
    }

    /**
     * Cancels the underlying statement (see com.mysql.jdbc.StatementImpl). It does nothing when no statement has
     * been created yet.
     */
    public void cancel() throws SQLException {
        // Read once: the field may be replaced by a concurrent execution between the check and the call.
        Statement stmt = this.baseStatement;
        if (stmt != null) {
            stmt.cancel();
        }
    }

    /*
     * ========================================================================
     * Remaining Statement methods; the ones throwing UnsupportedOperationException are not supported
     * ======================================================================
     */
    public int getFetchDirection() throws SQLException {
        throw new UnsupportedOperationException("getFetchDirection");
    }

    public int getFetchSize() throws SQLException {
        return this.fetchSize;
    }

    public int getMaxFieldSize() throws SQLException {
        throw new UnsupportedOperationException("getMaxFieldSize");
    }

    public int getMaxRows() throws SQLException {
        return this.maxRows;
    }

    public void setCursorName(String cursorName) throws SQLException {
        throw new UnsupportedOperationException("setCursorName");
    }

    public void setEscapeProcessing(boolean escapeProcessing) throws SQLException {
        throw new UnsupportedOperationException("setEscapeProcessing");
    }

    public boolean getMoreResults(int current) throws SQLException {
        throw new UnsupportedOperationException("getMoreResults");
    }

    public void setFetchDirection(int fetchDirection) throws SQLException {
        throw new UnsupportedOperationException("setFetchDirection");
    }

    public void setFetchSize(int fetchSize) throws SQLException {
        this.fetchSize = fetchSize;
    }

    public void setMaxFieldSize(int maxFieldSize) throws SQLException {
        throw new UnsupportedOperationException("setMaxFieldSize");
    }

    public void setMaxRows(int maxRows) throws SQLException {
        this.maxRows = maxRows;
    }

    /**
     * Returns the generated keys of the underlying statement.
     *
     * @throws SQLException if no statement has been executed yet
     */
    public ResultSet getGeneratedKeys() throws SQLException {
        if (this.baseStatement != null) {
            return this.baseStatement.getGeneratedKeys();
        }
        throw new SQLException("在调用getGeneratedKeys前未执行过任何更新操作");
    }

    /**
     * @return true if this object implements {@code iface} (directly or via a super type)
     */
    public boolean isWrapperFor(Class<?> iface) throws SQLException {
        return iface != null && iface.isInstance(this);
    }

    /**
     * Returns this object viewed as {@code iface}.
     *
     * @throws SQLException if this object does not implement {@code iface}
     */
    public <T> T unwrap(Class<T> iface) throws SQLException {
        if (isWrapperFor(iface)) {
            return iface.cast(this);
        }
        throw new SQLException("not a wrapper for " + iface);
    }

    /** @return whether {@link #close()} has been called on this statement */
    public boolean isClosed() throws SQLException {
        return closed;
    }

    public void setPoolable(boolean poolable) throws SQLException {
        throw new SQLException("not support exception");
    }

    public boolean isPoolable() throws SQLException {
        throw new SQLException("not support exception");
    }

    protected void fillSqlMetaData(Statement statement, String sql) {
        if (statement instanceof TStatement) {
            fillSqlMetaData((TStatement) statement, sql);
        }
    }

    /**
     * Pushes this statement's SQL metadata down to {@code statement}. The metadata is computed from {@code sql}
     * only if none is held yet, so later SQL reuses the first metadata (NOTE(review): confirm this caching is
     * intended when a statement is reused for different SQL).
     */
    protected void fillSqlMetaData(TStatement statement, String sql) {
        if (this.sqlMetaData == null) {
            this.sqlMetaData = SqlMetaDataFactory.getSqlMetaData(sql);
        }
        statement.fillMetaData(this.sqlMetaData);
    }

    @Override
    public void fillMetaData(SqlMetaData sqlMetaData) {
        this.sqlMetaData = sqlMetaData;
    }

    @Override
    public SqlMetaData getSqlMetaData() {
        return this.sqlMetaData;
    }
}