/*
* Copyright 2008-2108 amoeba.meidusa.com
*
* This program is free software; you can redistribute it and/or modify it under the terms of
* the GNU AFFERO GENERAL PUBLIC LICENSE as published by the Free Software Foundation; either version 3 of the License,
* or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU AFFERO GENERAL PUBLIC LICENSE for more details.
* You should have received a copy of the GNU AFFERO GENERAL PUBLIC LICENSE along with this program;
* if not, write to the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*/
package com.meidusa.amoeba.mysql.handler;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.log4j.Logger;
import com.meidusa.amoeba.context.ProxyRuntimeContext;
import com.meidusa.amoeba.exception.AmoebaRuntimeException;
import com.meidusa.amoeba.mysql.context.MysqlRuntimeContext;
import com.meidusa.amoeba.mysql.jdbc.MysqlDefs;
import com.meidusa.amoeba.mysql.net.MysqlClientConnection;
import com.meidusa.amoeba.mysql.net.packet.BindValue;
import com.meidusa.amoeba.mysql.net.packet.ConstantPacketBuffer;
import com.meidusa.amoeba.mysql.net.packet.ErrorPacket;
import com.meidusa.amoeba.mysql.net.packet.FieldPacket;
import com.meidusa.amoeba.mysql.net.packet.MysqlPacketBuffer;
import com.meidusa.amoeba.mysql.net.packet.QueryCommandPacket;
import com.meidusa.amoeba.mysql.net.packet.ResultSetHeaderPacket;
import com.meidusa.amoeba.mysql.net.packet.RowDataPacket;
import com.meidusa.amoeba.mysql.net.packet.result.MysqlResultSetPacket;
import com.meidusa.amoeba.net.Connection;
import com.meidusa.amoeba.net.MessageHandler;
import com.meidusa.amoeba.net.poolable.ObjectPool;
import com.meidusa.amoeba.parser.dbobject.Column;
import com.meidusa.amoeba.parser.dbobject.GlobalSeqColumn;
import com.meidusa.amoeba.parser.dbobject.Schema;
import com.meidusa.amoeba.parser.expression.FunctionExpression;
import com.meidusa.amoeba.parser.statement.BeginStatement;
import com.meidusa.amoeba.parser.statement.CommitCMD;
import com.meidusa.amoeba.parser.statement.DMLStatement;
import com.meidusa.amoeba.parser.statement.HelpStatement;
import com.meidusa.amoeba.parser.statement.PropertyStatement;
import com.meidusa.amoeba.parser.statement.RollbackCMD;
import com.meidusa.amoeba.parser.statement.SelectStatement;
import com.meidusa.amoeba.parser.statement.StartTansactionStatement;
import com.meidusa.amoeba.parser.statement.Statement;
import com.meidusa.amoeba.parser.statement.XAStatement;
import com.meidusa.amoeba.parser.statement.ddl.DDLCreateSequenceStatenment;
import com.meidusa.amoeba.parser.statement.ddl.DDLDropSequenceStatement;
import com.meidusa.amoeba.route.SqlBaseQueryRouter;
import com.meidusa.amoeba.route.SqlQueryObject;
import com.meidusa.amoeba.seq.fetcher.SeqFetchService;
import com.meidusa.amoeba.seq.fetcher.SeqOperationResult;
import com.meidusa.amoeba.util.StringUtil;
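/**
* Dispatches MySQL command packets received from a client connection (COM_QUERY, COM_STMT_PREPARE,
* COM_STMT_EXECUTE, COM_INIT_DB, COM_CHANGE_USER) to the matching command handler, answering some
* statements (version probes, BEGIN, sequence DDL) directly in the proxy.
*
* @author <a href=mailto:piratebase@sina.com>Struct chen</a>
*/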
public class MySqlCommandDispatcher implements MessageHandler {

    protected static Logger logger = Logger.getLogger(MySqlCommandDispatcher.class);

    /** MySQL error number carried by every error packet this dispatcher builds itself. */
    private static final int DISPATCH_ERRNO = 1044;

    /** SQLSTATE carried by every error packet this dispatcher builds itself. */
    private static final String DISPATCH_SQLSTATE = "42000";

    /**
     * Statement sent instead of an EXPLAIN of a query that uses a global sequence. Sequence values
     * are generated by the proxy itself, so there is no real plan to show and a trivial one is
     * simulated for clients that ask for execution plans.
     */
    private static final String SIMULATED_EXPLAIN_SQL = "EXPLAIN EXTENDED SELECT 1 FROM DUAL";

    /**
     * Timeout handed to every command object: the configured query timeout multiplied by 1000
     * (presumably seconds to milliseconds — TODO confirm). Long arithmetic avoids int overflow.
     */
    private final long timeout = ProxyRuntimeContext.getInstance().getRuntimeContext().getQueryTimeout() * 1000L;

    /**
     * Drains the connection's inbound queue and dispatches every pending command packet.
     * <p>
     * Ping and COM_STMT_SEND_LONG_DATA are handled earlier, in
     * {@code MysqlClientConnection#doReceiveMessage()}, and do not reach this method.
     * Any exception raised while handling one packet (including packet decoding and SQL parsing)
     * is reported to the client as an error packet; processing then continues with the next
     * queued packet.
     *
     * @param connection the client connection; expected to be a {@link MysqlClientConnection}
     */
    public void handleMessage(Connection connection) {
        byte[] message;
        while ((message = connection.getInQueue().getNonBlocking()) != null) {
            MysqlClientConnection conn = (MysqlClientConnection) connection;
            QueryCommandPacket command = new QueryCommandPacket();
            try {
                command.init(message, connection);
                if (logger.isDebugEnabled()) {
                    logger.debug(command.query);
                }
                SqlBaseQueryRouter router = (SqlBaseQueryRouter) ProxyRuntimeContext.getInstance().getQueryRouter();
                Statement statement = router.parseStatement(conn, command.query);
                dispatchPacket(conn, router, command, statement, message);
            } catch (Exception e) {
                // Some exceptions (e.g. NullPointerException) carry no message; never send a null text.
                String errMsg = (e.getMessage() != null) ? e.getMessage() : e.toString();
                sendInternalErrorMsg(conn, errMsg);
                logger.error("messageDispate error", e);
            }
        }
    }

    /**
     * Routes one decoded packet according to its MySQL command type.
     * <p>
     * Declared to throw {@code Exception} because the checked-exception contracts of the command
     * objects are not visible here; {@link #handleMessage(Connection)} is the boundary that reports
     * any failure to the client.
     */
    private void dispatchPacket(MysqlClientConnection conn, SqlBaseQueryRouter router, QueryCommandPacket command,
            Statement statement, byte[] message) throws Exception {
        if (MysqlPacketBuffer.isPacketType(message, QueryCommandPacket.COM_QUERY)) {
            handleQuery(conn, router, command, statement, message);
        } else if (MysqlPacketBuffer.isPacketType(message, QueryCommandPacket.COM_STMT_PREPARE)) {
            SqlQueryObject queryObject = new SqlQueryObject();
            queryObject.isPrepared = true;
            queryObject.sql = command.query;
            ObjectPool[] pools = router.doRoute(conn, queryObject, statement);
            PrepareQueryCommand prepareQueryCommand = new PrepareQueryCommand(timeout, conn);
            prepareQueryCommand.execute(pools, message, statement, queryObject, command);
        } else if (MysqlPacketBuffer.isPacketType(message, QueryCommandPacket.COM_STMT_EXECUTE)) {
            PrepareExecuteCommand prepareExecuteCommand = new PrepareExecuteCommand(timeout, conn, router);
            prepareExecuteCommand.execute(message, statement);
        } else if (MysqlPacketBuffer.isPacketType(message, QueryCommandPacket.COM_INIT_DB)) {
            conn.setSchema(command.query);
            conn.postMessage(ConstantPacketBuffer.STATIC_OK_BUFFER);
        } else if (MysqlPacketBuffer.isPacketType(message, QueryCommandPacket.COM_CHANGE_USER)) {
            conn.postMessage(ConstantPacketBuffer.STATIC_OK_BUFFER);
        } else {
            sendInternalErrorMsg(conn, "can not use this command here!!");
            logger.warn("unsupport packet:" + command);
        }
    }

    /**
     * Handles a COM_QUERY packet: answers proxy-local statements directly and forwards everything
     * else to the routed backend pools.
     */
    private void handleQuery(MysqlClientConnection conn, SqlBaseQueryRouter router, QueryCommandPacket command,
            Statement statement, byte[] message) throws Exception {
        // Version probes are answered locally. Only SELECTs qualify, so ordinary statements that
        // merely contain the literal (e.g. UPDATE ... SET v='$version') are routed instead of
        // failing with a ClassCastException.
        if (isVersionQuery(command.query) && statement instanceof SelectStatement) {
            MysqlResultSetPacket versionResult = createAmoebaVersion(conn, (SelectStatement) statement, false);
            versionResult.wirteToConnection(conn);
            return;
        }

        SqlQueryObject queryObject = new SqlQueryObject();
        queryObject.isPrepared = false;
        queryObject.sql = command.query;

        if (statement instanceof PropertyStatement) {
            // Property (session setting) statements.
            PropertyCommand propertyCommand = new PropertyCommand(timeout, conn);
            propertyCommand.execute((PropertyStatement) statement, queryObject);
        } else if (statement instanceof XAStatement && conn.isXaActive()) {
            // Nested XA statements are rejected while the connection is already in XA mode.
            sendInternalErrorMsg(conn, "can not use xa statement in xa model");
            logger.warn("can not use xa statement in xa model");
        } else if (statement instanceof BeginStatement || statement instanceof StartTansactionStatement) {
            // BEGIN / START TRANSACTION only switch the client connection off auto-commit.
            conn.setAutoCommit(false);
            conn.postMessage(ConstantPacketBuffer.STATIC_OK_BUFFER);
        } else if (statement instanceof CommitCMD) {
            CommitCommand commitCommand = new CommitCommand(timeout, conn);
            commitCommand.execute(message, statement, queryObject);
        } else if (statement instanceof RollbackCMD) {
            RollbackCommand rollbackCommand = new RollbackCommand(timeout, conn);
            rollbackCommand.execute(message, statement, queryObject);
        } else if (statement instanceof DDLCreateSequenceStatenment) {
            createSequence(conn, statement);
        } else if (statement instanceof DDLDropSequenceStatement) {
            dropSequence(conn, statement);
        } else {
            routeAndExecute(conn, router, statement, queryObject, message);
        }
    }

    /**
     * Returns whether the SQL text asks for the proxy version ({@code '$version'} or
     * {@code @amoebaversion}). A match at index 0 is deliberately not counted, as before.
     */
    private static boolean isVersionQuery(String sql) {
        return sql != null && (sql.indexOf("'$version'") > 0 || sql.indexOf("@amoebaversion") > 0);
    }

    /**
     * Picks the target pools for a regular query, substitutes global sequence values in DML, and
     * executes the (possibly rewritten) query.
     */
    private void routeAndExecute(MysqlClientConnection conn, SqlBaseQueryRouter router, Statement statement,
            SqlQueryObject queryObject, byte[] message) throws Exception {
        ObjectPool[] pools;
        if (statement instanceof HelpStatement) {
            // HELP only needs one backend: use the first default pool.
            pools = router.getDefaultObjectPool();
            if (pools != null && pools.length > 1) {
                pools = new ObjectPool[] { pools[0] };
            }
        } else {
            pools = router.doRoute(conn, queryObject, statement);
        }

        byte[] packet = message;
        if (statement instanceof DMLStatement) {
            byte[] rewritten = replaceGlobalSequences(conn, statement, queryObject);
            if (rewritten != null) {
                packet = rewritten;
            }
        }

        QueryCommand queryCommand = new QueryCommand(timeout, conn);
        queryCommand.execute(pools, packet, statement, queryObject);
    }

    /**
     * Replaces global sequence references ({@code seq.nextval}/{@code seq.currval} first, then
     * {@code seq.bulkval(count)}) in a DML statement.
     *
     * @return the re-encoded COM_QUERY packet when anything was substituted (and updates
     *         {@code queryObject.sql}), or {@code null} when the statement uses no sequences
     */
    private byte[] replaceGlobalSequences(MysqlClientConnection conn, Statement statement, SqlQueryObject queryObject)
            throws Exception {
        DMLStatement dmlStmt = (DMLStatement) statement;
        List<GlobalSeqColumn<Column>> seqColumns = dmlStmt.getSeqColumns();
        List<GlobalSeqColumn<FunctionExpression>> batchSeqFetchCalls = dmlStmt.getBatchFetchFuncCalls();
        String targetSQL = queryObject.sql;
        boolean replaced = false;

        if (seqColumns != null && !seqColumns.isEmpty()) {
            replaced = true;
            targetSQL = replaceSeqValue(conn, seqColumns, targetSQL, statement);
        }
        if (batchSeqFetchCalls != null && !batchSeqFetchCalls.isEmpty()) {
            replaced = true;
            targetSQL = replaceBatchSeqValue(conn, batchSeqFetchCalls, targetSQL, statement);
        }
        if (!replaced) {
            return null;
        }

        // The SQL text changed, so the raw packet must be rebuilt from it.
        QueryCommandPacket rewrittenCommand = new QueryCommandPacket();
        rewrittenCommand.query = targetSQL;
        rewrittenCommand.command = QueryCommandPacket.COM_QUERY;
        byte[] packet = rewrittenCommand.toByteBuffer(conn).array();
        queryObject.sql = targetSQL;
        return packet;
    }

    /**
     * Substitutes each {@code seq.nextval}/{@code seq.currval} token with its fetched value.
     * <p>
     * Token positions ({@code getSeqTokenEndColumn()}) refer to {@code sourceSql}; the running
     * length difference between the rewritten and the original text shifts them, which assumes
     * the tokens are listed in ascending position. For SELECTs without an explicit alias an
     * {@code as <column name>} alias is appended so the result column keeps its label. EXPLAIN
     * statements are replaced entirely by {@link #SIMULATED_EXPLAIN_SQL}.
     */
    private String replaceSeqValue(MysqlClientConnection conn, List<GlobalSeqColumn<Column>> seqColumns,
            String sourceSql, Statement stmt) {
        if (stmt.isExplain()) {
            return SIMULATED_EXPLAIN_SQL;
        }
        StringBuilder replacedSQL = new StringBuilder(sourceSql);
        for (GlobalSeqColumn<Column> globalSeq : seqColumns) {
            Column column = globalSeq.getSeqObject();
            String seqName = globalSeq.getSeqName();
            Long seqVal = globalSeq.getSeqValue();
            if (seqVal == null) {
                logger.error(String.format("skip to replace global sequence %s since its value is null", seqName));
                continue;
            }
            if (seqVal < 0) {
                logger.error(String.format("skip to replace global sequence %s since its value is negative number", seqName));
                continue;
            }
            int endIndex = replacedSQL.length() - sourceSql.length() + globalSeq.getSeqTokenEndColumn();
            int startIndex = endIndex - seqName.length();

            StringBuilder replacement = new StringBuilder(String.valueOf(seqVal));
            // Add an alias so the SELECT result column keeps its name.
            if (StringUtil.isEmpty(column.getAlias()) && (stmt instanceof SelectStatement)) {
                replacement.append(" as ").append(column.getName());
            }
            if (startIndex >= 0 && endIndex >= startIndex) {
                try {
                    replacedSQL.replace(startIndex, endIndex, replacement.toString());
                } catch (StringIndexOutOfBoundsException e) {
                    logger.error(String.format("error occours when replace global sequence name with value since: %s", e.getMessage()), e);
                }
            }
        }
        return replacedSQL.toString();
    }

    /**
     * Substitutes each {@code seq.bulkval(count)} call with its fetched value, using the same
     * position-shifting scheme as {@link #replaceSeqValue}. For SELECTs an {@code as <function name>}
     * alias is appended. EXPLAIN statements are replaced entirely by {@link #SIMULATED_EXPLAIN_SQL}.
     */
    private String replaceBatchSeqValue(MysqlClientConnection conn,
            List<GlobalSeqColumn<FunctionExpression>> batchSeqFetchCalls, String sourceSql, Statement stmt) {
        if (stmt.isExplain()) {
            return SIMULATED_EXPLAIN_SQL;
        }
        StringBuilder replacedSQL = new StringBuilder(sourceSql);
        for (GlobalSeqColumn<FunctionExpression> bulkFunCall : batchSeqFetchCalls) {
            FunctionExpression funExp = bulkFunCall.getSeqObject();
            Long seqVal = bulkFunCall.getSeqValue();
            String bulkFuncName = bulkFunCall.getSeqName();
            // NOTE(review): 0 is skipped here while replaceSeqValue accepts it; confirm whether 0 is
            // a failure sentinel for bulk fetches before unifying the two rules.
            if (seqVal == null || seqVal <= 0) {
                logger.error(String.format("skip to replace bulkval function %s since its value is not a positive number", bulkFuncName));
                continue;
            }
            int endIndex = replacedSQL.length() - sourceSql.length() + bulkFunCall.getSeqTokenEndColumn();
            // Assumes funExp.toString() reproduces the original call text exactly — TODO confirm.
            int startIndex = endIndex - funExp.toString().length();

            StringBuilder replacement = new StringBuilder(String.valueOf(seqVal));
            // Add an alias so the SELECT result column keeps its name.
            if (stmt instanceof SelectStatement) {
                replacement.append(" as ").append(bulkFuncName);
            }
            if (startIndex >= 0 && endIndex >= startIndex) {
                try {
                    replacedSQL.replace(startIndex, endIndex, replacement.toString());
                } catch (StringIndexOutOfBoundsException e) {
                    logger.error(String.format("error occours when replace bulkval function with value since: %s", e.getMessage()), e);
                }
            }
        }
        return replacedSQL.toString();
    }

    /**
     * Executes DROP SEQUENCE through {@link SeqFetchService} and replies OK or an error packet.
     *
     * @throws AmoebaRuntimeException if neither the statement nor the connection names a schema
     */
    private void dropSequence(MysqlClientConnection conn, Statement statement) {
        DDLDropSequenceStatement seqStmt = (DDLDropSequenceStatement) statement;
        String schemaName = resolveSchemaName(conn, seqStmt.getSchema(), "delete");
        SeqOperationResult result = SeqFetchService.deleteSeq(schemaName, seqStmt.getSeqName());
        replySeqOperationResult(conn, result);
    }

    /**
     * Executes CREATE SEQUENCE through {@link SeqFetchService} and replies OK or an error packet.
     *
     * @throws AmoebaRuntimeException if neither the statement nor the connection names a schema
     */
    private void createSequence(MysqlClientConnection conn, Statement statement) {
        DDLCreateSequenceStatenment seqStmt = (DDLCreateSequenceStatenment) statement;
        String schemaName = resolveSchemaName(conn, seqStmt.getSchema(), "create");
        long start = seqStmt.getStartWith();
        long offset = seqStmt.getOffset();
        SeqOperationResult result = SeqFetchService.createSeq(schemaName, seqStmt.getSeqName(), start, offset);
        replySeqOperationResult(conn, result);
    }

    /**
     * Returns the schema named by the statement, falling back to the connection's current schema.
     *
     * @param action verb used in the error message ("create" / "delete")
     * @throws AmoebaRuntimeException if no schema name is available from either source
     */
    private String resolveSchemaName(MysqlClientConnection conn, Schema schema, String action) {
        if (schema != null && !StringUtil.isEmpty(schema.getName())) {
            return schema.getName();
        }
        if (StringUtil.isEmpty(conn.getSchema())) {
            throw new AmoebaRuntimeException("can not " + action + " seq since schema is null");
        }
        return conn.getSchema();
    }

    /** Sends OK when the sequence operation succeeded, otherwise an error packet with its message. */
    private void replySeqOperationResult(MysqlClientConnection conn, SeqOperationResult result) {
        if (result.isSuccessed()) {
            conn.postMessage(ConstantPacketBuffer.STATIC_OK_BUFFER);
        } else {
            sendInternalErrorMsg(conn, result.getErrMsg());
        }
    }

    /**
     * Builds a one-row result set answering a proxy version probe.
     * <p>
     * Columns named {@code @amoebaversion} or {@code '$version'} receive
     * {@link MysqlRuntimeContext#SERVER_VERSION}; every other selected column is NULL. When the
     * statement selects no columns, a single {@code @amoebaversion} column is produced. The
     * statement's own column map is copied, never modified.
     */
    private MysqlResultSetPacket createAmoebaVersion(MysqlClientConnection conn, SelectStatement statment,
            boolean isPrepared) {
        Map<String, Column> sourceMap = statment.getSelectColumnMap();
        Map<String, Column> selectedMap = (sourceMap == null)
                ? new LinkedHashMap<String, Column>()
                : new LinkedHashMap<String, Column>(sourceMap);
        if (selectedMap.isEmpty()) {
            Column column = new Column();
            column.setName("@amoebaversion");
            selectedMap.put("@amoebaversion", column);
        }

        MysqlResultSetPacket lastPacketResult = new MysqlResultSetPacket(null);
        lastPacketResult.resulthead = new ResultSetHeaderPacket();
        lastPacketResult.resulthead.columns = selectedMap.size();
        lastPacketResult.resulthead.extra = 1;
        RowDataPacket row = new RowDataPacket(isPrepared);
        row.columns = new ArrayList<Object>();
        lastPacketResult.fieldPackets = new FieldPacket[selectedMap.size()];

        int index = 0;
        for (Map.Entry<String, Column> entry : selectedMap.entrySet()) {
            Column column = entry.getValue();
            String name = column.getName();
            String alias = column.getAlias();
            BindValue value = new BindValue();
            value.bufferType = MysqlDefs.FIELD_TYPE_VARCHAR;
            value.scale = 20;
            FieldPacket field = new FieldPacket();
            if ("@amoebaversion".equalsIgnoreCase(name) || "'$version'".equalsIgnoreCase(name)) {
                value.value = MysqlRuntimeContext.SERVER_VERSION;
                value.isSet = true;
                field.name = (alias == null ? name + "()" : alias);
            } else {
                value.isNull = true;
                field.name = (alias == null ? name : alias);
            }
            row.columns.add(value);
            field.type = (byte) MysqlDefs.FIELD_TYPE_VARCHAR;
            field.catalog = "def";
            field.length = 20;
            lastPacketResult.fieldPackets[index++] = field;
        }

        List<RowDataPacket> list = new ArrayList<RowDataPacket>();
        list.add(row);
        lastPacketResult.setRowList(list);
        return lastPacketResult;
    }

    /** Posts a dispatcher error packet (errno 1044, SQLSTATE 42000) carrying {@code msg}. */
    private void sendInternalErrorMsg(MysqlClientConnection conn, String msg) {
        ErrorPacket error = new ErrorPacket();
        error.errno = DISPATCH_ERRNO;
        error.packetId = 1;
        error.sqlstate = DISPATCH_SQLSTATE;
        error.serverErrorMessage = msg;
        conn.postMessage(error.toByteBuffer(conn).array());
    }

    /**
     * Builds a one-row result set answering {@code SELECT LAST_INSERT_ID()} / {@code SELECT @@IDENTITY}
     * from the connection's recorded last insert id.
     * <p>
     * Exactly one value is emitted per selected column so the row stays aligned with
     * {@code fieldPackets}; columns other than the two recognised names are NULL.
     */
    public static MysqlResultSetPacket createLastInsertIdPacket(MysqlClientConnection conn, SelectStatement statment,
            boolean isPrepared) {
        Map<String, Column> selectedMap = statment.getSelectColumnMap();
        MysqlResultSetPacket lastPacketResult = new MysqlResultSetPacket(null);
        lastPacketResult.resulthead = new ResultSetHeaderPacket();
        lastPacketResult.resulthead.columns = selectedMap.size();
        lastPacketResult.resulthead.extra = 1;
        RowDataPacket row = new RowDataPacket(isPrepared);
        row.columns = new ArrayList<Object>();
        lastPacketResult.fieldPackets = new FieldPacket[selectedMap.size()];

        int index = 0;
        for (Map.Entry<String, Column> entry : selectedMap.entrySet()) {
            Column column = entry.getValue();
            String name = column.getName();
            String alias = column.getAlias();
            boolean isLastInsertIdCall = "LAST_INSERT_ID".equalsIgnoreCase(name);
            BindValue value = new BindValue();
            value.scale = 20;
            FieldPacket field = new FieldPacket();
            if (isLastInsertIdCall || "@@IDENTITY".equalsIgnoreCase(name)) {
                value.bufferType = MysqlDefs.FIELD_TYPE_LONGLONG;
                value.longBinding = conn.getLastInsertId();
                value.isSet = true;
                // LAST_INSERT_ID is a function call, so its default label carries "()".
                String defaultName = isLastInsertIdCall ? name + "()" : name;
                field.name = (alias == null ? defaultName : alias);
            } else {
                value.bufferType = MysqlDefs.FIELD_TYPE_STRING;
                value.isNull = true;
                field.name = (alias == null ? name : alias);
            }
            // One value per column (previously @@IDENTITY was added twice, misaligning the row).
            row.columns.add(value);
            field.type = MysqlDefs.FIELD_TYPE_LONGLONG;
            field.catalog = "def";
            field.length = 20;
            lastPacketResult.fieldPackets[index++] = field;
        }

        List<RowDataPacket> list = new ArrayList<RowDataPacket>();
        list.add(row);
        lastPacketResult.setRowList(list);
        return lastPacketResult;
    }
}