/* * Copyright (c) 2008-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.cometd.websocket.client; import java.io.IOException; import java.util.EnumSet; import java.util.HashMap; import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import javax.servlet.DispatcherType; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; import org.cometd.bayeux.Channel; import org.cometd.bayeux.MarkedReference; import org.cometd.bayeux.Message; import org.cometd.bayeux.client.ClientSessionChannel; import org.cometd.bayeux.server.BayeuxServer; import org.cometd.bayeux.server.SecurityPolicy; import org.cometd.bayeux.server.ServerChannel; import org.cometd.bayeux.server.ServerMessage; import org.cometd.bayeux.server.ServerMessage.Mutable; import org.cometd.bayeux.server.ServerSession; import org.cometd.client.BayeuxClient; import org.cometd.client.BayeuxClient.State; import org.cometd.client.transport.ClientTransport; import org.cometd.client.transport.LongPollingTransport; import org.cometd.common.HashMapMessage; import org.cometd.server.DefaultSecurityPolicy; import org.cometd.websocket.ClientServerWebSocketTest; import org.eclipse.jetty.servlet.FilterHolder; import org.eclipse.jetty.util.BlockingArrayQueue; import org.junit.Assert; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; public class BayeuxClientTest extends ClientServerWebSocketTest { public BayeuxClientTest(String implementation) { super(implementation); } @Before public void setUp() throws Exception { prepareAndStart(null); } @Test public void testHandshakeDenied() throws Exception { BayeuxClient client = newBayeuxClient(); SecurityPolicy oldPolicy = bayeux.getSecurityPolicy(); bayeux.setSecurityPolicy(new DefaultSecurityPolicy() { @Override public boolean canHandshake(BayeuxServer server, ServerSession session, ServerMessage message) { return false; } }); try { final AtomicReference<CountDownLatch> latch = new AtomicReference<>(new CountDownLatch(1)); client.getChannel(Channel.META_HANDSHAKE).addListener(new ClientSessionChannel.MessageListener() { @Override public void onMessage(ClientSessionChannel channel, Message message) { Assert.assertFalse(message.isSuccessful()); latch.get().countDown(); } }); client.handshake(); Assert.assertTrue(latch.get().await(5, TimeUnit.SECONDS)); // Be sure it does not retry latch.set(new CountDownLatch(1)); Assert.assertFalse(latch.get().await(client.getBackoffIncrement() * 2, TimeUnit.MILLISECONDS)); Assert.assertTrue(client.waitFor(5000, State.DISCONNECTED)); } finally { bayeux.setSecurityPolicy(oldPolicy); disconnectBayeuxClient(client); } } @Test public void testPublish() throws Exception { final BlockingArrayQueue<String> results = new BlockingArrayQueue<>(); String channelName = "/chat/msg"; MarkedReference<ServerChannel> channel = bayeux.createChannelIfAbsent(channelName); channel.getReference().addListener(new ServerChannel.MessageListener() { @Override public boolean onMessage(ServerSession from, ServerChannel channel, Mutable message) { results.add(from.getId()); results.add(channel.getId()); results.add(String.valueOf(message.getData())); return true; } }); BayeuxClient client = newBayeuxClient(); client.handshake(); Assert.assertTrue(client.waitFor(5000, State.CONNECTED)); String data = "Hello World"; client.getChannel(channelName).publish(data); String id = results.poll(10, TimeUnit.SECONDS); Assert.assertEquals(client.getId(), id); Assert.assertEquals(channelName, results.poll(10, TimeUnit.SECONDS)); Assert.assertEquals(data, results.poll(10, TimeUnit.SECONDS)); disconnectBayeuxClient(client); } @Test public void testWaitFor() throws Exception { final BlockingArrayQueue<String> results = new BlockingArrayQueue<>(); String channelName = "/chat/msg"; MarkedReference<ServerChannel> channel = bayeux.createChannelIfAbsent(channelName); channel.getReference().addListener(new ServerChannel.MessageListener() { @Override public boolean onMessage(ServerSession from, ServerChannel channel, Mutable message) { results.add(from.getId()); results.add(channel.getId()); results.add(String.valueOf(message.getData())); return true; } }); BayeuxClient client = newBayeuxClient(); long wait = 1000L; long start = System.nanoTime(); client.handshake(wait); long stop = System.nanoTime(); Assert.assertTrue(TimeUnit.NANOSECONDS.toMillis(stop - start) < wait); Assert.assertNotNull(client.getId()); String data = "Hello World"; final CountDownLatch latch = new CountDownLatch(1); client.getChannel(channelName).addListener(new ClientSessionChannel.MessageListener() { @Override public void onMessage(ClientSessionChannel channel, Message message) { latch.countDown(); } }); client.getChannel(channelName).publish(data); Assert.assertEquals(client.getId(), results.poll(1, TimeUnit.SECONDS)); Assert.assertEquals(channelName, results.poll(1, TimeUnit.SECONDS)); Assert.assertEquals(data, results.poll(1, TimeUnit.SECONDS)); Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); disconnectBayeuxClient(client); } @Test public void testAuthentication() throws Exception { final AtomicReference<String> sessionId = new AtomicReference<>(); class A extends DefaultSecurityPolicy implements ServerSession.RemoveListener { @Override public boolean canHandshake(BayeuxServer server, ServerSession session, ServerMessage message) { Map<String, Object> ext = message.getExt(); if (ext == null) { return false; } Object authn = ext.get("authentication"); if (!(authn instanceof Map)) { return false; } @SuppressWarnings("unchecked") Map<String, Object> authentication = (Map<String, Object>)authn; String token = (String)authentication.get("token"); if (token == null) { return false; } sessionId.set(session.getId()); session.addListener(this); return true; } @Override public void removed(ServerSession session, boolean timeout) { sessionId.set(null); } } A authenticator = new A(); SecurityPolicy oldPolicy = bayeux.getSecurityPolicy(); bayeux.setSecurityPolicy(authenticator); try { BayeuxClient client = newBayeuxClient(); Map<String, Object> authentication = new HashMap<>(); authentication.put("token", "1234567890"); Message.Mutable fields = new HashMapMessage(); fields.getExt(true).put("authentication", authentication); client.handshake(fields); Assert.assertTrue(client.waitFor(5000, State.CONNECTED)); Assert.assertEquals(client.getId(), sessionId.get()); disconnectBayeuxClient(client); Assert.assertNull(sessionId.get()); } finally { bayeux.setSecurityPolicy(oldPolicy); } } @Test public void testClient() throws Exception { BayeuxClient client = newBayeuxClient(); final CountDownLatch handshakeLatch = new CountDownLatch(1); client.getChannel(Channel.META_HANDSHAKE).addListener(new ClientSessionChannel.MessageListener() { @Override public void onMessage(ClientSessionChannel channel, Message message) { System.err.println("<<" + message + " @ " + channel); if (message.isSuccessful()) { handshakeLatch.countDown(); } } }); final CountDownLatch connectLatch = new CountDownLatch(1); client.getChannel(Channel.META_CONNECT).addListener(new ClientSessionChannel.MessageListener() { @Override public void onMessage(ClientSessionChannel channel, Message message) { System.err.println("<<" + message + " @ " + channel); if (message.isSuccessful()) { connectLatch.countDown(); } } }); final CountDownLatch subscribeLatch = new CountDownLatch(1); client.getChannel(Channel.META_SUBSCRIBE).addListener(new ClientSessionChannel.MessageListener() { @Override public void onMessage(ClientSessionChannel channel, Message message) { System.err.println("<<" + message + " @ " + channel); if (message.isSuccessful()) { subscribeLatch.countDown(); } } }); final CountDownLatch unsubscribeLatch = new CountDownLatch(1); client.getChannel(Channel.META_SUBSCRIBE).addListener(new ClientSessionChannel.MessageListener() { @Override public void onMessage(ClientSessionChannel channel, Message message) { System.err.println("<<" + message + " @ " + channel); if (message.isSuccessful()) { unsubscribeLatch.countDown(); } } }); client.handshake(); Assert.assertTrue(handshakeLatch.await(5, TimeUnit.SECONDS)); Assert.assertTrue(connectLatch.await(5, TimeUnit.SECONDS)); final CountDownLatch publishLatch = new CountDownLatch(1); ClientSessionChannel.MessageListener subscriber = new ClientSessionChannel.MessageListener() { @Override public void onMessage(ClientSessionChannel channel, Message message) { System.err.println(" <" + message + " @ " + channel); publishLatch.countDown(); } }; ClientSessionChannel aChannel = client.getChannel("/a/channel"); aChannel.subscribe(subscriber); Assert.assertTrue(subscribeLatch.await(5, TimeUnit.SECONDS)); String data = "data"; aChannel.publish(data); Assert.assertTrue(publishLatch.await(5, TimeUnit.SECONDS)); aChannel.unsubscribe(subscriber); Assert.assertTrue(unsubscribeLatch.await(5, TimeUnit.SECONDS)); disconnectBayeuxClient(client); } @Ignore("TODO: verify why it does not work; I suspect the setAllowedTransport() does not play since the WSUpgradeFilter kicks in first") @Test public void testHandshakeOverWebSocketReportsHTTPFailure() throws Exception { // No transports on server, to make the client fail bayeux.setAllowedTransports(); BayeuxClient client = newBayeuxClient(); final CountDownLatch latch = new CountDownLatch(1); client.getChannel(Channel.META_HANDSHAKE).addListener(new ClientSessionChannel.MessageListener() { @Override public void onMessage(ClientSessionChannel channel, Message message) { // Verify the failure object is there @SuppressWarnings("unchecked") Map<String, Object> failure = (Map<String, Object>)message.get("failure"); Assert.assertNotNull(failure); // Verify that the transport is there Assert.assertEquals("websocket", failure.get(Message.CONNECTION_TYPE_FIELD)); // Verify the original message is there Assert.assertNotNull(failure.get("message")); // Verify the HTTP status code is there Assert.assertEquals(400, failure.get("httpCode")); // Verify the exception string is there Assert.assertNotNull(failure.get("exception")); latch.countDown(); } }); client.handshake(); Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); disconnectBayeuxClient(client); } @Ignore("The test filter is not called because the WSUpgradeFilter is added first") @Test public void testWebSocketResponseHeadersRemoved() throws Exception { context.addFilter(new FilterHolder(new Filter() { @Override public void init(FilterConfig filterConfig) throws ServletException { } @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { try { // Wrap the response to remove the header chain.doFilter(request, new HttpServletResponseWrapper((HttpServletResponse)response) { @Override public void addHeader(String name, String value) { if (!"Sec-WebSocket-Accept".equals(name)) { super.addHeader(name, value); } } }); } finally { ((HttpServletResponse)response).setHeader("Sec-WebSocket-Accept", null); } } @Override public void destroy() { } }), cometdServletPath, EnumSet.of(DispatcherType.REQUEST, DispatcherType.ASYNC)); ClientTransport webSocketTransport = newWebSocketTransport(null); ClientTransport longPollingTransport = newLongPollingTransport(null); final BayeuxClient client = new BayeuxClient(cometdURL, webSocketTransport, longPollingTransport); final CountDownLatch latch = new CountDownLatch(1); client.getChannel(Channel.META_CONNECT).addListener(new ClientSessionChannel.MessageListener() { @Override public void onMessage(ClientSessionChannel channel, Message message) { if (message.isSuccessful()) { Assert.assertEquals(LongPollingTransport.NAME, client.getTransport().getName()); latch.countDown(); } } }); client.handshake(); Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); disconnectBayeuxClient(client); } @Test public void testCustomTransportURL() throws Exception { ClientTransport transport = newWebSocketTransport(cometdURL, null); // Pass a bogus URL that must not be used BayeuxClient client = new BayeuxClient("http://foo/bar", transport); client.handshake(); Assert.assertTrue(client.waitFor(5000, State.CONNECTED)); disconnectBayeuxClient(client); } }