/* * Copyright (c) 2015-2016, Christoph Engelbert (aka noctarius) and * contributors. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.noctarius.tengi.server.impl.transport.tcp; import com.noctarius.tengi.core.model.Identifier; import com.noctarius.tengi.core.model.Message; import com.noctarius.tengi.core.model.Packet; import com.noctarius.tengi.server.ServerTransports; import com.noctarius.tengi.server.impl.transport.AbstractStreamingTransportTestCase; import com.noctarius.tengi.spi.buffer.MemoryBuffer; import com.noctarius.tengi.spi.buffer.impl.MemoryBufferFactory; import com.noctarius.tengi.spi.connection.packets.Handshake; import com.noctarius.tengi.spi.serialization.Serializer; import com.noctarius.tengi.spi.serialization.codec.AutoClosableEncoder; import com.noctarius.tengi.spi.serialization.codec.impl.DefaultCodec; import com.noctarius.tengi.spi.serialization.impl.DefaultProtocol; import com.noctarius.tengi.spi.serialization.impl.DefaultProtocolConstants; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import org.junit.Test; import java.util.Collections; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; public class TcpTransportTestCase extends AbstractStreamingTransportTestCase { @Test(timeout = 120000) public void test_tcp_transport() throws Exception { Serializer serializer = Serializer.create(new DefaultProtocol(Collections.emptyList())); CompletableFuture<Object> future = new CompletableFuture<>(); Packet packet = new Packet("login"); packet.setValue("username", "Stan"); Message message = Message.create(packet); ChannelReader<ChannelHandlerContext, ByteBuf> channelReader = (ctx, buffer) -> { MemoryBuffer memoryBuffer = MemoryBufferFactory.create(buffer); DefaultCodec codec = new DefaultCodec(serializer.getProtocol(), memoryBuffer); boolean loggedIn = codec.readBoolean(); Identifier connectionId = codec.readObject(); Object object = codec.readObject(); if (loggedIn && object instanceof Handshake) { writeChannel(serializer, ctx, connectionId, message); return; } future.complete(object); }; Initializer initializer = (pipeline) -> pipeline.addLast(inboundHandler(channelReader)); Runner<Object, Channel> runner = (channel) -> { ByteBuf buffer = Unpooled.buffer(); MemoryBuffer memoryBuffer = MemoryBufferFactory.create(buffer); DefaultCodec codec = new DefaultCodec(serializer.getProtocol(), memoryBuffer); codec.writeBytes("magic", DefaultProtocolConstants.PROTOCOL_MAGIC_HEADER); codec.writeBoolean("loggedIn", false); codec.writeObject("handshake", new Handshake()); channel.writeAndFlush(buffer); channel.closeFuture().addListener((ChannelFutureListener) (f) -> future.complete(null)); Object result = future.get(120, TimeUnit.SECONDS); channel.close().sync(); return result; }; Object response = practice(runner, clientFactory(initializer), false, ServerTransports.TCP_TRANSPORT); assertEquals(message, response); } @Test(timeout = 120000) public void test_tcp_transport_ping_pong() throws Exception { Serializer serializer = Serializer.create(new DefaultProtocol(Collections.emptyList())); CompletableFuture<Packet> future = new CompletableFuture<>(); ChannelReader<ChannelHandlerContext, ByteBuf> channelReader = (ctx, buffer) -> { MemoryBuffer memoryBuffer = MemoryBufferFactory.create(buffer); DefaultCodec codec = new DefaultCodec(serializer.getProtocol(), memoryBuffer); boolean loggedIn = codec.readBoolean(); Identifier connectionId = codec.readObject(); Object object = codec.readObject(); if (object instanceof Handshake) { Packet packet = new Packet("pingpong"); packet.setValue("counter", 1); Message message = Message.create(packet); writeChannel(serializer, ctx, connectionId, message); return; } Message message = (Message) object; Packet packet = message.getBody(); int counter = packet.getValue("counter"); if (counter == 4) { future.complete(packet); } else { packet.setValue("counter", counter + 1); message = Message.create(packet); ByteBuf buffer2 = Unpooled.buffer(); MemoryBuffer memoryBuffer2 = MemoryBufferFactory.create(buffer2); DefaultCodec codec2 = new DefaultCodec(serializer.getProtocol(), memoryBuffer2); codec2.writeBoolean("loggedIn", loggedIn); codec2.writeObject("connectionId", connectionId); serializer.writeObject("message", message, codec2); ctx.channel().writeAndFlush(buffer2); } }; Initializer initializer = (pipeline) -> pipeline.addLast(inboundHandler(channelReader)); Runner<Packet, Channel> runner = (channel) -> { ByteBuf buffer = Unpooled.buffer(); MemoryBuffer memoryBuffer = MemoryBufferFactory.create(buffer); DefaultCodec codec = new DefaultCodec(serializer.getProtocol(), memoryBuffer); codec.writeBytes("magic", DefaultProtocolConstants.PROTOCOL_MAGIC_HEADER); codec.writeBoolean("loggedIn", false); codec.writeObject("handshake", new Handshake()); channel.writeAndFlush(buffer); channel.closeFuture().addListener((ChannelFutureListener) (f) -> future.complete(null)); Packet result = future.get(120, TimeUnit.SECONDS); channel.close().sync(); return result; }; Packet response = practice(runner, clientFactory(initializer), false, ServerTransports.TCP_TRANSPORT); assertEquals(4, (int) response.getValue("counter")); } protected static <T> SimpleChannelInboundHandler<T> inboundHandler(ChannelReader<ChannelHandlerContext, T> channelReader) { return new SimpleChannelInboundHandler<T>() { @Override protected void channelRead0(ChannelHandlerContext ctx, T object) throws Exception { channelReader.channelRead(ctx, object); } }; } private static void writeChannel(Serializer serializer, ChannelHandlerContext ctx, Identifier connectionId, Object value) throws Exception { ByteBuf buffer = ctx.alloc().directBuffer(); MemoryBuffer memoryBuffer = MemoryBufferFactory.create(buffer); try (AutoClosableEncoder encoder = serializer.retrieveEncoder(memoryBuffer)) { encoder.writeBoolean("loggedIn", true); encoder.writeObject("connectionId", connectionId); serializer.writeObject("value", value, encoder); } ctx.channel().writeAndFlush(buffer); } private static ClientFactory<Channel> clientFactory(Initializer initializer) { return (host, port, ssl, group) -> { Bootstrap bootstrap = new Bootstrap().group(group) // .channel(NioSocketChannel.class).handler(new ChannelInitializer<SocketChannel>() { @Override protected void initChannel(SocketChannel channel) throws Exception { ChannelPipeline pipeline = channel.pipeline(); if (ssl) { SslContext sslContext = SslContextBuilder // .forClient().trustManager(InsecureTrustManagerFactory.INSTANCE).build(); pipeline.addLast(sslContext.newHandler(channel.alloc(), "localhost", 8080)); } initializer.initChannel(pipeline); } }); ChannelFuture future = bootstrap.connect("localhost", 8080); return future.sync().channel(); }; } private static interface Initializer { void initChannel(ChannelPipeline pipeline) throws Exception; } }