/* * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.web.socket.config.annotation; import static org.junit.Assert.*; import static org.mockito.Mockito.*; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ScheduledThreadPoolExecutor; import org.junit.Test; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.annotation.SubscribeMapping; import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.user.UserDestinationMessageHandler; import org.springframework.messaging.support.AbstractSubscribableChannel; import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.ExecutorSubscribableChannel; import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; import org.springframework.scheduling.TaskScheduler; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.stereotype.Controller; import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.config.WebSocketMessageBrokerStats; import org.springframework.web.socket.handler.TestWebSocketSession; import org.springframework.web.socket.handler.WebSocketHandlerDecorator; import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory; import org.springframework.web.socket.messaging.StompSubProtocolHandler; import org.springframework.web.socket.messaging.StompTextMessageBuilder; import org.springframework.web.socket.messaging.SubProtocolHandler; import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; /** * Test fixture for * {@link org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurationSupport}. * * @author Rossen Stoyanchev */ public class WebSocketMessageBrokerConfigurationSupportTests { @Test public void handlerMapping() { ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class); SimpleUrlHandlerMapping hm = (SimpleUrlHandlerMapping) config.getBean(HandlerMapping.class); assertEquals(1, hm.getOrder()); Map<String, Object> handlerMap = hm.getHandlerMap(); assertEquals(1, handlerMap.size()); assertNotNull(handlerMap.get("/simpleBroker")); } @Test public void clientInboundChannelSendMessage() throws Exception { ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class); TestChannel channel = config.getBean("clientInboundChannel", TestChannel.class); SubProtocolWebSocketHandler webSocketHandler = config.getBean(SubProtocolWebSocketHandler.class); List<ChannelInterceptor> interceptors = channel.getInterceptors(); assertEquals(ImmutableMessageChannelInterceptor.class, interceptors.get(interceptors.size()-1).getClass()); TestWebSocketSession session = new TestWebSocketSession("s1"); session.setOpen(true); webSocketHandler.afterConnectionEstablished(session); TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.SEND).headers("destination:/foo").build(); webSocketHandler.handleMessage(session, textMessage); Message<?> message = channel.messages.get(0); StompHeaderAccessor accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); assertNotNull(accessor); assertFalse(accessor.isMutable()); assertEquals(SimpMessageType.MESSAGE, accessor.getMessageType()); assertEquals("/foo", accessor.getDestination()); } @Test public void clientOutboundChannel() { ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class); TestChannel channel = config.getBean("clientOutboundChannel", TestChannel.class); Set<MessageHandler> handlers = channel.getSubscribers(); List<ChannelInterceptor> interceptors = channel.getInterceptors(); assertEquals(ImmutableMessageChannelInterceptor.class, interceptors.get(interceptors.size()-1).getClass()); assertEquals(1, handlers.size()); assertTrue(handlers.contains(config.getBean(SubProtocolWebSocketHandler.class))); } @Test public void brokerChannel() { ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class); TestChannel channel = config.getBean("brokerChannel", TestChannel.class); Set<MessageHandler> handlers = channel.getSubscribers(); List<ChannelInterceptor> interceptors = channel.getInterceptors(); assertEquals(ImmutableMessageChannelInterceptor.class, interceptors.get(interceptors.size()-1).getClass()); assertEquals(2, handlers.size()); assertTrue(handlers.contains(config.getBean(SimpleBrokerMessageHandler.class))); assertTrue(handlers.contains(config.getBean(UserDestinationMessageHandler.class))); } @Test public void webSocketHandler() { ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class); SubProtocolWebSocketHandler subWsHandler = config.getBean(SubProtocolWebSocketHandler.class); assertEquals(1024 * 1024, subWsHandler.getSendBufferSizeLimit()); assertEquals(25 * 1000, subWsHandler.getSendTimeLimit()); Map<String, SubProtocolHandler> handlerMap = subWsHandler.getProtocolHandlerMap(); StompSubProtocolHandler protocolHandler = (StompSubProtocolHandler) handlerMap.get("v12.stomp"); assertEquals(128 * 1024, protocolHandler.getMessageSizeLimit()); } @Test public void taskScheduler() { ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class); String name = "messageBrokerSockJsTaskScheduler"; ThreadPoolTaskScheduler taskScheduler = config.getBean(name, ThreadPoolTaskScheduler.class); ScheduledThreadPoolExecutor executor = taskScheduler.getScheduledThreadPoolExecutor(); assertEquals(Runtime.getRuntime().availableProcessors(), executor.getCorePoolSize()); assertTrue(executor.getRemoveOnCancelPolicy()); SimpleBrokerMessageHandler handler = config.getBean(SimpleBrokerMessageHandler.class); assertNotNull(handler.getTaskScheduler()); assertArrayEquals(new long[] {15000, 15000}, handler.getHeartbeatValue()); } @Test public void webSocketMessageBrokerStats() { ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class); String name = "webSocketMessageBrokerStats"; WebSocketMessageBrokerStats stats = config.getBean(name, WebSocketMessageBrokerStats.class); String actual = stats.toString(); String expected = "WebSocketSession\\[0 current WS\\(0\\)-HttpStream\\(0\\)-HttpPoll\\(0\\), " + "0 total, 0 closed abnormally \\(0 connect failure, 0 send limit, 0 transport error\\)\\], " + "stompSubProtocol\\[processed CONNECT\\(0\\)-CONNECTED\\(0\\)-DISCONNECT\\(0\\)\\], " + "stompBrokerRelay\\[null\\], " + "inboundChannel\\[pool size = \\d, active threads = \\d, queued tasks = \\d, completed tasks = \\d\\], " + "outboundChannelpool size = \\d, active threads = \\d, queued tasks = \\d, completed tasks = \\d\\], " + "sockJsScheduler\\[pool size = \\d, active threads = \\d, queued tasks = \\d, completed tasks = \\d\\]"; assertTrue("\nExpected: " + expected.replace("\\", "") + "\n Actual: " + actual, actual.matches(expected)); } @Test public void webSocketHandlerDecorator() throws Exception { ApplicationContext config = createConfig(WebSocketHandlerDecoratorConfig.class); WebSocketHandler handler = config.getBean(SubProtocolWebSocketHandler.class); assertNotNull(handler); SimpleUrlHandlerMapping mapping = (SimpleUrlHandlerMapping) config.getBean("stompWebSocketHandlerMapping"); WebSocketHttpRequestHandler httpHandler = (WebSocketHttpRequestHandler) mapping.getHandlerMap().get("/test"); handler = httpHandler.getWebSocketHandler(); WebSocketSession session = new TestWebSocketSession("id"); handler.afterConnectionEstablished(session); assertEquals(true, session.getAttributes().get("decorated")); } private ApplicationContext createConfig(Class<?>... configClasses) { AnnotationConfigApplicationContext config = new AnnotationConfigApplicationContext(); config.register(configClasses); config.refresh(); return config; } @Controller static class TestController { @SubscribeMapping("/foo") public String handleSubscribe() { return "bar"; } @MessageMapping("/foo") @SendTo("/bar") public String handleMessage() { return "bar"; } } @Configuration static class TestConfigurer extends AbstractWebSocketMessageBrokerConfigurer { @Bean public TestController subscriptionController() { return new TestController(); } @Override public void registerStompEndpoints(StompEndpointRegistry registry) { registry.addEndpoint("/simpleBroker"); } @Override public void configureWebSocketTransport(WebSocketTransportRegistration registration) { registration.setMessageSizeLimit(128 * 1024); registration.setSendTimeLimit(25 * 1000); registration.setSendBufferSizeLimit(1024 * 1024); } @Override public void configureMessageBroker(MessageBrokerRegistry registry) { registry.enableSimpleBroker() .setTaskScheduler(mock(TaskScheduler.class)) .setHeartbeatValue(new long[] {15000, 15000}); } } @Configuration static class TestChannelConfig extends DelegatingWebSocketMessageBrokerConfiguration { @Override @Bean public AbstractSubscribableChannel clientInboundChannel() { TestChannel channel = new TestChannel(); channel.setInterceptors(super.clientInboundChannel().getInterceptors()); return channel; } @Override @Bean public AbstractSubscribableChannel clientOutboundChannel() { TestChannel channel = new TestChannel(); channel.setInterceptors(super.clientOutboundChannel().getInterceptors()); return channel; } @Override public AbstractSubscribableChannel brokerChannel() { TestChannel channel = new TestChannel(); channel.setInterceptors(super.brokerChannel().getInterceptors()); return channel; } } @Configuration static class WebSocketHandlerDecoratorConfig extends WebSocketMessageBrokerConfigurationSupport { @Override protected void registerStompEndpoints(StompEndpointRegistry registry) { registry.addEndpoint("/test"); } @Override protected void configureWebSocketTransport(WebSocketTransportRegistration registry) { registry.addDecoratorFactory(new WebSocketHandlerDecoratorFactory() { @Override public WebSocketHandlerDecorator decorate(WebSocketHandler handler) { return new WebSocketHandlerDecorator(handler) { @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { session.getAttributes().put("decorated", true); super.afterConnectionEstablished(session); } }; } }); } } private static class TestChannel extends ExecutorSubscribableChannel { private final List<Message<?>> messages = new ArrayList<>(); @Override public boolean sendInternal(Message<?> message, long timeout) { this.messages.add(message); return true; } } }