/*
 * Copyright 2002-2015 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.socket.sockjs.client;

import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;
import java.util.function.Supplier;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.tests.Assume;
import org.springframework.tests.TestGroup; import org.springframework.util.concurrent.ListenableFutureCallback; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketHttpHeaders; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketTestServer; import org.springframework.web.socket.config.annotation.EnableWebSocket; import org.springframework.web.socket.config.annotation.WebSocketConfigurer; import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; import org.springframework.web.socket.handler.TextWebSocketHandler; import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.RequestUpgradeStrategy; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; /** * Abstract base class for integration tests using the * {@link org.springframework.web.socket.sockjs.client.SockJsClient SockJsClient} * against actual SockJS server endpoints. 
* * @author Rossen Stoyanchev * @author Sam Brannen */ public abstract class AbstractSockJsIntegrationTests { @Rule public final TestName testName = new TestName(); protected Log logger = LogFactory.getLog(getClass()); private SockJsClient sockJsClient; private WebSocketTestServer server; private AnnotationConfigWebApplicationContext wac; private TestFilter testFilter; private String baseUrl; @BeforeClass public static void performanceTestGroupAssumption() throws Exception { Assume.group(TestGroup.PERFORMANCE); } @Before public void setup() throws Exception { logger.debug("Setting up '" + this.testName.getMethodName() + "'"); this.testFilter = new TestFilter(); this.wac = new AnnotationConfigWebApplicationContext(); this.wac.register(TestConfig.class, upgradeStrategyConfigClass()); this.server = createWebSocketTestServer(); this.server.setup(); this.server.deployConfig(this.wac, this.testFilter); this.server.start(); this.wac.setServletContext(this.server.getServletContext()); this.wac.refresh(); this.baseUrl = "http://localhost:" + this.server.getPort(); } @After public void teardown() throws Exception { try { this.sockJsClient.stop(); } catch (Throwable ex) { logger.error("Failed to stop SockJsClient", ex); } try { this.server.undeployConfig(); } catch (Throwable t) { logger.error("Failed to undeploy application config", t); } try { this.server.stop(); } catch (Throwable t) { logger.error("Failed to stop server", t); } try { this.wac.close(); } catch (Throwable t) { logger.error("Failed to close WebApplicationContext", t); } } protected abstract Class<?> upgradeStrategyConfigClass(); protected abstract WebSocketTestServer createWebSocketTestServer(); protected abstract Transport createWebSocketTransport(); protected abstract AbstractXhrTransport createXhrTransport(); protected void initSockJsClient(Transport... 
transports) { this.sockJsClient = new SockJsClient(Arrays.asList(transports)); this.sockJsClient.start(); } @Test public void echoWebSocket() throws Exception { testEcho(100, createWebSocketTransport(), null); } @Test public void echoXhrStreaming() throws Exception { testEcho(100, createXhrTransport(), null); } @Test public void echoXhr() throws Exception { AbstractXhrTransport xhrTransport = createXhrTransport(); xhrTransport.setXhrStreamingDisabled(true); testEcho(100, xhrTransport, null); } // SPR-13254 @Test public void echoXhrWithHeaders() throws Exception { AbstractXhrTransport xhrTransport = createXhrTransport(); xhrTransport.setXhrStreamingDisabled(true); WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); headers.add("auth", "123"); testEcho(10, xhrTransport, headers); for (Map.Entry<String, HttpHeaders> entry : this.testFilter.requests.entrySet()) { HttpHeaders httpHeaders = entry.getValue(); assertEquals("No auth header for: " + entry.getKey(), "123", httpHeaders.getFirst("auth")); } } @Test public void receiveOneMessageWebSocket() throws Exception { testReceiveOneMessage(createWebSocketTransport(), null); } @Test public void receiveOneMessageXhrStreaming() throws Exception { testReceiveOneMessage(createXhrTransport(), null); } @Test public void receiveOneMessageXhr() throws Exception { AbstractXhrTransport xhrTransport = createXhrTransport(); xhrTransport.setXhrStreamingDisabled(true); testReceiveOneMessage(xhrTransport, null); } @Test public void infoRequestFailure() throws Exception { TestClientHandler handler = new TestClientHandler(); this.testFilter.sendErrorMap.put("/info", 500); CountDownLatch latch = new CountDownLatch(1); initSockJsClient(createWebSocketTransport()); this.sockJsClient.doHandshake(handler, this.baseUrl + "/echo").addCallback( new ListenableFutureCallback<WebSocketSession>() { @Override public void onSuccess(WebSocketSession result) { } @Override public void onFailure(Throwable ex) { latch.countDown(); } } ); 
assertTrue(latch.await(5000, TimeUnit.MILLISECONDS)); } @Test public void fallbackAfterTransportFailure() throws Exception { this.testFilter.sendErrorMap.put("/websocket", 200); this.testFilter.sendErrorMap.put("/xhr_streaming", 500); TestClientHandler handler = new TestClientHandler(); initSockJsClient(createWebSocketTransport(), createXhrTransport()); WebSocketSession session = this.sockJsClient.doHandshake(handler, this.baseUrl + "/echo").get(); assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, session.getClass()); TextMessage message = new TextMessage("message1"); session.sendMessage(message); handler.awaitMessage(message, 5000); } @Test(timeout = 5000) public void fallbackAfterConnectTimeout() throws Exception { TestClientHandler clientHandler = new TestClientHandler(); this.testFilter.sleepDelayMap.put("/xhr_streaming", 10000L); this.testFilter.sendErrorMap.put("/xhr_streaming", 503); initSockJsClient(createXhrTransport()); this.sockJsClient.setConnectTimeoutScheduler(this.wac.getBean(ThreadPoolTaskScheduler.class)); WebSocketSession clientSession = sockJsClient.doHandshake(clientHandler, this.baseUrl + "/echo").get(); assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, clientSession.getClass()); TextMessage message = new TextMessage("message1"); clientSession.sendMessage(message); clientHandler.awaitMessage(message, 5000); clientSession.close(); } private void testEcho(int messageCount, Transport transport, WebSocketHttpHeaders headers) throws Exception { List<TextMessage> messages = new ArrayList<>(); for (int i = 0; i < messageCount; i++) { messages.add(new TextMessage("m" + i)); } TestClientHandler handler = new TestClientHandler(); initSockJsClient(transport); URI url = new URI(this.baseUrl + "/echo"); WebSocketSession session = this.sockJsClient.doHandshake(handler, headers, url).get(); for (TextMessage message : messages) { session.sendMessage(message); } handler.awaitMessageCount(messageCount, 5000); for (TextMessage 
message : messages) { assertTrue("Message not received: " + message, handler.receivedMessages.remove(message)); } assertEquals("Remaining messages: " + handler.receivedMessages, 0, handler.receivedMessages.size()); session.close(); } private void testReceiveOneMessage(Transport transport, WebSocketHttpHeaders headers) throws Exception { TestClientHandler clientHandler = new TestClientHandler(); initSockJsClient(transport); this.sockJsClient.doHandshake(clientHandler, headers, new URI(this.baseUrl + "/test")).get(); TestServerHandler serverHandler = this.wac.getBean(TestServerHandler.class); assertNotNull("afterConnectionEstablished should have been called", clientHandler.session); serverHandler.awaitSession(5000); TextMessage message = new TextMessage("message1"); serverHandler.session.sendMessage(message); clientHandler.awaitMessage(message, 5000); } private static void awaitEvent(BooleanSupplier condition, long timeToWait, String description) { long timeToSleep = 200; for (int i = 0 ; i < Math.floor(timeToWait / timeToSleep); i++) { if (condition.getAsBoolean()) { return; } try { Thread.sleep(timeToSleep); } catch (InterruptedException e) { throw new IllegalStateException("Interrupted while waiting for " + description, e); } } throw new IllegalStateException("Timed out waiting for " + description); } @Configuration @EnableWebSocket static class TestConfig implements WebSocketConfigurer { @Autowired private RequestUpgradeStrategy upgradeStrategy; @Override public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { HandshakeHandler handshakeHandler = new DefaultHandshakeHandler(this.upgradeStrategy); registry.addHandler(new EchoHandler(), "/echo").setHandshakeHandler(handshakeHandler).withSockJS(); registry.addHandler(testServerHandler(), "/test").setHandshakeHandler(handshakeHandler).withSockJS(); } @Bean public TestServerHandler testServerHandler() { return new TestServerHandler(); } } private static class TestClientHandler extends 
TextWebSocketHandler { private final BlockingQueue<TextMessage> receivedMessages = new LinkedBlockingQueue<>(); private volatile WebSocketSession session; private volatile Throwable transportError; @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { this.session = session; } @Override protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { this.receivedMessages.add(message); } @Override public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { this.transportError = exception; } public void awaitMessageCount(final int count, long timeToWait) throws Exception { awaitEvent(() -> receivedMessages.size() >= count, timeToWait, count + " number of messages. Received so far: " + this.receivedMessages); } public void awaitMessage(TextMessage expected, long timeToWait) throws InterruptedException { TextMessage actual = this.receivedMessages.poll(timeToWait, TimeUnit.MILLISECONDS); if (actual != null) { assertEquals(expected, actual); } else if (this.transportError != null) { throw new AssertionError("Transport error", this.transportError); } else { fail("Timed out waiting for [" + expected + "]"); } } } private static class EchoHandler extends TextWebSocketHandler { @Override protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { session.sendMessage(message); } } private static class TestServerHandler extends TextWebSocketHandler { private WebSocketSession session; @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { this.session = session; } public WebSocketSession awaitSession(long timeToWait) throws InterruptedException { awaitEvent(() -> this.session != null, timeToWait, " session"); return this.session; } } private static class TestFilter implements Filter { private final Map<String, HttpHeaders> requests = new HashMap<>(); private final Map<String, Long> 
sleepDelayMap = new HashMap<>(); private final Map<String, Integer> sendErrorMap = new HashMap<>(); @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { HttpServletRequest httpRequest = (HttpServletRequest) request; String uri = httpRequest.getRequestURI(); HttpHeaders headers = new ServletServerHttpRequest(httpRequest).getHeaders(); this.requests.put(uri, headers); for (String suffix : this.sleepDelayMap.keySet()) { if ((httpRequest).getRequestURI().endsWith(suffix)) { try { Thread.sleep(this.sleepDelayMap.get(suffix)); break; } catch (InterruptedException e) { e.printStackTrace(); } } } for (String suffix : this.sendErrorMap.keySet()) { if ((httpRequest).getRequestURI().endsWith(suffix)) { ((HttpServletResponse) response).sendError(this.sendErrorMap.get(suffix)); return; } } chain.doFilter(request, response); } @Override public void init(FilterConfig filterConfig) { } @Override public void destroy() { } } }