/* * Copyright (C) 2012-2016 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.nifty.codec; import com.facebook.nifty.core.TNiftyTransport; import com.facebook.nifty.core.ThriftMessage; import com.facebook.nifty.core.ThriftTransportType; import org.apache.thrift.TException; import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.protocol.TProtocolFactory; import org.apache.thrift.protocol.TProtocolUtil; import org.apache.thrift.protocol.TType; import org.apache.thrift.transport.TTransportException; import org.jboss.netty.buffer.ChannelBuffer; import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.ChannelHandlerContext; import org.jboss.netty.channel.Channels; import org.jboss.netty.handler.codec.frame.TooLongFrameException; public class DefaultThriftFrameDecoder extends ThriftFrameDecoder { public static final int MESSAGE_FRAME_SIZE = 4; private final int maxFrameSize; private final TProtocolFactory inputProtocolFactory; public DefaultThriftFrameDecoder(int maxFrameSize, TProtocolFactory inputProtocolFactory) { this.maxFrameSize = maxFrameSize; this.inputProtocolFactory = inputProtocolFactory; } @Override protected ThriftMessage decode(ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer) throws Exception { if (!buffer.readable()) { return null; } short firstByte = buffer.getUnsignedByte(0); if (firstByte >= 0x80) { ChannelBuffer messageBuffer = tryDecodeUnframedMessage(ctx, channel, buffer, inputProtocolFactory); if (messageBuffer == null) { return null; } // A non-zero MSB for the first byte of the message implies the message starts with a // protocol id (and thus it is unframed). return new ThriftMessage(messageBuffer,ThriftTransportType.UNFRAMED); } else if (buffer.readableBytes() < MESSAGE_FRAME_SIZE) { // Expecting a framed message, but not enough bytes available to read the frame size return null; } else { ChannelBuffer messageBuffer = tryDecodeFramedMessage(ctx, channel, buffer, true); if (messageBuffer == null) { return null; } // Messages with a zero MSB in the first byte are framed messages return new ThriftMessage(messageBuffer, ThriftTransportType.FRAMED); } } protected ChannelBuffer tryDecodeFramedMessage(ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer, boolean stripFraming) { // Framed messages are prefixed by the size of the frame (which doesn't include the // framing itself). int messageStartReaderIndex = buffer.readerIndex(); int messageContentsOffset; if (stripFraming) { messageContentsOffset = messageStartReaderIndex + MESSAGE_FRAME_SIZE; } else { messageContentsOffset = messageStartReaderIndex; } // The full message is larger by the size of the frame size prefix int messageLength = buffer.getInt(messageStartReaderIndex) + MESSAGE_FRAME_SIZE; int messageContentsLength = messageStartReaderIndex + messageLength - messageContentsOffset; if (messageContentsLength > maxFrameSize) { Channels.fireExceptionCaught( ctx, new TooLongFrameException("Maximum frame size of " + maxFrameSize + " exceeded") ); } if (messageLength == 0) { // Zero-sized frame: just ignore it and return nothing buffer.readerIndex(messageContentsOffset); return null; } else if (buffer.readableBytes() < messageLength) { // Full message isn't available yet, return nothing for now return null; } else { // Full message is available, return it ChannelBuffer messageBuffer = extractFrame(buffer, messageContentsOffset, messageContentsLength); buffer.readerIndex(messageStartReaderIndex + messageLength); return messageBuffer; } } protected ChannelBuffer tryDecodeUnframedMessage(ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer, TProtocolFactory inputProtocolFactory) throws TException { // Perform a trial decode, skipping through // the fields, to see whether we have an entire message available. int messageLength = 0; int messageStartReaderIndex = buffer.readerIndex(); try { TNiftyTransport decodeAttemptTransport = new TNiftyTransport(channel, buffer, ThriftTransportType.UNFRAMED); int initialReadBytes = decodeAttemptTransport.getReadByteCount(); TProtocol inputProtocol = inputProtocolFactory.getProtocol(decodeAttemptTransport); // Skip through the message inputProtocol.readMessageBegin(); TProtocolUtil.skip(inputProtocol, TType.STRUCT); inputProtocol.readMessageEnd(); messageLength = decodeAttemptTransport.getReadByteCount() - initialReadBytes; } catch (TTransportException | IndexOutOfBoundsException e) { // No complete message was decoded: ran out of bytes return null; } finally { if (buffer.readerIndex() - messageStartReaderIndex > maxFrameSize) { Channels.fireExceptionCaught( ctx, new TooLongFrameException("Maximum frame size of " + maxFrameSize + " exceeded") ); } buffer.readerIndex(messageStartReaderIndex); } if (messageLength <= 0) { return null; } // We have a full message in the read buffer, slice it off ChannelBuffer messageBuffer = extractFrame(buffer, messageStartReaderIndex, messageLength); buffer.readerIndex(messageStartReaderIndex + messageLength); return messageBuffer; } protected ChannelBuffer extractFrame(ChannelBuffer buffer, int index, int length) { // Slice should be sufficient here (and avoids the copy in LengthFieldBasedFrameDecoder) // because we know no one is going to modify the contents in the read buffers. return buffer.slice(index, length); } }