/**
*
*/
package io.nettythrift.codec;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.CorruptedFrameException;
import io.netty.handler.codec.TooLongFrameException;
import io.nettythrift.core.*;
import org.apache.thrift.protocol.*;
import org.apache.thrift.transport.TTransportException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
/**
 * Decodes the Thrift-protocol bytes of a client request into a {@link ThriftMessage}.
 * <p>
 * Both unframed messages (the buffer starts directly with a protocol header recognised by
 * {@code serverDef.protocolFactorySelector}) and framed messages (a 4-byte big-endian frame
 * size followed by the payload) are supported. Each decoded message carries a
 * {@link ThriftMessageWrapper} matching its transport type, so the response can be written
 * back in the same framing.
 *
 * @author HouKx
 */
public class ThriftMessageDecoder extends ByteToMessageDecoder {
    private static final Logger logger = LoggerFactory.getLogger(ThriftMessageDecoder.class);

    /** Size in bytes of the length prefix that precedes a framed message. */
    public static final int MESSAGE_FRAME_SIZE = 4;

    private final ThriftServerDef serverDef;
    /** Upper bound on a single message's payload size, taken from {@code serverDef.maxFrameSize}. */
    private final int maxFrameSize;

    /**
     * Optional wrapper delivered through a user event. When set, it is chained behind the
     * framed/unframed wrapper attached to every subsequently decoded message.
     */
    private ThriftMessageWrapper successor;

    public ThriftMessageDecoder(ThriftServerDef serverDef) {
        this.serverDef = serverDef;
        this.maxFrameSize = serverDef.maxFrameSize;
    }

    /**
     * Captures a {@link ThriftMessageWrapper} sent as a user event and stores it as the
     * successor wrapper. Any other event is passed on unchanged.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof ThriftMessageWrapper) {
            successor = (ThriftMessageWrapper) evt;
        } else {
            super.userEventTriggered(ctx, evt);
        }
    }

    /**
     * Decodes at most one message from {@code in} while the channel is active.
     * <p>
     * The decoded message's content is a slice of the cumulation buffer. It is retained
     * here because {@link ByteToMessageDecoder} releases the cumulation independently.
     *
     * @see io.netty.handler.codec.ByteToMessageDecoder#decode(ChannelHandlerContext, ByteBuf, List)
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        if (ctx.channel().isActive()) {
            ThriftMessage msg = decodeMessage(ctx, in);
            logger.debug("decodeMessage() return:{}", msg);
            if (msg != null) {
                msg.getContent().retain();
                out.add(msg);
            }
        }
    }

    /**
     * Tries to decode one unframed or framed message from {@code buffer}.
     *
     * @return the decoded message, or {@code null} if more bytes are needed (or an empty
     *         frame was skipped)
     * @throws TooLongFrameException   if the message exceeds {@code maxFrameSize}
     * @throws CorruptedFrameException if a framed message declares a negative size
     */
    protected ThriftMessage decodeMessage(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
        TProtocolFactory inputProtocolFactory = getProtocolFactory(buffer);
        if (inputProtocolFactory != null) {
            // The leading head code maps to a known protocol, so the buffer starts directly
            // with a protocol header: an unframed message.
            ByteBuf messageBuffer = tryDecodeUnframedMessage(ctx, ctx.channel(), buffer, inputProtocolFactory);
            if (messageBuffer == null) {
                return null;
            }
            return new ThriftMessage(messageBuffer, inputProtocolFactory).setWrapper(unframedMessageWrapper(successor));
        }

        final int readable = buffer.readableBytes();
        if (readable < MESSAGE_FRAME_SIZE) {
            // Expecting a framed message, but the frame size prefix is not complete yet.
            return null;
        }
        if (readable == MESSAGE_FRAME_SIZE && buffer.getShort(buffer.readerIndex()) != 0) {
            // getProtocolFactory() needs more than MESSAGE_FRAME_SIZE bytes before it will
            // recognise an unframed header. A non-zero head code here may therefore be the
            // start of an unframed message. Wait for more data instead of misreading the
            // protocol header as a frame length.
            return null;
        }

        // Framed message: 4-byte size prefix followed by the payload.
        ByteBuf messageBuffer = tryDecodeFramedMessage(ctx, ctx.channel(), buffer);
        if (messageBuffer == null) {
            return null;
        }
        inputProtocolFactory = getProtocolFactory(messageBuffer);
        logger.debug("get inputProtocolFactory [{}] for frameRequest", inputProtocolFactory);
        return new ThriftMessage(messageBuffer, inputProtocolFactory).setWrapper(framedMessageWrapper(successor));
    }

    /**
     * Looks up the protocol factory for the 16-bit head code at the buffer's reader index.
     *
     * @return the factory chosen by {@code serverDef.protocolFactorySelector}. Returns
     *         {@code null} when at most {@code MESSAGE_FRAME_SIZE} bytes are readable or the
     *         head code is zero.
     */
    private TProtocolFactory getProtocolFactory(ByteBuf buffer) {
        // Check the length first: getShort() needs two readable bytes and would otherwise
        // throw IndexOutOfBoundsException when only one byte has arrived.
        if (buffer.readableBytes() <= MESSAGE_FRAME_SIZE) {
            return null;
        }
        short headCode = buffer.getShort(buffer.readerIndex());
        if (headCode == 0) {
            return null;
        }
        return serverDef.protocolFactorySelector.getProtocolFactory(headCode);
    }

    /**
     * Extracts one framed message. The frame size prefix does not count itself.
     * <p>
     * On success the reader index is advanced past the prefix and payload. A zero-size
     * frame is consumed and ignored.
     *
     * @return the payload (prefix stripped), or {@code null} if the payload is not fully
     *         available yet or the frame was empty
     * @throws CorruptedFrameException if the declared frame size is negative
     * @throws TooLongFrameException   if the declared frame size exceeds {@code maxFrameSize}
     */
    protected ByteBuf tryDecodeFramedMessage(ChannelHandlerContext ctx, Channel channel, ByteBuf buffer) {
        final int messageStartReaderIndex = buffer.readerIndex();
        final int messageContentsOffset = messageStartReaderIndex + MESSAGE_FRAME_SIZE;
        final int frameSize = buffer.getInt(messageStartReaderIndex);
        // Computed as long so that a huge declared size cannot overflow in the log line.
        final long messageLength = (long) frameSize + MESSAGE_FRAME_SIZE;
        logger.debug("messageLength={}, rIndex={}, offset={}, readableBytes={}", messageLength, messageStartReaderIndex,
                messageContentsOffset, buffer.readableBytes());

        if (frameSize < 0) {
            throw new CorruptedFrameException("Negative frame size: " + frameSize);
        }
        if (frameSize > maxFrameSize) {
            throw new TooLongFrameException(
                    String.format("Frame size exceeded on decode: frame was %d bytes, maximum allowed is %d bytes",
                            frameSize, maxFrameSize));
        }
        if (frameSize == 0) {
            // Zero-sized frame: consume the prefix and return nothing.
            buffer.readerIndex(messageContentsOffset);
            return null;
        }
        // Written as a subtraction so that frameSize + prefix cannot overflow.
        if (buffer.readableBytes() - MESSAGE_FRAME_SIZE < frameSize) {
            // Full message isn't available yet.
            return null;
        }
        ByteBuf messageBuffer = extractFrame(buffer, messageContentsOffset, frameSize);
        buffer.readerIndex(messageContentsOffset + frameSize);
        return messageBuffer;
    }

    /**
     * Extracts one unframed message by doing a trial decode.
     * <p>
     * The trial decode skips through the message with {@code inputProtocolFactory} to check
     * whether a complete message is buffered. The reader index is always restored before
     * this method returns or throws. It is then advanced past the message only on success.
     *
     * @return the complete message bytes, or {@code null} if the message is not complete yet
     * @throws TooLongFrameException if the trial decode scanned more than {@code maxFrameSize}
     *                               bytes, or if more than {@code maxFrameSize} bytes are
     *                               buffered without a complete message
     */
    protected ByteBuf tryDecodeUnframedMessage(ChannelHandlerContext ctx, Channel channel, ByteBuf buffer,
            TProtocolFactory inputProtocolFactory) throws Exception {
        final int messageStartReaderIndex = buffer.readerIndex();
        int messageLength = 0;
        int bytesScanned;
        try {
            TNettyTransport decodeAttemptTransport = new TNettyTransport(channel, buffer);
            int initialReadBytes = decodeAttemptTransport.getReadByteCount();
            TProtocol inputProtocol = inputProtocolFactory.getProtocol(decodeAttemptTransport);
            // Skip through the whole message without materialising it.
            inputProtocol.readMessageBegin();
            TProtocolUtil.skip(inputProtocol, TType.STRUCT);
            inputProtocol.readMessageEnd();
            messageLength = decodeAttemptTransport.getReadByteCount() - initialReadBytes;
        } catch (TTransportException ignored) {
            // No complete message was decoded: ran out of bytes.
        } catch (IndexOutOfBoundsException ignored) {
            // No complete message was decoded: ran out of bytes.
        } catch (TProtocolException ignored) {
            // Treated as "not decodable yet". The size cap below keeps malformed input from
            // buffering forever.
        } finally {
            // Measure how far the trial decode advanced, then always restore the reader index.
            // The size check is done after the try statement rather than in this block, so a
            // throw cannot skip this reset or mask another in-flight exception.
            bytesScanned = buffer.readerIndex() - messageStartReaderIndex;
            buffer.readerIndex(messageStartReaderIndex);
        }

        if (bytesScanned > maxFrameSize) {
            throw new TooLongFrameException("Maximum frame size of " + maxFrameSize + " exceeded");
        }
        if (messageLength <= 0) {
            // Every buffered byte belongs to the still-incomplete message. Once that exceeds
            // the limit, the message can never be accepted.
            if (buffer.readableBytes() > maxFrameSize) {
                throw new TooLongFrameException("Maximum frame size of " + maxFrameSize + " exceeded");
            }
            return null;
        }

        // A full message is in the read buffer: slice it off and consume it.
        ByteBuf messageBuffer = extractFrame(buffer, messageStartReaderIndex, messageLength);
        buffer.readerIndex(messageStartReaderIndex + messageLength);
        return messageBuffer;
    }

    /**
     * Returns a slice of {@code buffer} covering {@code [index, index + length)}.
     * <p>
     * The slice shares the buffer's content and does not copy it. It is not retained here;
     * {@link #decode} retains it before emitting the message.
     */
    protected ByteBuf extractFrame(ByteBuf buffer, int index, int length) {
        return buffer.slice(index, length);
    }

    /** Wrapper for unframed responses: the message content is written as-is. */
    private static final ThriftMessageWrapper UNFRAMED_WRAPPER = new ThriftMessageWrapper() {
        @Override
        public Object wrapMessage(ChannelHandlerContext ctx, ThriftMessage msg) {
            logger.debug("UNFRAMED_WRAPPER::wrapMessage");
            return msg.getContent();
        }
    };

    /**
     * Wrapper for framed responses.
     * <p>
     * {@code beforeMessageWrite} reserves a {@code MESSAGE_FRAME_SIZE}-byte placeholder at
     * index 0. {@code wrapMessage} then fills it with the payload length.
     */
    private static final ThriftMessageWrapper FRAMED_WRAPPER = new ThriftMessageWrapper() {
        @Override
        public void beforeMessageWrite(ChannelHandlerContext ctx, ThriftMessage msg) {
            ByteBuf buf = msg.getContent();
            buf.writerIndex(MESSAGE_FRAME_SIZE);
        }

        @Override
        public Object wrapMessage(ChannelHandlerContext ctx, ThriftMessage msg) {
            ByteBuf buf = msg.getContent();
            final int size = buf.readableBytes() - MESSAGE_FRAME_SIZE;
            logger.debug("framedMessage::wrapMessage , size={}, buf.writeables={}", size, buf.writableBytes());
            // Patch the placeholder in place. setInt() leaves readerIndex/writerIndex
            // untouched, unlike rewinding writerIndex to 0.
            buf.setInt(0, size);
            return buf;
        }
    };

    /** Returns the unframed wrapper, chained before {@code successor} when one is set. */
    private ThriftMessageWrapper unframedMessageWrapper(ThriftMessageWrapper successor) {
        if (successor == null) {
            return UNFRAMED_WRAPPER;
        }
        return new ThriftMessageWrapper(successor) {
            @Override
            public Object wrapMessageInner(ChannelHandlerContext ctx, ThriftMessage msg) {
                return UNFRAMED_WRAPPER.wrapMessage(ctx, msg);
            }
        };
    }

    /**
     * Returns the framed wrapper, chained before {@code successor} when one is set. The frame
     * placeholder is reserved before the successor's own {@code beforeMessageWrite} runs.
     */
    private ThriftMessageWrapper framedMessageWrapper(ThriftMessageWrapper successor) {
        if (successor == null) {
            return FRAMED_WRAPPER;
        }
        return new ThriftMessageWrapper(successor) {
            @Override
            public void beforeMessageWrite(ChannelHandlerContext ctx, ThriftMessage msg) {
                FRAMED_WRAPPER.beforeMessageWrite(ctx, msg);
                getSuccessor().beforeMessageWrite(ctx, msg);
            }

            @Override
            protected Object wrapMessageInner(ChannelHandlerContext ctx, ThriftMessage msg) {
                return FRAMED_WRAPPER.wrapMessage(ctx, msg);
            }
        };
    }
}