/*
 * Copyright (C) 2012-2016 Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.nifty.core;

import com.facebook.nifty.codec.DefaultThriftFrameDecoder;
import com.facebook.nifty.codec.ThriftFrameDecoder;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TField;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TMessageType;
import org.apache.thrift.protocol.TStruct;
import org.apache.thrift.protocol.TType;
import org.apache.thrift.transport.TFramedTransport;
import org.easymock.EasyMock;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.AbstractChannelSink;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelEvent;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.DefaultChannelConfig;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.testng.Assert;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import java.net.InetSocketAddress;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Tests for {@link DefaultThriftFrameDecoder}.
 *
 * <p>Each test pushes bytes upstream through a pipeline made of the decoder followed by a
 * counting handler, then checks how many messages and exceptions that handler observed.
 * The pipeline is attached to a mocked {@link Channel}; see {@link #setUp()}.
 */
public class TestThriftFrameDecoder
{
    /** Maximum frame size handed to the decoder under test. */
    private static final int MAX_FRAME_SIZE = 1024;

    /** Size, in bytes, of each slice sent by the chunked tests. */
    public static final int MESSAGE_CHUNK_SIZE = 10;

    // All three fields are re-created before every test method by setUp().
    private Channel channel;
    private AtomicInteger messagesReceived;
    private AtomicInteger exceptionsCaught;

    // Send an empty buffer and make sure nothing breaks, and that no message is decoded
    @Test
    public void testDecodeEmptyBuffer() throws Exception
    {
        Channels.fireMessageReceived(channel, ChannelBuffers.EMPTY_BUFFER);

        Assert.assertEquals(exceptionsCaught.get(), 0);
        Assert.assertEquals(messagesReceived.get(), 0);
    }

    // Send two unframed messages in a single buffer, and check they both get decoded
    @Test
    public void testDecodeUnframedMessages() throws Exception
    {
        TChannelBufferOutputTransport transport = new TChannelBufferOutputTransport();
        TBinaryProtocol protocol = new TBinaryProtocol(transport);

        writeTestMessages(protocol, 2);
        Channels.fireMessageReceived(channel, transport.getOutputBuffer());

        Assert.assertEquals(exceptionsCaught.get(), 0);
        Assert.assertEquals(messagesReceived.get(), 2);
    }

    // Send two framed messages in a single buffer, and check they both get decoded
    @Test
    public void testDecodeFramedMessages() throws Exception
    {
        TChannelBufferOutputTransport transport = new TChannelBufferOutputTransport();
        TBinaryProtocol protocol = new TBinaryProtocol(new TFramedTransport(transport));

        writeTestMessages(protocol, 2);
        Channels.fireMessageReceived(channel, transport.getOutputBuffer());

        // Same exception check as the unframed test, so a failure that still
        // delivers the messages cannot pass unnoticed.
        Assert.assertEquals(exceptionsCaught.get(), 0);
        Assert.assertEquals(messagesReceived.get(), 2);
    }

    // Send three unframed messages, chunked into 10-byte buffers and make sure they all get decoded
    @Test
    public void testDecodeChunkedUnframedMessages() throws Exception
    {
        TChannelBufferOutputTransport transport = new TChannelBufferOutputTransport();
        TBinaryProtocol protocol = new TBinaryProtocol(transport);

        writeTestMessages(protocol, 3);
        sendMessagesInChunks(channel, transport, MESSAGE_CHUNK_SIZE);

        Assert.assertEquals(exceptionsCaught.get(), 0);
        Assert.assertEquals(messagesReceived.get(), 3);
    }

    // Send three framed messages, chunked into 10-byte buffers and make sure they all get decoded
    @Test
    public void testDecodeChunkedFramedMessages() throws Exception
    {
        TChannelBufferOutputTransport transport = new TChannelBufferOutputTransport();
        TBinaryProtocol protocol = new TBinaryProtocol(new TFramedTransport(transport));

        writeTestMessages(protocol, 3);
        sendMessagesInChunks(channel, transport, MESSAGE_CHUNK_SIZE);

        Assert.assertEquals(exceptionsCaught.get(), 0);
        Assert.assertEquals(messagesReceived.get(), 3);
    }

    /**
     * Fires the transport's output buffer upstream on {@code channel} as a sequence of slices.
     *
     * @param channel   channel whose pipeline receives the slices
     * @param transport transport whose output buffer is drained
     * @param chunkSize maximum number of bytes per slice; the last slice may be shorter
     */
    private void sendMessagesInChunks(Channel channel, TChannelBufferOutputTransport transport, int chunkSize)
    {
        ChannelBuffer buffer = transport.getOutputBuffer();
        while (buffer.readable()) {
            // readSlice advances the reader index, so the loop ends once everything is sent
            int sliceLength = Math.min(chunkSize, buffer.readableBytes());
            Channels.fireMessageReceived(channel, buffer.readSlice(sliceLength));
        }
    }

    /**
     * Writes {@code count} CALL messages to {@code protocol}, flushing the transport after each.
     *
     * <p>Message {@code i} is named {@code "testmessage" + i}, uses {@code i} as its sequence id,
     * and carries a struct with an i32 field (123), a string field ("foo") and a bool field (true).
     *
     * @param protocol protocol to write the messages to
     * @param count    number of messages to write
     * @throws TException if any write or the flush fails
     */
    private void writeTestMessages(TBinaryProtocol protocol, int count)
            throws TException
    {
        for (int i = 0; i < count; i++) {
            protocol.writeMessageBegin(new TMessage("testmessage" + i, TMessageType.CALL, i));
            protocol.writeStructBegin(new TStruct());

            protocol.writeFieldBegin(new TField("i32field", TType.I32, (short) 1));
            protocol.writeI32(123);
            protocol.writeFieldEnd();

            protocol.writeFieldBegin(new TField("strfield", TType.STRING, (short) 2));
            protocol.writeString("foo");
            protocol.writeFieldEnd();

            protocol.writeFieldBegin(new TField("boolfield", TType.BOOL, (short) 3));
            protocol.writeBool(true);
            protocol.writeFieldEnd();

            protocol.writeFieldStop();
            protocol.writeStructEnd();
            protocol.writeMessageEnd();

            protocol.getTransport().flush();
        }
    }

    /**
     * Builds a fresh decoder pipeline, counters and mock channel before every test method.
     *
     * <p>The pipeline holds the decoder followed by a handler that counts the message and
     * exception events reaching it. The mocked channel answers {@code getRemoteAddress()},
     * {@code getPipeline()} and {@code getConfig()} any number of times, and the pipeline is
     * attached to it with a sink that ignores every event.
     */
    @BeforeMethod(alwaysRun = true)
    public void setUp()
    {
        // Reset the counters first so the handler below never observes stale or null state
        exceptionsCaught = new AtomicInteger(0);
        messagesReceived = new AtomicInteger(0);

        ThriftFrameDecoder decoder = new DefaultThriftFrameDecoder(MAX_FRAME_SIZE, new TBinaryProtocol.Factory());
        ChannelPipeline pipeline = Channels.pipeline(
                decoder,
                new SimpleChannelUpstreamHandler()
                {
                    @Override
                    public void messageReceived(ChannelHandlerContext ctx, MessageEvent e)
                            throws Exception
                    {
                        messagesReceived.incrementAndGet();
                        super.messageReceived(ctx, e);
                    }

                    @Override
                    public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e)
                            throws Exception
                    {
                        exceptionsCaught.incrementAndGet();
                        super.exceptionCaught(ctx, e);
                    }
                }
        );

        InetSocketAddress fakeRemoteAddress = new InetSocketAddress("localhost", 1234);

        channel = EasyMock.createMock(Channel.class);
        EasyMock.expect(channel.getRemoteAddress()).andReturn(fakeRemoteAddress).anyTimes();
        EasyMock.expect(channel.getPipeline()).andReturn(pipeline).anyTimes();
        EasyMock.expect(channel.getConfig()).andReturn(new DefaultChannelConfig()).anyTimes();
        EasyMock.replay(channel);

        pipeline.attach(channel, new AbstractChannelSink()
        {
            @Override
            public void eventSunk(ChannelPipeline pipeline, ChannelEvent e)
            {
                // Intentionally empty: these tests only exercise the upstream path
            }
        });
    }
}