/* * Copyright (c) 2008-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.cometd.websocket.server; import java.io.IOException; import java.io.OutputStream; import java.net.HttpCookie; import java.net.Socket; import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import javax.servlet.ServletException; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; import javax.websocket.HandshakeResponse; import javax.websocket.server.HandshakeRequest; import org.cometd.bayeux.Channel; import org.cometd.bayeux.server.BayeuxContext; import org.cometd.bayeux.server.BayeuxServer; import org.cometd.bayeux.server.ServerChannel; import org.cometd.bayeux.server.ServerMessage; import org.cometd.bayeux.server.ServerSession; import org.cometd.client.BayeuxClient; import org.cometd.server.BayeuxServerImpl; import org.cometd.websocket.ClientServerWebSocketTest; import org.eclipse.jetty.client.api.ContentResponse; import org.eclipse.jetty.servlet.ServletHolder; import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse; import org.junit.Assert; import org.junit.Test; import static org.junit.Assert.assertTrue; public class BayeuxContextTest extends ClientServerWebSocketTest { public BayeuxContextTest(String implementation) { super(implementation); } @Test public void testRequestHeaderIsCaseInsensitive() throws Exception { prepareAndStart(null); final CountDownLatch latch = new CountDownLatch(1); bayeux.getChannel(Channel.META_HANDSHAKE).addListener(new ServerChannel.MessageListener() { @Override public boolean onMessage(ServerSession from, ServerChannel channel, ServerMessage.Mutable message) { BayeuxContext context = bayeux.getContext(); Assert.assertEquals(context.getHeader("Host"), context.getHeader("HOST")); Assert.assertEquals(context.getHeader("Host"), context.getHeaderValues("HOST").get(0)); latch.countDown(); return true; } }); BayeuxClient client = newBayeuxClient(); client.handshake(); Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); disconnectBayeuxClient(client); } @Test public void testCookiesSentToServer() throws Exception { prepareAndStart(null); final String cookieName = "name"; final String cookieValue = "value"; final CountDownLatch latch = new CountDownLatch(1); bayeux.getChannel(Channel.META_HANDSHAKE).addListener(new ServerChannel.MessageListener() { @Override public boolean onMessage(ServerSession from, ServerChannel channel, ServerMessage.Mutable message) { BayeuxContext context = bayeux.getContext(); Assert.assertEquals(cookieValue, context.getCookie(cookieName)); latch.countDown(); return true; } }); BayeuxClient client = newBayeuxClient(); client.putCookie(new HttpCookie(cookieName, cookieValue)); client.handshake(); Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); disconnectBayeuxClient(client); } @Test public void testCookiesSentToClient() throws Exception { String wsTransportClass; switch (wsTransportType) { case WEBSOCKET_JSR_356: wsTransportClass = CookieWebSocketTransport.class.getName(); break; case WEBSOCKET_JETTY: wsTransportClass = CookieJettyWebSocketTransport.class.getName(); break; default: throw new IllegalArgumentException(); } prepareServer(0, "/cometd", null, true, wsTransportClass); startServer(); prepareClient(); startClient(); BayeuxClient client = newBayeuxClient(); client.handshake(); Assert.assertTrue(client.waitFor(5000, BayeuxClient.State.CONNECTED)); HttpCookie cookie = client.getCookie(CookieConstants.COOKIE_NAME); Assert.assertEquals(CookieConstants.COOKIE_VALUE, cookie.getValue()); disconnectBayeuxClient(client); } @Test public void testMultipleCookiesSentToServer() throws Exception { prepareAndStart(null); final List<String> cookieNames = Arrays.asList("a", "BAYEUX_BROWSER", "b"); final List<String> cookieValues = Arrays.asList("1", "761e1pplr7yo3wmsri1x5y0gnnby", "2"); StringBuilder cookies = new StringBuilder(); for (int i = 0; i < cookieNames.size(); ++i) { cookies.append(cookieNames.get(i)).append("=").append(cookieValues.get(i)).append("; "); } final CountDownLatch latch = new CountDownLatch(1); bayeux.getChannel(Channel.META_HANDSHAKE).addListener(new ServerChannel.MessageListener() { @Override public boolean onMessage(ServerSession from, ServerChannel channel, ServerMessage.Mutable message) { BayeuxContext context = bayeux.getContext(); for (int i = 0; i < cookieNames.size(); ++i) { Assert.assertEquals(cookieValues.get(i), context.getCookie(cookieNames.get(i))); } latch.countDown(); return true; } }); try (Socket socket = new Socket("localhost", connector.getLocalPort())) { OutputStream output = socket.getOutputStream(); String upgrade = "" + "GET " + cometdServletPath + " HTTP/1.1\r\n" + "Host: localhost:" + connector.getLocalPort() + "\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "Cookie: " + cookies + "\r\n" + "\r\n"; output.write(upgrade.getBytes(StandardCharsets.UTF_8)); output.flush(); // Wait for the upgrade to take place on server side. Thread.sleep(1000); String handshake = "" + "{" + "\"id\":\"1\"," + "\"channel\":\"/meta/handshake\"," + "\"version\":\"1.0\"," + "\"supportedConnectionTypes\":[\"websocket\"]" + "}"; byte[] handshakeBytes = handshake.getBytes(StandardCharsets.UTF_8); Assert.assertTrue(handshakeBytes.length <= 125); // Max payload length output.write(0x81); // FIN FLAG + TYPE=TEXT output.write(0x80 + handshakeBytes.length); // MASK FLAG + LENGTH output.write(new byte[]{0, 0, 0, 0}); // MASK BYTES output.write(handshakeBytes); // PAYLOAD output.flush(); Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); } } @Test public void testSessionAttribute() throws Exception { String wsTransportClass; switch (wsTransportType) { case WEBSOCKET_JSR_356: wsTransportClass = SessionWebSocketTransport.class.getName(); break; case WEBSOCKET_JETTY: wsTransportClass = SessionJettyWebSocketTransport.class.getName(); break; default: throw new IllegalArgumentException(); } prepareServer(0, "/cometd", null, true, wsTransportClass); context.addServlet(new ServletHolder(new HttpServlet() { @Override protected void service(HttpServletRequest request, HttpServletResponse resp) throws ServletException, IOException { HttpSession session = request.getSession(true); session.setAttribute(SessionConstants.ATTRIBUTE_NAME, SessionConstants.ATTRIBUTE_VALUE); } }), "/session"); startServer(); prepareClient(); startClient(); // Make an HTTP request to prime the HttpSession URI uri = URI.create("http://localhost:" + connector.getLocalPort() + "/session"); ContentResponse response = httpClient.newRequest(uri) .path("/session") .timeout(5, TimeUnit.SECONDS) .send(); Assert.assertEquals(200, response.getStatus()); List<HttpCookie> cookies = httpClient.getCookieStore().get(uri); Assert.assertNotNull(cookies); HttpCookie sessionCookie = null; for (HttpCookie cookie : cookies) { if ("jsessionid".equalsIgnoreCase(cookie.getName())) { sessionCookie = cookie; } } Assert.assertNotNull(sessionCookie); final CountDownLatch latch = new CountDownLatch(1); bayeux.addListener(new BayeuxServer.SessionListener() { @Override public void sessionAdded(ServerSession session, ServerMessage message) { Assert.assertNotNull(bayeux.getContext().getHttpSessionId()); Assert.assertEquals(SessionConstants.ATTRIBUTE_VALUE, bayeux.getContext().getHttpSessionAttribute(SessionConstants.ATTRIBUTE_NAME)); latch.countDown(); } @Override public void sessionRemoved(ServerSession session, boolean timedout) { } }); BayeuxClient client = newBayeuxClient(); client.getCookieStore().add(uri, sessionCookie); client.handshake(); Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); Assert.assertTrue(client.waitFor(5000, BayeuxClient.State.CONNECTED)); disconnectBayeuxClient(client); } @Test public void testContextAttribute() throws Exception { prepareAndStart(null); final CountDownLatch latch = new CountDownLatch(1); bayeux.addListener(new BayeuxServer.SessionListener() { @Override public void sessionAdded(ServerSession session, ServerMessage message) { Assert.assertSame(bayeux, bayeux.getContext().getContextAttribute(BayeuxServer.ATTRIBUTE)); latch.countDown(); } @Override public void sessionRemoved(ServerSession session, boolean timedout) { } }); BayeuxClient client = newBayeuxClient(); client.handshake(); Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); Assert.assertTrue(client.waitFor(5000, BayeuxClient.State.CONNECTED)); disconnectBayeuxClient(client); } @Test public void testConcurrentClientsHaveDifferentBayeuxContexts() throws Exception { String wsTransportClass; switch (wsTransportType) { case WEBSOCKET_JSR_356: wsTransportClass = ConcurrentBayeuxContextWebSocketTransport.class.getName(); break; case WEBSOCKET_JETTY: wsTransportClass = ConcurrentBayeuxContextJettyWebSocketTransport.class.getName(); break; default: throw new IllegalArgumentException(); } prepareServer(0, "/cometd", null, true, wsTransportClass); startServer(); prepareClient(); startClient(); // The first client will be held by the server. final BayeuxClient client1 = newBayeuxClient(); // The connect operation is blocking, so we must use another thread. new Thread(new Runnable() { @Override public void run() { client1.handshake(); } }).start(); // Wait for the first client to arrive at the concurrency point. switch (wsTransportType) { case WEBSOCKET_JSR_356: { CountDownLatch enterLatch = ((ConcurrentBayeuxContextWebSocketTransport)bayeux.getTransport("websocket")).enterLatch; assertTrue(enterLatch.await(5, TimeUnit.SECONDS)); break; } case WEBSOCKET_JETTY: { CountDownLatch enterLatch = ((ConcurrentBayeuxContextJettyWebSocketTransport)bayeux.getTransport("websocket")).enterLatch; assertTrue(enterLatch.await(5, TimeUnit.SECONDS)); break; } default: throw new IllegalArgumentException(); } // Connect the second client. BayeuxClient client2 = newBayeuxClient(); client2.handshake(); assertTrue(client2.waitFor(1000, BayeuxClient.State.CONNECTED)); // Release the first client. switch (wsTransportType) { case WEBSOCKET_JSR_356: ((ConcurrentBayeuxContextWebSocketTransport)bayeux.getTransport("websocket")).proceedLatch.countDown(); break; case WEBSOCKET_JETTY: ((ConcurrentBayeuxContextJettyWebSocketTransport)bayeux.getTransport("websocket")).proceedLatch.countDown(); break; default: throw new IllegalArgumentException(); } assertTrue(client1.waitFor(1000, BayeuxClient.State.CONNECTED)); final String channelName = "/service/test"; final Map<String, BayeuxContext> contexts = new ConcurrentHashMap<>(); final CountDownLatch contextLatch = new CountDownLatch(2); bayeux.createChannelIfAbsent(channelName).getReference().addListener(new ServerChannel.MessageListener() { @Override public boolean onMessage(ServerSession from, ServerChannel channel, ServerMessage.Mutable message) { contexts.put(from.getId(), bayeux.getContext()); contextLatch.countDown(); return true; } }); client1.getChannel(channelName).publish("data"); client2.getChannel(channelName).publish("data"); assertTrue(contextLatch.await(5, TimeUnit.SECONDS)); Assert.assertEquals(2, contexts.size()); Assert.assertNotSame(contexts.get(client1.getId()), contexts.get(client2.getId())); disconnectBayeuxClient(client1); disconnectBayeuxClient(client2); } public interface CookieConstants { public static final String COOKIE_NAME = "name"; public static final String COOKIE_VALUE = "value"; } public static class CookieWebSocketTransport extends WebSocketTransport implements CookieConstants { public CookieWebSocketTransport(BayeuxServerImpl bayeux) { super(bayeux); } @Override protected void modifyHandshake(HandshakeRequest request, HandshakeResponse response) { response.getHeaders().put("Set-Cookie", Collections.singletonList(COOKIE_NAME + "=" + COOKIE_VALUE)); } } public static class CookieJettyWebSocketTransport extends JettyWebSocketTransport implements CookieConstants { public CookieJettyWebSocketTransport(BayeuxServerImpl bayeux) { super(bayeux); } @Override protected void modifyUpgrade(ServletUpgradeRequest request, ServletUpgradeResponse response) { response.setHeader("Set-Cookie", COOKIE_NAME + "=" + COOKIE_VALUE); } } public interface SessionConstants { public static final String ATTRIBUTE_NAME = "name"; public static final String ATTRIBUTE_VALUE = "value"; } public static class SessionWebSocketTransport extends WebSocketTransport implements SessionConstants { public SessionWebSocketTransport(BayeuxServerImpl bayeux) { super(bayeux); } @Override protected void modifyHandshake(HandshakeRequest request, HandshakeResponse response) { HttpSession session = (HttpSession)request.getHttpSession(); Assert.assertNotNull(session); Assert.assertEquals(ATTRIBUTE_VALUE, session.getAttribute(ATTRIBUTE_NAME)); } } public static class SessionJettyWebSocketTransport extends JettyWebSocketTransport implements SessionConstants { public SessionJettyWebSocketTransport(BayeuxServerImpl bayeux) { super(bayeux); } @Override protected void modifyUpgrade(ServletUpgradeRequest request, ServletUpgradeResponse response) { HttpSession session = request.getSession(); Assert.assertNotNull(session); Assert.assertEquals(ATTRIBUTE_VALUE, session.getAttribute(ATTRIBUTE_NAME)); } } public static class ConcurrentBayeuxContextWebSocketTransport extends WebSocketTransport { private final AtomicInteger handshakes = new AtomicInteger(); private final CountDownLatch enterLatch = new CountDownLatch(1); private final CountDownLatch proceedLatch = new CountDownLatch(1); public ConcurrentBayeuxContextWebSocketTransport(BayeuxServerImpl bayeux) { super(bayeux); } @Override protected void modifyHandshake(HandshakeRequest request, HandshakeResponse response) { onUpgrade(handshakes, enterLatch, proceedLatch); super.modifyHandshake(request, response); } } public static class ConcurrentBayeuxContextJettyWebSocketTransport extends JettyWebSocketTransport { private final AtomicInteger handshakes = new AtomicInteger(); private final CountDownLatch enterLatch = new CountDownLatch(1); private final CountDownLatch proceedLatch = new CountDownLatch(1); public ConcurrentBayeuxContextJettyWebSocketTransport(BayeuxServerImpl bayeux) { super(bayeux); } @Override protected void modifyUpgrade(ServletUpgradeRequest request, ServletUpgradeResponse response) { onUpgrade(handshakes, enterLatch, proceedLatch); super.modifyUpgrade(request, response); } } private static void onUpgrade(AtomicInteger handshakes, CountDownLatch enterLatch, CountDownLatch proceedLatch) { int count = handshakes.incrementAndGet(); if (count == 1) { try { enterLatch.countDown(); if (!proceedLatch.await(5, TimeUnit.SECONDS)) { throw new IllegalStateException(); } } catch (InterruptedException x) { throw new IllegalStateException(x); } } } }