/* * Copyright (c) 2015-2016, Christoph Engelbert (aka noctarius) and * contributors. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.noctarius.tengi.server.impl.transport.udt; import com.barchart.udt.SocketUDT; import com.barchart.udt.TypeUDT; import com.noctarius.tengi.core.model.Identifier; import com.noctarius.tengi.core.model.Message; import com.noctarius.tengi.core.model.Packet; import com.noctarius.tengi.server.ServerTransports; import com.noctarius.tengi.server.impl.transport.AbstractStreamingTransportTestCase; import com.noctarius.tengi.spi.buffer.MemoryBuffer; import com.noctarius.tengi.spi.buffer.impl.MemoryBufferFactory; import com.noctarius.tengi.spi.connection.packets.Handshake; import com.noctarius.tengi.spi.serialization.Serializer; import com.noctarius.tengi.spi.serialization.codec.AutoClosableEncoder; import com.noctarius.tengi.spi.serialization.codec.impl.DefaultCodec; import com.noctarius.tengi.spi.serialization.impl.DefaultProtocol; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import org.junit.Test; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.util.Collections; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import static org.junit.Assert.assertEquals; public class UdtTransportTestCase extends AbstractStreamingTransportTestCase { @Test(timeout = 120000) public void test_udt_transport() throws Exception { Serializer serializer = Serializer.create(new DefaultProtocol(Collections.emptyList())); CompletableFuture<Object> future = new CompletableFuture<>(); Packet packet = new Packet("login"); packet.setValue("username", "Stan"); Message message = Message.create(packet); ChannelReader<UdtTestClient, ByteBuf> channelReader = (client, buffer) -> { MemoryBuffer memoryBuffer = MemoryBufferFactory.create(buffer); DefaultCodec codec = new DefaultCodec(serializer.getProtocol(), memoryBuffer); boolean loggedIn = codec.readBoolean(); Identifier connectionId = codec.readObject(); Object object = codec.readObject(); if (loggedIn && object instanceof Handshake) { writeChannel(serializer, client, connectionId, message); return; } future.complete(object); }; Runner<Object, UdtTestClient> runner = (client) -> { ByteBuf buffer = Unpooled.buffer(); MemoryBuffer memoryBuffer = MemoryBufferFactory.create(buffer); DefaultCodec codec = new DefaultCodec(serializer.getProtocol(), memoryBuffer); codec.writeBoolean("loggedIn", false); codec.writeObject("handshake", new Handshake()); client.sendMessage(buffer); Object result = future.get(120, TimeUnit.SECONDS); client.close(); return result; }; Object response = practice(runner, clientFactory(channelReader), false, ServerTransports.UDT_TRANSPORT); assertEquals(message, response); Thread.sleep(5000); } @Test(timeout = 120000) public void test_udt_transport_ping_pong() throws Exception { Serializer serializer = Serializer.create(new DefaultProtocol(Collections.emptyList())); CompletableFuture<Packet> future = new CompletableFuture<>(); ChannelReader<UdtTestClient, ByteBuf> channelReader = (client, buffer) -> { MemoryBuffer memoryBuffer = MemoryBufferFactory.create(buffer); DefaultCodec codec = new DefaultCodec(serializer.getProtocol(), memoryBuffer); boolean loggedIn = codec.readBoolean(); Identifier connectionId = codec.readObject(); Object object = codec.readObject(); if (object instanceof Handshake) { Packet packet = new Packet("pingpong"); packet.setValue("counter", 1); Message message = Message.create(packet); writeChannel(serializer, client, connectionId, message); return; } Message message = (Message) object; Packet packet = message.getBody(); int counter = packet.getValue("counter"); if (counter == 4) { future.complete(packet); } else { packet.setValue("counter", counter + 1); message = Message.create(packet); ByteBuf buffer2 = Unpooled.buffer(); MemoryBuffer memoryBuffer2 = MemoryBufferFactory.create(buffer2); DefaultCodec codec2 = new DefaultCodec(serializer.getProtocol(), memoryBuffer2); codec2.writeBoolean("loggedIn", loggedIn); codec2.writeObject("connectionId", connectionId); serializer.writeObject("message", message, codec2); client.sendMessage(buffer2); } }; Runner<Packet, UdtTestClient> runner = (client) -> { ByteBuf buffer = Unpooled.buffer(); MemoryBuffer memoryBuffer = MemoryBufferFactory.create(buffer); DefaultCodec codec = new DefaultCodec(serializer.getProtocol(), memoryBuffer); codec.writeBoolean("loggedIn", false); codec.writeObject("handshake", new Handshake()); client.sendMessage(buffer); Packet result = future.get(120, TimeUnit.SECONDS); client.close(); return result; }; Packet response = practice(runner, clientFactory(channelReader), false, ServerTransports.UDT_TRANSPORT); assertEquals(4, (int) response.getValue("counter")); } private static void writeChannel(Serializer serializer, UdtTestClient client, Identifier connectionId, Object value) throws Exception { ByteBuf buffer = Unpooled.directBuffer(); MemoryBuffer memoryBuffer = MemoryBufferFactory.create(buffer); try (AutoClosableEncoder encoder = serializer.retrieveEncoder(memoryBuffer)) { encoder.writeBoolean("loggedIn", true); encoder.writeObject("connectionId", connectionId); serializer.writeObject("value", value, encoder); } client.sendMessage(buffer); } private static ClientFactory<UdtTestClient> clientFactory(ChannelReader<UdtTestClient, ByteBuf> channelReader) { return (host, port, ssl, group) -> new UdtTestClient(host, port, ssl, channelReader); } public static class UdtTestClient { private final SocketUDT socket; private final ChannelReader<UdtTestClient, ByteBuf> channelReader; private final AtomicBoolean stop = new AtomicBoolean(false); private UdtTestClient(String host, int port, boolean ssl, ChannelReader<UdtTestClient, ByteBuf> channelReader) throws Exception { this.channelReader = channelReader; this.socket = new SocketUDT(TypeUDT.STREAM); this.socket.setBlocking(false); this.socket.connect(new InetSocketAddress(host, port)); while (!this.socket.isConnected()) { Thread.yield(); } new Thread(new Worker()).start(); } private void sendMessage(ByteBuf buffer) throws Exception { ByteBuffer nioBuffer = Unpooled.directBuffer(buffer.writerIndex()).writeBytes(buffer).nioBuffer(); socket.send(nioBuffer); } private void close() throws Exception { stop.set(true); socket.close(); } private final class Worker implements Runnable { @Override public void run() { ByteBuffer buffer = ByteBuffer.allocateDirect(1024); while (!stop.get()) { try { if (socket.isConnected() && socket.receive(buffer) > 0) { buffer.flip(); ByteBuf buf = Unpooled.copiedBuffer(buffer); channelReader.channelRead(UdtTestClient.this, buf); buffer.clear(); } } catch (Exception e) { e.printStackTrace(); } } } } } }