/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.StubMessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.TestPrincipal;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.tcp.ReconnectStrategy;
import org.springframework.messaging.tcp.TcpConnection;
import org.springframework.messaging.tcp.TcpConnectionHandler;
import org.springframework.messaging.tcp.TcpOperations;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureTask;
/**
 * Unit tests for {@link StompBrokerRelayMessageHandler}.
 *
 * <p>The handler under test is given a {@link StubTcpOperations} in place of a real
 * TCP client. Messages the handler sends to the "broker" are recorded by the stub and
 * can be inspected, and messages "from the broker" are injected through the stub.
 *
 * @author Rossen Stoyanchev
 */
public class StompBrokerRelayMessageHandlerTests {

	private StompBrokerRelayMessageHandler brokerRelay;

	private StubMessageChannel outboundChannel;

	private StubTcpOperations tcpClient;


	/**
	 * Creates a relay for the {@code "/topic"} destination prefix backed by the stub
	 * TCP client and a recording outbound channel.
	 */
	@Before
	public void setup() {
		this.outboundChannel = new StubMessageChannel();
		this.brokerRelay = new StompBrokerRelayMessageHandler(new StubMessageChannel(),
				this.outboundChannel, new StubMessageChannel(), Arrays.asList("/topic")) {

			@Override
			protected void startInternal() {
				publishBrokerAvailableEvent(); // Force this, since we'll never actually connect
				super.startInternal();
			}
		};
		this.tcpClient = new StubTcpOperations();
		this.brokerRelay.setTcpClient(this.tcpClient);
	}


	/**
	 * The configured virtual host is expected in the {@code host} header of both the
	 * system CONNECT frame and the client CONNECT frame.
	 */
	@Test
	public void virtualHost() throws Exception {
		this.brokerRelay.setVirtualHost("ABC");

		this.brokerRelay.start();
		this.brokerRelay.handleMessage(connectMessage("sess1", "joe"));

		assertEquals(2, this.tcpClient.getSentMessages().size());

		StompHeaderAccessor headers1 = this.tcpClient.getSentHeaders(0);
		assertEquals(StompCommand.CONNECT, headers1.getCommand());
		assertEquals(StompBrokerRelayMessageHandler.SYSTEM_SESSION_ID, headers1.getSessionId());
		assertEquals("ABC", headers1.getHost());

		StompHeaderAccessor headers2 = this.tcpClient.getSentHeaders(1);
		assertEquals(StompCommand.CONNECT, headers2.getCommand());
		assertEquals("sess1", headers2.getSessionId());
		assertEquals("ABC", headers2.getHost());
	}

	/**
	 * The system CONNECT frame is expected to carry the system credentials and the
	 * client CONNECT frame the client credentials.
	 */
	@Test
	public void loginAndPasscode() throws Exception {
		this.brokerRelay.setSystemLogin("syslogin");
		this.brokerRelay.setSystemPasscode("syspasscode");
		this.brokerRelay.setClientLogin("clientlogin");
		this.brokerRelay.setClientPasscode("clientpasscode");

		this.brokerRelay.start();
		this.brokerRelay.handleMessage(connectMessage("sess1", "joe"));

		assertEquals(2, this.tcpClient.getSentMessages().size());

		StompHeaderAccessor headers1 = this.tcpClient.getSentHeaders(0);
		assertEquals(StompCommand.CONNECT, headers1.getCommand());
		assertEquals("syslogin", headers1.getLogin());
		assertEquals("syspasscode", headers1.getPasscode());

		StompHeaderAccessor headers2 = this.tcpClient.getSentHeaders(1);
		assertEquals(StompCommand.CONNECT, headers2.getCommand());
		assertEquals("clientlogin", headers2.getLogin());
		assertEquals("clientpasscode", headers2.getPasscode());
	}

	/**
	 * A message to {@code "/user/daisy/foo"}, which does not start with the configured
	 * {@code "/topic"} prefix, is expected not to be sent to the broker: the only
	 * recorded frame is the system CONNECT.
	 */
	@Test
	public void destinationExcluded() throws Exception {
		this.brokerRelay.start();

		SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
		headers.setSessionId("sess1");
		headers.setDestination("/user/daisy/foo");
		this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()));

		assertEquals(1, this.tcpClient.getSentMessages().size());

		StompHeaderAccessor headers1 = this.tcpClient.getSentHeaders(0);
		assertEquals(StompCommand.CONNECT, headers1.getCommand());
		assertEquals(StompBrokerRelayMessageHandler.SYSTEM_SESSION_ID, headers1.getSessionId());
	}

	/**
	 * A MESSAGE frame injected from the broker without session id or user is expected
	 * to reach the outbound channel with the session id and user of the client
	 * connection.
	 */
	@Test
	public void messageFromBrokerIsEnriched() throws Exception {
		this.brokerRelay.start();
		this.brokerRelay.handleMessage(connectMessage("sess1", "joe"));

		assertEquals(2, this.tcpClient.getSentMessages().size());
		assertEquals(StompCommand.CONNECT, this.tcpClient.getSentHeaders(0).getCommand());
		assertEquals(StompCommand.CONNECT, this.tcpClient.getSentHeaders(1).getCommand());

		this.tcpClient.handleMessage(message(StompCommand.MESSAGE, null, null, null));

		StompHeaderAccessor accessor = getFirstOutboundHeaders();
		assertEquals("sess1", accessor.getSessionId());
		assertEquals("joe", accessor.getUser().getName());
	}

	/**
	 * A client CONNECT handled after the relay has been stopped is expected to produce
	 * an ERROR frame for that client on the outbound channel.
	 */
	// SPR-12820
	@Test
	public void connectWhenBrokerNotAvailable() throws Exception {
		this.brokerRelay.start();
		this.brokerRelay.stopInternal();
		this.brokerRelay.handleMessage(connectMessage("sess1", "joe"));

		StompHeaderAccessor accessor = getFirstOutboundHeaders();
		assertEquals(StompCommand.ERROR, accessor.getCommand());
		assertEquals("sess1", accessor.getSessionId());
		assertEquals("joe", accessor.getUser().getName());
		assertEquals("Broker not available.", accessor.getMessage());
	}

	/**
	 * A SEND from a connected client handled after the relay has been stopped is
	 * expected to produce an ERROR frame for that client and to bring the connection
	 * count back from 2 to 1.
	 */
	@Test
	public void sendAfterBrokerUnavailable() throws Exception {
		this.brokerRelay.start();
		assertEquals(1, this.brokerRelay.getConnectionCount());

		this.brokerRelay.handleMessage(connectMessage("sess1", "joe"));
		assertEquals(2, this.brokerRelay.getConnectionCount());

		this.brokerRelay.stopInternal();
		this.brokerRelay.handleMessage(message(StompCommand.SEND, "sess1", "joe", "/foo"));
		assertEquals(1, this.brokerRelay.getConnectionCount());

		StompHeaderAccessor accessor = getFirstOutboundHeaders();
		assertEquals(StompCommand.ERROR, accessor.getCommand());
		assertEquals("sess1", accessor.getSessionId());
		assertEquals("joe", accessor.getUser().getName());
		assertEquals("Broker not available.", accessor.getMessage());
	}

	/**
	 * Once a CONNECTED frame is injected from the broker, a SUBSCRIBE for the
	 * configured system subscription destination is expected to be sent. A later
	 * MESSAGE injected from the broker for that destination is expected to be passed,
	 * as the same instance, to the registered handler.
	 */
	@Test
	public void systemSubscription() throws Exception {
		MessageHandler handler = mock(MessageHandler.class);
		this.brokerRelay.setSystemSubscriptions(Collections.singletonMap("/topic/foo", handler));
		this.brokerRelay.start();

		StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
		accessor.setLeaveMutable(true);
		MessageHeaders headers = accessor.getMessageHeaders();
		this.tcpClient.handleMessage(MessageBuilder.createMessage(new byte[0], headers));

		assertEquals(2, this.tcpClient.getSentMessages().size());
		assertEquals(StompCommand.CONNECT, this.tcpClient.getSentHeaders(0).getCommand());
		assertEquals(StompCommand.SUBSCRIBE, this.tcpClient.getSentHeaders(1).getCommand());
		assertEquals("/topic/foo", this.tcpClient.getSentHeaders(1).getDestination());

		Message<byte[]> message = message(StompCommand.MESSAGE, null, null, "/topic/foo");
		this.tcpClient.handleMessage(message);

		ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
		verify(handler).handleMessage(captor.capture());
		assertSame(message, captor.getValue());
	}


	/**
	 * Returns the STOMP header accessor of the first message recorded on the outbound
	 * channel, failing with an assertion error (rather than an
	 * {@code IndexOutOfBoundsException} or {@code NullPointerException}) when there is
	 * no such message or it carries no STOMP accessor.
	 */
	private StompHeaderAccessor getFirstOutboundHeaders() {
		assertFalse("No message on the outbound channel", this.outboundChannel.getMessages().isEmpty());
		Message<byte[]> message = this.outboundChannel.getMessages().get(0);
		StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
		assertNotNull("No StompHeaderAccessor on the outbound message", accessor);
		return accessor;
	}

	/**
	 * Builds a CONNECT message with an empty payload for the given session and user.
	 * Unlike {@link #message}, the headers are not left mutable.
	 */
	private Message<byte[]> connectMessage(String sessionId, String user) {
		StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
		headers.setSessionId(sessionId);
		headers.setUser(new TestPrincipal(user));
		return MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
	}

	/**
	 * Builds a message with an empty payload for the given STOMP command. Each of
	 * {@code sessionId}, {@code user} and {@code destination} is set only when it is
	 * not {@code null}. The headers are left mutable.
	 */
	private Message<byte[]> message(StompCommand command, String sessionId, String user, String destination) {
		StompHeaderAccessor accessor = StompHeaderAccessor.create(command);
		if (sessionId != null) {
			accessor.setSessionId(sessionId);
		}
		if (user != null) {
			accessor.setUser(new TestPrincipal(user));
		}
		if (destination != null) {
			accessor.setDestination(destination);
		}
		accessor.setLeaveMutable(true);
		return MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders());
	}

	/**
	 * Returns a future task that has already been run and whose result is {@code null}.
	 */
	private static ListenableFutureTask<Void> getVoidFuture() {
		ListenableFutureTask<Void> futureTask = new ListenableFutureTask<>(new Callable<Void>() {
			@Override
			public Void call() throws Exception {
				return null;
			}
		});
		futureTask.run();
		return futureTask;
	}


	/**
	 * {@link TcpOperations} stub that "connects" immediately to a single shared
	 * {@link StubTcpConnection} and remembers the most recently supplied connection
	 * handler so that tests can inject inbound messages through it.
	 */
	private static class StubTcpOperations implements TcpOperations<byte[]> {

		private final StubTcpConnection connection = new StubTcpConnection();

		// Handler passed to the most recent connect call; null until the first one.
		private TcpConnectionHandler<byte[]> connectionHandler;


		/** Returns the messages sent on the stub connection, in send order. */
		public List<Message<byte[]>> getSentMessages() {
			return this.connection.getMessages();
		}

		/**
		 * Returns the STOMP header accessor of the sent message at {@code index},
		 * failing with an assertion error if fewer messages were sent or the message
		 * has no STOMP accessor.
		 */
		public StompHeaderAccessor getSentHeaders(int index) {
			assertTrue("Size: " + getSentMessages().size(), getSentMessages().size() > index);
			Message<byte[]> message = getSentMessages().get(index);
			StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
			assertNotNull(accessor);
			return accessor;
		}

		@Override
		public ListenableFuture<Void> connect(TcpConnectionHandler<byte[]> handler) {
			return connectInternal(handler);
		}

		@Override
		public ListenableFuture<Void> connect(TcpConnectionHandler<byte[]> handler, ReconnectStrategy strategy) {
			// The reconnect strategy is ignored by this stub.
			return connectInternal(handler);
		}

		/**
		 * Records the handler, notifies it synchronously that the stub connection is
		 * established, and returns an already-run future.
		 */
		private ListenableFuture<Void> connectInternal(TcpConnectionHandler<byte[]> handler) {
			this.connectionHandler = handler;
			handler.afterConnected(this.connection);
			return getVoidFuture();
		}

		@Override
		public ListenableFuture<Void> shutdown() {
			return getVoidFuture();
		}

		/**
		 * Passes the given message to the most recently recorded connection handler.
		 * Fails with an assertion error if no connect call has been made yet.
		 */
		public void handleMessage(Message<byte[]> message) {
			assertNotNull("No connection handler recorded: connect has not been called", this.connectionHandler);
			this.connectionHandler.handleMessage(message);
		}
	}


	/**
	 * {@link TcpConnection} stub that records every sent message and ignores
	 * inactivity callbacks and {@code close()}.
	 */
	private static class StubTcpConnection implements TcpConnection<byte[]> {

		private final List<Message<byte[]>> messages = new ArrayList<>();


		/** Returns the live list of messages passed to {@link #send}. */
		public List<Message<byte[]>> getMessages() {
			return this.messages;
		}

		@Override
		public ListenableFuture<Void> send(Message<byte[]> message) {
			this.messages.add(message);
			return getVoidFuture();
		}

		@Override
		public void onReadInactivity(Runnable runnable, long duration) {
			// no-op in this stub
		}

		@Override
		public void onWriteInactivity(Runnable runnable, long duration) {
			// no-op in this stub
		}

		@Override
		public void close() {
			// no-op in this stub
		}
	}

}