/* * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.web.socket.messaging; import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.concurrent.ScheduledFuture; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.springframework.messaging.Message; import org.springframework.messaging.simp.stomp.ConnectionHandlingStompSession; import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.stomp.StompHeaders; import org.springframework.messaging.simp.stomp.StompSessionHandler; import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.messaging.tcp.TcpConnection; import org.springframework.scheduling.TaskScheduler; import org.springframework.util.MimeTypeUtils; import org.springframework.util.concurrent.SettableListenableFuture; import org.springframework.web.socket.BinaryMessage; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.PongMessage; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.client.WebSocketClient; import static org.junit.Assert.*; import static org.mockito.Mockito.*; /** * Unit tests for {@link WebSocketStompClient}. * * @author Rossen Stoyanchev */ public class WebSocketStompClientTests { @Mock private TaskScheduler taskScheduler; @Mock private ConnectionHandlingStompSession stompSession; @Mock private WebSocketSession webSocketSession; private TestWebSocketStompClient stompClient; private ArgumentCaptor<WebSocketHandler> webSocketHandlerCaptor; private SettableListenableFuture<WebSocketSession> handshakeFuture; @Before public void setUp() throws Exception { MockitoAnnotations.initMocks(this); WebSocketClient webSocketClient = mock(WebSocketClient.class); this.stompClient = new TestWebSocketStompClient(webSocketClient); this.stompClient.setTaskScheduler(this.taskScheduler); this.stompClient.setStompSession(this.stompSession); this.webSocketHandlerCaptor = ArgumentCaptor.forClass(WebSocketHandler.class); this.handshakeFuture = new SettableListenableFuture<>(); when(webSocketClient.doHandshake(this.webSocketHandlerCaptor.capture(), any(), any(URI.class))) .thenReturn(this.handshakeFuture); } @Test public void webSocketHandshakeFailure() throws Exception { connect(); IllegalStateException handshakeFailure = new IllegalStateException("simulated exception"); this.handshakeFuture.setException(handshakeFailure); verify(this.stompSession).afterConnectFailure(same(handshakeFailure)); } @Test public void webSocketConnectionEstablished() throws Exception { connect().afterConnectionEstablished(this.webSocketSession); verify(this.stompSession).afterConnected(notNull()); } @Test public void webSocketTransportError() throws Exception { IllegalStateException exception = new IllegalStateException("simulated exception"); connect().handleTransportError(this.webSocketSession, exception); verify(this.stompSession).handleFailure(same(exception)); } @Test public void webSocketConnectionClosed() throws Exception { connect().afterConnectionClosed(this.webSocketSession, CloseStatus.NORMAL); verify(this.stompSession).afterConnectionClosed(); } @Test @SuppressWarnings({"unchecked", "rawtypes"}) public void handleWebSocketMessage() throws Exception { String text = "SEND\na:alpha\n\nMessage payload\0"; connect().handleMessage(this.webSocketSession, new TextMessage(text)); ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class); verify(this.stompSession).handleMessage(captor.capture()); Message<byte[]> message = captor.getValue(); assertNotNull(message); StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); StompHeaders headers = StompHeaders.readOnlyStompHeaders(accessor.toNativeHeaderMap()); assertEquals(StompCommand.SEND, accessor.getCommand()); assertEquals("alpha", headers.getFirst("a")); assertEquals("Message payload", new String(message.getPayload(), StandardCharsets.UTF_8)); } @Test @SuppressWarnings({"unchecked", "rawtypes"}) public void handleWebSocketMessageSplitAcrossTwoMessage() throws Exception { WebSocketHandler webSocketHandler = connect(); String part1 = "SEND\na:alpha\n\nMessage"; webSocketHandler.handleMessage(this.webSocketSession, new TextMessage(part1)); verifyNoMoreInteractions(this.stompSession); String part2 = " payload\0"; webSocketHandler.handleMessage(this.webSocketSession, new TextMessage(part2)); ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class); verify(this.stompSession).handleMessage(captor.capture()); Message<byte[]> message = captor.getValue(); assertNotNull(message); StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); StompHeaders headers = StompHeaders.readOnlyStompHeaders(accessor.toNativeHeaderMap()); assertEquals(StompCommand.SEND, accessor.getCommand()); assertEquals("alpha", headers.getFirst("a")); assertEquals("Message payload", new String(message.getPayload(), StandardCharsets.UTF_8)); } @Test @SuppressWarnings({"unchecked", "rawtypes"}) public void handleWebSocketMessageBinary() throws Exception { String text = "SEND\na:alpha\n\nMessage payload\0"; connect().handleMessage(this.webSocketSession, new BinaryMessage(text.getBytes(StandardCharsets.UTF_8))); ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class); verify(this.stompSession).handleMessage(captor.capture()); Message<byte[]> message = captor.getValue(); assertNotNull(message); StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); StompHeaders headers = StompHeaders.readOnlyStompHeaders(accessor.toNativeHeaderMap()); assertEquals(StompCommand.SEND, accessor.getCommand()); assertEquals("alpha", headers.getFirst("a")); assertEquals("Message payload", new String(message.getPayload(), StandardCharsets.UTF_8)); } @Test public void handleWebSocketMessagePong() throws Exception { connect().handleMessage(this.webSocketSession, new PongMessage()); verifyNoMoreInteractions(this.stompSession); } @Test public void sendWebSocketMessage() throws Exception { StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND); accessor.setDestination("/topic/foo"); byte[] payload = "payload".getBytes(StandardCharsets.UTF_8); getTcpConnection().send(MessageBuilder.createMessage(payload, accessor.getMessageHeaders())); ArgumentCaptor<TextMessage> textMessageCaptor = ArgumentCaptor.forClass(TextMessage.class); verify(this.webSocketSession).sendMessage(textMessageCaptor.capture()); TextMessage textMessage = textMessageCaptor.getValue(); assertNotNull(textMessage); assertEquals("SEND\ndestination:/topic/foo\ncontent-length:7\n\npayload\0", textMessage.getPayload()); } @Test public void sendWebSocketBinary() throws Exception { StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND); accessor.setDestination("/b"); accessor.setContentType(MimeTypeUtils.APPLICATION_OCTET_STREAM); byte[] payload = "payload".getBytes(StandardCharsets.UTF_8); getTcpConnection().send(MessageBuilder.createMessage(payload, accessor.getMessageHeaders())); ArgumentCaptor<BinaryMessage> binaryMessageCaptor = ArgumentCaptor.forClass(BinaryMessage.class); verify(this.webSocketSession).sendMessage(binaryMessageCaptor.capture()); BinaryMessage binaryMessage = binaryMessageCaptor.getValue(); assertNotNull(binaryMessage); assertEquals("SEND\ndestination:/b\ncontent-type:application/octet-stream\ncontent-length:7\n\npayload\0", new String(binaryMessage.getPayload().array(), StandardCharsets.UTF_8)); } @Test public void heartbeatDefaultValue() throws Exception { WebSocketStompClient stompClient = new WebSocketStompClient(mock(WebSocketClient.class)); assertArrayEquals(new long[] {0, 0}, stompClient.getDefaultHeartbeat()); StompHeaders connectHeaders = stompClient.processConnectHeaders(null); assertArrayEquals(new long[] {0, 0}, connectHeaders.getHeartbeat()); } @Test public void heartbeatDefaultValueWithScheduler() throws Exception { WebSocketStompClient stompClient = new WebSocketStompClient(mock(WebSocketClient.class)); stompClient.setTaskScheduler(mock(TaskScheduler.class)); assertArrayEquals(new long[] {10000, 10000}, stompClient.getDefaultHeartbeat()); StompHeaders connectHeaders = stompClient.processConnectHeaders(null); assertArrayEquals(new long[] {10000, 10000}, connectHeaders.getHeartbeat()); } @Test public void heartbeatDefaultValueSetWithoutScheduler() throws Exception { WebSocketStompClient stompClient = new WebSocketStompClient(mock(WebSocketClient.class)); stompClient.setDefaultHeartbeat(new long[] {5, 5}); try { stompClient.processConnectHeaders(null); fail("Expected IllegalStateException"); } catch (IllegalStateException ex) { // ignore } } @Test public void readInactivityAfterDelayHasElapsed() throws Exception { TcpConnection<byte[]> tcpConnection = getTcpConnection(); Runnable runnable = mock(Runnable.class); long delay = 2; tcpConnection.onReadInactivity(runnable, delay); testInactivityTaskScheduling(runnable, delay, 10); } @Test public void readInactivityBeforeDelayHasElapsed() throws Exception { TcpConnection<byte[]> tcpConnection = getTcpConnection(); Runnable runnable = mock(Runnable.class); long delay = 10000; tcpConnection.onReadInactivity(runnable, delay); testInactivityTaskScheduling(runnable, delay, 0); } @Test public void writeInactivityAfterDelayHasElapsed() throws Exception { TcpConnection<byte[]> tcpConnection = getTcpConnection(); Runnable runnable = mock(Runnable.class); long delay = 2; tcpConnection.onWriteInactivity(runnable, delay); testInactivityTaskScheduling(runnable, delay, 10); } @Test public void writeInactivityBeforeDelayHasElapsed() throws Exception { TcpConnection<byte[]> tcpConnection = getTcpConnection(); Runnable runnable = mock(Runnable.class); long delay = 1000; tcpConnection.onWriteInactivity(runnable, delay); testInactivityTaskScheduling(runnable, delay, 0); } @Test @SuppressWarnings({"rawtypes", "unchecked"}) public void cancelInactivityTasks() throws Exception { TcpConnection<byte[]> tcpConnection = getTcpConnection(); ScheduledFuture future = mock(ScheduledFuture.class); when(this.taskScheduler.scheduleWithFixedDelay(any(), eq(1L))).thenReturn(future); tcpConnection.onReadInactivity(mock(Runnable.class), 2L); tcpConnection.onWriteInactivity(mock(Runnable.class), 2L); this.webSocketHandlerCaptor.getValue().afterConnectionClosed(this.webSocketSession, CloseStatus.NORMAL); verify(future, times(2)).cancel(true); verifyNoMoreInteractions(future); } private WebSocketHandler connect() { this.stompClient.connect("/foo", mock(StompSessionHandler.class)); verify(this.stompSession).getSessionFuture(); verifyNoMoreInteractions(this.stompSession); WebSocketHandler webSocketHandler = this.webSocketHandlerCaptor.getValue(); assertNotNull(webSocketHandler); return webSocketHandler; } @SuppressWarnings("unchecked") private TcpConnection<byte[]> getTcpConnection() throws Exception { WebSocketHandler webSocketHandler = connect(); webSocketHandler.afterConnectionEstablished(this.webSocketSession); return (TcpConnection<byte[]>) webSocketHandler; } private void testInactivityTaskScheduling(Runnable runnable, long delay, long sleepTime) throws InterruptedException { ArgumentCaptor<Runnable> inactivityTaskCaptor = ArgumentCaptor.forClass(Runnable.class); verify(this.taskScheduler).scheduleWithFixedDelay(inactivityTaskCaptor.capture(), eq(delay/2)); verifyNoMoreInteractions(this.taskScheduler); if (sleepTime > 0) { Thread.sleep(sleepTime); } Runnable inactivityTask = inactivityTaskCaptor.getValue(); assertNotNull(inactivityTask); inactivityTask.run(); if (sleepTime > 0) { verify(runnable).run(); } else { verifyNoMoreInteractions(runnable); } } private static class TestWebSocketStompClient extends WebSocketStompClient { private ConnectionHandlingStompSession stompSession; public TestWebSocketStompClient(WebSocketClient webSocketClient) { super(webSocketClient); } public void setStompSession(ConnectionHandlingStompSession stompSession) { this.stompSession = stompSession; } @Override protected ConnectionHandlingStompSession createSession(StompHeaders headers, StompSessionHandler handler) { return this.stompSession; } } }