/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.messaging;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.ScheduledFuture;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.stomp.ConnectionHandlingStompSession;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.stomp.StompHeaders;
import org.springframework.messaging.simp.stomp.StompSessionHandler;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.tcp.TcpConnection;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.MimeTypeUtils;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PongMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.WebSocketClient;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
/**
* Unit tests for {@link WebSocketStompClient}.
*
* @author Rossen Stoyanchev
*/
public class WebSocketStompClientTests {
// Scheduler mock; setUp() hands it to the client under test and the
// inactivity tests verify/stub scheduleWithFixedDelay(..) on it.
@Mock
private TaskScheduler taskScheduler;
// STOMP session mock returned by the test client's createSession(..) override;
// tests verify which callbacks the WebSocket handler forwards to it.
@Mock
private ConnectionHandlingStompSession stompSession;
// WebSocket session mock passed to the captured WebSocketHandler's callbacks.
@Mock
private WebSocketSession webSocketSession;
// Client under test (test subclass that returns the mocked STOMP session).
private TestWebSocketStompClient stompClient;
// Captures the WebSocketHandler passed to the mocked WebSocketClient.doHandshake(..).
private ArgumentCaptor<WebSocketHandler> webSocketHandlerCaptor;
// Future returned by the mocked doHandshake(..); tests complete it manually.
private SettableListenableFuture<WebSocketSession> handshakeFuture;
/**
 * Initializes the annotated mocks and wires the client under test on top of a
 * mocked {@link WebSocketClient}. The mocked handshake records the
 * {@link WebSocketHandler} it is given and returns {@link #handshakeFuture}.
 */
@Before
public void setUp() throws Exception {
    MockitoAnnotations.initMocks(this);

    this.webSocketHandlerCaptor = ArgumentCaptor.forClass(WebSocketHandler.class);
    this.handshakeFuture = new SettableListenableFuture<>();

    WebSocketClient mockedClient = mock(WebSocketClient.class);

    TestWebSocketStompClient clientUnderTest = new TestWebSocketStompClient(mockedClient);
    clientUnderTest.setTaskScheduler(this.taskScheduler);
    clientUnderTest.setStompSession(this.stompSession);
    this.stompClient = clientUnderTest;

    // Capture the handler on every handshake and hand back the controllable future.
    when(mockedClient.doHandshake(this.webSocketHandlerCaptor.capture(), any(), any(URI.class)))
            .thenReturn(this.handshakeFuture);
}
/**
 * Failing the handshake future must be reported to the STOMP session through
 * {@code afterConnectFailure} with the very same exception instance.
 */
@Test
public void webSocketHandshakeFailure() throws Exception {
    IllegalStateException simulated = new IllegalStateException("simulated exception");

    connect();
    this.handshakeFuture.setException(simulated);

    verify(this.stompSession).afterConnectFailure(same(simulated));
}
/**
 * Once the WebSocket connection is established the STOMP session must be told
 * it is connected, with a non-null argument.
 */
@Test
public void webSocketConnectionEstablished() throws Exception {
    WebSocketHandler handler = connect();
    handler.afterConnectionEstablished(this.webSocketSession);

    verify(this.stompSession).afterConnected(notNull());
}
/**
 * A transport error on the WebSocket handler must be forwarded to the STOMP
 * session's {@code handleFailure} with the same exception instance.
 */
@Test
public void webSocketTransportError() throws Exception {
    IllegalStateException failure = new IllegalStateException("simulated exception");
    WebSocketHandler handler = connect();

    handler.handleTransportError(this.webSocketSession, failure);

    verify(this.stompSession).handleFailure(same(failure));
}
/**
 * Closing the WebSocket connection must be propagated to the STOMP session.
 */
@Test
public void webSocketConnectionClosed() throws Exception {
    WebSocketHandler handler = connect();
    handler.afterConnectionClosed(this.webSocketSession, CloseStatus.NORMAL);

    verify(this.stompSession).afterConnectionClosed();
}
/**
 * A complete STOMP SEND frame delivered as a single text message must reach
 * the STOMP session as one decoded message with the expected command, header
 * and payload.
 */
@Test
@SuppressWarnings({"unchecked", "rawtypes"})
public void handleWebSocketMessage() throws Exception {
    WebSocketHandler handler = connect();
    TextMessage frame = new TextMessage("SEND\na:alpha\n\nMessage payload\0");
    handler.handleMessage(this.webSocketSession, frame);

    ArgumentCaptor<Message> messageCaptor = ArgumentCaptor.forClass(Message.class);
    verify(this.stompSession).handleMessage(messageCaptor.capture());
    Message<byte[]> received = messageCaptor.getValue();
    assertNotNull(received);

    StompHeaderAccessor stompAccessor = MessageHeaderAccessor.getAccessor(received, StompHeaderAccessor.class);
    StompHeaders stompHeaders = StompHeaders.readOnlyStompHeaders(stompAccessor.toNativeHeaderMap());

    assertEquals(StompCommand.SEND, stompAccessor.getCommand());
    assertEquals("alpha", stompHeaders.getFirst("a"));
    String body = new String(received.getPayload(), StandardCharsets.UTF_8);
    assertEquals("Message payload", body);
}
/**
 * A STOMP SEND frame split across two text messages must not reach the STOMP
 * session after the first fragment, and must be delivered as one decoded
 * message once the terminating fragment arrives.
 */
@Test
@SuppressWarnings({"unchecked", "rawtypes"})
public void handleWebSocketMessageSplitAcrossTwoMessage() throws Exception {
    WebSocketHandler handler = connect();

    // First fragment: frame is incomplete, so nothing may be forwarded yet.
    TextMessage firstFragment = new TextMessage("SEND\na:alpha\n\nMessage");
    handler.handleMessage(this.webSocketSession, firstFragment);
    verifyNoMoreInteractions(this.stompSession);

    // Second fragment completes the frame (ends with the NUL terminator).
    TextMessage secondFragment = new TextMessage(" payload\0");
    handler.handleMessage(this.webSocketSession, secondFragment);

    ArgumentCaptor<Message> messageCaptor = ArgumentCaptor.forClass(Message.class);
    verify(this.stompSession).handleMessage(messageCaptor.capture());
    Message<byte[]> received = messageCaptor.getValue();
    assertNotNull(received);

    StompHeaderAccessor stompAccessor = MessageHeaderAccessor.getAccessor(received, StompHeaderAccessor.class);
    StompHeaders stompHeaders = StompHeaders.readOnlyStompHeaders(stompAccessor.toNativeHeaderMap());

    assertEquals(StompCommand.SEND, stompAccessor.getCommand());
    assertEquals("alpha", stompHeaders.getFirst("a"));
    String body = new String(received.getPayload(), StandardCharsets.UTF_8);
    assertEquals("Message payload", body);
}
/**
 * A complete STOMP SEND frame delivered as a binary message (UTF-8 bytes) must
 * reach the STOMP session as one decoded message with the expected command,
 * header and payload.
 */
@Test
@SuppressWarnings({"unchecked", "rawtypes"})
public void handleWebSocketMessageBinary() throws Exception {
    String frameText = "SEND\na:alpha\n\nMessage payload\0";
    WebSocketHandler handler = connect();
    BinaryMessage frame = new BinaryMessage(frameText.getBytes(StandardCharsets.UTF_8));
    handler.handleMessage(this.webSocketSession, frame);

    ArgumentCaptor<Message> messageCaptor = ArgumentCaptor.forClass(Message.class);
    verify(this.stompSession).handleMessage(messageCaptor.capture());
    Message<byte[]> received = messageCaptor.getValue();
    assertNotNull(received);

    StompHeaderAccessor stompAccessor = MessageHeaderAccessor.getAccessor(received, StompHeaderAccessor.class);
    StompHeaders stompHeaders = StompHeaders.readOnlyStompHeaders(stompAccessor.toNativeHeaderMap());

    assertEquals(StompCommand.SEND, stompAccessor.getCommand());
    assertEquals("alpha", stompHeaders.getFirst("a"));
    String body = new String(received.getPayload(), StandardCharsets.UTF_8);
    assertEquals("Message payload", body);
}
/**
 * A pong message must not cause any further interaction with the STOMP session.
 */
@Test
public void handleWebSocketMessagePong() throws Exception {
    WebSocketHandler handler = connect();
    PongMessage pong = new PongMessage();

    handler.handleMessage(this.webSocketSession, pong);

    verifyNoMoreInteractions(this.stompSession);
}
/**
 * Sending a SEND message with only a destination header must result in a
 * {@link TextMessage} on the WebSocket session containing the encoded frame.
 */
@Test
public void sendWebSocketMessage() throws Exception {
    StompHeaderAccessor sendHeaders = StompHeaderAccessor.create(StompCommand.SEND);
    sendHeaders.setDestination("/topic/foo");
    byte[] body = "payload".getBytes(StandardCharsets.UTF_8);

    TcpConnection<byte[]> connection = getTcpConnection();
    Message<byte[]> outbound = MessageBuilder.createMessage(body, sendHeaders.getMessageHeaders());
    connection.send(outbound);

    ArgumentCaptor<TextMessage> sentCaptor = ArgumentCaptor.forClass(TextMessage.class);
    verify(this.webSocketSession).sendMessage(sentCaptor.capture());
    TextMessage sent = sentCaptor.getValue();
    assertNotNull(sent);

    String expectedFrame = "SEND\ndestination:/topic/foo\ncontent-length:7\n\npayload\0";
    assertEquals(expectedFrame, sent.getPayload());
}
/**
 * Sending a SEND message whose content type is application/octet-stream must
 * result in a {@link BinaryMessage} on the WebSocket session containing the
 * encoded frame.
 */
@Test
public void sendWebSocketBinary() throws Exception {
    StompHeaderAccessor sendHeaders = StompHeaderAccessor.create(StompCommand.SEND);
    sendHeaders.setDestination("/b");
    sendHeaders.setContentType(MimeTypeUtils.APPLICATION_OCTET_STREAM);
    byte[] body = "payload".getBytes(StandardCharsets.UTF_8);

    TcpConnection<byte[]> connection = getTcpConnection();
    Message<byte[]> outbound = MessageBuilder.createMessage(body, sendHeaders.getMessageHeaders());
    connection.send(outbound);

    ArgumentCaptor<BinaryMessage> sentCaptor = ArgumentCaptor.forClass(BinaryMessage.class);
    verify(this.webSocketSession).sendMessage(sentCaptor.capture());
    BinaryMessage sent = sentCaptor.getValue();
    assertNotNull(sent);

    String expectedFrame =
            "SEND\ndestination:/b\ncontent-type:application/octet-stream\ncontent-length:7\n\npayload\0";
    String actualFrame = new String(sent.getPayload().array(), StandardCharsets.UTF_8);
    assertEquals(expectedFrame, actualFrame);
}
/**
 * Without a task scheduler the default heartbeat is expected to be {0, 0},
 * both on the client and in the processed CONNECT headers.
 */
@Test
public void heartbeatDefaultValue() throws Exception {
    long[] expected = {0, 0};
    WebSocketStompClient client = new WebSocketStompClient(mock(WebSocketClient.class));

    assertArrayEquals(expected, client.getDefaultHeartbeat());

    StompHeaders processed = client.processConnectHeaders(null);
    assertArrayEquals(expected, processed.getHeartbeat());
}
/**
 * After a task scheduler has been set the default heartbeat is expected to be
 * {10000, 10000}, both on the client and in the processed CONNECT headers.
 */
@Test
public void heartbeatDefaultValueWithScheduler() throws Exception {
    long[] expected = {10000, 10000};
    WebSocketStompClient client = new WebSocketStompClient(mock(WebSocketClient.class));
    client.setTaskScheduler(mock(TaskScheduler.class));

    assertArrayEquals(expected, client.getDefaultHeartbeat());

    StompHeaders processed = client.processConnectHeaders(null);
    assertArrayEquals(expected, processed.getHeartbeat());
}
/**
 * Setting a non-zero default heartbeat without a task scheduler must make
 * {@code processConnectHeaders} fail with an {@link IllegalStateException}.
 */
@Test
public void heartbeatDefaultValueSetWithoutScheduler() throws Exception {
    WebSocketStompClient client = new WebSocketStompClient(mock(WebSocketClient.class));
    client.setDefaultHeartbeat(new long[] {5, 5});

    boolean rejected = false;
    try {
        client.processConnectHeaders(null);
    }
    catch (IllegalStateException expected) {
        // This is the outcome the test is looking for.
        rejected = true;
    }

    if (!rejected) {
        fail("Expected IllegalStateException");
    }
}
/**
 * Registers a read-inactivity callback with a delay of 2 and checks, after
 * sleeping 10 ms, that running the scheduled task invokes the callback.
 */
@Test
public void readInactivityAfterDelayHasElapsed() throws Exception {
    final long inactivityDelay = 2;
    Runnable callback = mock(Runnable.class);

    getTcpConnection().onReadInactivity(callback, inactivityDelay);

    testInactivityTaskScheduling(callback, inactivityDelay, 10);
}
/**
 * Registers a read-inactivity callback with a delay of 10000 and checks that
 * running the scheduled task right away does not invoke the callback.
 */
@Test
public void readInactivityBeforeDelayHasElapsed() throws Exception {
    final long inactivityDelay = 10000;
    Runnable callback = mock(Runnable.class);

    getTcpConnection().onReadInactivity(callback, inactivityDelay);

    testInactivityTaskScheduling(callback, inactivityDelay, 0);
}
/**
 * Registers a write-inactivity callback with a delay of 2 and checks, after
 * sleeping 10 ms, that running the scheduled task invokes the callback.
 */
@Test
public void writeInactivityAfterDelayHasElapsed() throws Exception {
    final long inactivityDelay = 2;
    Runnable callback = mock(Runnable.class);

    getTcpConnection().onWriteInactivity(callback, inactivityDelay);

    testInactivityTaskScheduling(callback, inactivityDelay, 10);
}
/**
 * Registers a write-inactivity callback with a delay of 1000 and checks that
 * running the scheduled task right away does not invoke the callback.
 */
@Test
public void writeInactivityBeforeDelayHasElapsed() throws Exception {
    final long inactivityDelay = 1000;
    Runnable callback = mock(Runnable.class);

    getTcpConnection().onWriteInactivity(callback, inactivityDelay);

    testInactivityTaskScheduling(callback, inactivityDelay, 0);
}
/**
 * Registers one read- and one write-inactivity callback, closes the
 * connection, and checks that the future returned by the scheduler was
 * cancelled exactly twice and not touched otherwise.
 */
@Test
@SuppressWarnings({"rawtypes", "unchecked"})
public void cancelInactivityTasks() throws Exception {
    TcpConnection<byte[]> connection = getTcpConnection();

    // The stub only matches a fixed delay of 1L; both callbacks below are registered with 2L.
    ScheduledFuture scheduled = mock(ScheduledFuture.class);
    when(this.taskScheduler.scheduleWithFixedDelay(any(), eq(1L))).thenReturn(scheduled);

    connection.onReadInactivity(mock(Runnable.class), 2L);
    connection.onWriteInactivity(mock(Runnable.class), 2L);

    WebSocketHandler handler = this.webSocketHandlerCaptor.getValue();
    handler.afterConnectionClosed(this.webSocketSession, CloseStatus.NORMAL);

    verify(scheduled, times(2)).cancel(true);
    verifyNoMoreInteractions(scheduled);
}
/**
 * Connects the client under test to "/foo" with a mocked session handler,
 * checks that the only interaction with the STOMP session so far was
 * {@code getSessionFuture()}, and returns the handler captured from the
 * mocked handshake.
 * @return the captured, non-null {@link WebSocketHandler}
 */
private WebSocketHandler connect() {
    StompSessionHandler sessionHandler = mock(StompSessionHandler.class);
    this.stompClient.connect("/foo", sessionHandler);

    verify(this.stompSession).getSessionFuture();
    verifyNoMoreInteractions(this.stompSession);

    WebSocketHandler captured = this.webSocketHandlerCaptor.getValue();
    assertNotNull(captured);
    return captured;
}
/**
 * Connects, signals that the WebSocket connection is established, and returns
 * the captured handler viewed as a {@link TcpConnection}.
 * @return the captured handler cast to {@code TcpConnection<byte[]>}
 */
@SuppressWarnings("unchecked")
private TcpConnection<byte[]> getTcpConnection() throws Exception {
    WebSocketHandler handler = connect();
    handler.afterConnectionEstablished(this.webSocketSession);
    TcpConnection<byte[]> connection = (TcpConnection<byte[]>) handler;
    return connection;
}
/**
 * Verifies that exactly one task was scheduled with a fixed delay of
 * {@code delay / 2}, optionally sleeps, runs the captured task, and then
 * checks whether the inactivity callback fired.
 * @param runnable the mocked inactivity callback registered by the caller
 * @param delay the inactivity delay the caller registered the callback with
 * @param sleepTime time to sleep (in ms) before running the task; when
 * positive the callback is expected to run, otherwise it must stay untouched
 * @throws InterruptedException if the sleep is interrupted
 */
private void testInactivityTaskScheduling(Runnable runnable, long delay, long sleepTime)
        throws InterruptedException {

    ArgumentCaptor<Runnable> taskCaptor = ArgumentCaptor.forClass(Runnable.class);
    verify(this.taskScheduler).scheduleWithFixedDelay(taskCaptor.capture(), eq(delay / 2));
    verifyNoMoreInteractions(this.taskScheduler);

    boolean waitFirst = (sleepTime > 0);
    if (waitFirst) {
        Thread.sleep(sleepTime);
    }

    Runnable scheduledTask = taskCaptor.getValue();
    assertNotNull(scheduledTask);
    scheduledTask.run();

    if (!waitFirst) {
        // No waiting happened, so the callback must not have been touched.
        verifyNoMoreInteractions(runnable);
        return;
    }
    verify(runnable).run();
}
/**
 * {@link WebSocketStompClient} variant whose {@code createSession} returns a
 * pre-configured session instead of creating one, so tests can supply a mock.
 */
private static class TestWebSocketStompClient extends WebSocketStompClient {

    /** Session returned from {@link #createSession}; {@code null} until set. */
    private ConnectionHandlingStompSession sessionToReturn;

    public TestWebSocketStompClient(WebSocketClient webSocketClient) {
        super(webSocketClient);
    }

    /**
     * Returns the configured session, ignoring both arguments.
     */
    @Override
    protected ConnectionHandlingStompSession createSession(StompHeaders headers, StompSessionHandler handler) {
        return this.sessionToReturn;
    }

    /**
     * Sets the session that {@link #createSession} will return.
     */
    public void setStompSession(ConnectionHandlingStompSession stompSession) {
        this.sessionToReturn = stompSession;
    }
}
}