/**
* <pre>
* This program is free software; you can redistribute it and/or modify it under the terms of
* the GNU AFFERO GENERAL PUBLIC LICENSE as published by the Free Software Foundation; either version 3 of the License,
* or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU AFFERO GENERAL PUBLIC LICENSE for more details.
* You should have received a copy of the GNU AFFERO GENERAL PUBLIC LICENSE along with this program;
* if not, write to the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
* </pre>
*/
package com.meidusa.amoeba.mysql.net;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.collections.map.LRUMap;
import org.apache.log4j.Logger;
import com.meidusa.amoeba.context.ProxyRuntimeContext;
import com.meidusa.amoeba.mysql.context.MysqlRuntimeContext;
import com.meidusa.amoeba.mysql.handler.MySqlCommandDispatcher;
import com.meidusa.amoeba.mysql.handler.PreparedStatmentInfo;
import com.meidusa.amoeba.mysql.io.MySqlPacketConstant;
import com.meidusa.amoeba.mysql.jdbc.MysqlDefs;
import com.meidusa.amoeba.mysql.net.packet.AuthenticationPacket;
import com.meidusa.amoeba.mysql.net.packet.ErrorPacket;
import com.meidusa.amoeba.mysql.net.packet.FieldPacket;
import com.meidusa.amoeba.mysql.net.packet.HandshakePacket;
import com.meidusa.amoeba.mysql.net.packet.MysqlPacketBuffer;
import com.meidusa.amoeba.mysql.net.packet.OkPacket;
import com.meidusa.amoeba.mysql.net.packet.QueryCommandPacket;
import com.meidusa.amoeba.mysql.net.packet.ResultSetHeaderPacket;
import com.meidusa.amoeba.mysql.net.packet.result.MysqlResultSetPacket;
import com.meidusa.amoeba.mysql.util.CharsetMapping;
import com.meidusa.amoeba.net.AuthResponseData;
import com.meidusa.amoeba.net.Connection;
import com.meidusa.amoeba.net.poolable.ObjectPool;
import com.meidusa.amoeba.net.poolable.PoolableObject;
import com.meidusa.amoeba.parser.ParseException;
import com.meidusa.amoeba.server.MultipleServerPool;
import com.meidusa.amoeba.util.StringUtil;
import com.meidusa.amoeba.util.ThreadLocalMap;
/**
 * Wraps the connection of a client that has connected to the proxy server.
 *
 * <p>Besides the authentication handshake, an instance keeps per-client state:
 * the last insert id, buffered long-data packets, a cache of prepared statement
 * information and the backend connections that are currently bound ("sticky")
 * to this client.
 *
 * @author <a href=mailto:piratebase@sina.com>Struct chen</a>
 */
public class MysqlClientConnection extends MysqlConnection implements MySqlPacketConstant {

    private static final Logger logger = Logger.getLogger(MysqlClientConnection.class);
    private static final Logger authLogger = Logger.getLogger("auth");
    private static final Logger lastInsertID = Logger.getLogger("lastInsertId");

    /** Idle time (ms) after which a connection that is not yet authenticated is reported as idle. */
    private static final long AUTHENTICATING_IDLE_MILLIS = 5000;

    /** Serialized OK packet (packetId 2) posted to the client after a successful authentication. */
    protected static byte[] AUTHENTICATEOKPACKETDATA;

    static {
        OkPacket ok = new OkPacket();
        ok.packetId = 2;
        ok.affectedRows = 0;
        ok.insertId = 0;
        ok.serverStatus = 2;
        ok.warningCount = 0;
        AUTHENTICATEOKPACKETDATA = ok.toByteBuffer(null).array();
    }

    /** Creation time of this object; only used to log how long authentication took. */
    private long createTime = System.currentTimeMillis();

    /** Random string sent by the server (this proxy) that the client uses for encryption. */
    protected String seed;

    private long lastInsertId;

    private int statementCacheSize = 500;

    /** Encrypted string returned by the client. */
    protected byte[] authenticationMessage;

    private MultipleServerPool lastVirtualReadPool;

    private ObjectPool lastReadRealPool;

    /** Pre-built single-column ("last_insert_id") result set skeleton. */
    public MysqlResultSetPacket lastPacketResult = new MysqlResultSetPacket(null);

    {
        lastPacketResult.resulthead = new ResultSetHeaderPacket();
        lastPacketResult.resulthead.columns = 1;
        lastPacketResult.fieldPackets = new FieldPacket[1];
        FieldPacket field = new FieldPacket();
        field.type = MysqlDefs.FIELD_TYPE_LONGLONG;
        field.name = "last_insert_id";
        field.catalog = "def";
        field.length = 20;
        lastPacketResult.fieldPackets[0] = field;
    }

    private List<byte[]> longDataList = new ArrayList<byte[]>();

    /** Read-only view over {@link #longDataList} handed out to callers. */
    private List<byte[]> unmodifiableLongDataList = Collections.unmodifiableList(longDataList);

    /** Stores (sql, statementId) pairs; kept in step with {@link #prepared_statment_map}. */
    private final Map<String, Long> sql_statment_id_map =
            Collections.synchronizedMap(new HashMap<String, Long>(256));

    /** Source of statement ids for newly created prepared statement info objects. */
    private AtomicLong atomicLong = new AtomicLong(1);

    /** Caches the backend connections used by this front-end connection, keyed by their pool. */
    private ConcurrentHashMap<ObjectPool, MysqlConnection> stickyConnMap;

    /**
     * LRU cache of prepared statement information: key = statementId,
     * value = PreparedStatmentInfo object. Every mutation overridden below also
     * updates {@link #sql_statment_id_map} so that both maps stay consistent.
     */
    @SuppressWarnings("unchecked")
    private final Map<Long, PreparedStatmentInfo> prepared_statment_map = Collections
            .synchronizedMap(new LRUMap(((MysqlRuntimeContext) ProxyRuntimeContext.getInstance()
                    .getRuntimeContext()).getStatementCacheSize()) {

                private static final long serialVersionUID = 1L;

                // Called when the least recently used entry is about to be evicted:
                // drop the matching sql -> id mapping and allow the eviction.
                protected boolean removeLRU(LinkEntry entry) {
                    PreparedStatmentInfo info = (PreparedStatmentInfo) entry.getValue();
                    sql_statment_id_map.remove(info.getSql());
                    return true;
                }

                public PreparedStatmentInfo remove(Object key) {
                    PreparedStatmentInfo info = (PreparedStatmentInfo) super.remove(key);
                    // super.remove returns null for an absent key; nothing to clean up then.
                    if (info != null) {
                        sql_statment_id_map.remove(info.getSql());
                    }
                    return info;
                }

                public Object put(Object key, Object value) {
                    PreparedStatmentInfo info = (PreparedStatmentInfo) value;
                    sql_statment_id_map.put(info.getSql(), (Long) key);
                    return super.put(key, value);
                }

                public void putAll(Map map) {
                    for (Iterator it = map.entrySet().iterator(); it.hasNext();) {
                        Map.Entry<Long, PreparedStatmentInfo> entry =
                                (Map.Entry<Long, PreparedStatmentInfo>) it.next();
                        sql_statment_id_map.put(entry.getValue().getSql(), entry.getKey());
                    }
                    super.putAll(map);
                }
            });

    /**
     * Creates the wrapper for an accepted client channel.
     *
     * @param channel     the client socket channel, passed to the superclass
     * @param createStamp creation timestamp, passed to the superclass
     */
    public MysqlClientConnection(SocketChannel channel, long createStamp) {
        super(channel, createStamp);
        this.stickyConnMap = new ConcurrentHashMap<ObjectPool, MysqlConnection>();
    }

    /**
     * Logs, at debug level on the "auth" logger, the time elapsed since this
     * object was created.
     */
    public void afterAuth() {
        if (authLogger.isDebugEnabled()) {
            authLogger.debug("authentication time:" + (System.currentTimeMillis() - createTime)
                    + " Id=" + this.getInetAddress());
        }
    }

    /**
     * Looks up cached prepared statement information by statement id.
     *
     * @param id the statement id
     * @return the cached info, or {@code null} if the id is not in the cache
     */
    public PreparedStatmentInfo getPreparedStatmentInfo(long id) {
        return prepared_statment_map.get(id);
    }

    /**
     * Returns the prepared statement information for the given SQL, creating and
     * caching a new one (with a fresh statement id) when the SQL is not known yet.
     *
     * @param preparedSql the SQL text of the prepared statement
     * @return the cached or newly created info
     * @throws ParseException declared by the {@code PreparedStatmentInfo} constructor
     */
    public PreparedStatmentInfo getPreparedStatmentInfo(String preparedSql) throws ParseException {
        Long id = sql_statment_id_map.get(preparedSql);
        if (id != null) {
            return getPreparedStatmentInfo(id);
        }
        PreparedStatmentInfo info =
                new PreparedStatmentInfo(this, atomicLong.getAndIncrement(), preparedSql);
        prepared_statment_map.put(info.getStatmentId(), info);
        return info;
    }

    /**
     * Same as {@link #getPreparedStatmentInfo(String)}, but a newly created info
     * object is additionally constructed with the given byte arrays.
     *
     * @param preparedSql the SQL text of the prepared statement
     * @param byts        byte arrays handed to the {@code PreparedStatmentInfo} constructor
     * @return the cached or newly created info
     * @throws ParseException declared by the {@code PreparedStatmentInfo} constructor
     */
    public PreparedStatmentInfo createStatementInfo(String preparedSql, List<byte[]> byts)
            throws ParseException {
        Long id = sql_statment_id_map.get(preparedSql);
        if (id != null) {
            return getPreparedStatmentInfo(id);
        }
        PreparedStatmentInfo info =
                new PreparedStatmentInfo(this, atomicLong.getAndIncrement(), preparedSql, byts);
        prepared_statment_map.put(info.getStatmentId(), info);
        return info;
    }

    public String getSeed() {
        return seed;
    }

    public void setSeed(String seed) {
        this.seed = seed;
    }

    /**
     * Takes one message from the in-queue (non-blocking) and, if present, treats
     * it as the client's authentication packet and hands it to the authenticator.
     *
     * @param conn the connection passed to {@code AuthenticationPacket.init}
     */
    public void handleMessage(Connection conn) {
        byte[] message = this.getInQueue().getNonBlocking();
        if (message == null) {
            return;
        }
        // Authentication has not passed yet, so the data received at this point
        // is expected to be the authentication data.
        AuthenticationPacket autheticationPacket = new AuthenticationPacket();
        autheticationPacket.init(message, conn);
        this.getAuthenticator().authenticateConnection(this, autheticationPacket);
    }

    /**
     * Builds the handshake packet, remembers the full scramble
     * (seed + restOfScrambleBuff) as this connection's seed and posts the packet.
     */
    protected void beforeAuthing() {
        HandshakePacket handshakePacket = new HandshakePacket();
        handshakePacket.packetId = 0;
        handshakePacket.protocolVersion = 0x0a; // protocol version 10
        handshakePacket.seed = StringUtil.getRandomString(8);
        handshakePacket.restOfScrambleBuff = StringUtil.getRandomString(12);
        handshakePacket.serverStatus = 2;
        handshakePacket.serverVersion = MysqlRuntimeContext.SERVER_VERSION;
        //handshakePacket.serverCapabilities = 41516 & (~32);
        handshakePacket.serverCapabilities = CLIENT_LONG_FLAG | CLIENT_CONNECT_WITH_DB
                | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION;
        MysqlRuntimeContext context =
                (MysqlRuntimeContext) ProxyRuntimeContext.getInstance().getRuntimeContext();
        handshakePacket.serverCharsetIndex = (byte) (context.getServerCharsetIndex() & 0xff);
        handshakePacket.threadId = Thread.currentThread().hashCode();
        this.setSeed(handshakePacket.seed + handshakePacket.restOfScrambleBuff);
        this.postMessage(handshakePacket.toByteBuffer(this).array());
    }

    /**
     * After the superclass handling, installs a {@code MySqlCommandDispatcher} as
     * message handler and posts the pre-built OK packet.
     */
    protected void connectionAuthenticateSuccess(AuthResponseData data) {
        super.connectionAuthenticateSuccess(data);
        setMessageHandler(new MySqlCommandDispatcher());
        postMessage(AUTHENTICATEOKPACKETDATA);
        this.afterAuth();
    }

    /**
     * After the superclass handling, posts an error packet carrying
     * {@code data.message} to the client.
     */
    protected void connectionAuthenticateFaild(AuthResponseData data) {
        super.connectionAuthenticateFaild(data);
        ErrorPacket error = new ErrorPacket();
        error.resultPacketType = ErrorPacket.PACKET_TYPE_ERROR;
        error.packetId = 2;
        error.serverErrorMessage = data.message;
        error.sqlstate = "42S02";
        error.errno = 1000;
        postMessage(error.toByteBuffer(this).array());
        this.afterAuth();
    }

    /**
     * Handles a few commands directly; every other message goes to the superclass.
     * <ul>
     * <li>COM_QUIT: posts a close of this connection</li>
     * <li>COM_STMT_CLOSE: ignored</li>
     * <li>COM_PING: answered with an OK packet</li>
     * <li>COM_STMT_SEND_LONG_DATA: buffered via {@link #addLongData(byte[])}</li>
     * </ul>
     */
    protected void doReceiveMessage(byte[] message) {
        if (MysqlPacketBuffer.isPacketType(message, QueryCommandPacket.COM_QUIT)) {
            postClose(null);
            return;
        }
        if (MysqlPacketBuffer.isPacketType(message, QueryCommandPacket.COM_STMT_CLOSE)) {
            // intentionally ignored
            return;
        }
        if (MysqlPacketBuffer.isPacketType(message, QueryCommandPacket.COM_PING)) {
            OkPacket ok = new OkPacket();
            ok.affectedRows = 0;
            ok.insertId = 0;
            ok.packetId = 1;
            ok.serverStatus = 2;
            this.postMessage(ok.toByteBuffer(null).array());
            return;
        }
        if (MysqlPacketBuffer.isPacketType(message, QueryCommandPacket.COM_STMT_SEND_LONG_DATA)) {
            this.addLongData(message);
            return;
        }
        super.doReceiveMessage(message);
    }

    /**
     * Submits the handling of pending messages to an executor: the client-side
     * executor when {@code isAuthenticatedSeted()} is true, otherwise the
     * server-side executor. {@code ThreadLocalMap} is reset after each run.
     */
    protected void messageProcess() {
        Executor executor = isAuthenticatedSeted()
                ? ProxyRuntimeContext.getInstance().getRuntimeContext().getClientSideExecutor()
                : ProxyRuntimeContext.getInstance().getRuntimeContext().getServerSideExecutor();
        executor.execute(new Runnable() {
            public void run() {
                try {
                    MysqlClientConnection.this.getMessageHandler()
                            .handleMessage(MysqlClientConnection.this);
                } finally {
                    ThreadLocalMap.reset();
                }
            }
        });
    }

    public void addLongData(byte[] longData) {
        longDataList.add(longData);
    }

    public void clearLongData() {
        longDataList.clear();
    }

    /**
     * @return an unmodifiable view of the buffered long-data packets
     */
    public List<byte[]> getLongDataList() {
        return unmodifiableLongDataList;
    }

    public long getLastInsertId() {
        if (lastInsertID.isDebugEnabled()) {
            lastInsertID.debug("get last_insert_Id=" + lastInsertId);
        }
        return lastInsertId;
    }

    public void setLastInsertId(long lastInsertId) {
        if (lastInsertID.isDebugEnabled()) {
            lastInsertID.debug("set last_insert_Id=" + lastInsertId);
        }
        this.lastInsertId = lastInsertId;
    }

    public ConcurrentHashMap<ObjectPool, MysqlConnection> getStickyConnMap() {
        return stickyConnMap;
    }

    /**
     * Idle check that only ever reports connections which are still
     * authenticating, so those can use a shorter idle time.
     *
     * @param now the current time in milliseconds
     * @return {@code false} for authenticated connections; otherwise {@code true}
     *         once at least 5000 ms have passed since the last event
     */
    public boolean checkIdle(long now) {
        if (isAuthenticated()) {
            return false;
        }
        return now - _lastEvent >= AUTHENTICATING_IDLE_MILLIS;
    }

    /**
     * Binds a backend connection to this client for the given pool, unless one is
     * already bound. The backend connection is marked as not closeable.
     */
    public void bindTransaction(ObjectPool pool, MysqlConnection conn) {
        conn.setCloseable(false);
        stickyConnMap.putIfAbsent(pool, conn);
    }

    /**
     * Clears all transaction related bindings.
     *
     * @param canRelease {@code true} to return the bound backend connections to
     *                   their pools, {@code false} to close them
     */
    public void clearTransaction(boolean canRelease) {
        int size = stickyConnMap.size();
        this.setIsXaActive(false);
        for (ObjectPool pool : stickyConnMap.keySet()) {
            MysqlConnection conn = stickyConnMap.remove(pool);
            if (conn == null) {
                continue;
            }
            conn.setCloseable(true);
            if (canRelease) {
                // Transaction mode with a successful commit/rollback, or
                // non-transaction mode: the connection can always be released.
                returnConnToPool(conn);
            } else {
                // Transaction mode with a failed commit/rollback: the backend
                // connection can only be closed forcibly.
                closeStickyServerConnection(conn);
            }
        }
        if (logger.isInfoEnabled()) {
            String opt = canRelease ? "return" : "close";
            logger.info(String.format("%s %d sticky connections to conn pool", opt, size));
        }
    }

    /**
     * Closes the given backend connection if it is an open
     * {@code MysqlServerConnection}; anything else is ignored.
     *
     * @param connection the backend connection, may be {@code null}
     */
    public void closeStickyServerConnection(MysqlConnection connection) {
        if (connection == null || connection.isClosed()
                || !(connection instanceof MysqlServerConnection)) {
            return;
        }
        try {
            ((MysqlServerConnection) connection).close(null);
            if (logger.isDebugEnabled()) {
                logger.debug("close connection:" + connection);
            }
        } catch (Exception e) {
            handleFailure(e);
            // Pass the exception as the Throwable argument so the stack trace is logged.
            logger.error("failed to close sticky server connection:" + connection, e);
        }
    }

    /**
     * Returns the given backend connection to its object pool if it is an open,
     * active {@code MysqlServerConnection} that belongs to a pool; anything else
     * is ignored.
     *
     * @param connection the backend connection, may be {@code null}
     */
    public void returnConnToPool(MysqlConnection connection) {
        if (connection == null || connection.isClosed()
                || !(connection instanceof MysqlServerConnection)) {
            return;
        }
        PoolableObject pooledObject = (PoolableObject) connection;
        if (pooledObject.getObjectPool() == null || !pooledObject.isActive()) {
            return;
        }
        try {
            pooledObject.getObjectPool().returnObject(connection);
            if (logger.isDebugEnabled()) {
                logger.debug("connection:" + connection + " return to pool");
            }
        } catch (Exception e) {
            // Best effort: the failure is only logged (with its stack trace).
            logger.error("failed to return connection to pool:" + connection, e);
        }
    }

    /**
     * Releases (auto-commit) or closes (otherwise) all sticky backend connections
     * before closing this connection.
     */
    @Override
    protected void close(Exception exception) {
        clearTransaction(isAutoCommit());
        super.close(exception);
    }

    /**
     * Supports GUI clients which require that two adjacent SQL statements are
     * sent to the same database.
     *
     * <p>This method does not cover non-adjacent statements; that would need a
     * map. It is only meant for read statements.
     *
     * @param pool the pool selected for the current statement; either a
     *             {@code MultipleServerPool} (virtual pool) or a real pool
     * @return the real pool to use for the read
     * @throws Exception declared by the pool operations used here
     */
    public ObjectPool getLastReadPool(ObjectPool pool) throws Exception {
        if (pool instanceof MultipleServerPool) {
            MultipleServerPool recentVirtualPool = (MultipleServerPool) pool;
            if (!recentVirtualPool.equals(lastVirtualReadPool)) {
                // A different virtual pool than last time: remember it and pick a real pool.
                lastVirtualReadPool = recentVirtualPool;
                lastReadRealPool = recentVirtualPool.selectPool();
            } else if (lastReadRealPool == null || !lastReadRealPool.validate()) {
                // Same virtual pool, but the remembered real pool is unusable: pick again.
                lastReadRealPool = recentVirtualPool.selectPool();
            }
        } else if (pool != null && !pool.equals(lastReadRealPool)) {
            lastReadRealPool = pool;
        }
        return lastReadRealPool;
    }

    /**
     * Sets the charset only when {@code CharsetMapping} yields a positive index for it.
     *
     * @param charset the charset name
     * @return {@code true} if the charset was accepted and stored
     */
    @Override
    public boolean setCharset(String charset) {
        if (CharsetMapping.getCharsetIndex(charset) > 0) {
            this.charset = charset;
            return true;
        }
        return false;
    }

    public int getStatementCacheSize() {
        return statementCacheSize;
    }

    public void setStatementCacheSize(int statementCacheSize) {
        this.statementCacheSize = statementCacheSize;
    }
}