/* * Copyright (c) 2015-2016, Christoph Engelbert (aka noctarius) and * contributors. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.noctarius.tengi.server.impl.transport.websocket; import com.noctarius.tengi.core.model.Identifier; import com.noctarius.tengi.core.model.Message; import com.noctarius.tengi.core.model.Packet; import com.noctarius.tengi.server.ServerTransports; import com.noctarius.tengi.server.impl.transport.AbstractStreamingTransportTestCase; import com.noctarius.tengi.spi.buffer.MemoryBuffer; import com.noctarius.tengi.spi.buffer.impl.MemoryBufferFactory; import com.noctarius.tengi.spi.connection.impl.TransportConstants; import com.noctarius.tengi.spi.connection.packets.Handshake; import com.noctarius.tengi.spi.serialization.Serializer; import com.noctarius.tengi.spi.serialization.codec.AutoClosableEncoder; import com.noctarius.tengi.spi.serialization.codec.impl.DefaultCodec; import com.noctarius.tengi.spi.serialization.impl.DefaultProtocol; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import org.junit.Test; import javax.websocket.ClientEndpoint; import javax.websocket.ContainerProvider; import javax.websocket.OnClose; import javax.websocket.OnMessage; import javax.websocket.Session; import javax.websocket.WebSocketContainer; import java.io.IOException; import java.net.URI; import java.nio.ByteBuffer; import java.util.Collections; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import static org.junit.Assert.assertEquals; public class WebsocketTransportTestCase extends AbstractStreamingTransportTestCase { @Test(timeout = 120000) public void test_websocket_transport() throws Exception { Serializer serializer = Serializer.create(new DefaultProtocol(Collections.emptyList())); CompletableFuture<Object> future = new CompletableFuture<>(); Packet packet = new Packet("login"); packet.setValue("username", "Stan"); Message message = Message.create(packet); ChannelReader<WebsocketTestClient, ByteBuf> channelReader = (client, buffer) -> { MemoryBuffer memoryBuffer = MemoryBufferFactory.create(buffer); DefaultCodec codec = new DefaultCodec(serializer.getProtocol(), memoryBuffer); boolean loggedIn = codec.readBoolean(); Identifier connectionId = codec.readObject(); Object object = codec.readObject(); if (loggedIn && object instanceof Handshake) { writeChannel(serializer, client, connectionId, message); return; } future.complete(object); }; Runner<Object, WebsocketTestClient> runner = (client) -> { ByteBuf buffer = Unpooled.buffer(); MemoryBuffer memoryBuffer = MemoryBufferFactory.create(buffer); DefaultCodec codec = new DefaultCodec(serializer.getProtocol(), memoryBuffer); codec.writeBoolean("loggedIn", false); codec.writeObject("handshake", new Handshake()); client.sendMessage(buffer); Object result = future.get(120, TimeUnit.SECONDS); client.close(); return result; }; ClientFactory<WebsocketTestClient> clientFactory = clientFactory(channelReader, (s) -> future.complete(null)); Object response = practice(runner, clientFactory, false, ServerTransports.WEBSOCKET_TRANSPORT); assertEquals(message, response); } @Test(timeout = 120000) public void test_websocket_transport_ping_pong() throws Exception { Serializer serializer = Serializer.create(new DefaultProtocol(Collections.emptyList())); CompletableFuture<Packet> future = new CompletableFuture<>(); ChannelReader<WebsocketTestClient, ByteBuf> channelReader = (client, buffer) -> { MemoryBuffer memoryBuffer = MemoryBufferFactory.create(buffer); DefaultCodec codec = new DefaultCodec(serializer.getProtocol(), memoryBuffer); boolean loggedIn = codec.readBoolean(); Identifier connectionId = codec.readObject(); Object object = codec.readObject(); if (object instanceof Handshake) { Packet packet = new Packet("pingpong"); packet.setValue("counter", 1); Message message = Message.create(packet); writeChannel(serializer, client, connectionId, message); return; } Message message = (Message) object; Packet packet = message.getBody(); int counter = packet.getValue("counter"); if (counter == 4) { future.complete(packet); } else { packet.setValue("counter", counter + 1); message = Message.create(packet); ByteBuf buffer2 = Unpooled.buffer(); MemoryBuffer memoryBuffer2 = MemoryBufferFactory.create(buffer2); DefaultCodec codec2 = new DefaultCodec(serializer.getProtocol(), memoryBuffer2); codec2.writeBoolean("loggedIn", loggedIn); codec2.writeObject("connectionId", connectionId); serializer.writeObject("message", message, codec2); client.sendMessage(buffer2); } }; Runner<Packet, WebsocketTestClient> runner = (client) -> { ByteBuf buffer = Unpooled.buffer(); MemoryBuffer memoryBuffer = MemoryBufferFactory.create(buffer); DefaultCodec codec = new DefaultCodec(serializer.getProtocol(), memoryBuffer); codec.writeBoolean("loggedIn", false); codec.writeObject("handshake", new Handshake()); client.sendMessage(buffer); Packet result = future.get(120, TimeUnit.SECONDS); client.close(); return result; }; ClientFactory<WebsocketTestClient> clientFactory = clientFactory(channelReader, (s) -> future.complete(null)); Packet response = practice(runner, clientFactory, false, ServerTransports.WEBSOCKET_TRANSPORT); assertEquals(4, (int) response.getValue("counter")); } private static void writeChannel(Serializer serializer, WebsocketTestClient client, Identifier connectionId, Object value) throws Exception { ByteBuf buffer = Unpooled.directBuffer(); MemoryBuffer memoryBuffer = MemoryBufferFactory.create(buffer); try (AutoClosableEncoder encoder = serializer.retrieveEncoder(memoryBuffer)) { encoder.writeBoolean("loggedIn", true); encoder.writeObject("connectionId", connectionId); serializer.writeObject("value", value, encoder); } client.sendMessage(buffer); } private static ClientFactory<WebsocketTestClient> clientFactory(ChannelReader<WebsocketTestClient, ByteBuf> channelReader, Consumer<Session> closeListener) { return (host, port, ssl, group) -> new WebsocketTestClient(host, port, ssl, channelReader, closeListener); } @ClientEndpoint public static class WebsocketTestClient { private final Session session; private final Consumer<Session> closeListener; private final ChannelReader<WebsocketTestClient, ByteBuf> channelReader; private WebsocketTestClient(String host, int port, boolean ssl, ChannelReader<WebsocketTestClient, ByteBuf> channelReader, Consumer<Session> closeListener) throws Exception { WebSocketContainer container = ContainerProvider.getWebSocketContainer(); this.session = container.connectToServer(this, createURI(host, port, ssl)); this.channelReader = channelReader; this.closeListener = closeListener; } private URI createURI(String host, int port, boolean ssl) { String url = (ssl ? "wss" : "ws") + "://" + host + ":" + port + TransportConstants.WEBSOCKET_RELATIVE_PATH; return URI.create(url); } @OnMessage public void onMessage(ByteBuffer byteBuffer) throws Exception { ByteBuf buffer = Unpooled.wrappedBuffer(byteBuffer); channelReader.channelRead(this, buffer); } private void sendMessage(ByteBuf buffer) { ByteBuffer nioBuffer = buffer.nioBuffer(); if (buffer.isDirect()) { nioBuffer = Unpooled.copiedBuffer(buffer).nioBuffer(); } session.getAsyncRemote().sendBinary(nioBuffer); } @OnClose private void onClose() { closeListener.accept(session); } private void close() throws IOException { session.close(); } } }