/* * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 * (the "License"). You may not use this work except in compliance with the License, which is * available at www.apache.org/licenses/LICENSE-2.0 * * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, * either express or implied, as more fully set forth in the License. * * See the NOTICE file distributed with this work for information regarding copyright ownership. */ package alluxio.network.protocol; import alluxio.PropertyKey; import alluxio.BaseIntegrationTest; import alluxio.network.protocol.databuffer.DataByteBuffer; import alluxio.util.io.BufferUtils; import alluxio.util.network.NetworkAddressUtils; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import org.junit.After; import org.junit.AfterClass; import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.nio.ByteBuffer; import java.util.concurrent.TimeUnit; /** * This tests the encoding and decoding of RPCMessage's. This is done by setting up a simple * client/server bootstrap connection, and writing messages on the client side, and verifying it on * the server side. In this case, the server simply stores the message received, and does not reply * to the client. */ public class RPCMessageIntegrationTest extends BaseIntegrationTest { private static final long SESSION_ID = 10; private static final long BLOCK_ID = 11; private static final long OFFSET = 22; private static final long LENGTH = 33; private static final long LOCK_ID = 44; // This channel initializer sets up a simple pipeline with the encoder and decoder. private static class PipelineInitializer extends ChannelInitializer<SocketChannel> { private MessageSavingHandler mHandler = null; public PipelineInitializer(MessageSavingHandler handler) { mHandler = handler; } @Override protected void initChannel(SocketChannel ch) throws Exception { ChannelPipeline pipeline = ch.pipeline(); pipeline.addLast("frameDecoder", RPCMessage.createFrameDecoder()); pipeline.addLast("RPCMessageDecoder", new RPCMessageDecoder()); pipeline.addLast("RPCMessageEncoder", new RPCMessageEncoder()); pipeline.addLast("handler", mHandler); } } @Rule public TemporaryFolder mFolder = new TemporaryFolder(); private Channel mOutgoingChannel; private static NioEventLoopGroup sEventClient; private static NioEventLoopGroup sEventServer; private static MessageSavingHandler sIncomingHandler; private static Bootstrap sBootstrapClient; private static SocketAddress sLocalAddress; @BeforeClass public static void beforeClass() { sEventClient = new NioEventLoopGroup(1); sEventServer = new NioEventLoopGroup(1); sIncomingHandler = new MessageSavingHandler(); // Setup the server. ServerBootstrap bootstrap = new ServerBootstrap(); bootstrap.group(sEventServer); bootstrap.channel(NioServerSocketChannel.class); bootstrap.childHandler(new PipelineInitializer(sIncomingHandler)); InetSocketAddress address = new InetSocketAddress(NetworkAddressUtils.getLocalHostName(100), Integer.parseInt(PropertyKey.MASTER_RPC_PORT.getDefaultValue())); ChannelFuture cf = bootstrap.bind(address).syncUninterruptibly(); sLocalAddress = cf.channel().localAddress(); // Setup the client. sBootstrapClient = new Bootstrap(); sBootstrapClient.group(sEventClient); sBootstrapClient.channel(NioSocketChannel.class); sBootstrapClient.handler(new PipelineInitializer(new MessageSavingHandler())); } @AfterClass public static void afterClass() { // Shut everything down. sEventClient.shutdownGracefully(0, 0, TimeUnit.SECONDS); sEventServer.shutdownGracefully(0, 0, TimeUnit.SECONDS); sEventClient.terminationFuture().syncUninterruptibly(); sEventServer.terminationFuture().syncUninterruptibly(); } @Before public final void before() { sIncomingHandler.reset(); // Connect to the server. ChannelFuture cf = sBootstrapClient.connect(sLocalAddress).syncUninterruptibly(); mOutgoingChannel = cf.channel(); } @After public final void after() { // Close the client connection. mOutgoingChannel.close().syncUninterruptibly(); } private void assertValid(RPCBlockReadRequest expected, RPCBlockReadRequest actual) { Assert.assertEquals(expected.getType(), actual.getType()); Assert.assertEquals(expected.getEncodedLength(), actual.getEncodedLength()); Assert.assertEquals(expected.getBlockId(), actual.getBlockId()); Assert.assertEquals(expected.getOffset(), actual.getOffset()); Assert.assertEquals(expected.getLength(), actual.getLength()); Assert.assertEquals(expected.getLockId(), actual.getLockId()); Assert.assertEquals(expected.getSessionId(), actual.getSessionId()); } private void assertValid(RPCBlockReadResponse expected, RPCBlockReadResponse actual) { Assert.assertEquals(expected.getType(), actual.getType()); Assert.assertEquals(expected.getEncodedLength(), actual.getEncodedLength()); Assert.assertEquals(expected.getBlockId(), actual.getBlockId()); Assert.assertEquals(expected.getOffset(), actual.getOffset()); Assert.assertEquals(expected.getLength(), actual.getLength()); Assert.assertEquals(expected.getStatus(), actual.getStatus()); if (expected.getLength() == 0) { // Length is 0, so payloads should be null. Assert.assertNull(expected.getPayloadDataBuffer()); Assert.assertNull(actual.getPayloadDataBuffer()); } else { Assert.assertTrue(BufferUtils.equalIncreasingByteBuffer((int) OFFSET, (int) LENGTH, actual .getPayloadDataBuffer().getReadOnlyByteBuffer())); } } private void assertValid(RPCBlockWriteRequest expected, RPCBlockWriteRequest actual) { Assert.assertEquals(expected.getType(), actual.getType()); Assert.assertEquals(expected.getEncodedLength(), actual.getEncodedLength()); Assert.assertEquals(expected.getBlockId(), actual.getBlockId()); Assert.assertEquals(expected.getOffset(), actual.getOffset()); Assert.assertEquals(expected.getLength(), actual.getLength()); Assert.assertEquals(expected.getSessionId(), actual.getSessionId()); if (expected.getLength() > 0) { Assert.assertTrue(BufferUtils.equalIncreasingByteBuffer((int) OFFSET, (int) LENGTH, actual .getPayloadDataBuffer().getReadOnlyByteBuffer())); } } private void assertValid(RPCBlockWriteResponse expected, RPCBlockWriteResponse actual) { Assert.assertEquals(expected.getType(), actual.getType()); Assert.assertEquals(expected.getEncodedLength(), actual.getEncodedLength()); Assert.assertEquals(expected.getBlockId(), actual.getBlockId()); Assert.assertEquals(expected.getOffset(), actual.getOffset()); Assert.assertEquals(expected.getLength(), actual.getLength()); Assert.assertEquals(expected.getSessionId(), actual.getSessionId()); Assert.assertEquals(expected.getStatus(), actual.getStatus()); } private void assertValid(RPCErrorResponse expected, RPCErrorResponse actual) { Assert.assertEquals(expected.getType(), actual.getType()); Assert.assertEquals(expected.getEncodedLength(), actual.getEncodedLength()); Assert.assertEquals(expected.getStatus(), actual.getStatus()); } /** * Encodes and decodes the 'msg' by sending it through the client and server pipelines. */ private RPCMessage encodeThenDecode(RPCMessage msg) { // Write the message to the outgoing channel. mOutgoingChannel.writeAndFlush(msg); // Read the decoded message from the incoming side. return sIncomingHandler.getMessage(); } @Test public void RPCBlockReadRequest() { RPCBlockReadRequest msg = new RPCBlockReadRequest(BLOCK_ID, OFFSET, LENGTH, LOCK_ID, SESSION_ID); RPCBlockReadRequest decoded = (RPCBlockReadRequest) encodeThenDecode(msg); assertValid(msg, decoded); } @Test public void RPCBlockReadResponse() { ByteBuffer payload = BufferUtils.getIncreasingByteBuffer((int) OFFSET, (int) LENGTH); RPCBlockReadResponse msg = new RPCBlockReadResponse(BLOCK_ID, OFFSET, LENGTH, new DataByteBuffer(payload, LENGTH), RPCResponse.Status.SUCCESS); RPCBlockReadResponse decoded = (RPCBlockReadResponse) encodeThenDecode(msg); assertValid(msg, decoded); } @Test public void RPCBlockReadResponseEmptyPayload() { RPCBlockReadResponse msg = new RPCBlockReadResponse(BLOCK_ID, OFFSET, 0, null, RPCResponse.Status.SUCCESS); RPCBlockReadResponse decoded = (RPCBlockReadResponse) encodeThenDecode(msg); assertValid(msg, decoded); } @Test public void RPCBlockReadResponseError() { RPCBlockReadResponse msg = RPCBlockReadResponse.createErrorResponse( new RPCBlockReadRequest(BLOCK_ID, OFFSET, LENGTH, LOCK_ID, SESSION_ID), RPCResponse.Status.FILE_DNE); RPCBlockReadResponse decoded = (RPCBlockReadResponse) encodeThenDecode(msg); assertValid(msg, decoded); } @Test public void RPCBlockWriteRequest() { ByteBuffer payload = BufferUtils.getIncreasingByteBuffer((int) OFFSET, (int) LENGTH); RPCBlockWriteRequest msg = new RPCBlockWriteRequest(SESSION_ID, BLOCK_ID, OFFSET, LENGTH, new DataByteBuffer(payload, LENGTH)); RPCBlockWriteRequest decoded = (RPCBlockWriteRequest) encodeThenDecode(msg); assertValid(msg, decoded); } @Test public void RPCBlockWriteResponse() { RPCBlockWriteResponse msg = new RPCBlockWriteResponse(SESSION_ID, BLOCK_ID, OFFSET, LENGTH, RPCResponse.Status.SUCCESS); RPCBlockWriteResponse decoded = (RPCBlockWriteResponse) encodeThenDecode(msg); assertValid(msg, decoded); } @Test public void RPCErrorResponse() { for (RPCResponse.Status status : RPCResponse.Status.values()) { RPCErrorResponse msg = new RPCErrorResponse(status); RPCErrorResponse decoded = (RPCErrorResponse) encodeThenDecode(msg); assertValid(msg, decoded); } } }