/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.tests.Assume;
import org.springframework.tests.TestGroup;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.WebSocketTestServer;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
* Abstract base class for integration tests using the
* {@link org.springframework.web.socket.sockjs.client.SockJsClient SockJsClient}
* against actual SockJS server endpoints.
*
* @author Rossen Stoyanchev
* @author Sam Brannen
*/
public abstract class AbstractSockJsIntegrationTests {
@Rule
public final TestName testName = new TestName();
protected Log logger = LogFactory.getLog(getClass());
private SockJsClient sockJsClient;
private WebSocketTestServer server;
private AnnotationConfigWebApplicationContext wac;
private TestFilter testFilter;
private String baseUrl;
@BeforeClass
public static void performanceTestGroupAssumption() throws Exception {
Assume.group(TestGroup.PERFORMANCE);
}
@Before
public void setup() throws Exception {
logger.debug("Setting up '" + this.testName.getMethodName() + "'");
this.testFilter = new TestFilter();
this.wac = new AnnotationConfigWebApplicationContext();
this.wac.register(TestConfig.class, upgradeStrategyConfigClass());
this.server = createWebSocketTestServer();
this.server.setup();
this.server.deployConfig(this.wac, this.testFilter);
this.server.start();
this.wac.setServletContext(this.server.getServletContext());
this.wac.refresh();
this.baseUrl = "http://localhost:" + this.server.getPort();
}
@After
public void teardown() throws Exception {
try {
this.sockJsClient.stop();
}
catch (Throwable ex) {
logger.error("Failed to stop SockJsClient", ex);
}
try {
this.server.undeployConfig();
}
catch (Throwable t) {
logger.error("Failed to undeploy application config", t);
}
try {
this.server.stop();
}
catch (Throwable t) {
logger.error("Failed to stop server", t);
}
try {
this.wac.close();
}
catch (Throwable t) {
logger.error("Failed to close WebApplicationContext", t);
}
}
protected abstract Class<?> upgradeStrategyConfigClass();
protected abstract WebSocketTestServer createWebSocketTestServer();
protected abstract Transport createWebSocketTransport();
protected abstract AbstractXhrTransport createXhrTransport();
protected void initSockJsClient(Transport... transports) {
this.sockJsClient = new SockJsClient(Arrays.asList(transports));
this.sockJsClient.start();
}
@Test
public void echoWebSocket() throws Exception {
testEcho(100, createWebSocketTransport(), null);
}
@Test
public void echoXhrStreaming() throws Exception {
testEcho(100, createXhrTransport(), null);
}
@Test
public void echoXhr() throws Exception {
AbstractXhrTransport xhrTransport = createXhrTransport();
xhrTransport.setXhrStreamingDisabled(true);
testEcho(100, xhrTransport, null);
}
// SPR-13254
@Test
public void echoXhrWithHeaders() throws Exception {
AbstractXhrTransport xhrTransport = createXhrTransport();
xhrTransport.setXhrStreamingDisabled(true);
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
headers.add("auth", "123");
testEcho(10, xhrTransport, headers);
for (Map.Entry<String, HttpHeaders> entry : this.testFilter.requests.entrySet()) {
HttpHeaders httpHeaders = entry.getValue();
assertEquals("No auth header for: " + entry.getKey(), "123", httpHeaders.getFirst("auth"));
}
}
@Test
public void receiveOneMessageWebSocket() throws Exception {
testReceiveOneMessage(createWebSocketTransport(), null);
}
@Test
public void receiveOneMessageXhrStreaming() throws Exception {
testReceiveOneMessage(createXhrTransport(), null);
}
@Test
public void receiveOneMessageXhr() throws Exception {
AbstractXhrTransport xhrTransport = createXhrTransport();
xhrTransport.setXhrStreamingDisabled(true);
testReceiveOneMessage(xhrTransport, null);
}
@Test
public void infoRequestFailure() throws Exception {
TestClientHandler handler = new TestClientHandler();
this.testFilter.sendErrorMap.put("/info", 500);
CountDownLatch latch = new CountDownLatch(1);
initSockJsClient(createWebSocketTransport());
this.sockJsClient.doHandshake(handler, this.baseUrl + "/echo").addCallback(
new ListenableFutureCallback<WebSocketSession>() {
@Override
public void onSuccess(WebSocketSession result) {
}
@Override
public void onFailure(Throwable ex) {
latch.countDown();
}
}
);
assertTrue(latch.await(5000, TimeUnit.MILLISECONDS));
}
@Test
public void fallbackAfterTransportFailure() throws Exception {
this.testFilter.sendErrorMap.put("/websocket", 200);
this.testFilter.sendErrorMap.put("/xhr_streaming", 500);
TestClientHandler handler = new TestClientHandler();
initSockJsClient(createWebSocketTransport(), createXhrTransport());
WebSocketSession session = this.sockJsClient.doHandshake(handler, this.baseUrl + "/echo").get();
assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, session.getClass());
TextMessage message = new TextMessage("message1");
session.sendMessage(message);
handler.awaitMessage(message, 5000);
}
@Test(timeout = 5000)
public void fallbackAfterConnectTimeout() throws Exception {
TestClientHandler clientHandler = new TestClientHandler();
this.testFilter.sleepDelayMap.put("/xhr_streaming", 10000L);
this.testFilter.sendErrorMap.put("/xhr_streaming", 503);
initSockJsClient(createXhrTransport());
this.sockJsClient.setConnectTimeoutScheduler(this.wac.getBean(ThreadPoolTaskScheduler.class));
WebSocketSession clientSession = sockJsClient.doHandshake(clientHandler, this.baseUrl + "/echo").get();
assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, clientSession.getClass());
TextMessage message = new TextMessage("message1");
clientSession.sendMessage(message);
clientHandler.awaitMessage(message, 5000);
clientSession.close();
}
private void testEcho(int messageCount, Transport transport, WebSocketHttpHeaders headers) throws Exception {
List<TextMessage> messages = new ArrayList<>();
for (int i = 0; i < messageCount; i++) {
messages.add(new TextMessage("m" + i));
}
TestClientHandler handler = new TestClientHandler();
initSockJsClient(transport);
URI url = new URI(this.baseUrl + "/echo");
WebSocketSession session = this.sockJsClient.doHandshake(handler, headers, url).get();
for (TextMessage message : messages) {
session.sendMessage(message);
}
handler.awaitMessageCount(messageCount, 5000);
for (TextMessage message : messages) {
assertTrue("Message not received: " + message, handler.receivedMessages.remove(message));
}
assertEquals("Remaining messages: " + handler.receivedMessages, 0, handler.receivedMessages.size());
session.close();
}
private void testReceiveOneMessage(Transport transport, WebSocketHttpHeaders headers)
throws Exception {
TestClientHandler clientHandler = new TestClientHandler();
initSockJsClient(transport);
this.sockJsClient.doHandshake(clientHandler, headers, new URI(this.baseUrl + "/test")).get();
TestServerHandler serverHandler = this.wac.getBean(TestServerHandler.class);
assertNotNull("afterConnectionEstablished should have been called", clientHandler.session);
serverHandler.awaitSession(5000);
TextMessage message = new TextMessage("message1");
serverHandler.session.sendMessage(message);
clientHandler.awaitMessage(message, 5000);
}
private static void awaitEvent(BooleanSupplier condition, long timeToWait, String description) {
long timeToSleep = 200;
for (int i = 0 ; i < Math.floor(timeToWait / timeToSleep); i++) {
if (condition.getAsBoolean()) {
return;
}
try {
Thread.sleep(timeToSleep);
}
catch (InterruptedException e) {
throw new IllegalStateException("Interrupted while waiting for " + description, e);
}
}
throw new IllegalStateException("Timed out waiting for " + description);
}
@Configuration
@EnableWebSocket
static class TestConfig implements WebSocketConfigurer {
@Autowired
private RequestUpgradeStrategy upgradeStrategy;
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
HandshakeHandler handshakeHandler = new DefaultHandshakeHandler(this.upgradeStrategy);
registry.addHandler(new EchoHandler(), "/echo").setHandshakeHandler(handshakeHandler).withSockJS();
registry.addHandler(testServerHandler(), "/test").setHandshakeHandler(handshakeHandler).withSockJS();
}
@Bean
public TestServerHandler testServerHandler() {
return new TestServerHandler();
}
}
private static class TestClientHandler extends TextWebSocketHandler {
private final BlockingQueue<TextMessage> receivedMessages = new LinkedBlockingQueue<>();
private volatile WebSocketSession session;
private volatile Throwable transportError;
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
this.session = session;
}
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
this.receivedMessages.add(message);
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
this.transportError = exception;
}
public void awaitMessageCount(final int count, long timeToWait) throws Exception {
awaitEvent(() -> receivedMessages.size() >= count, timeToWait,
count + " number of messages. Received so far: " + this.receivedMessages);
}
public void awaitMessage(TextMessage expected, long timeToWait) throws InterruptedException {
TextMessage actual = this.receivedMessages.poll(timeToWait, TimeUnit.MILLISECONDS);
if (actual != null) {
assertEquals(expected, actual);
}
else if (this.transportError != null) {
throw new AssertionError("Transport error", this.transportError);
}
else {
fail("Timed out waiting for [" + expected + "]");
}
}
}
private static class EchoHandler extends TextWebSocketHandler {
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
session.sendMessage(message);
}
}
private static class TestServerHandler extends TextWebSocketHandler {
private WebSocketSession session;
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
this.session = session;
}
public WebSocketSession awaitSession(long timeToWait) throws InterruptedException {
awaitEvent(() -> this.session != null, timeToWait, " session");
return this.session;
}
}
private static class TestFilter implements Filter {
private final Map<String, HttpHeaders> requests = new HashMap<>();
private final Map<String, Long> sleepDelayMap = new HashMap<>();
private final Map<String, Integer> sendErrorMap = new HashMap<>();
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
HttpServletRequest httpRequest = (HttpServletRequest) request;
String uri = httpRequest.getRequestURI();
HttpHeaders headers = new ServletServerHttpRequest(httpRequest).getHeaders();
this.requests.put(uri, headers);
for (String suffix : this.sleepDelayMap.keySet()) {
if ((httpRequest).getRequestURI().endsWith(suffix)) {
try {
Thread.sleep(this.sleepDelayMap.get(suffix));
break;
}
catch (InterruptedException e) {
e.printStackTrace();
}
}
}
for (String suffix : this.sendErrorMap.keySet()) {
if ((httpRequest).getRequestURI().endsWith(suffix)) {
((HttpServletResponse) response).sendError(this.sendErrorMap.get(suffix));
return;
}
}
chain.doFilter(request, response);
}
@Override
public void init(FilterConfig filterConfig) {
}
@Override
public void destroy() {
}
}
}