package com.taobao.tddl.common.mock;
import java.io.PrintWriter;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.sql.DataSource;
import com.taobao.tddl.common.exception.NotSupportException;
import com.taobao.tddl.common.model.DBType;
public class MockDataSource implements DataSource, Cloneable {
private int timeToObtainConnection = 0;
private int getConnectionInvokingTimes = 0;
private String name;
private String dbIndex;
private boolean isClosed;
private DBType dbType = DBType.MYSQL;
public MockDataSource(){
}
public MockDataSource(String dbIndex, String name){
this.dbIndex = dbIndex;
this.name = name;
}
public static class ExecuteInfo {
public ExecuteInfo(MockDataSource dataSource, String method, String sql, Object[] args){
this.ds = dataSource;
this.method = method;
this.sql = sql;
this.args = args;
}
public MockDataSource ds;
public String method;
public String sql;
public Object[] args;
@Override
public String toString() {
return new StringBuilder("ExecuteInfo:{ds:").append(ds)
.append(",method:")
.append(method)
.append(",sql:")
.append(sql)
.append(",args:")
.append(Arrays.toString(args))
.append("}")
.toString();
}
}
public static class QueryResult {
public QueryResult(Map<String, Integer> columns, List<Object[]> values){
this.columns = columns;
this.rows = values;
}
/**
* 只支持放一行数据,数据类型只支持数值long和String,例如: sku_id:0,item_id:65,seller_id:63
* sku_id:0,item_id:65,name:'aaa'
*/
public QueryResult(String row){
String[] cols = row.split(",");
this.columns = new HashMap<String, Integer>(cols.length);
List<Object> colvalues = new ArrayList<Object>(cols.length);
for (int i = 0; i < cols.length; i++) {
String col = cols[i];
String[] nv = col.split("\\:");
this.columns.put(nv[0], i);
if (nv[1].startsWith("'") && nv[1].endsWith("'")) {
colvalues.add(nv[1].substring(1, nv[1].length() - 1));// 字符串
} else if (nv[1].endsWith("NULL")) {
colvalues.add(null);
} else {
colvalues.add(Long.parseLong(nv[1]));// 数字
}
}
this.rows = new ArrayList<Object[]>(1);
this.rows.add(colvalues.toArray(new Object[colvalues.size()]));
}
public final Map<String, Integer> columns;
public final List<Object[]> rows;
}
public void checkState() throws SQLException {
if (isClosed) {
throw genFatalSQLException();
}
}
public SQLException genFatalSQLException() throws SQLException {
if (DBType.MYSQL.equals(dbType)) {
return new SQLException("dsClosed", "08001");// 来自MySQLExceptionSorter
} else if (DBType.ORACLE.equals(dbType)) {
return new SQLException("dsClosed", "28");// 来自OracleExceptionSorter
// //28 session has been
// killed
} else {
throw new RuntimeException("有了新的dbType而这里没有更新");
}
}
/**
* 存放每次执行的结果信息:实际的sql,参数,数据源名称
*/
private static ThreadLocal<ExecuteInfo> RESULT = new ThreadLocal<ExecuteInfo>(); // TODO
// 有了TRACE不需要这个了
private static ThreadLocal<List<ExecuteInfo>> TRACE = new ThreadLocal<List<ExecuteInfo>>();
private static ThreadLocal<List<QueryResult>> PREDATA = new ThreadLocal<List<QueryResult>>();
private static ThreadLocal<List<Integer>> PREAffectedRow = new ThreadLocal<List<Integer>>();
/**
* map中key的取值"getConnection"、"prepareStatement"、"executeQuery"、
* "executeUpdate"、"" ...
*/
private static ThreadLocal<Map<String, List<SQLException>>> PREException = new ThreadLocal<Map<String, List<SQLException>>>() {
@Override
protected Map<String, List<SQLException>> initialValue() {
Map<String, List<SQLException>> exceptions = new HashMap<String, List<SQLException>>(4);
exceptions.put(m_getConnection,
new ArrayList<SQLException>(0));
exceptions.put(m_prepareStatement,
new ArrayList<SQLException>(0));
exceptions.put(m_createStatement,
new ArrayList<SQLException>(0));
exceptions.put(m_executeQuery,
new ArrayList<SQLException>(0));
exceptions.put(m_executeUpdate,
new ArrayList<SQLException>(0));
return exceptions;
}
};
// 下面这些变量改成一个enum类
public static final String m_getConnection = "getConnection";
public static final String m_prepareStatement = "prepareStatement";
public static final String m_createStatement = "createStatement";
public static final String m_executeQuery = "executeQuery";
public static final String m_executeUpdate = "executeUpdate";
/**
* 需要在每个testcase的afterclass中调用这个方法
*/
public static void reset() {
RESULT.set(null);
TRACE.set(null);
PREDATA.set(null);
PREAffectedRow.set(null);
for (Map.Entry<String, List<SQLException>> e : PREException.get().entrySet()) {
e.getValue().clear();
}
}
public static void clearTrace() {
List<ExecuteInfo> trace = TRACE.get();
if (trace != null) trace.clear();
PREDATA.set(null);
}
public static void showTrace() {
showTrace("");
}
public static void showTrace(String msg) {
List<ExecuteInfo> trace = TRACE.get();
if (trace == null) {
return;
}
System.out.println("Invoke trace on MockDataSource:" + msg);
for (ExecuteInfo info : trace) {
System.out.println(info.toString());
}
}
public static ExecuteInfo getResult() {
return RESULT.get();
}
public static List<ExecuteInfo> getTrace() {
return TRACE.get();
}
// 是否在指定的dbIndex上执行过sqlHead开头的sql
public static boolean hasTrace(String dbIndex, String sqlHead) {
List<ExecuteInfo> trace = TRACE.get();
if (trace != null) {
for (ExecuteInfo info : trace) {
if (info.sql != null && dbIndex.equals(info.ds.dbIndex) && sqlHead.length() <= info.sql.length()
&& sqlHead.equalsIgnoreCase(info.sql.substring(0, sqlHead.length()))) {
return true;
}
}
}
return false;
}
// 是否在指定的dbIndex和dsName上执行过sqlHead开头的sql
public static boolean hasTrace(String dbIndex, String dsName, String sqlHead) {
List<ExecuteInfo> trace = TRACE.get();
if (trace != null) {
for (ExecuteInfo info : trace) {
if (info.sql != null && dbIndex.equals(info.ds.dbIndex) && info.ds.name.equals(dsName)
&& sqlHead.length() <= info.sql.length()
&& sqlHead.equalsIgnoreCase(info.sql.substring(0, sqlHead.length()))) {
return true;
}
}
}
return false;
}
public static boolean hasMethod(String dbIndex, String method) {
List<ExecuteInfo> trace = TRACE.get();
if (trace != null) {
for (ExecuteInfo info : trace) {
if (dbIndex.equals(info.ds.dbIndex) && method.equals(info.method)) {
return true;
}
}
}
return false;
}
public static boolean hasMethod(String dbIndex, String dsName, String method) {
List<ExecuteInfo> trace = TRACE.get();
if (trace != null) {
for (ExecuteInfo info : trace) {
if (dbIndex.equals(info.ds.dbIndex) && method.equals(info.method) && info.ds.name.equals(dsName)) {
return true;
}
}
}
return false;
}
/*
* public static List<QueryResult> getPreData(){ return PREDATA.get(); }
*/
/**
* 记录一个Datasource、Connection、Statement上的执行动作
*/
public static void record(ExecuteInfo info) {
RESULT.set(info);
if (TRACE.get() == null) {
TRACE.set(new ArrayList<ExecuteInfo>());
}
TRACE.get().add(info);
}
/**
* 加入一个预置的查询结果。返回ResultSet时,会按顺序提取预置的数据构造ResultSet
*
* @param arow 格式:sku_id:0,item_id:65,seller_id:63,name:'尺码'
*/
public static void addPreData(String arow) {
addPreData(new QueryResult(arow));
}
public static void addPreData(QueryResult queryResult) {
if (PREDATA.get() == null) {
PREDATA.set(new ArrayList<QueryResult>(5));
}
PREDATA.get().add(queryResult);
}
public static void addPreAffectedRow(int preAffectedRow) {
if (PREAffectedRow.get() == null) {
PREAffectedRow.set(new ArrayList<Integer>(1));
}
PREAffectedRow.get().add(preAffectedRow);
}
/**
* 包权限,构造ResultSet时,用这个方法提取预设数据
*/
static QueryResult popPreData() {
List<QueryResult> preData = PREDATA.get();
if (preData == null || preData.isEmpty()) {
return null; // 没有预设数据也是可以的,方便有些Test不care返回数据
}
return PREDATA.get().remove(0);
}
/**
* 包权限,构造ResultSet时,用这个方法提取预设数据
*/
static int popPreAffectedRow() {
List<Integer> preAffectedRow = PREAffectedRow.get();
if (preAffectedRow == null || preAffectedRow.isEmpty()) {
return 1; // 没有预设AffectedRow默认返回1
}
return PREAffectedRow.get().remove(0);
}
public static void addPreException(String key, SQLException e) {
PREException.get().get(key).add(e);
}
public static SQLException popPreException(String key) {
List<SQLException> pre = PREException.get().get(key);
return pre.size() == 0 ? null : pre.remove(0);
}
/**
* ==============================================================
* 以下为jdbc接口实现
* ==============================================================
*/
public Connection getConnection() throws SQLException {
try {
Thread.sleep(timeToObtainConnection);
} catch (Exception e) {
}
getConnectionInvokingTimes++;
return new MockConnection(m_getConnection, this);
}
public Connection getConnection(String username, String password) throws SQLException {
try {
Thread.sleep(timeToObtainConnection);
} catch (Exception e) {
}
getConnectionInvokingTimes++;
return new MockConnection("getConnection#username_password", this);
}
public PrintWriter getLogWriter() throws SQLException {
throw new NotSupportException("");
}
public int getLoginTimeout() throws SQLException {
throw new NotSupportException("");
}
public void setLogWriter(PrintWriter out) throws SQLException {
throw new NotSupportException("");
}
public void setLoginTimeout(int seconds) throws SQLException {
throw new NotSupportException("");
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getDbIndex() {
return dbIndex;
}
public void setDbIndex(String dbIndex) {
this.dbIndex = dbIndex;
}
@Override
public MockDataSource clone() throws CloneNotSupportedException {
return (MockDataSource) super.clone();
}
public boolean isClosed() {
return isClosed;
}
public void setClosed(boolean isClosed) {
this.isClosed = isClosed;
}
public int getGetConnectionInvokingTimes() {
return getConnectionInvokingTimes;
}
public void setGetConnectionInvokingTimes(int getConnectionInvokingTimes) {
this.getConnectionInvokingTimes = getConnectionInvokingTimes;
}
@Override
public String toString() {
return new StringBuilder(super.toString().substring(getClass().getPackage().getName().length() + 1)).append("{dbIndex:")
.append(dbIndex)
.append(",name:")
.append(name)
.append("}")
.toString();
}
public <T> T unwrap(Class<T> iface) throws SQLException {
return null;
}
public boolean isWrapperFor(Class<?> iface) throws SQLException {
return false;
}
}