/*
* Copyright 2014-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.integration.websocket.client;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationListener;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.integration.annotation.Gateway;
import org.springframework.integration.annotation.IntegrationComponentScan;
import org.springframework.integration.annotation.MessagingGateway;
import org.springframework.integration.annotation.ServiceActivator;
import org.springframework.integration.annotation.Transformer;
import org.springframework.integration.channel.DirectChannel;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.config.EnableIntegration;
import org.springframework.integration.core.MessageProducer;
import org.springframework.integration.event.inbound.ApplicationEventListeningMessageProducer;
import org.springframework.integration.test.support.LogAdjustingTestSupport;
import org.springframework.integration.transformer.ExpressionEvaluatingTransformer;
import org.springframework.integration.websocket.ClientWebSocketContainer;
import org.springframework.integration.websocket.IntegrationWebSocketContainer;
import org.springframework.integration.websocket.TomcatWebSocketTestServer;
import org.springframework.integration.websocket.event.ReceiptEvent;
import org.springframework.integration.websocket.inbound.WebSocketInboundChannelAdapter;
import org.springframework.integration.websocket.outbound.WebSocketOutboundMessageHandler;
import org.springframework.integration.websocket.support.SubProtocolHandlerRegistry;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.PollableChannel;
import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.simp.annotation.SendToUser;
import org.springframework.messaging.simp.annotation.SubscribeMapping;
import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler;
import org.springframework.messaging.simp.broker.SubscriptionRegistry;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.AbstractSubscribableChannel;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.stereotype.Controller;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.util.MultiValueMap;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.messaging.AbstractSubProtocolEvent;
import org.springframework.web.socket.messaging.SessionConnectedEvent;
import org.springframework.web.socket.messaging.SessionSubscribeEvent;
import org.springframework.web.socket.messaging.StompSubProtocolHandler;
import org.springframework.web.socket.messaging.SubProtocolHandler;
import org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.sockjs.client.SockJsClient;
import org.springframework.web.socket.sockjs.client.Transport;
import org.springframework.web.socket.sockjs.client.WebSocketTransport;
/**
* @author Artem Bilan
* @since 4.1
*/
@ContextConfiguration
@RunWith(SpringJUnit4ClassRunner.class)
@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_EACH_TEST_METHOD)
public class StompIntegrationTests extends LogAdjustingTestSupport {
// Application context exposed by the 'server' bean (its 'serverContext' property, resolved via SpEL);
// the tests use it to look up server-side beans such as the controllers and the simple broker handler.
@Value("#{server.serverContext}")
private ApplicationContext serverContext;
// Client-side WebSocket container; the user-destination tests read its session id.
@Autowired
private IntegrationWebSocketContainer clientWebSocketContainer;
// Channel the tests send STOMP frames to; the outbound WebSocket handler is subscribed to it.
@Autowired
@Qualifier("webSocketOutputChannel")
private MessageChannel webSocketOutputChannel;
// Queue fed by the inbound WebSocket channel adapter with messages arriving from the server.
@Autowired
@Qualifier("webSocketInputChannel")
private QueueChannel webSocketInputChannel;
// Queue fed by the client-side event listener with AbstractSubProtocolEvent instances.
@Autowired
@Qualifier("webSocketEvents")
private QueueChannel webSocketEvents;
/**
 * Hands the logging categories of interest to {@link LogAdjustingTestSupport}.
 * NOTE(review): presumably the superclass raises the log level for these categories
 * while a test runs - confirm against LogAdjustingTestSupport.
 */
public StompIntegrationTests() {
super("org.springframework", "org.springframework.integration", "org.apache.catalina");
}
/**
 * Sends a STOMP {@code CONNECT} frame, expects a {@link SessionConnectedEvent} carrying a
 * {@code CONNECTED} frame on the events channel, then sends a message to {@code /app/simple}
 * and waits until the server-side {@link SimpleController} has been invoked.
 */
@Test
public void sendMessageToController() throws Exception {
	// Handshake: CONNECT -> CONNECTED
	StompHeaderAccessor connectHeaders = StompHeaderAccessor.create(StompCommand.CONNECT);
	Message<byte[]> connectFrame = MessageBuilder.withPayload(new byte[0])
			.setHeaders(connectHeaders)
			.build();
	this.webSocketOutputChannel.send(connectFrame);

	Message<?> eventMessage = this.webSocketEvents.receive(20000);
	assertNotNull(eventMessage);
	Object eventPayload = eventMessage.getPayload();
	assertThat(eventPayload, instanceOf(SessionConnectedEvent.class));
	SessionConnectedEvent connectedEvent = (SessionConnectedEvent) eventPayload;
	StompHeaderAccessor connectedHeaders = StompHeaderAccessor.wrap(connectedEvent.getMessage());
	assertEquals(StompCommand.CONNECTED, connectedHeaders.getCommand());

	// Application message routed to the @MessageMapping("/simple") handler
	StompHeaderAccessor sendHeaders = StompHeaderAccessor.create(StompCommand.SEND);
	sendHeaders.setSubscriptionId("sub1");
	sendHeaders.setDestination("/app/simple");
	Message<String> request = MessageBuilder.withPayload("foo")
			.setHeaders(sendHeaders)
			.build();
	this.webSocketOutputChannel.send(request);

	// The controller counts its latch down once the message has reached it
	SimpleController simpleController = this.serverContext.getBean(SimpleController.class);
	boolean handled = simpleController.latch.await(20, TimeUnit.SECONDS);
	assertTrue(handled);
}
/**
 * Subscribes to {@code /topic/increment} with a receipt request, verifies the {@link ReceiptEvent},
 * then sends {@code 5} to {@code /app/increment} and expects {@code "6"} back through the topic.
 */
@Test
public void sendMessageToControllerAndReceiveReplyViaTopic() throws Exception {
	// SUBSCRIBE carrying a 'receipt' header
	StompHeaderAccessor subscribeHeaders = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
	subscribeHeaders.setSubscriptionId("subs1");
	subscribeHeaders.setDestination("/topic/increment");
	subscribeHeaders.setReceipt("myReceipt");
	byte[] emptyPayload = ByteBuffer.allocate(0).array();
	Message<byte[]> subscribeFrame = MessageBuilder.withPayload(emptyPayload)
			.setHeaders(subscribeHeaders)
			.build();
	this.webSocketOutputChannel.send(subscribeFrame);

	// The RECEIPT frame surfaces on the client as a ReceiptEvent
	Message<?> eventMessage = this.webSocketEvents.receive(20000);
	assertNotNull(eventMessage);
	Object eventPayload = eventMessage.getPayload();
	assertThat(eventPayload, instanceOf(ReceiptEvent.class));
	ReceiptEvent receiptEvent = (ReceiptEvent) eventPayload;
	StompHeaderAccessor receiptHeaders = StompHeaderAccessor.wrap(receiptEvent.getMessage());
	assertEquals(StompCommand.RECEIPT, receiptHeaders.getCommand());
	assertEquals("myReceipt", receiptHeaders.getReceiptId());

	waitForSubscribe("/topic/increment");

	// Ask the controller to increment the number
	StompHeaderAccessor sendHeaders = StompHeaderAccessor.create(StompCommand.SEND);
	sendHeaders.setSubscriptionId("subs1");
	sendHeaders.setDestination("/app/increment");
	Message<Integer> request = MessageBuilder.withPayload(5)
			.setHeaders(sendHeaders)
			.build();
	this.webSocketOutputChannel.send(request);

	Message<?> reply = this.webSocketInputChannel.receive(20000);
	assertNotNull(reply);
	assertEquals("6", reply.getPayload());
}
/**
 * Subscribes to {@code /topic/foo}, publishes {@code 10} straight to that broker destination
 * and expects to receive {@code "10"} on the input channel.
 */
@Test
public void sendMessageToBrokerAndReceiveReplyViaTopic() throws Exception {
	StompHeaderAccessor subscribeHeaders = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
	subscribeHeaders.setSubscriptionId("subs1");
	subscribeHeaders.setDestination("/topic/foo");
	byte[] emptyPayload = ByteBuffer.allocate(0).array();
	Message<byte[]> subscribeFrame = MessageBuilder.withPayload(emptyPayload)
			.setHeaders(subscribeHeaders)
			.build();

	StompHeaderAccessor sendHeaders = StompHeaderAccessor.create(StompCommand.SEND);
	sendHeaders.setSubscriptionId("subs1");
	sendHeaders.setDestination("/topic/foo");
	Message<Integer> publication = MessageBuilder.withPayload(10)
			.setHeaders(sendHeaders)
			.build();

	// The subscription must be in place on the server before publishing
	this.webSocketOutputChannel.send(subscribeFrame);
	waitForSubscribe("/topic/foo");
	this.webSocketOutputChannel.send(publication);

	Message<?> delivered = this.webSocketInputChannel.receive(20000);
	assertNotNull(delivered);
	assertEquals("10", delivered.getPayload());
}
/**
 * Subscribes to {@code /app/number} and expects the reply {@code "42"} to come back
 * on the very same destination.
 */
@Test
public void sendSubscribeToControllerAndReceiveReply() throws Exception {
	String destination = "/app/number";

	StompHeaderAccessor subscribeHeaders = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
	subscribeHeaders.setSubscriptionId("subs1");
	subscribeHeaders.setDestination(destination);
	byte[] emptyPayload = ByteBuffer.allocate(0).array();
	Message<byte[]> subscribeFrame = MessageBuilder.withPayload(emptyPayload)
			.setHeaders(subscribeHeaders)
			.build();
	this.webSocketOutputChannel.send(subscribeFrame);

	Message<?> reply = this.webSocketInputChannel.receive(20000);
	assertNotNull(reply);

	StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(reply);
	String destinationFailure = "Expected STOMP destination=/app/number, got " + replyHeaders;
	assertEquals(destinationFailure, destination, replyHeaders.getDestination());

	Object replyPayload = reply.getPayload();
	String payloadFailure = "Expected STOMP Payload=42, got " + replyPayload;
	assertEquals(payloadFailure, "42", replyPayload);
}
/**
 * Subscribes to the user error queue, triggers the failing {@code /app/exception} handler
 * and expects the {@code @MessageExceptionHandler} result on {@code /user/queue/error}.
 */
@Test
public void handleExceptionAndSendToUser() throws Exception {
	String userErrorDestination = "/user/queue/error";

	StompHeaderAccessor subscribeHeaders = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
	subscribeHeaders.setSubscriptionId("subs1");
	subscribeHeaders.setDestination(userErrorDestination);
	byte[] emptyPayload = ByteBuffer.allocate(0).array();
	Message<byte[]> subscribeFrame = MessageBuilder.withPayload(emptyPayload)
			.setHeaders(subscribeHeaders)
			.build();

	StompHeaderAccessor sendHeaders = StompHeaderAccessor.create(StompCommand.SEND);
	sendHeaders.setSubscriptionId("subs1");
	sendHeaders.setDestination("/app/exception");
	Message<String> request = MessageBuilder.withPayload("foo")
			.setHeaders(sendHeaders)
			.build();

	this.webSocketOutputChannel.send(subscribeFrame);
	// The broker is polled for the session-specific queue name, hence the session id suffix
	String sessionId = this.clientWebSocketContainer.getSession(null).getId();
	waitForSubscribe("/queue/error-user" + sessionId);
	this.webSocketOutputChannel.send(request);

	Message<?> reply = this.webSocketInputChannel.receive(20000);
	assertNotNull(reply);

	StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(reply);
	String destinationFailure = "Expected STOMP destination=/user/queue/error, got " + replyHeaders;
	assertEquals(destinationFailure, userErrorDestination, replyHeaders.getDestination());
	assertEquals("Got error: Bad input", reply.getPayload());
}
/**
 * Subscribes to the user answer queue, sends {@code "Bob"} to {@code /app/greeting}
 * (handled by the {@link WebSocketGateway}) and expects {@code "Hello Bob"} in return.
 */
@Test
public void sendMessageToGateway() throws Exception {
	StompHeaderAccessor subscribeHeaders = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
	subscribeHeaders.setSubscriptionId("subs1");
	subscribeHeaders.setDestination("/user/queue/answer");
	byte[] emptyPayload = ByteBuffer.allocate(0).array();
	Message<byte[]> subscribeFrame = MessageBuilder.withPayload(emptyPayload)
			.setHeaders(subscribeHeaders)
			.build();

	StompHeaderAccessor sendHeaders = StompHeaderAccessor.create(StompCommand.SEND);
	sendHeaders.setSubscriptionId("subs1");
	sendHeaders.setDestination("/app/greeting");
	Message<String> request = MessageBuilder.withPayload("Bob")
			.setHeaders(sendHeaders)
			.build();

	this.webSocketOutputChannel.send(subscribeFrame);
	// The broker is polled for the session-specific queue name, hence the session id suffix
	String sessionId = this.clientWebSocketContainer.getSession(null).getId();
	waitForSubscribe("/queue/answer-user" + sessionId);
	this.webSocketOutputChannel.send(request);

	Message<?> reply = this.webSocketInputChannel.receive(20000);
	assertNotNull(reply);
	assertEquals("Hello Bob", reply.getPayload());
}
/**
 * Blocks until the server-side simple broker reports a subscription for the given destination,
 * polling every 100 ms for at most 100 attempts.
 * <p>
 * The final assertion is made on the observed subscription state and not on the retry counter,
 * so a subscription that shows up on the very last poll is still treated as success.
 *
 * @param destination the broker destination expected to be subscribed
 * @throws InterruptedException if the thread is interrupted while sleeping between polls
 */
private void waitForSubscribe(String destination) throws InterruptedException {
	SimpleBrokerMessageHandler serverBrokerMessageHandler =
			this.serverContext.getBean("simpleBrokerMessageHandler", SimpleBrokerMessageHandler.class);
	SubscriptionRegistry subscriptionRegistry = serverBrokerMessageHandler.getSubscriptionRegistry();

	boolean subscribed = containsDestination(destination, subscriptionRegistry);
	for (int attempt = 0; !subscribed && attempt < 100; attempt++) {
		Thread.sleep(100);
		subscribed = containsDestination(destination, subscriptionRegistry);
	}

	assertTrue("The subscription for the '" + destination + "' destination hasn't been registered", subscribed);
}
/**
 * Probes the registry with an empty STOMP {@code MESSAGE} addressed to the destination.
 *
 * @param destination the destination to look up
 * @param subscriptionRegistry the registry to query
 * @return {@code true} if the registry finds at least one subscription for the probe message
 */
private boolean containsDestination(String destination, SubscriptionRegistry subscriptionRegistry) {
	StompHeaderAccessor probeHeaders = StompHeaderAccessor.create(StompCommand.MESSAGE);
	probeHeaders.setDestination(destination);
	Message<byte[]> probe = MessageBuilder.createMessage(new byte[0], probeHeaders.toMessageHeaders());
	MultiValueMap<String, String> matches = subscriptionRegistry.findSubscriptions(probe);
	if (matches.isEmpty()) {
		return false;
	}
	return true;
}
@Configuration
@EnableIntegration
public static class ContextConfiguration {
@Bean
public TomcatWebSocketTestServer server() {
return new TomcatWebSocketTestServer(ServerConfig.class);
}
@Bean
public WebSocketClient webSocketClient() {
return new SockJsClient(Collections.<Transport>singletonList(new WebSocketTransport(new StandardWebSocketClient())));
}
@Bean
public IntegrationWebSocketContainer clientWebSocketContainer() {
return new ClientWebSocketContainer(webSocketClient(), server().getWsBaseUrl() + "/ws");
}
@Bean
public SubProtocolHandler stompSubProtocolHandler() {
return new StompSubProtocolHandler();
}
@Bean
public MessageChannel webSocketInputChannel() {
return new QueueChannel();
}
@Bean
public MessageChannel webSocketOutputChannel() {
return new DirectChannel();
}
@Bean
public MessageProducer webSocketInboundChannelAdapter() {
WebSocketInboundChannelAdapter webSocketInboundChannelAdapter =
new WebSocketInboundChannelAdapter(clientWebSocketContainer(),
new SubProtocolHandlerRegistry(stompSubProtocolHandler()));
webSocketInboundChannelAdapter.setOutputChannel(webSocketInputChannel());
return webSocketInboundChannelAdapter;
}
@Bean
@ServiceActivator(inputChannel = "webSocketOutputChannel")
public MessageHandler webSocketOutboundMessageHandler() {
return new WebSocketOutboundMessageHandler(clientWebSocketContainer(),
new SubProtocolHandlerRegistry(stompSubProtocolHandler()));
}
@Bean
public PollableChannel webSocketEvents() {
return new QueueChannel();
}
@Bean
@SuppressWarnings("unchecked")
public ApplicationListener<ApplicationEvent> webSocketEventListener() {
ApplicationEventListeningMessageProducer producer = new ApplicationEventListeningMessageProducer();
producer.setEventTypes(AbstractSubProtocolEvent.class);
producer.setOutputChannel(webSocketEvents());
return producer;
}
}
// WebSocket Server part
/**
 * Marker stereotype for the controllers of this test; it is the include filter
 * of the component scan declared on {@code ServerConfig}.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@Controller
private @interface IntegrationTestController {

}
/**
 * Server-side controller with a plain handler, a handler that always fails
 * and the exception handler reporting that failure back to the user.
 */
@IntegrationTestController
static class SimpleController {

	/** Released once a message has reached {@link #handle()}. */
	private final CountDownLatch latch = new CountDownLatch(1);

	/** Signals the waiting test that the {@code /simple} message has arrived. */
	@MessageMapping("/simple")
	public void handle() {
		latch.countDown();
	}

	/** Always fails so that the exception handler below gets exercised. */
	@MessageMapping("/exception")
	public void handleWithError() {
		throw new IllegalArgumentException("Bad input");
	}

	/**
	 * Converts the failure into a text reply for the user's {@code /queue/error} destination.
	 * @param ex the exception thrown by a message handler
	 * @return the exception message prefixed with {@code "Got error: "}
	 */
	@MessageExceptionHandler
	@SendToUser("/queue/error")
	public String handleException(IllegalArgumentException ex) {
		String reason = ex.getMessage();
		return "Got error: " + reason;
	}

}
/**
 * Server-side controller answering the {@code /increment} messages
 * and the {@code /number} subscriptions.
 */
@IntegrationTestController
static class IncrementController {

	/**
	 * @param i the number received from the client
	 * @return the received number plus one
	 */
	@MessageMapping("/increment")
	public int handle(int i) {
		int incremented = i + 1;
		return incremented;
	}

	/**
	 * @return the constant reply for a {@code /number} subscription
	 */
	@SubscribeMapping("/number")
	public int number() {
		final int answer = 42;
		return answer;
	}

}
/**
 * Messaging gateway that is also declared as a controller: messages mapped to
 * {@code /greeting} are sent to the {@code greetingChannel} and the gateway result
 * is addressed to the user's {@code /queue/answer} destination.
 */
@MessagingGateway
@Controller
interface WebSocketGateway {
/**
 * @param payload the text received on the {@code /greeting} mapping
 * @return the reply produced by the flow behind {@code greetingChannel}
 */
@MessageMapping("/greeting")
@SendToUser("/queue/answer")
@Gateway(requestChannel = "greetingChannel")
String greeting(String payload);
}
/**
 * Server-side configuration: a STOMP endpoint at {@code /ws} with SockJS, a simple broker
 * for {@code /topic} and {@code /queue}, the test controllers and the greeting flow.
 */
@Configuration
@EnableWebSocketMessageBroker
@EnableIntegration
@ComponentScan(
		basePackageClasses = StompIntegrationTests.class,
		useDefaultFilters = false,
		includeFilters = @ComponentScan.Filter(IntegrationTestController.class))
@IntegrationComponentScan
static class ServerConfig extends AbstractWebSocketMessageBrokerConfigurer {

	private static final ExpressionParser expressionParser = new SpelExpressionParser();

	@Bean
	public MessageChannel greetingChannel() {
		return new DirectChannel();
	}

	/** Prefixes the payload arriving on {@code greetingChannel} with {@code 'Hello '}. */
	@Bean
	@Transformer(inputChannel = "greetingChannel")
	public ExpressionEvaluatingTransformer greetingTransformer() {
		String greetingExpression = "'Hello ' + payload";
		return new ExpressionEvaluatingTransformer(expressionParser.parseExpression(greetingExpression));
	}

	@Bean
	public DefaultHandshakeHandler handshakeHandler() {
		TomcatRequestUpgradeStrategy upgradeStrategy = new TomcatRequestUpgradeStrategy();
		return new DefaultHandshakeHandler(upgradeStrategy);
	}

	@Override
	public void registerStompEndpoints(StompEndpointRegistry registry) {
		registry.addEndpoint("/ws")
				.setHandshakeHandler(handshakeHandler())
				.withSockJS();
	}

	@Override
	public void configureMessageBroker(MessageBrokerRegistry configurer) {
		configurer.setApplicationDestinationPrefixes("/app");
		configurer.enableSimpleBroker("/topic", "/queue");
	}

	/**
	 * The SimpleBrokerMessageHandler doesn't support the RECEIPT frame, so it is emulated here:
	 * whenever a SUBSCRIBE frame carries a {@code receipt} header, a RECEIPT frame built from
	 * the same headers is sent to the client outbound channel.
	 */
	@Bean
	public ApplicationListener<SessionSubscribeEvent> webSocketEventListener(
			final AbstractSubscribableChannel clientOutboundChannel) {

		return subscribeEvent -> {
			StompHeaderAccessor accessor = StompHeaderAccessor.wrap(subscribeEvent.getMessage());
			String requestedReceipt = accessor.getReceipt();
			if (requestedReceipt == null) {
				// No receipt was asked for - nothing to emulate
				return;
			}
			accessor.setHeader("stompCommand", StompCommand.RECEIPT);
			accessor.setReceiptId(requestedReceipt);
			Message<byte[]> receiptFrame =
					MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders());
			clientOutboundChannel.send(receiptFrame);
		};
	}

}
}