/* * Copyright 2014-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.integration.websocket.config; import static org.hamcrest.Matchers.instanceOf; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import java.net.URI; import java.net.URISyntaxException; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; import org.junit.Test; import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.integration.support.converter.MapMessageConverter; import org.springframework.integration.support.converter.SimpleMessageConverter; import org.springframework.integration.test.util.TestUtils; import org.springframework.integration.websocket.IntegrationWebSocketContainer; import org.springframework.integration.websocket.inbound.WebSocketInboundChannelAdapter; import org.springframework.integration.websocket.outbound.WebSocketOutboundMessageHandler; import org.springframework.integration.websocket.support.PassThruSubProtocolHandler; import org.springframework.integration.websocket.support.SubProtocolHandlerRegistry; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.converter.CompositeMessageConverter; import org.springframework.messaging.converter.MessageConverter; import org.springframework.messaging.converter.StringMessageConverter; import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler; import org.springframework.scheduling.TaskScheduler; import org.springframework.test.annotation.DirtiesContext; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHttpHeaders; import org.springframework.web.socket.client.WebSocketClient; import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory; import org.springframework.web.socket.messaging.StompSubProtocolHandler; import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; import org.springframework.web.socket.sockjs.transport.TransportHandler; import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService; import org.springframework.web.socket.sockjs.transport.TransportType; /** * @author Artem Bilan * @since 4.1 */ @ContextConfiguration @RunWith(SpringJUnit4ClassRunner.class) @DirtiesContext public class WebSocketParserTests { @Autowired @Qualifier("integrationWebSocketHandlerMapping") private HandlerMapping handlerMapping; @Autowired @Qualifier("serverWebSocketContainer") private IntegrationWebSocketContainer serverWebSocketContainer; @Autowired private TaskScheduler taskScheduler; @Autowired private HandshakeHandler handshakeHandler; @Autowired private HandshakeInterceptor handshakeInterceptor; @Autowired private SockJsMessageCodec sockJsMessageCodec; @Autowired @Qualifier("defaultInboundAdapter.adapter") private WebSocketInboundChannelAdapter defaultInboundAdapter; @Autowired private AbstractBrokerMessageHandler brokerHandler; @Autowired @Qualifier("clientWebSocketContainer") private IntegrationWebSocketContainer clientWebSocketContainer; @Autowired @Qualifier("simpleClientWebSocketContainer") private IntegrationWebSocketContainer simpleClientWebSocketContainer; @Autowired @Qualifier("customInboundAdapter") private WebSocketInboundChannelAdapter customInboundAdapter; @Autowired private MessageChannel clientInboundChannel; @Autowired private MessageChannel errorChannel; @Autowired private StompSubProtocolHandler stompSubProtocolHandler; @Autowired private SimpleMessageConverter simpleMessageConverter; @Autowired private MapMessageConverter mapMessageConverter; @Autowired private WebSocketClient webSocketClient; @Autowired @Qualifier("defaultOutboundAdapter.handler") private WebSocketOutboundMessageHandler defaultOutboundAdapter; @Autowired @Qualifier("customOutboundAdapter.handler") private WebSocketOutboundMessageHandler customOutboundAdapter; @Autowired private WebSocketHandlerDecoratorFactory decoratorFactory; @Test @SuppressWarnings("unckecked") public void testDefaultInboundChannelAdapterAndServerContainer() { Map<?, ?> urlMap = TestUtils.getPropertyValue(this.handlerMapping, "urlMap", Map.class); assertEquals(1, urlMap.size()); assertTrue(urlMap.containsKey("/ws/**")); Object mappedHandler = urlMap.get("/ws/**"); //WebSocketHttpRequestHandler -> ExceptionWebSocketHandlerDecorator - > LoggingWebSocketHandlerDecorator // -> IntegrationWebSocketContainer$IntegrationWebSocketHandler assertSame(TestUtils.getPropertyValue(this.serverWebSocketContainer, "webSocketHandler"), TestUtils.getPropertyValue(mappedHandler, "webSocketHandler.delegate.delegate")); assertSame(this.handshakeHandler, TestUtils.getPropertyValue(this.serverWebSocketContainer, "handshakeHandler")); HandshakeInterceptor[] interceptors = TestUtils.getPropertyValue(this.serverWebSocketContainer, "interceptors", HandshakeInterceptor[].class); assertNotNull(interceptors); assertEquals(1, interceptors.length); assertSame(this.handshakeInterceptor, interceptors[0]); assertEquals(100, TestUtils.getPropertyValue(this.serverWebSocketContainer, "sendTimeLimit")); assertEquals(100000, TestUtils.getPropertyValue(this.serverWebSocketContainer, "sendBufferSizeLimit")); assertArrayEquals(new String[] {"http://foo.com"}, TestUtils.getPropertyValue(this.serverWebSocketContainer, "origins", String[].class)); WebSocketHandlerDecoratorFactory[] decoratorFactories = TestUtils.getPropertyValue(this.serverWebSocketContainer, "decoratorFactories", WebSocketHandlerDecoratorFactory[].class); assertNotNull(decoratorFactories); assertEquals(1, decoratorFactories.length); assertSame(this.decoratorFactory, decoratorFactories[0]); TransportHandlingSockJsService sockJsService = TestUtils.getPropertyValue(mappedHandler, "sockJsService", TransportHandlingSockJsService.class); assertSame(this.taskScheduler, sockJsService.getTaskScheduler()); assertSame(this.sockJsMessageCodec, sockJsService.getMessageCodec()); Map<TransportType, TransportHandler> transportHandlers = sockJsService.getTransportHandlers(); //If "handshake-handler" is provided, "transport-handlers" isn't allowed assertEquals(8, transportHandlers.size()); assertSame(this.handshakeHandler, TestUtils.getPropertyValue(transportHandlers.get(TransportType.WEBSOCKET), "handshakeHandler")); assertEquals(4000L, sockJsService.getDisconnectDelay()); assertEquals(30000L, sockJsService.getHeartbeatTime()); assertEquals(10000, sockJsService.getHttpMessageCacheSize()); assertEquals(2000, sockJsService.getStreamBytesLimit()); assertEquals("https://foo.sock.js", sockJsService.getSockJsClientLibraryUrl()); assertFalse(sockJsService.isSessionCookieNeeded()); assertFalse(sockJsService.isWebSocketEnabled()); assertTrue(sockJsService.shouldSuppressCors()); assertSame(this.serverWebSocketContainer, TestUtils.getPropertyValue(this.defaultInboundAdapter, "webSocketContainer")); assertNull(TestUtils.getPropertyValue(this.defaultInboundAdapter, "messageConverters")); assertEquals(TestUtils.getPropertyValue(this.defaultInboundAdapter, "messageConverter.converters"), TestUtils.getPropertyValue(this.defaultInboundAdapter, "defaultConverters")); assertEquals(String.class, TestUtils.getPropertyValue(this.defaultInboundAdapter, "payloadType", AtomicReference.class).get()); assertTrue(TestUtils.getPropertyValue(this.defaultInboundAdapter, "useBroker", Boolean.class)); assertSame(this.brokerHandler, TestUtils.getPropertyValue(this.defaultInboundAdapter, "brokerHandler")); SubProtocolHandlerRegistry subProtocolHandlerRegistry = TestUtils.getPropertyValue(this.defaultInboundAdapter, "subProtocolHandlerRegistry", SubProtocolHandlerRegistry.class); assertThat(TestUtils.getPropertyValue(subProtocolHandlerRegistry, "defaultProtocolHandler"), instanceOf(PassThruSubProtocolHandler.class)); assertTrue(TestUtils.getPropertyValue(subProtocolHandlerRegistry, "protocolHandlers", Map.class).isEmpty()); } @Test public void testCustomInboundChannelAdapterAndClientContainer() throws URISyntaxException { assertSame(this.clientInboundChannel, TestUtils.getPropertyValue(this.customInboundAdapter, "outputChannel")); assertSame(this.errorChannel, TestUtils.getPropertyValue(this.customInboundAdapter, "errorChannel")); assertSame(this.clientWebSocketContainer, TestUtils.getPropertyValue(this.customInboundAdapter, "webSocketContainer")); assertEquals(2000L, TestUtils.getPropertyValue(this.customInboundAdapter, "messagingTemplate.sendTimeout")); assertEquals(200, TestUtils.getPropertyValue(this.customInboundAdapter, "phase")); assertFalse(TestUtils.getPropertyValue(this.customInboundAdapter, "autoStartup", Boolean.class)); assertEquals(Integer.class, TestUtils.getPropertyValue(this.customInboundAdapter, "payloadType", AtomicReference.class).get()); SubProtocolHandlerRegistry subProtocolHandlerRegistry = TestUtils.getPropertyValue(this.customInboundAdapter, "subProtocolHandlerRegistry", SubProtocolHandlerRegistry.class); assertSame(this.stompSubProtocolHandler, TestUtils.getPropertyValue(subProtocolHandlerRegistry, "defaultProtocolHandler")); Map<?, ?> protocolHandlers = TestUtils.getPropertyValue(subProtocolHandlerRegistry, "protocolHandlers", Map.class); assertEquals(3, protocolHandlers.size()); //PassThruSubProtocolHandler is ignored because it doesn't provide any 'protocol' by default. //See warn log message. for (Object handler : protocolHandlers.values()) { assertSame(this.stompSubProtocolHandler, handler); } assertTrue(TestUtils.getPropertyValue(this.customInboundAdapter, "mergeWithDefaultConverters", Boolean.class)); CompositeMessageConverter compositeMessageConverter = TestUtils.getPropertyValue(this.customInboundAdapter, "messageConverter", CompositeMessageConverter.class); List<MessageConverter> converters = compositeMessageConverter.getConverters(); assertEquals(5, converters.size()); assertSame(this.simpleMessageConverter, converters.get(0)); assertSame(this.mapMessageConverter, converters.get(1)); assertThat(converters.get(2), instanceOf(StringMessageConverter.class)); //Test ClientWebSocketContainer parser assertSame(this.customInboundAdapter, TestUtils.getPropertyValue(this.clientWebSocketContainer, "messageListener")); assertEquals(100, TestUtils.getPropertyValue(this.clientWebSocketContainer, "sendTimeLimit")); assertEquals(1000, TestUtils.getPropertyValue(this.clientWebSocketContainer, "sendBufferSizeLimit")); assertEquals(new URI("ws://foo.bar/ws?service=user"), TestUtils.getPropertyValue(this.clientWebSocketContainer, "connectionManager.uri", URI.class)); assertSame(this.webSocketClient, TestUtils.getPropertyValue(this.clientWebSocketContainer, "connectionManager.client")); assertEquals(100, TestUtils.getPropertyValue(this.clientWebSocketContainer, "connectionManager.phase")); WebSocketHttpHeaders headers = TestUtils.getPropertyValue(this.clientWebSocketContainer, "headers", WebSocketHttpHeaders.class); assertEquals("FOO", headers.getOrigin()); assertEquals(Arrays.asList("BAR", "baz"), headers.get("FOO")); assertEquals(10 * 1000, TestUtils.getPropertyValue(this.simpleClientWebSocketContainer, "sendTimeLimit")); assertEquals(512 * 1024, TestUtils.getPropertyValue(this.simpleClientWebSocketContainer, "sendBufferSizeLimit")); assertEquals(new URI("ws://foo.bar"), TestUtils.getPropertyValue(this.simpleClientWebSocketContainer, "connectionManager.uri", URI.class)); assertSame(this.webSocketClient, TestUtils.getPropertyValue(this.simpleClientWebSocketContainer, "connectionManager.client")); assertEquals(Integer.MAX_VALUE, TestUtils.getPropertyValue(this.simpleClientWebSocketContainer, "connectionManager.phase")); assertFalse(TestUtils.getPropertyValue(this.simpleClientWebSocketContainer, "connectionManager.autoStartup", Boolean.class)); assertTrue(TestUtils.getPropertyValue(this.simpleClientWebSocketContainer, "headers", WebSocketHttpHeaders.class).isEmpty()); } @Test public void testDefaultOutboundChannelAdapter() { assertSame(this.serverWebSocketContainer, TestUtils.getPropertyValue(this.defaultOutboundAdapter, "webSocketContainer")); assertNull(TestUtils.getPropertyValue(this.defaultOutboundAdapter, "messageConverters")); assertEquals(TestUtils.getPropertyValue(this.defaultOutboundAdapter, "messageConverter.converters"), TestUtils.getPropertyValue(this.defaultOutboundAdapter, "defaultConverters")); SubProtocolHandlerRegistry subProtocolHandlerRegistry = TestUtils.getPropertyValue(this.defaultOutboundAdapter, "subProtocolHandlerRegistry", SubProtocolHandlerRegistry.class); assertThat(TestUtils.getPropertyValue(subProtocolHandlerRegistry, "defaultProtocolHandler"), instanceOf(PassThruSubProtocolHandler.class)); assertTrue(TestUtils.getPropertyValue(subProtocolHandlerRegistry, "protocolHandlers", Map.class).isEmpty()); assertFalse(TestUtils.getPropertyValue(this.defaultOutboundAdapter, "client", Boolean.class)); } @Test public void testCustomOutboundChannelAdapter() throws URISyntaxException { assertSame(this.clientWebSocketContainer, TestUtils.getPropertyValue(this.customOutboundAdapter, "webSocketContainer")); SubProtocolHandlerRegistry subProtocolHandlerRegistry = TestUtils.getPropertyValue(this.customOutboundAdapter, "subProtocolHandlerRegistry", SubProtocolHandlerRegistry.class); assertSame(this.stompSubProtocolHandler, TestUtils.getPropertyValue(subProtocolHandlerRegistry, "defaultProtocolHandler")); Map<?, ?> protocolHandlers = TestUtils.getPropertyValue(subProtocolHandlerRegistry, "protocolHandlers", Map.class); assertEquals(3, protocolHandlers.size()); //PassThruSubProtocolHandler is ignored because it doesn't provide any 'protocol' by default. //See warn log message. for (Object handler : protocolHandlers.values()) { assertSame(this.stompSubProtocolHandler, handler); } assertTrue(TestUtils.getPropertyValue(this.customOutboundAdapter, "mergeWithDefaultConverters", Boolean.class)); CompositeMessageConverter compositeMessageConverter = TestUtils.getPropertyValue(this.customOutboundAdapter, "messageConverter", CompositeMessageConverter.class); List<MessageConverter> converters = compositeMessageConverter.getConverters(); assertEquals(5, converters.size()); assertSame(this.simpleMessageConverter, converters.get(0)); assertSame(this.mapMessageConverter, converters.get(1)); assertThat(converters.get(2), instanceOf(StringMessageConverter.class)); assertTrue(TestUtils.getPropertyValue(this.customOutboundAdapter, "client", Boolean.class)); } private static class TestWebSocketHandlerDecoratorFactory implements WebSocketHandlerDecoratorFactory { @Override public WebSocketHandler decorate(WebSocketHandler handler) { return handler; } } }