/* * Copyright 2015-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.integration.stomp.inbound; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.Matchers.containsString; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import java.util.Collections; import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationListener; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.config.EnableIntegration; import org.springframework.integration.event.inbound.ApplicationEventListeningMessageProducer; import 
org.springframework.integration.stomp.StompSessionManager; import org.springframework.integration.stomp.WebSocketStompSessionManager; import org.springframework.integration.stomp.event.StompConnectionFailedEvent; import org.springframework.integration.stomp.event.StompIntegrationEvent; import org.springframework.integration.stomp.event.StompReceiptEvent; import org.springframework.integration.stomp.event.StompSessionConnectedEvent; import org.springframework.integration.test.support.LogAdjustingTestSupport; import org.springframework.integration.websocket.TomcatWebSocketTestServer; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHandlingException; import org.springframework.messaging.PollableChannel; import org.springframework.messaging.converter.MappingJackson2MessageConverter; import org.springframework.messaging.converter.MessageConversionException; import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler; import org.springframework.messaging.simp.broker.SubscriptionRegistry; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.stomp.StompHeaders; import org.springframework.messaging.support.AbstractSubscribableChannel; import org.springframework.messaging.support.ErrorMessage; import org.springframework.messaging.support.MessageBuilder; import org.springframework.scheduling.TaskScheduler; import org.springframework.test.annotation.DirtiesContext; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.util.MultiValueMap; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHttpHeaders; import 
org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.messaging.SessionSubscribeEvent;
import org.springframework.web.socket.messaging.WebSocketStompClient;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.sockjs.client.SockJsClient;
import org.springframework.web.socket.sockjs.client.Transport;
import org.springframework.web.socket.sockjs.client.WebSocketTransport;

/**
 * Integration test for the {@link StompInboundChannelAdapter} running over a
 * {@link WebSocketStompSessionManager} against a {@link TomcatWebSocketTestServer}
 * that hosts a simple STOMP broker (see {@link ServerConfig}).
 *
 * @author Artem Bilan
 * @since 4.2
 */
@ContextConfiguration(classes = StompInboundChannelAdapterWebSocketIntegrationTests.ContextConfiguration.class)
@RunWith(SpringJUnit4ClassRunner.class)
@DirtiesContext
public class StompInboundChannelAdapterWebSocketIntegrationTests extends LogAdjustingTestSupport {

	/** Maximum number of sleeps performed while polling the broker's subscription registry. */
	private static final int MAX_POLL_ATTEMPTS = 100;

	/** Pause between two consecutive polls of the subscription registry, in milliseconds. */
	private static final long POLL_INTERVAL_MILLIS = 100;

	/** The application context of the server side, taken from the {@code server} bean. */
	@Value("#{server.serverContext}")
	private ConfigurableApplicationContext serverContext;

	@Autowired
	@Qualifier("stompInputChannel")
	private PollableChannel stompInputChannel;

	@Autowired
	@Qualifier("errorChannel")
	private PollableChannel errorChannel;

	@Autowired
	@Qualifier("stompEvents")
	private PollableChannel stompEvents;

	@Autowired
	private StompInboundChannelAdapter stompInboundChannelAdapter;

	public StompInboundChannelAdapterWebSocketIntegrationTests() {
		super("org.springframework", "org.springframework.integration.stomp");
	}

	/**
	 * Walks the adapter through connect, subscribe, receive, unsubscribe, re-subscribe,
	 * a conversion failure and finally a server restart with client reconnection.
	 * @throws Exception if any step fails unexpectedly.
	 */
	@Test
	public void testWebSocketStompClient() throws Exception {
		// The first published STOMP event must be the "session connected" one.
		Message<?> eventMessage = this.stompEvents.receive(10000);
		assertNotNull(eventMessage);
		assertThat(eventMessage.getPayload(), instanceOf(StompSessionConnectedEvent.class));

		// Next comes the receipt for the SUBSCRIBE of the initially configured destination.
		Message<?> receive = this.stompEvents.receive(10000);
		assertNotNull(receive);
		assertThat(receive.getPayload(), instanceOf(StompReceiptEvent.class));
		StompReceiptEvent stompReceiptEvent = (StompReceiptEvent) receive.getPayload();
		assertEquals(StompCommand.SUBSCRIBE, stompReceiptEvent.getStompCommand());
		assertEquals("/topic/myTopic", stompReceiptEvent.getDestination());

		waitForSubscribe("/topic/myTopic");

		SimpMessagingTemplate messagingTemplate =
				this.serverContext.getBean("brokerMessagingTemplate", SimpMessagingTemplate.class);

		// Send a JSON payload from the server side; the adapter is expected to deliver it as a Map.
		StompHeaderAccessor stompHeaderAccessor = StompHeaderAccessor.create(StompCommand.MESSAGE);
		stompHeaderAccessor.setContentType(MediaType.APPLICATION_JSON);
		stompHeaderAccessor.updateStompCommandAsServerMessage();
		stompHeaderAccessor.setLeaveMutable(true);
		messagingTemplate.send("/topic/myTopic",
				MessageBuilder.createMessage("{\"foo\": \"bar\"}".getBytes(),
						stompHeaderAccessor.getMessageHeaders()));

		receive = this.stompInputChannel.receive(10000);
		assertNotNull(receive);
		assertThat(receive.getPayload(), instanceOf(Map.class));
		@SuppressWarnings("unchecked")
		Map<String, String> payload = (Map<String, String>) receive.getPayload();
		String foo = payload.get("foo");
		assertNotNull(foo);
		assertEquals("bar", foo);

		// After removing the destination nothing may reach the adapter, not even as an error.
		this.stompInboundChannelAdapter.removeDestination("/topic/myTopic");

		waitForUnsubscribe("/topic/myTopic");

		messagingTemplate.convertAndSend("/topic/myTopic", "foo");
		receive = this.errorChannel.receive(100);
		assertNull(receive);

		// Re-subscribe at runtime; one more STOMP event is expected on the events channel.
		this.stompInboundChannelAdapter.addDestination("/topic/myTopic");
		receive = this.stompEvents.receive(10000);
		assertNotNull(receive);

		waitForSubscribe("/topic/myTopic");

		// A plain "foo" payload is expected to end up on the error channel as a conversion failure.
		messagingTemplate.convertAndSend("/topic/myTopic", "foo");
		receive = this.errorChannel.receive(10000);
		assertNotNull(receive);
		assertThat(receive, instanceOf(ErrorMessage.class));
		ErrorMessage errorMessage = (ErrorMessage) receive;
		Throwable throwable = errorMessage.getPayload();
		assertThat(throwable, instanceOf(MessageHandlingException.class));
		assertThat(throwable.getCause(), instanceOf(MessageConversionException.class));
		assertThat(throwable.getMessage(),
				containsString("No suitable converter, payloadType=interface java.util.Map"));

		// Closing the server context must surface as a connection failure on the client side.
		this.serverContext.close();

		eventMessage = this.stompEvents.receive(10000);
		assertNotNull(eventMessage);
		assertThat(eventMessage.getPayload(), instanceOf(StompConnectionFailedEvent.class));

		// Bring the server back and skip events until the client reports a new connected session.
		this.serverContext.refresh();

		do {
			eventMessage = this.stompEvents.receive(10000);
			assertNotNull(eventMessage);
		}
		while (!(eventMessage.getPayload() instanceof StompSessionConnectedEvent));

		waitForSubscribe("/topic/myTopic");

		// The context has been refreshed, so the template bean has to be looked up again.
		messagingTemplate = this.serverContext.getBean("brokerMessagingTemplate", SimpMessagingTemplate.class);

		// Exactly one error is expected for the single message sent after the reconnection.
		messagingTemplate.convertAndSend("/topic/myTopic", "foo");
		receive = this.errorChannel.receive(10000);
		assertNotNull(receive);
		assertEquals(0, ((QueueChannel) this.errorChannel).getQueueSize());
	}

	/**
	 * Blocks until the server-side broker has a subscription for the destination,
	 * failing the test when the polling budget is exhausted first.
	 * @param destination the destination to look up in the broker's subscription registry.
	 * @throws InterruptedException if the thread is interrupted while sleeping between polls.
	 */
	private void waitForSubscribe(String destination) throws InterruptedException {
		assertTrue("The subscription for the '" + destination + "' destination hasn't been registered",
				awaitSubscriptionState(destination, true));
	}

	/**
	 * Blocks until the server-side broker has no subscription for the destination any more,
	 * failing the test when the polling budget is exhausted first.
	 * @param destination the destination to look up in the broker's subscription registry.
	 * @throws InterruptedException if the thread is interrupted while sleeping between polls.
	 */
	private void waitForUnsubscribe(String destination) throws InterruptedException {
		assertTrue("The subscription for the '" + destination + "' destination hasn't been removed",
				awaitSubscriptionState(destination, false));
	}

	/**
	 * Polls the subscription registry of the server's {@code simpleBrokerMessageHandler} bean
	 * until the presence of a subscription for the destination equals the expected state.
	 * The result reflects the state actually observed, so reaching the expected state
	 * on the very last poll still counts as success.
	 * @param destination the destination to look up.
	 * @param subscribed {@code true} to wait for a subscription to appear,
	 * {@code false} to wait for it to disappear.
	 * @return {@code true} if the expected state has been observed within the polling budget.
	 * @throws InterruptedException if the thread is interrupted while sleeping between polls.
	 */
	private boolean awaitSubscriptionState(String destination, boolean subscribed) throws InterruptedException {
		// Looked up on every call: the server context is closed and refreshed during the test.
		SimpleBrokerMessageHandler serverBrokerMessageHandler =
				this.serverContext.getBean("simpleBrokerMessageHandler", SimpleBrokerMessageHandler.class);
		SubscriptionRegistry subscriptionRegistry = serverBrokerMessageHandler.getSubscriptionRegistry();

		for (int attempt = 0; attempt < MAX_POLL_ATTEMPTS; attempt++) {
			if (containsDestination(destination, subscriptionRegistry) == subscribed) {
				return true;
			}
			Thread.sleep(POLL_INTERVAL_MILLIS);
		}
		// One final check after the last sleep.
		return containsDestination(destination, subscriptionRegistry) == subscribed;
	}

	/**
	 * Asks the registry which subscriptions match a MESSAGE frame addressed to the destination.
	 * @param destination the destination to put into the probe message.
	 * @param subscriptionRegistry the registry to query.
	 * @return {@code true} if the registry reports at least one matching subscription.
	 */
	private boolean containsDestination(String destination, SubscriptionRegistry subscriptionRegistry) {
		StompHeaderAccessor stompHeaderAccessor = StompHeaderAccessor.create(StompCommand.MESSAGE);
		stompHeaderAccessor.setDestination(destination);
		Message<byte[]> message = MessageBuilder.createMessage(new byte[0], stompHeaderAccessor.toMessageHeaders());
		MultiValueMap<String, String> subscriptions = subscriptionRegistry.findSubscriptions(message);
		return !subscriptions.isEmpty();
	}

	// STOMP Client

	/**
	 * Client-side configuration: the embedded test server, the STOMP-over-WebSocket client,
	 * the inbound channel adapter under test and the channels its results are collected on.
	 */
	@Configuration
	@EnableIntegration
	public static class ContextConfiguration {

		@Bean
		public TomcatWebSocketTestServer server() {
			return new TomcatWebSocketTestServer(ServerConfig.class);
		}

		@Bean
		public WebSocketClient webSocketClient() {
			return new SockJsClient(Collections.<Transport>singletonList(
					new WebSocketTransport(new StandardWebSocketClient())));
		}

		@Bean
		public WebSocketStompClient stompClient(TaskScheduler taskScheduler) {
			WebSocketStompClient webSocketStompClient = new WebSocketStompClient(webSocketClient());
			webSocketStompClient.setMessageConverter(new MappingJackson2MessageConverter());
			webSocketStompClient.setTaskScheduler(taskScheduler);
			return webSocketStompClient;
		}

		@Bean
		public StompSessionManager stompSessionManager(WebSocketStompClient stompClient) {
			WebSocketStompSessionManager webSocketStompSessionManager =
					new WebSocketStompSessionManager(stompClient, server().getWsBaseUrl() + "/ws");
			webSocketStompSessionManager.setAutoReceipt(true);
			webSocketStompSessionManager.setRecoveryInterval(1000);
			// The server only allows the "http://foo.com" origin, see ServerConfig.
			WebSocketHttpHeaders handshakeHeaders = new WebSocketHttpHeaders();
			handshakeHeaders.setOrigin("http://foo.com");
			webSocketStompSessionManager.setHandshakeHeaders(handshakeHeaders);
			StompHeaders stompHeaders = new StompHeaders();
			stompHeaders.setHeartbeat(new long[] {10000, 10000});
			webSocketStompSessionManager.setConnectHeaders(stompHeaders);
			return webSocketStompSessionManager;
		}

		@Bean
		public PollableChannel stompInputChannel() {
			return new QueueChannel();
		}

		@Bean
		public PollableChannel errorChannel() {
			return new QueueChannel();
		}

		@Bean
		public StompInboundChannelAdapter stompInboundChannelAdapter(StompSessionManager stompSessionFactory) {
			StompInboundChannelAdapter adapter =
					new StompInboundChannelAdapter(stompSessionFactory, "/topic/myTopic");
			adapter.setPayloadType(Map.class);
			adapter.setOutputChannel(stompInputChannel());
			adapter.setErrorChannel(errorChannel());
			return adapter;
		}

		@Bean
		public PollableChannel stompEvents() {
			return new QueueChannel();
		}

		/**
		 * Forwards every {@link StompIntegrationEvent} to the {@code stompEvents} channel.
		 */
		@Bean
		public ApplicationListener<ApplicationEvent> stompEventListener() {
			ApplicationEventListeningMessageProducer producer = new ApplicationEventListeningMessageProducer();
			producer.setEventTypes(StompIntegrationEvent.class);
			producer.setOutputChannel(stompEvents());
			return producer;
		}

	}

	// WebSocket Server part

	/**
	 * Server-side configuration: a SockJS-enabled {@code /ws} STOMP endpoint
	 * backed by the simple broker for the {@code /topic} and {@code /queue} prefixes.
	 */
	@Configuration
	@EnableWebSocketMessageBroker
	static class ServerConfig extends AbstractWebSocketMessageBrokerConfigurer {

		@Bean
		public DefaultHandshakeHandler handshakeHandler() {
			return new DefaultHandshakeHandler(new TomcatRequestUpgradeStrategy());
		}

		@Override
		public void registerStompEndpoints(StompEndpointRegistry registry) {
			registry.addEndpoint("/ws")
					.setHandshakeHandler(handshakeHandler())
					.setAllowedOrigins("http://foo.com")
					.addInterceptors(new HandshakeInterceptor() {

						// Rejects any handshake request that carries no Origin header.
						@Override
						public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
								WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
							return request.getHeaders().getOrigin() != null;
						}

						@Override
						public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
								WebSocketHandler wsHandler, Exception exception) {
						}

					})
					.withSockJS();
		}

		@Override
		public void configureMessageBroker(MessageBrokerRegistry configurer) {
			configurer.setApplicationDestinationPrefixes("/app")
					.enableSimpleBroker("/topic", "/queue");
		}

		//SimpleBrokerMessageHandler doesn't support RECEIPT frame, hence we emulate it this way
		@Bean
		public ApplicationListener<SessionSubscribeEvent> webSocketEventListener(
				final AbstractSubscribableChannel clientOutboundChannel) {
			return event -> {
				Message<byte[]> message = event.getMessage();
				StompHeaderAccessor stompHeaderAccessor = StompHeaderAccessor.wrap(message);
				// Only answer SUBSCRIBE frames which asked for a receipt.
				if (stompHeaderAccessor.getReceipt() != null) {
					stompHeaderAccessor.setHeader("stompCommand", StompCommand.RECEIPT);
					stompHeaderAccessor.setReceiptId(stompHeaderAccessor.getReceipt());
					clientOutboundChannel.send(
							MessageBuilder.createMessage(new byte[0], stompHeaderAccessor.getMessageHeaders()));
				}
			};
		}

	}

}