/* * Copyright 2014 Red Hat, Inc. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * and Apache License v2.0 which accompanies this distribution. * * The Eclipse Public License is available at * http://www.eclipse.org/legal/epl-v10.html * * The Apache License v2.0 is available at * http://www.opensource.org/licenses/apache2.0.php * * You may elect to redistribute this code under either of these licenses. */ package io.vertx.test.core; import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException; import io.vertx.core.AbstractVerticle; import io.vertx.core.AsyncResult; import io.vertx.core.Context; import io.vertx.core.DeploymentOptions; import io.vertx.core.Future; import io.vertx.core.Handler; import io.vertx.core.Vertx; import io.vertx.core.VertxOptions; import io.vertx.core.buffer.Buffer; import io.vertx.core.http.*; import io.vertx.core.impl.ConcurrentHashSet; import io.vertx.core.net.NetServer; import io.vertx.core.net.NetSocket; import io.vertx.core.streams.ReadStream; import io.vertx.test.core.tls.Cert; import io.vertx.test.core.tls.Trust; import org.junit.Test; import javax.net.ssl.SSLPeerUnverifiedException; import javax.security.cert.X509Certificate; import java.io.UnsupportedEncodingException; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; import static io.vertx.test.core.TestUtils.*; /** * @author <a href="http://tfox.org">Tim Fox</a> */ public class WebsocketTest extends VertxTestBase { private HttpClient client; private HttpServer server; private NetServer netServer; public void setUp() throws Exception { super.setUp(); client = vertx.createHttpClient(new HttpClientOptions()); } protected void tearDown() throws Exception { client.close(); if (server != null) { CountDownLatch latch = new CountDownLatch(1); server.close(ar -> { assertTrue(ar.succeeded()); latch.countDown(); }); awaitLatch(latch); } if (netServer != null) { CountDownLatch latch = new CountDownLatch(1); netServer.close(ar -> { assertTrue(ar.succeeded()); latch.countDown(); }); awaitLatch(latch); } super.tearDown(); } @Override protected VertxOptions getOptions() { VertxOptions options = super.getOptions(); options.getAddressResolverOptions().setHostsValue(Buffer.buffer("" + "127.0.0.1 localhost\n" + "127.0.0.1 host2.com")); return options; } @Test public void testRejectHybi00() throws Exception { testReject(WebsocketVersion.V00); } @Test public void testRejectHybi08() throws Exception { testReject(WebsocketVersion.V08); } @Test public void testWSBinaryHybi00() throws Exception { testWSFrames(true, WebsocketVersion.V00); } @Test public void testWSStringHybi00() throws Exception { testWSFrames(false, WebsocketVersion.V00); } @Test public void testWSBinaryHybi08() throws Exception { testWSFrames(true, WebsocketVersion.V08); } @Test public void testWSStringHybi08() throws Exception { testWSFrames(false, WebsocketVersion.V08); } @Test public void testWSBinaryHybi17() throws Exception { testWSFrames(true, WebsocketVersion.V13); } @Test public void testWSStringHybi17() throws Exception { testWSFrames(false, WebsocketVersion.V13); } @Test public void testWSStreamsHybi00() throws Exception { testWSWriteStream(WebsocketVersion.V00); } @Test public void testWSStreamsHybi08() throws Exception { testWSWriteStream(WebsocketVersion.V08); } @Test public void testWSStreamsHybi17() throws Exception { testWSWriteStream(WebsocketVersion.V13); } @Test public void testWriteFromConnectHybi00() throws Exception { testWriteFromConnectHandler(WebsocketVersion.V00); } @Test public void testWriteFromConnectHybi08() throws Exception { testWriteFromConnectHandler(WebsocketVersion.V08); } @Test public void testWriteFromConnectHybi17() throws Exception { testWriteFromConnectHandler(WebsocketVersion.V13); } @Test public void testContinuationWriteFromConnectHybi08() throws Exception { testContinuationWriteFromConnectHandler(WebsocketVersion.V08); } @Test public void testContinuationWriteFromConnectHybi17() throws Exception { testContinuationWriteFromConnectHandler(WebsocketVersion.V13); } @Test public void testValidSubProtocolHybi00() throws Exception { testValidSubProtocol(WebsocketVersion.V00); } @Test public void testValidSubProtocolHybi08() throws Exception { testValidSubProtocol(WebsocketVersion.V08); } @Test public void testValidSubProtocolHybi17() throws Exception { testValidSubProtocol(WebsocketVersion.V13); } @Test public void testInvalidSubProtocolHybi00() throws Exception { testInvalidSubProtocol(WebsocketVersion.V00); } @Test public void testInvalidSubProtocolHybi08() throws Exception { testInvalidSubProtocol(WebsocketVersion.V08); } @Test public void testInvalidSubProtocolHybi17() throws Exception { testInvalidSubProtocol(WebsocketVersion.V13); } // TODO close and exception tests // TODO pause/resume/drain tests @Test // Client trusts all server certs public void testTLSClientTrustAll() throws Exception { testTLS(Cert.NONE, Trust.NONE, Cert.SERVER_JKS, Trust.NONE, false, false, true, false, true); } @Test // Server specifies cert that the client trusts (not trust all) public void testTLSClientTrustServerCert() throws Exception { testTLS(Cert.NONE, Trust.SERVER_JKS, Cert.SERVER_JKS, Trust.NONE, false, false, false, false, true); } @Test // Server specifies cert that the client trusts (not trust all) public void testTLSClientTrustServerCertWithSNI() throws Exception { testTLS(Cert.NONE, Trust.SNI_JKS_HOST2, Cert.SNI_JKS, Trust.NONE, false, false, false, false, true, true, true, true, new String[0], client -> client.websocketStream(4043, "host2.com", "/")); } @Test // Server specifies cert that the client trusts (not trust all) public void testTLSClientTrustServerCertPKCS12() throws Exception { testTLS(Cert.NONE, Trust.SERVER_JKS, Cert.SERVER_PKCS12, Trust.NONE, false, false, false, false, true); } @Test // Server specifies cert that the client trusts (not trust all) public void testTLSClientTrustServerCertPEM() throws Exception { testTLS(Cert.NONE, Trust.SERVER_JKS, Cert.SERVER_PEM, Trust.NONE, false, false, false, false, true); } @Test // Server specifies cert that the client trusts via a CA (not trust all) public void testTLSClientTrustServerCertPEM_CA() throws Exception { testTLS(Cert.NONE, Trust.SERVER_PEM_ROOT_CA, Cert.SERVER_PEM_ROOT_CA, Trust.NONE, false, false, false, false, true); } @Test // Server specifies cert that the client trusts (not trust all) public void testTLSClientTrustPKCS12ServerCert() throws Exception { testTLS(Cert.NONE, Trust.SERVER_PKCS12, Cert.SERVER_JKS, Trust.NONE, false, false, false, false, true); } @Test // Server specifies cert that the client trusts (not trust all) public void testTLSClientTrustPEMServerCert() throws Exception { testTLS(Cert.NONE, Trust.SERVER_PEM, Cert.SERVER_JKS, Trust.NONE, false, false, false, false, true); } @Test // Server specifies cert that the client doesn't trust public void testTLSClientUntrustedServer() throws Exception { testTLS(Cert.NONE, Trust.NONE, Cert.SERVER_JKS, Trust.NONE, false, false, false, false, false); } @Test //Client specifies cert even though it's not required public void testTLSClientCertNotRequired() throws Exception { testTLS(Cert.CLIENT_JKS, Trust.SERVER_JKS, Cert.SERVER_JKS, Trust.CLIENT_JKS, false, false, false, false, true); } @Test //Client specifies cert and it is required public void testTLSClientCertRequired() throws Exception { testTLS(Cert.CLIENT_JKS, Trust.SERVER_JKS, Cert.SERVER_JKS, Trust.CLIENT_JKS, true, false, false, false, true); } @Test //Client specifies cert and it is required public void testTLSClientCertRequiredPKCS12() throws Exception { testTLS(Cert.CLIENT_JKS, Trust.SERVER_JKS, Cert.SERVER_JKS, Trust.CLIENT_PKCS12, true, false, false, false, true); } @Test //Client specifies cert and it is required public void testTLSClientCertRequiredPEM() throws Exception { testTLS(Cert.CLIENT_JKS, Trust.SERVER_JKS, Cert.SERVER_JKS, Trust.CLIENT_PEM, true, false, false, false, true); } @Test //Client specifies cert and it is required public void testTLSClientCertPKCS12Required() throws Exception { testTLS(Cert.CLIENT_PKCS12, Trust.SERVER_JKS, Cert.SERVER_JKS, Trust.CLIENT_JKS, true, false, false, false, true); } @Test //Client specifies cert and it is required public void testTLSClientCertPEMRequired() throws Exception { testTLS(Cert.CLIENT_PEM, Trust.SERVER_JKS, Cert.SERVER_JKS, Trust.CLIENT_JKS, true, false, false, false, true); } @Test //Client specifies cert signed by CA and it is required public void testTLSClientCertPEM_CARequired() throws Exception { testTLS(Cert.CLIENT_PEM_ROOT_CA, Trust.SERVER_JKS, Cert.SERVER_JKS, Trust.CLIENT_PEM_ROOT_CA, true, false, false, false, true); } @Test //Client doesn't specify cert but it's required public void testTLSClientCertRequiredNoClientCert() throws Exception { testTLS(Cert.NONE, Trust.SERVER_JKS, Cert.SERVER_JKS, Trust.CLIENT_JKS, true, false, false, false, false); } @Test //Client specifies cert but it's not trusted public void testTLSClientCertClientNotTrusted() throws Exception { testTLS(Cert.CLIENT_JKS, Trust.SERVER_JKS, Cert.SERVER_JKS, Trust.NONE, true, false, false, false, false); } @Test // Server specifies cert that the client does not trust via a revoked certificate of the CA public void testTLSClientRevokedServerCert() throws Exception { testTLS(Cert.NONE, Trust.SERVER_PEM_ROOT_CA, Cert.SERVER_PEM_ROOT_CA, Trust.NONE, false, false, false, true, false); } @Test //Client specifies cert that the server does not trust via a revoked certificate of the CA public void testTLSRevokedClientCertServer() throws Exception { testTLS(Cert.CLIENT_PEM_ROOT_CA, Trust.SERVER_JKS, Cert.SERVER_JKS, Trust.CLIENT_PEM_ROOT_CA, true, true, false, false, false); } @Test // Test with cipher suites public void testTLSCipherSuites() throws Exception { testTLS(Cert.NONE, Trust.NONE, Cert.SERVER_JKS, Trust.NONE, false, false, true, false, true, ENABLED_CIPHER_SUITES); } // RequestOptions tests @Test // Client trusts all server certs public void testClearClientRequestOptionsSetSSL() throws Exception { RequestOptions options = new RequestOptions().setHost(HttpTestBase.DEFAULT_HTTP_HOST).setURI("/").setPort(4043).setSsl(true); testTLS(Cert.NONE, Trust.NONE, Cert.SERVER_JKS, Trust.NONE, false, false, true, false, true, false, true, false, new String[0], client -> client.websocketStream(options)); } @Test // Client trusts all server certs public void testSSLClientRequestOptionsSetSSL() throws Exception { RequestOptions options = new RequestOptions().setHost(HttpTestBase.DEFAULT_HTTP_HOST).setURI("/").setPort(4043).setSsl(true); testTLS(Cert.NONE, Trust.NONE, Cert.SERVER_JKS, Trust.NONE, false, false, true, false, true, true, true, false, new String[0], client -> client.websocketStream(options)); } @Test // Client trusts all server certs public void testClearClientRequestOptionsSetClear() throws Exception { RequestOptions options = new RequestOptions().setHost(HttpTestBase.DEFAULT_HTTP_HOST).setURI("/").setPort(4043).setSsl(false); testTLS(Cert.NONE, Trust.NONE, Cert.SERVER_JKS, Trust.NONE, false, false, true, false, true, false, false, false, new String[0], client -> client.websocketStream(options)); } @Test // Client trusts all server certs public void testSSLClientRequestOptionsSetClear() throws Exception { RequestOptions options = new RequestOptions().setHost(HttpTestBase.DEFAULT_HTTP_HOST).setURI("/").setPort(4043).setSsl(false); testTLS(Cert.NONE, Trust.NONE, Cert.SERVER_JKS, Trust.NONE, false, false, true, false, true, true, false, false, new String[0], client -> client.websocketStream(options)); } private void testTLS(Cert<?> clientCert, Trust<?> clientTrust, Cert<?> serverCert, Trust<?> serverTrust, boolean requireClientAuth, boolean serverUsesCrl, boolean clientTrustAll, boolean clientUsesCrl, boolean shouldPass, String... enabledCipherSuites) throws Exception { testTLS(clientCert, clientTrust, serverCert, serverTrust, requireClientAuth, serverUsesCrl, clientTrustAll, clientUsesCrl, shouldPass, true, true, false, enabledCipherSuites, client -> client.websocketStream(4043, HttpTestBase.DEFAULT_HTTP_HOST, "/")); } private void testTLS(Cert<?> clientCert, Trust<?> clientTrust, Cert<?> serverCert, Trust<?> serverTrust, boolean requireClientAuth, boolean serverUsesCrl, boolean clientTrustAll, boolean clientUsesCrl, boolean shouldPass, boolean clientSsl, boolean serverSsl, boolean sni, String[] enabledCipherSuites, Function<HttpClient, ReadStream<WebSocket>> wsProvider) throws Exception { HttpClientOptions options = new HttpClientOptions(); options.setSsl(clientSsl); options.setTrustAll(clientTrustAll); if (clientUsesCrl) { options.addCrlPath("tls/root-ca/crl.pem"); } options.setTrustOptions(clientTrust.get()); options.setKeyCertOptions(clientCert.get()); for (String suite: enabledCipherSuites) { options.addEnabledCipherSuite(suite); } client = vertx.createHttpClient(options); HttpServerOptions serverOptions = new HttpServerOptions(); serverOptions.setSsl(serverSsl); serverOptions.setSni(sni); serverOptions.setTrustOptions(serverTrust.get()); serverOptions.setKeyCertOptions(serverCert.get()); if (requireClientAuth) { serverOptions.setClientAuth(ClientAuth.REQUIRED); } if (serverUsesCrl) { serverOptions.addCrlPath("tls/root-ca/crl.pem"); } for (String suite: enabledCipherSuites) { serverOptions.addEnabledCipherSuite(suite); } server = vertx.createHttpServer(serverOptions.setPort(4043)); server.websocketHandler(ws -> { ws.handler(ws::write); }); try { server.listen(ar -> { assertTrue(ar.succeeded()); Handler<Throwable> errHandler = t -> { if (shouldPass) { t.printStackTrace(); fail("Should not throw exception"); } else { testComplete(); } }; Handler<WebSocket> wsHandler = ws -> { if (clientSsl && sni) { try { X509Certificate clientPeerCert = ws.peerCertificateChain()[0]; assertEquals("host2.com", cnOf(clientPeerCert)); } catch (Exception err) { fail(err); } } int size = 100; Buffer received = Buffer.buffer(); ws.handler(data -> { received.appendBuffer(data); if (received.length() == size) { ws.close(); testComplete(); } }); Buffer buff = Buffer.buffer(TestUtils.randomByteArray(size)); ws.writeFrame(WebSocketFrame.binaryFrame(buff, true)); }; wsProvider.apply(client). exceptionHandler(errHandler). handler(wsHandler); }); } catch (Exception e) { e.printStackTrace(); } await(); } @Test // Let's manually handle the websocket handshake and write a frame to the client public void testHandleWSManually() throws Exception { String path = "/some/path"; String message = "here is some text data"; server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)).requestHandler(req -> { NetSocket sock = getUpgradedNetSocket(req, path); // Let's write a Text frame raw Buffer buff = Buffer.buffer(); buff.appendByte((byte)129); // Text frame buff.appendByte((byte)message.length()); buff.appendString(message); sock.write(buff); }); server.listen(ar -> { assertTrue(ar.succeeded()); client.websocketStream(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, path). exceptionHandler(t -> fail(t.getMessage())). handler(ws -> { ws.handler(buff -> { assertEquals(message, buff.toString("UTF-8")); testComplete(); }); }); }); await(); } @Test public void testSharedServersRoundRobin() throws Exception { int numServers = 5; int numConnections = numServers * 100; List<HttpServer> servers = new ArrayList<>(); Set<HttpServer> connectedServers = new ConcurrentHashSet<>(); Map<HttpServer, Integer> connectCount = new ConcurrentHashMap<>(); CountDownLatch latchListen = new CountDownLatch(numServers); CountDownLatch latchConns = new CountDownLatch(numConnections); for (int i = 0; i < numServers; i++) { HttpServer theServer = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)); servers.add(theServer); theServer.websocketHandler(ws -> { connectedServers.add(theServer); Integer cnt = connectCount.get(theServer); int icnt = cnt == null ? 0 : cnt; icnt++; connectCount.put(theServer, icnt); latchConns.countDown(); }).listen(ar -> { if (ar.succeeded()) { latchListen.countDown(); } else { fail("Failed to bind server"); } }); } assertTrue(latchListen.await(10, TimeUnit.SECONDS)); // Create a bunch of connections CountDownLatch latchClient = new CountDownLatch(numConnections); for (int i = 0; i < numConnections; i++) { client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, "/someuri", ws -> { ws.closeHandler(v -> latchClient.countDown()); ws.close(); }); } assertTrue(latchClient.await(10, TimeUnit.SECONDS)); assertTrue(latchConns.await(10, TimeUnit.SECONDS)); assertEquals(numServers, connectedServers.size()); for (HttpServer server: servers) { assertTrue(connectedServers.contains(server)); } assertEquals(numServers, connectCount.size()); for (int cnt: connectCount.values()) { assertEquals(numConnections / numServers, cnt); } CountDownLatch closeLatch = new CountDownLatch(numServers); for (HttpServer server: servers) { server.close(ar -> { assertTrue(ar.succeeded()); closeLatch.countDown(); }); } assertTrue(closeLatch.await(10, TimeUnit.SECONDS)); testComplete(); } @Test public void testSharedServersRoundRobinWithOtherServerRunningOnDifferentPort() throws Exception { // Have a server running on a different port to make sure it doesn't interact CountDownLatch latch = new CountDownLatch(1); HttpServer theServer = vertx.createHttpServer(new HttpServerOptions().setPort(4321)); theServer.websocketHandler(ws -> { fail("Should not connect"); }).listen(ar -> { if (ar.succeeded()) { latch.countDown(); } else { fail("Failed to bind server"); } }); awaitLatch(latch); testSharedServersRoundRobin(); } @Test public void testSharedServersRoundRobinButFirstStartAndStopServer() throws Exception { // Start and stop a server on the same port/host before hand to make sure it doesn't interact CountDownLatch latch = new CountDownLatch(1); HttpServer theServer = vertx.createHttpServer(new HttpServerOptions().setPort(4321)); theServer.websocketHandler(ws -> { fail("Should not connect"); }).listen(ar -> { if (ar.succeeded()) { latch.countDown(); } else { fail("Failed to bind server"); } }); awaitLatch(latch); CountDownLatch closeLatch = new CountDownLatch(1); theServer.close(ar -> { assertTrue(ar.succeeded()); closeLatch.countDown(); }); assertTrue(closeLatch.await(10, TimeUnit.SECONDS)); testSharedServersRoundRobin(); } @Test public void testWebsocketFrameFactoryArguments() throws Exception { assertNullPointerException(() -> WebSocketFrame.binaryFrame(null, true)); assertNullPointerException(() -> WebSocketFrame.textFrame(null, true)); assertNullPointerException(() -> WebSocketFrame.continuationFrame(null, true)); } private String sha1(String s) { try { MessageDigest md = MessageDigest.getInstance("SHA1"); //Hash the data byte[] bytes = md.digest(s.getBytes("UTF-8")); return Base64.getEncoder().encodeToString(bytes); } catch (Exception e) { throw new InternalError("Failed to compute sha-1"); } } private NetSocket getUpgradedNetSocket(HttpServerRequest req, String path) { assertEquals(path, req.path()); assertEquals("upgrade", req.headers().get("Connection")); NetSocket sock = req.netSocket(); String secHeader = req.headers().get("Sec-WebSocket-Key"); String tmp = secHeader + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; String encoded = sha1(tmp); sock.write("HTTP/1.1 101 Web Socket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: upgrade\r\n" + "Sec-WebSocket-Accept: " + encoded + "\r\n" + "\r\n"); return sock; } private void testWSWriteStream(WebsocketVersion version) throws Exception { String path = "/some/path"; String query = "foo=bar&wibble=eek"; String uri = path + "?" + query; server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)).websocketHandler(ws -> { assertEquals(uri, ws.uri()); assertEquals(path, ws.path()); assertEquals(query, ws.query()); assertEquals("upgrade", ws.headers().get("Connection")); ws.handler(data -> ws.write(data)); }); server.listen(ar -> { assertTrue(ar.succeeded()); int bsize = 100; int sends = 10; client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, path + "?" + query, null, version, ws -> { final Buffer received = Buffer.buffer(); ws.handler(data -> { received.appendBuffer(data); if (received.length() == bsize * sends) { ws.close(); testComplete(); } }); final Buffer sent = Buffer.buffer(); for (int i = 0; i < sends; i++) { Buffer buff = Buffer.buffer(TestUtils.randomByteArray(bsize)); ws.write(buff); sent.appendBuffer(buff); } }); }); await(); } private void testWSFrames(boolean binary, WebsocketVersion version) throws Exception { String path = "/some/path"; String query = "foo=bar&wibble=eek"; String uri = path + "?" + query; // version 0 doesn't support continuations so we just send 1 frame per message int frames = version == WebsocketVersion.V00 ? 1: 10; server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)).websocketHandler(ws -> { assertEquals(uri, ws.uri()); assertEquals(path, ws.path()); assertEquals(query, ws.query()); assertEquals("upgrade", ws.headers().get("Connection")); AtomicInteger count = new AtomicInteger(); ws.frameHandler(frame -> { if (count.get() == 0) { if (binary) { assertTrue(frame.isBinary()); assertFalse(frame.isText()); } else { assertFalse(frame.isBinary()); assertTrue(frame.isText()); } assertFalse(frame.isContinuation()); } else { assertFalse(frame.isBinary()); assertFalse(frame.isText()); assertTrue(frame.isContinuation()); } if (count.get() == frames - 1) { assertTrue(frame.isFinal()); } else { assertFalse(frame.isFinal()); } ws.writeFrame(frame); if (count.incrementAndGet() == frames) { count.set(0); } }); }); server.listen(ar -> { assertTrue(ar.succeeded()); int bsize = 100; int msgs = 10; client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, path + "?" + query, null, version, ws -> { final List<Buffer> sent = new ArrayList<>(); final List<Buffer> received = new ArrayList<>(); AtomicReference<Buffer> currentReceived = new AtomicReference<>(Buffer.buffer()); ws.frameHandler(frame -> { //received.appendBuffer(frame.binaryData()); currentReceived.get().appendBuffer(frame.binaryData()); if (frame.isFinal()) { received.add(currentReceived.get()); currentReceived.set(Buffer.buffer()); } if (received.size() == msgs) { int pos = 0; for (Buffer rec : received) { assertEquals(rec, sent.get(pos++)); } testComplete(); } }); AtomicReference<Buffer> currentSent = new AtomicReference<>(Buffer.buffer()); for (int i = 0; i < msgs; i++) { for (int j = 0; j < frames; j++) { Buffer buff; WebSocketFrame frame; if (binary) { buff = Buffer.buffer(TestUtils.randomByteArray(bsize)); if (j == 0) { frame = WebSocketFrame.binaryFrame(buff, false); } else { frame = WebSocketFrame.continuationFrame(buff, j == frames - 1); } } else { String str = TestUtils.randomAlphaString(bsize); buff = Buffer.buffer(str); if (j == 0) { frame = WebSocketFrame.textFrame(str, false); } else { frame = WebSocketFrame.continuationFrame(buff, j == frames - 1); } } currentSent.get().appendBuffer(buff); ws.writeFrame(frame); if (j == frames - 1) { sent.add(currentSent.get()); currentSent.set(Buffer.buffer()); } } } }); }); await(); } @Test public void testWriteFinalTextFrame() throws Exception { testWriteFinalFrame(false); } @Test public void testWriteFinalBinaryFrame() throws Exception { testWriteFinalFrame(true); } private void testWriteFinalFrame(boolean binary) throws Exception { String text = TestUtils.randomUnicodeString(100); Buffer data = TestUtils.randomBuffer(100); Consumer<WebSocketFrame> frameConsumer = frame -> { if (binary) { assertTrue(frame.isBinary()); assertFalse(frame.isText()); assertEquals(data, frame.binaryData()); } else { assertFalse(frame.isBinary()); assertTrue(frame.isText()); assertEquals(text, frame.textData()); } assertTrue(frame.isFinal()); }; server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)).websocketHandler(ws -> ws.frameHandler(frame -> { frameConsumer.accept(frame); if (binary) { ws.writeFinalBinaryFrame(frame.binaryData()); } else { ws.writeFinalTextFrame(frame.textData()); } }) ); server.listen(onSuccess(s -> client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, "/", ws -> { ws.frameHandler(frame -> { frameConsumer.accept(frame); testComplete(); }); if (binary) { ws.writeFinalBinaryFrame(data); } else { ws.writeFinalTextFrame(text); } }) )); await(); } private void testContinuationWriteFromConnectHandler(WebsocketVersion version) throws Exception { String path = "/some/path"; String firstFrame = "AAA"; String continuationFrame = "BBB"; server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)).requestHandler(req -> { NetSocket sock = getUpgradedNetSocket(req, path); // Let's write a Text frame raw Buffer buff = Buffer.buffer(); buff.appendByte((byte) 0x01); // Incomplete Text frame buff.appendByte((byte) firstFrame.length()); buff.appendString(firstFrame); sock.write(buff); buff = Buffer.buffer(); buff.appendByte((byte) (0x00 | 0x80)); // Complete continuation frame buff.appendByte((byte) continuationFrame.length()); buff.appendString(continuationFrame); sock.write(buff); }); server.listen(ar -> { assertTrue(ar.succeeded()); client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, path, null, version, ws -> { AtomicBoolean receivedFirstFrame = new AtomicBoolean(); ws.frameHandler(received -> { Buffer receivedBuffer = Buffer.buffer(received.textData()); if (!received.isFinal()) { assertEquals(firstFrame, receivedBuffer.toString()); receivedFirstFrame.set(true); } else if (receivedFirstFrame.get() && received.isFinal()) { assertEquals(continuationFrame, receivedBuffer.toString()); ws.close(); testComplete(); } }); }); }); await(); } private void testWriteFromConnectHandler(WebsocketVersion version) throws Exception { String path = "/some/path"; Buffer buff = Buffer.buffer("AAA"); server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)).websocketHandler(ws -> { assertEquals(path, ws.path()); ws.writeFrame(WebSocketFrame.binaryFrame(buff, true)); }); server.listen(ar -> { assertTrue(ar.succeeded()); client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, path, null, version, ws -> { Buffer received = Buffer.buffer(); ws.handler(data -> { received.appendBuffer(data); if (received.length() == buff.length()) { assertEquals(buff, received); ws.close(); testComplete(); } }); }); }); await(); } @Test public void testWriteFromConnectHandlerFromAnotherThread() { Buffer expected = Buffer.buffer("AAA"); server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)); server.websocketHandler(ws -> { Thread t = new Thread() { @Override public void run() { ws.writeFrame(WebSocketFrame.binaryFrame(expected, true)); } }; t.start(); while (t.getState() != Thread.State.BLOCKED) { Thread.yield(); } }); server.listen(onSuccess(server -> { client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, "/", ws -> { ws.handler(buff -> { assertEquals(buff, expected); testComplete(); }); }); })); await(); } private void testValidSubProtocol(WebsocketVersion version) throws Exception { String path = "/some/path"; String subProtocol = "myprotocol"; Buffer buff = Buffer.buffer("AAA"); server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT).setWebsocketSubProtocols(subProtocol)).websocketHandler(ws -> { assertEquals(path, ws.path()); ws.writeFrame(WebSocketFrame.binaryFrame(buff, true)); }); server.listen(ar -> { assertTrue(ar.succeeded()); client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, path, null, version, subProtocol, ws -> { final Buffer received = Buffer.buffer(); ws.handler(data -> { received.appendBuffer(data); if (received.length() == buff.length()) { assertEquals(buff, received); ws.close(); testComplete(); } }); }); }); await(); } private void testInvalidSubProtocol(WebsocketVersion version) throws Exception { String path = "/some/path"; String subProtocol = "myprotocol"; server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT).setWebsocketSubProtocols("invalid")).websocketHandler(ws -> { }); server.listen(onSuccess(ar -> { client.websocketStream(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, path, null, version, subProtocol). exceptionHandler(t -> { // Should fail testComplete(); }). handler(ws -> { }); })); await(); } private void testReject(WebsocketVersion version) throws Exception { String path = "/some/path"; server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)).websocketHandler(ws -> { assertEquals(path, ws.path()); ws.reject(); }); server.listen(ar -> { assertTrue(ar.succeeded()); client.websocketStream(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, path, null, version). exceptionHandler(t -> testComplete()). handler(ws -> fail("Should not be called")); }); await(); } @Test public void testWriteMessageHybi00() { testWriteMessage(256, WebsocketVersion.V00); } @Test public void testWriteFragmentedMessage1Hybi00() { testWriteMessage(65536 + 256, WebsocketVersion.V00); } @Test public void testWriteFragmentedMessage2Hybi00() { testWriteMessage(65536 + 65536 + 256, WebsocketVersion.V00); } @Test public void testWriteMessageHybi08() { testWriteMessage(256, WebsocketVersion.V08); } @Test public void testWriteFragmentedMessage1Hybi08() { testWriteMessage(65536 + 256, WebsocketVersion.V08); } @Test public void testWriteFragmentedMessage2Hybi08() { testWriteMessage(65536 + 65536 + 256, WebsocketVersion.V08); } @Test public void testWriteMessageHybi17() { testWriteMessage(256, WebsocketVersion.V13); } @Test public void testWriteFragmentedMessage1Hybi17() { testWriteMessage(65536 + 256, WebsocketVersion.V13); } @Test public void testWriteFragmentedMessage2Hybi17() { testWriteMessage(65536 + 65536 + 256, WebsocketVersion.V13); } private void testWriteMessage(int size, WebsocketVersion version) { String path = "/some/path"; byte[] expected = TestUtils.randomByteArray(size); server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)).websocketHandler(ws -> { ws.writeBinaryMessage(Buffer.buffer(expected)); ws.close(); }); server.listen(ar -> { assertTrue(ar.succeeded()); client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, path, null, version, ws -> { Buffer actual = Buffer.buffer(); ws.handler(actual::appendBuffer); ws.closeHandler(v -> { assertArrayEquals(expected, actual.getBytes()); testComplete(); }); }); }); await(); } @Test public void testNonFragmentedTextMessage2Hybi00() { String messageToSend = TestUtils.randomAlphaString(256); testWriteSingleTextMessage(messageToSend, WebsocketVersion.V00); } @Test public void testFragmentedTextMessage2Hybi07() { String messageToSend = TestUtils.randomAlphaString(65536 + 65536 + 256); testWriteSingleTextMessage(messageToSend, WebsocketVersion.V07); } @Test public void testFragmentedTextMessage2Hybi08() { String messageToSend = TestUtils.randomAlphaString(65536 + 65536 + 256); testWriteSingleTextMessage(messageToSend, WebsocketVersion.V08); } @Test public void testFragmentedTextMessage2Hybi13() { String messageToSend = TestUtils.randomAlphaString(65536 + 65536 + 256); testWriteSingleTextMessage(messageToSend, WebsocketVersion.V13); } @Test public void testMaxLengthFragmentedTextMessage() { String messageToSend = TestUtils.randomAlphaString(HttpServerOptions.DEFAULT_MAX_WEBSOCKET_MESSAGE_SIZE); testWriteSingleTextMessage(messageToSend, WebsocketVersion.V13); } @Test public void testFragmentedUnicodeTextMessage2Hybi07() { String messageToSend = TestUtils.randomUnicodeString(65536 + 256); testWriteSingleTextMessage(messageToSend, WebsocketVersion.V07); } @Test public void testFragmentedUnicodeTextMessage2Hybi08() { String messageToSend = TestUtils.randomUnicodeString(65536 + 256); testWriteSingleTextMessage(messageToSend, WebsocketVersion.V08); } @Test public void testFragmentedUnicodeTextMessage2Hybi13() { String messageToSend = TestUtils.randomUnicodeString(65536 + 256); testWriteSingleTextMessage(messageToSend, WebsocketVersion.V13); } @Test public void testTooLargeMessage() { String messageToSend = TestUtils.randomAlphaString(HttpClientOptions.DEFAULT_MAX_WEBSOCKET_MESSAGE_SIZE + 1); SocketMessages socketMessages = testWriteTextMessages(Collections.singletonList(messageToSend), WebsocketVersion.V13); List<String> receivedMessages = socketMessages.getReceivedMessages(); List<String> expectedMessages = Collections.emptyList(); assertEquals("Should not have received any messages", expectedMessages, receivedMessages); List<Throwable> receivedExceptions = socketMessages.getReceivedExceptions(); assertEquals("Should have received a single exception", 1, receivedExceptions.size()); assertTrue("Should have received IllegalStateException", receivedExceptions.get(0) instanceof IllegalStateException); } @Test public void testContinueAfterTooLargeMessage() { int shortMessageLength = HttpClientOptions.DEFAULT_MAX_WEBSOCKET_FRAME_SIZE; String shortFirstMessage = TestUtils.randomAlphaString(shortMessageLength); String tooLongMiddleMessage = TestUtils.randomAlphaString(HttpClientOptions.DEFAULT_MAX_WEBSOCKET_MESSAGE_SIZE * 2); String shortLastMessage = TestUtils.randomAlphaString(shortMessageLength); List<String> messagesToSend = Arrays.asList(shortFirstMessage, tooLongMiddleMessage, shortLastMessage); SocketMessages socketMessages = testWriteTextMessages(messagesToSend, WebsocketVersion.V13); List<String> receivedMessages = socketMessages.getReceivedMessages(); List<String> expectedMessages = Arrays.asList(shortFirstMessage, shortLastMessage); assertEquals("Incorrect received messages", expectedMessages, receivedMessages); } private void testWriteSingleTextMessage(String messageToSend, WebsocketVersion version) { List<String> messagesToSend = Collections.singletonList(messageToSend); SocketMessages socketMessages = testWriteTextMessages(messagesToSend, version); assertEquals("Did not receive all messages", messagesToSend, socketMessages.getReceivedMessages()); List<Throwable> expectedExceptions = Collections.emptyList(); assertEquals("Should not have received any exceptions", expectedExceptions, socketMessages.getReceivedExceptions()); } private SocketMessages testWriteTextMessages(List<String> messagesToSend, WebsocketVersion version) { String path = "/some/path"; server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)).websocketHandler(ws -> { for (String messageToSend : messagesToSend) { ws.writeTextMessage(messageToSend); } ws.close(); }); List<String> receivedMessages = new ArrayList<>(); List<Throwable> receivedExceptions = new ArrayList<>(); server.listen(ar -> { assertTrue(ar.succeeded()); client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, path, null, version, ws -> { ws.textMessageHandler(receivedMessages::add); ws.exceptionHandler(receivedExceptions::add); ws.closeHandler(v -> testComplete()); }); }); await(); return new SocketMessages(receivedMessages, receivedExceptions); } private static class SocketMessages { private final List<String> receivedMessages; private final List<Throwable> receivedExceptions; public SocketMessages(List<String> receivedMessages, List<Throwable> receivedExceptions) { this.receivedMessages = receivedMessages; this.receivedExceptions = receivedExceptions; } public List<String> getReceivedMessages() { return receivedMessages; } public List<Throwable> getReceivedExceptions() { return receivedExceptions; } } @Test public void testWebsocketPauseAndResume() { client.close(); client = vertx.createHttpClient(new HttpClientOptions().setConnectTimeout(1000)); this.server = vertx.createHttpServer(new HttpServerOptions().setAcceptBacklog(1).setPort(HttpTestBase.DEFAULT_HTTP_PORT)); AtomicBoolean paused = new AtomicBoolean(); ReadStream<ServerWebSocket> stream = server.websocketStream(); stream.handler(ws -> { assertFalse(paused.get()); ws.writeBinaryMessage(Buffer.buffer("whatever")); ws.close(); }); server.listen(listenAR -> { assertTrue(listenAR.succeeded()); stream.pause(); paused.set(true); connectUntilWebsocketHandshakeException(client, 0, res -> { if (!res.succeeded()) { fail(new AssertionError("Was expecting error to be WebSocketHandshakeException", res.cause())); } assertTrue(paused.get()); paused.set(false); stream.resume(); client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, "/some/path", ws -> { ws.handler(buffer -> { assertEquals("whatever", buffer.toString("UTF-8")); ws.closeHandler(v2 -> { testComplete(); }); }); }); }); }); await(); } private void connectUntilWebsocketHandshakeException(HttpClient client, int count, Handler<AsyncResult<Void>> doneHandler) { vertx.runOnContext(v -> { client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, "/some/path", ws -> { if (count < 100) { connectUntilWebsocketHandshakeException(client, count + 1, doneHandler); } else { doneHandler.handle(Future.failedFuture(new AssertionError())); } }, err -> { if (err instanceof WebSocketHandshakeException) { doneHandler.handle(Future.succeededFuture()); } else if (count < 100) { connectUntilWebsocketHandshakeException(client, count + 1, doneHandler); } else { doneHandler.handle(Future.failedFuture(err)); } }); }); } @Test public void testClosingServerClosesWebSocketStreamEndHandler() { this.server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)); ReadStream<ServerWebSocket> stream = server.websocketStream(); AtomicBoolean closed = new AtomicBoolean(); stream.endHandler(v -> closed.set(true)); stream.handler(ws -> { }); server.listen(ar -> { assertTrue(ar.succeeded()); assertFalse(closed.get()); server.close(v -> { assertTrue(ar.succeeded()); assertTrue(closed.get()); testComplete(); }); }); await(); } @Test public void testWebsocketStreamCallbackAsynchronously() { this.server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)); AtomicInteger done = new AtomicInteger(); ReadStream<ServerWebSocket> stream = server.websocketStream(); stream.handler(req -> { }); ThreadLocal<Object> stack = new ThreadLocal<>(); stack.set(true); stream.endHandler(v -> { assertTrue(Vertx.currentContext().isEventLoopContext()); assertNull(stack.get()); if (done.incrementAndGet() == 2) { testComplete(); } }); server.listen(ar -> { assertTrue(Vertx.currentContext().isEventLoopContext()); assertNull(stack.get()); ThreadLocal<Object> stack2 = new ThreadLocal<>(); stack2.set(true); server.close(v -> { assertTrue(Vertx.currentContext().isEventLoopContext()); assertNull(stack2.get()); if (done.incrementAndGet() == 2) { testComplete(); } }); stack2.set(null); }); await(); } @Test public void testMultipleServerClose() { this.server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)); AtomicInteger times = new AtomicInteger(); // We assume the endHandler and the close completion handler are invoked in the same context task ThreadLocal stack = new ThreadLocal(); stack.set(true); server.websocketStream().endHandler(v -> { assertNull(stack.get()); assertTrue(Vertx.currentContext().isEventLoopContext()); times.incrementAndGet(); }); server.close(ar1 -> { assertNull(stack.get()); assertTrue(Vertx.currentContext().isEventLoopContext()); server.close(ar2 -> { server.close(ar3 -> { assertEquals(1, times.get()); testComplete(); }); }); }); await(); } @Test public void testEndHandlerCalled() { String path = "/some/path"; server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)).websocketHandler(WebSocketBase::close); AtomicInteger doneCount = new AtomicInteger(); server.listen(ar -> { assertTrue(ar.succeeded()); client.websocketStream(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, path, null). endHandler(done -> doneCount.incrementAndGet()). handler(ws -> { assertEquals(0, doneCount.get()); ws.closeHandler(v -> { assertEquals(1, doneCount.get()); testComplete(); }); }); }); await(); } @Test public void testClearClientHandlersOnEnd() { String path = "/some/path"; server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)).websocketHandler(WebSocketBase::close); server.listen(ar -> { assertTrue(ar.succeeded()); client.websocketStream(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, path, null). handler(ws -> { ws.endHandler(v -> { try { ws.endHandler(null); ws.exceptionHandler(null); ws.handler(null); } catch (Exception e) { fail("Was expecting to set to null the handlers when the socket is closed"); return; } testComplete(); }); }); }); await(); } @Test public void testUpgrade() { testUpgrade(false); } @Test public void testUpgradeDelayed() { testUpgrade(true); } private void testUpgrade(boolean delayed) { String path = "/some/path"; server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)); server.requestHandler(request -> { Runnable runner = () -> { ServerWebSocket ws = request.upgrade(); ws.handler(buff -> { ws.write(Buffer.buffer("helloworld")); ws.close(); }); }; if (delayed) { // This tests the case where the last http content comes of the request (its not full) comes in // before the upgrade has happened and before HttpServerImpl.expectWebsockets is true vertx.runOnContext(v -> { runner.run(); }); } else { runner.run(); } }); server.listen(ar -> { assertTrue(ar.succeeded()); client.websocketStream(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, path, null). handler(ws -> { Buffer buff = Buffer.buffer(); ws.handler(b -> { buff.appendBuffer(b); }); ws.endHandler(v -> { assertEquals("helloworld", buff.toString()); testComplete(); }); ws.write(Buffer.buffer("foo")); }); }); await(); } @Test public void testUnmaskedFrameRequest(){ client = vertx.createHttpClient(new HttpClientOptions().setSendUnmaskedFrames(true)); server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT).setAcceptUnmaskedFrames(true)); server.requestHandler(req -> { req.response().setChunked(true).write("connect"); }); server.websocketHandler(ws -> { ws.handler(new Handler<Buffer>() { public void handle(Buffer data) { assertEquals(data.toString(), "first unmasked frame"); testComplete(); } }); }); server.listen(onSuccess(server -> { client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, "/", ws -> { ws.writeFinalTextFrame("first unmasked frame"); }); })); await(); } @Test public void testInvalidUnmaskedFrameRequest(){ client = vertx.createHttpClient(new HttpClientOptions().setSendUnmaskedFrames(true)); server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)); server.requestHandler(req -> { req.response().setChunked(true).write("connect"); }); server.websocketHandler(ws -> { ws.exceptionHandler(exception -> { testComplete(); }); ws.handler(result -> { fail("Cannot decode unmasked message because I require masked frame as configured"); }); }); server.listen(onSuccess(server -> { client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, "/", ws -> { ws.writeFinalTextFrame("first unmasked frame"); }); })); await(); } @Test public void testUpgradeInvalidRequest() { server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)); server.requestHandler(request -> { try { request.upgrade(); fail("Should throw exception"); } catch (IllegalStateException e) { // OK } testComplete(); }); server.listen(ar -> { assertTrue(ar.succeeded()); client.request(HttpMethod.GET, HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, "/", resp -> { }).end(); }); await(); } @Test public void testRaceConditionWithWebsocketClientEventLoop() { testRaceConditionWithWebsocketClient(vertx.getOrCreateContext()); } @Test public void testRaceConditionWithWebsocketClientWorker() throws Exception { CompletableFuture<Context> fut = new CompletableFuture<>(); vertx.deployVerticle(new AbstractVerticle() { @Override public void start() throws Exception { fut.complete(context); } }, new DeploymentOptions().setWorker(true), ar -> { if (ar.failed()) { fut.completeExceptionally(ar.cause()); } }); testRaceConditionWithWebsocketClient(fut.get()); } private void testRaceConditionWithWebsocketClient(Context context) { server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)); // Handcrafted websocket handshake for sending a frame immediatly after the handshake server.requestHandler(req -> { byte[] accept; try { MessageDigest digest = MessageDigest.getInstance("SHA-1"); byte[] inputBytes = (req.getHeader("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").getBytes(); digest.update(inputBytes); byte[] hashedBytes = digest.digest(); accept = Base64.getEncoder().encode(hashedBytes); } catch (NoSuchAlgorithmException e) { fail(e.getMessage()); return; } NetSocket so = req.netSocket(); Buffer data = Buffer.buffer(); data.appendString("HTTP/1.1 101 Switching Protocols\r\n"); data.appendString("Upgrade: websocket\r\n"); data.appendString("Connection: upgrade\r\n"); data.appendString("Sec-WebSocket-Accept: " + new String(accept) + "\r\n"); data.appendString("\r\n"); data.appendBytes(new byte[]{ (byte) 0x82, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, }); so.write(data); }); server.listen(ar -> { assertTrue(ar.succeeded()); context.runOnContext(v -> { client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, "/", ws -> { ws.handler(buf -> { assertEquals("hello", buf.toString()); testComplete(); }); }); }); }); await(); } @Test public void testRaceConditionWithWebsocketClientWorker2() throws Exception { int size = getOptions().getWorkerPoolSize() - 4; List<Context> workers = createWorkers(size + 1); server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)); server.websocketHandler(ws -> { ws.write(Buffer.buffer("hello")); }); server.listen(ar -> { assertTrue(ar.succeeded()); workers.get(0).runOnContext(v -> { ReadStream<WebSocket> webSocketStream = client.websocketStream(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, "/"); webSocketStream.handler(ws -> { ws.handler(buf -> { assertEquals("hello", buf.toString()); testComplete(); }); }); }); }); await(); } @Test public void httpClientWebsocketConnectionFailureHandlerShouldBeCalled() throws Exception { String nonExistingHost = "idont.even.exist"; int port = 7867; HttpClient client = vertx.createHttpClient(); client.websocket(port, nonExistingHost, "", websocket -> { websocket.handler(data -> { fail("connection should not succeed"); }); }, throwable -> testComplete()); await(); } @Test public void testClientWebsocketWithHttp2Client() throws Exception { client.close(); client = vertx.createHttpClient(new HttpClientOptions().setHttp2ClearTextUpgrade(false).setProtocolVersion(HttpVersion.HTTP_2)); server = vertx.createHttpServer(new HttpServerOptions().setPort(HttpTestBase.DEFAULT_HTTP_PORT)); server.requestHandler(req -> { req.response().setChunked(true).write("connect"); }); server.websocketHandler(ws -> { ws.writeFinalTextFrame("ok"); }); server.listen(onSuccess(server -> { client.getNow(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, "/", resp -> { client.websocket(HttpTestBase.DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, "/", ws -> { ws.handler(buff -> { assertEquals("ok", buff.toString()); testComplete(); }); }); }); })); await(); } @Test public void testClientWebsocketConnectionCloseOnBadResponseWithKeepalive() throws Throwable { // issue #1757 doTestClientWebsocketConnectionCloseOnBadResponse(true); } @Test public void testClientWebsocketConnectionCloseOnBadResponseWithoutKeepalive() throws Throwable { doTestClientWebsocketConnectionCloseOnBadResponse(false); } final BlockingQueue<Throwable> resultQueue = new ArrayBlockingQueue<Throwable>(10); void addResult(Throwable result) { try { resultQueue.put(result); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } } private void doTestClientWebsocketConnectionCloseOnBadResponse(boolean keepAliveInOptions) throws Throwable { final Exception serverGotCloseException = new Exception(); netServer = vertx.createNetServer().connectHandler(sock -> { final Buffer fullReq = Buffer.buffer(230); sock.handler(b -> { fullReq.appendBuffer(b); String reqPart = b.toString(); if (fullReq.toString().contains("\r\n\r\n")) { try { String content = "0123456789"; content = content + content; content = content + content + content + content + content; String resp = "HTTP/1.1 200 OK\r\n"; if (keepAliveInOptions) { resp += "Connection: close\r\n"; } resp += "Content-Length: 100\r\n\r\n" + content; sock.write(Buffer.buffer(resp.getBytes("ASCII"))); } catch (UnsupportedEncodingException e) { addResult(e); } } }); sock.closeHandler(v -> { addResult(serverGotCloseException); }); }).listen(ar -> { if (ar.failed()) { addResult(ar.cause()); return; } NetServer server = ar.result(); int port = server.actualPort(); HttpClientOptions opts = new HttpClientOptions().setKeepAlive(keepAliveInOptions); client.close(); client = vertx.createHttpClient(opts).websocket(port, "localhost", "/", ws -> { addResult(new AssertionError("Websocket unexpectedly connected")); ws.close(); }, t -> { addResult(t); }); }); boolean serverGotClose = false; boolean clientGotCorrectException = false; while (!serverGotClose || !clientGotCorrectException) { Throwable result = resultQueue.poll(20, TimeUnit.SECONDS); if (result == null) { throw new AssertionError("Timed out waiting for expected state, current: serverGotClose = " + serverGotClose + ", clientGotCorrectException = " + clientGotCorrectException); } else if (result == serverGotCloseException) { serverGotClose = true; } else if (result instanceof WebSocketHandshakeException && result.getMessage().equals("Websocket connection attempt returned HTTP status code 200")) { clientGotCorrectException = true; } else { throw result; } } } }