package io.mycat.net;
import io.mycat.util.TimeUtil;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* @author wuzh
*/
public abstract class Connection implements ClosableConnection{
public static Logger LOGGER = LoggerFactory.getLogger(Connection.class);
protected String host;
protected int port;
protected int localPort;
protected long id;
public enum State {
connecting, connected, closing, closed, failed
}
private State state = State.connecting;
// 连接的方向,in表示是客户端连接过来的,out表示自己作为客户端去连接对端Sever
public enum Direction {
in, out
}
private Direction direction = Direction.in;
protected final SocketChannel channel;
private SelectionKey processKey;
private static final int OP_NOT_READ = ~SelectionKey.OP_READ;
private static final int OP_NOT_WRITE = ~SelectionKey.OP_WRITE;
private ByteBuffer readBuffer;
private ByteBuffer writeBuffer;
private final ConcurrentLinkedQueue<ByteBuffer> writeQueue = new ConcurrentLinkedQueue<ByteBuffer>();
private final ReentrantLock writeQueueLock = new ReentrantLock();
private int readBufferOffset;
private long lastLargeMessageTime;
protected boolean isClosed;
protected boolean isSocketClosed;
protected long startupTime;
protected long lastReadTime;
protected long lastWriteTime;
protected int netInBytes;
protected int netOutBytes;
protected int pkgTotalSize;
protected int pkgTotalCount;
private long idleTimeout;
private long lastPerfCollectTime;
@SuppressWarnings("rawtypes")
protected NIOHandler handler;
private int maxPacketSize;
private int packetHeaderSize;
public Connection(SocketChannel channel) {
this.channel = channel;
this.isClosed = false;
this.startupTime = TimeUtil.currentTimeMillis();
this.lastReadTime = startupTime;
this.lastWriteTime = startupTime;
this.lastPerfCollectTime = startupTime;
}
public void resetPerfCollectTime() {
netInBytes = 0;
netOutBytes = 0;
pkgTotalCount = 0;
pkgTotalSize = 0;
lastPerfCollectTime = TimeUtil.currentTimeMillis();
}
public long getLastPerfCollectTime() {
return lastPerfCollectTime;
}
public long getIdleTimeout() {
return idleTimeout;
}
public void setIdleTimeout(long idleTimeout) {
this.idleTimeout = idleTimeout;
}
public String getHost() {
return host;
}
public void setHost(String host) {
this.host = host;
}
public int getPort() {
return port;
}
public void setPort(int port) {
this.port = port;
}
public long getId() {
return id;
}
public int getLocalPort() {
return localPort;
}
public void setLocalPort(int localPort) {
this.localPort = localPort;
}
public void setId(long id) {
this.id = id;
}
public boolean isIdleTimeout() {
return TimeUtil.currentTimeMillis() > Math.max(lastWriteTime,
lastReadTime) + idleTimeout;
}
public SocketChannel getChannel() {
return channel;
}
public long getStartupTime() {
return startupTime;
}
public long getLastReadTime() {
return lastReadTime;
}
public long getLastWriteTime() {
return lastWriteTime;
}
public long getNetInBytes() {
return netInBytes;
}
public long getNetOutBytes() {
return netOutBytes;
}
public ByteBuffer getReadBuffer() {
return readBuffer;
}
private ByteBuffer allocate() {
return NetSystem.getInstance().getBufferPool().allocate();
}
private final void recycle(ByteBuffer buffer) {
NetSystem.getInstance().getBufferPool().recycle(buffer);
}
public void setHandler(NIOHandler<? extends Connection> handler) {
this.handler = handler;
}
@SuppressWarnings("rawtypes")
public NIOHandler getHandler() {
return this.handler;
}
@SuppressWarnings("unchecked")
public void handle(final ByteBuffer data, final int start,
final int readedLength) {
handler.handle(this, data, start, readedLength);
}
/**
* 读取可能的Socket字节流
*
* @param got
* @throws IOException
*/
public void onReadData(int got) throws IOException {
if (isClosed) {
return;
}
lastReadTime = TimeUtil.currentTimeMillis();
if (got < 0) {
this.close("stream closed");
return;
} else if (got == 0) {
if (!this.channel.isOpen()) {
this.close("socket closed");
return;
}
}
netInBytes += got;
// System.out.println("readed new size "+got);
NetSystem.getInstance().addNetInBytes(got);
// 循环处理字节信息
int offset = readBufferOffset, length = 0, position = readBuffer.position();
while(readBuffer != null && !isClosed) {
length = getPacketLength(readBuffer, offset, position);
// LOGGER.info("message lenth "+length+" offset "+offset+" positon "+position+" capactiy "+readBuffer.capacity());
// System.out.println("message lenth "+length+" offset "+offset+" positon "+position);
if (length == -1) {
if (offset != 0) {
this.readBuffer = compactReadBuffer(readBuffer, offset);
} else if (readBuffer != null && !readBuffer.hasRemaining()) {
throw new RuntimeException(
"invalid readbuffer capacity ,too little buffer size "
+ readBuffer.capacity());
}
break;
}
pkgTotalCount++;
pkgTotalSize += length;
// check if a complete message packge received
if (offset + length <= position && readBuffer != null) {
// handle this package
readBuffer.position(offset);
handle(readBuffer, offset, length);
// maybe handle stmt_close
if(isClosed()) {
return ;
}
// offset to next position
offset += length;
// reached end
if (position == offset) {
// if cur buffer is temper none direct byte buffer and not
// received large message in recent 30 seconds
// then change to direct buffer for performance
if (readBuffer != null && !readBuffer.isDirect()
&& lastLargeMessageTime < lastReadTime - 30 * 1000L) {// used
// temp
// heap
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("change to direct con read buffer ,cur temp buf size :"
+ readBuffer.capacity());
}
recycle(readBuffer);
readBuffer = NetSystem.getInstance().getBufferPool()
.allocateConReadBuffer();
} else {
if (readBuffer != null)
readBuffer.clear();
}
// no more data ,break
readBufferOffset = 0;
break;
} else {
// try next package parse
readBufferOffset = offset;
if(readBuffer != null)
readBuffer.position(position);
continue;
}
} else {
// not read whole message package ,so check if buffer enough and
// compact readbuffer
if (!readBuffer.hasRemaining()) {
readBuffer = ensureFreeSpaceOfReadBuffer(readBuffer,
offset, length);
}
break;
}
}
}
public boolean isConnected() {
return (this.state == Connection.State.connected);
}
private boolean isConReadBuffer(ByteBuffer buffer) {
return buffer.capacity() == NetSystem.getInstance().getBufferPool()
.getConReadBuferChunk()
&& buffer.isDirect();
}
private ByteBuffer ensureFreeSpaceOfReadBuffer(ByteBuffer buffer,
int offset, final int pkgLength) {
// need a large buffer to hold the package
if (pkgLength > maxPacketSize) {
throw new IllegalArgumentException("Packet size over the limit.");
} else if (buffer.capacity() < pkgLength) {
ByteBuffer newBuffer = NetSystem.getInstance().getBufferPool()
.allocate(pkgLength);
lastLargeMessageTime = TimeUtil.currentTimeMillis();
buffer.position(offset);
newBuffer.put(buffer);
readBuffer = newBuffer;
if (isConReadBuffer(buffer)) {
NetSystem.getInstance().getBufferPool()
.recycleConReadBuffer(buffer);
} else {
recycle(buffer);
}
readBufferOffset = 0;
return newBuffer;
} else {
if (offset != 0) {
// compact bytebuffer only
return compactReadBuffer(buffer, offset);
} else {
throw new RuntimeException(" not enough space");
}
}
}
private ByteBuffer compactReadBuffer(ByteBuffer buffer, int offset) {
if(buffer == null) return null;
buffer.limit(buffer.position());
buffer.position(offset);
buffer = buffer.compact();
readBufferOffset = 0;
return buffer;
}
public void write(byte[] src) {
try {
writeQueueLock.lock();
ByteBuffer buffer = this.allocate();
int offset = 0;
int remains = src.length;
while (remains > 0) {
int writeable = buffer.remaining();
if (writeable >= remains) {
// can write whole srce
buffer.put(src, offset, remains);
this.writeQueue.offer(buffer);
break;
} else {
// can write partly
buffer.put(src, offset, writeable);
offset += writeable;
remains -= writeable;
writeQueue.offer(buffer);
buffer = allocate();
continue;
}
}
} finally {
writeQueueLock.unlock();
}
this.enableWrite(true);
}
/**
* note only use this method when the input buffer is shared
*
* @param buffer
* @param from
* @param lenth
*/
public final void write(ByteBuffer buffer, int from, int lenth) {
try {
writeQueueLock.lock();
buffer.position(from);
int remainByts = lenth;
while (remainByts > 0) {
ByteBuffer newBuf = allocate();
int batchSize = newBuf.capacity();
for (int i = 0; i < batchSize & remainByts > 0; i++) {
newBuf.put(buffer.get());
remainByts--;
}
writeQueue.offer(newBuf);
}
} finally {
writeQueueLock.unlock();
}
this.enableWrite(true);
}
public final void write(ByteBuffer buffer) {
try {
writeQueueLock.lock();
writeQueue.offer(buffer);
} finally {
writeQueueLock.unlock();
}
this.enableWrite(true);
}
@SuppressWarnings("unchecked")
public void close(String reason) {
if (!isClosed) {
closeSocket();
this.cleanup();
isClosed = true;
NetSystem.getInstance().removeConnection(this);
LOGGER.info("close connection,reason:" + reason + " ," + this);
if (handler != null) {
handler.onClosed(this, reason);
}
}
}
/**
* asyn close (executed later in thread)
* 该函数使用多线程异步关闭 Connection,会存在并发安全问题,暂时注释
* @param reason
*/
// public void asynClose(final String reason) {
// Runnable runn = new Runnable() {
// public void run() {
// Connection.this.close(reason);
// }
// };
// NetSystem.getInstance().getTimer().schedule(runn, 1, TimeUnit.SECONDS);
//
// }
public boolean isClosed() {
return isClosed;
}
public void idleCheck() {
if (isIdleTimeout()) {
LOGGER.info(toString() + " idle timeout");
close(" idle ");
}
}
/**
* 清理资源
*/
protected void cleanup() {
// 清理资源占用
if (readBuffer != null) {
if (isConReadBuffer(readBuffer)) {
NetSystem.getInstance().getBufferPool()
.recycleConReadBuffer(readBuffer);
} else {
this.recycle(readBuffer);
}
this.readBuffer = null;
this.readBufferOffset = 0;
}
if (writeBuffer != null) {
recycle(writeBuffer);
this.writeBuffer = null;
}
ByteBuffer buffer = null;
while ((buffer = writeQueue.poll()) != null) {
recycle(buffer);
}
}
protected final int getPacketLength(ByteBuffer buffer, int offset,
int position) {
if (position < offset + packetHeaderSize) {
return -1;
} else {
int length = buffer.get(offset) & 0xff;
length |= (buffer.get(++offset) & 0xff) << 8;
length |= (buffer.get(++offset) & 0xff) << 16;
return length + packetHeaderSize;
}
}
public ConcurrentLinkedQueue<ByteBuffer> getWriteQueue() {
return writeQueue;
}
@SuppressWarnings("unchecked")
public void register(Selector selector) throws IOException {
processKey = channel.register(selector, SelectionKey.OP_READ, this);
NetSystem.getInstance().addConnection(this);
readBuffer = NetSystem.getInstance().getBufferPool()
.allocateConReadBuffer();
this.handler.onConnected(this);
}
public void doWriteQueue() {
try {
boolean noMoreData = write0();
lastWriteTime = TimeUtil.currentTimeMillis();
if (noMoreData && writeQueue.isEmpty()) {
if ((processKey.isValid() && (processKey.interestOps() & SelectionKey.OP_WRITE) != 0)) {
disableWrite();
}
} else {
if ((processKey.isValid() && (processKey.interestOps() & SelectionKey.OP_WRITE) == 0)) {
enableWrite(false);
}
}
} catch (IOException e) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("caught err:", e);
}
close("err:" + e);
}
}
public void write(BufferArray bufferArray) {
try {
writeQueueLock.lock();
List<ByteBuffer> blockes = bufferArray.getWritedBlockLst();
if (!bufferArray.getWritedBlockLst().isEmpty()) {
for (ByteBuffer curBuf : blockes) {
writeQueue.offer(curBuf);
}
}
ByteBuffer curBuf = bufferArray.getCurWritingBlock();
if (curBuf.position() == 0) {// empty
this.recycle(curBuf);
} else {
writeQueue.offer(curBuf);
}
} finally {
writeQueueLock.unlock();
bufferArray.clear();
}
this.enableWrite(true);
}
private boolean write0() throws IOException {
int written = 0;
ByteBuffer buffer = writeBuffer;
if (buffer != null) {
while (buffer.hasRemaining()) {
written = channel.write(buffer);
if (written > 0) {
netOutBytes += written;
NetSystem.getInstance().addNetOutBytes(written);
lastWriteTime = TimeUtil.currentTimeMillis();
} else {
break;
}
}
if (buffer.hasRemaining()) {
return false;
} else {
writeBuffer = null;
recycle(buffer);
}
}
while ((buffer = writeQueue.poll()) != null) {
if (buffer.limit() == 0) {
recycle(buffer);
close("quit send");
return true;
}
buffer.flip();
while (buffer.hasRemaining()) {
written = channel.write(buffer);
if (written > 0) {
netOutBytes += written;
NetSystem.getInstance().addNetOutBytes(written);
lastWriteTime = TimeUtil.currentTimeMillis();
} else {
break;
}
}
if (buffer.hasRemaining()) {
writeBuffer = buffer;
return false;
} else {
recycle(buffer);
}
}
return true;
}
private void disableWrite() {
try {
SelectionKey key = this.processKey;
key.interestOps(key.interestOps() & OP_NOT_WRITE);
} catch (Exception e) {
LOGGER.warn("can't disable write " + e + " con " + this);
}
}
public void enableWrite(boolean wakeup) {
boolean needWakeup = false;
try {
SelectionKey key = this.processKey;
key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
needWakeup = true;
} catch (Exception e) {
LOGGER.warn("can't enable write " + e);
}
if (needWakeup && wakeup) {
processKey.selector().wakeup();
}
}
public void disableRead() {
SelectionKey key = this.processKey;
key.interestOps(key.interestOps() & OP_NOT_READ);
}
public void enableRead() {
boolean needWakeup = false;
try {
SelectionKey key = this.processKey;
key.interestOps(key.interestOps() | SelectionKey.OP_READ);
needWakeup = true;
} catch (Exception e) {
LOGGER.warn("enable read fail " + e);
}
if (needWakeup) {
processKey.selector().wakeup();
}
}
public void setState(State newState) {
this.state = newState;
}
/**
* 异步读取数据,only nio thread call
*
* @throws IOException
*/
protected void asynRead() throws IOException {
if (this.isClosed) {
return;
}
int got = channel.read(readBuffer);
onReadData(got);
}
private void closeSocket() {
if (channel != null) {
boolean isSocketClosed = true;
try {
processKey.cancel();
channel.close();
} catch (Throwable e) {
}
boolean closed = isSocketClosed && (!channel.isOpen());
if (!closed) {
LOGGER.warn("close socket of connnection failed " + this);
}
}
}
public State getState() {
return state;
}
public Direction getDirection() {
return direction;
}
public void setDirection(Connection.Direction in) {
this.direction = in;
}
public int getPkgTotalSize() {
return pkgTotalSize;
}
public int getPkgTotalCount() {
return pkgTotalCount;
}
@Override
public String toString() {
return "Connection [host=" + host + ", port=" + port + ", id=" + id
+ ", state=" + state + ", direction=" + direction
+ ", startupTime=" + startupTime + ", lastReadTime="
+ lastReadTime + ", lastWriteTime=" + lastWriteTime + "]";
}
public void setMaxPacketSize(int maxPacketSize) {
this.maxPacketSize = maxPacketSize;
}
public void setPacketHeaderSize(int packetHeaderSize) {
this.packetHeaderSize = packetHeaderSize;
}
}