/* * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.messaging.simp.broker; import java.security.Principal; import java.util.Collections; import java.util.List; import java.util.concurrent.ScheduledFuture; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.TestPrincipal; import org.springframework.messaging.support.MessageBuilder; import org.springframework.scheduling.TaskScheduler; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.any; import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; /** * Unit tests for SimpleBrokerMessageHandler. * * @author Rossen Stoyanchev * @since 4.0 */ @SuppressWarnings("unchecked") public class SimpleBrokerMessageHandlerTests { private SimpleBrokerMessageHandler messageHandler; @Mock private SubscribableChannel clientInboundChannel; @Mock private MessageChannel clientOutboundChannel; @Mock private SubscribableChannel brokerChannel; @Mock private TaskScheduler taskScheduler; @Captor ArgumentCaptor<Message<?>> messageCaptor; @Before public void setup() { MockitoAnnotations.initMocks(this); this.messageHandler = new SimpleBrokerMessageHandler(this.clientInboundChannel, this.clientOutboundChannel, this.brokerChannel, Collections.emptyList()); } @Test public void subcribePublish() { this.messageHandler.start(); this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub1", "/foo")); this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub2", "/foo")); this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub3", "/bar")); this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub1", "/foo")); this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub2", "/foo")); this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub3", "/bar")); this.messageHandler.handleMessage(createMessage("/foo", "message1")); this.messageHandler.handleMessage(createMessage("/bar", "message2")); verify(this.clientOutboundChannel, times(6)).send(this.messageCaptor.capture()); assertTrue(messageCaptured("sess1", "sub1", "/foo")); assertTrue(messageCaptured("sess1", "sub2", "/foo")); assertTrue(messageCaptured("sess2", "sub1", "/foo")); assertTrue(messageCaptured("sess2", "sub2", "/foo")); assertTrue(messageCaptured("sess1", "sub3", "/bar")); assertTrue(messageCaptured("sess2", "sub3", "/bar")); } @Test public void subcribeDisconnectPublish() { String sess1 = "sess1"; String sess2 = "sess2"; this.messageHandler.start(); this.messageHandler.handleMessage(createSubscriptionMessage(sess1, "sub1", "/foo")); this.messageHandler.handleMessage(createSubscriptionMessage(sess1, "sub2", "/foo")); this.messageHandler.handleMessage(createSubscriptionMessage(sess1, "sub3", "/bar")); this.messageHandler.handleMessage(createSubscriptionMessage(sess2, "sub1", "/foo")); this.messageHandler.handleMessage(createSubscriptionMessage(sess2, "sub2", "/foo")); this.messageHandler.handleMessage(createSubscriptionMessage(sess2, "sub3", "/bar")); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT); headers.setSessionId(sess1); headers.setUser(new TestPrincipal("joe")); Message<byte[]> message = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); this.messageHandler.handleMessage(message); this.messageHandler.handleMessage(createMessage("/foo", "message1")); this.messageHandler.handleMessage(createMessage("/bar", "message2")); verify(this.clientOutboundChannel, times(4)).send(this.messageCaptor.capture()); Message<?> captured = this.messageCaptor.getAllValues().get(0); assertEquals(SimpMessageType.DISCONNECT_ACK, SimpMessageHeaderAccessor.getMessageType(captured.getHeaders())); assertSame(message, captured.getHeaders().get(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER)); assertEquals(sess1, SimpMessageHeaderAccessor.getSessionId(captured.getHeaders())); assertEquals("joe", SimpMessageHeaderAccessor.getUser(captured.getHeaders()).getName()); assertTrue(messageCaptured(sess2, "sub1", "/foo")); assertTrue(messageCaptured(sess2, "sub2", "/foo")); assertTrue(messageCaptured(sess2, "sub3", "/bar")); } @Test public void connect() { this.messageHandler.start(); String id = "sess1"; Message<String> connectMessage = createConnectMessage(id, new TestPrincipal("joe"), null); this.messageHandler.setTaskScheduler(this.taskScheduler); this.messageHandler.handleMessage(connectMessage); verify(this.clientOutboundChannel, times(1)).send(this.messageCaptor.capture()); Message<?> connectAckMessage = this.messageCaptor.getValue(); SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.wrap(connectAckMessage); assertEquals(connectMessage, connectAckHeaders.getHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER)); assertEquals(id, connectAckHeaders.getSessionId()); assertEquals("joe", connectAckHeaders.getUser().getName()); assertArrayEquals(new long[] {10000, 10000}, SimpMessageHeaderAccessor.getHeartbeat(connectAckHeaders.getMessageHeaders())); } @Test public void heartbeatValueWithAndWithoutTaskScheduler() throws Exception { assertNull(this.messageHandler.getHeartbeatValue()); this.messageHandler.setTaskScheduler(this.taskScheduler); assertNotNull(this.messageHandler.getHeartbeatValue()); assertArrayEquals(new long[] {10000, 10000}, this.messageHandler.getHeartbeatValue()); } @Test(expected = IllegalArgumentException.class) public void startWithHeartbeatValueWithoutTaskScheduler() throws Exception { this.messageHandler.setHeartbeatValue(new long[] {10000, 10000}); this.messageHandler.start(); } @SuppressWarnings("unchecked") @Test public void startAndStopWithHeartbeatValue() throws Exception { ScheduledFuture future = mock(ScheduledFuture.class); when(this.taskScheduler.scheduleWithFixedDelay(any(Runnable.class), eq(15000L))).thenReturn(future); this.messageHandler.setTaskScheduler(this.taskScheduler); this.messageHandler.setHeartbeatValue(new long[] {15000, 16000}); this.messageHandler.start(); verify(this.taskScheduler).scheduleWithFixedDelay(any(Runnable.class), eq(15000L)); verifyNoMoreInteractions(this.taskScheduler, future); this.messageHandler.stop(); verify(future).cancel(true); verifyNoMoreInteractions(future); } @SuppressWarnings("unchecked") @Test public void startWithOneZeroHeartbeatValue() throws Exception { this.messageHandler.setTaskScheduler(this.taskScheduler); this.messageHandler.setHeartbeatValue(new long[] {0, 10000}); this.messageHandler.start(); verify(this.taskScheduler).scheduleWithFixedDelay(any(Runnable.class), eq(10000L)); } @Test public void readInactivity() throws Exception { this.messageHandler.setHeartbeatValue(new long[] {0, 1}); this.messageHandler.setTaskScheduler(this.taskScheduler); this.messageHandler.start(); ArgumentCaptor<Runnable> taskCaptor = ArgumentCaptor.forClass(Runnable.class); verify(this.taskScheduler).scheduleWithFixedDelay(taskCaptor.capture(), eq(1L)); Runnable heartbeatTask = taskCaptor.getValue(); assertNotNull(heartbeatTask); String id = "sess1"; TestPrincipal user = new TestPrincipal("joe"); Message<String> connectMessage = createConnectMessage(id, user, new long[] {1, 0}); this.messageHandler.handleMessage(connectMessage); Thread.sleep(10); heartbeatTask.run(); verify(this.clientOutboundChannel, atLeast(2)).send(this.messageCaptor.capture()); List<Message<?>> messages = this.messageCaptor.getAllValues(); assertEquals(2, messages.size()); MessageHeaders headers = messages.get(0).getHeaders(); assertEquals(SimpMessageType.CONNECT_ACK, headers.get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER)); headers = messages.get(1).getHeaders(); assertEquals(SimpMessageType.DISCONNECT_ACK, headers.get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER)); assertEquals(id, headers.get(SimpMessageHeaderAccessor.SESSION_ID_HEADER)); assertEquals(user, headers.get(SimpMessageHeaderAccessor.USER_HEADER)); } @Test public void writeInactivity() throws Exception { this.messageHandler.setHeartbeatValue(new long[] {1, 0}); this.messageHandler.setTaskScheduler(this.taskScheduler); this.messageHandler.start(); ArgumentCaptor<Runnable> taskCaptor = ArgumentCaptor.forClass(Runnable.class); verify(this.taskScheduler).scheduleWithFixedDelay(taskCaptor.capture(), eq(1L)); Runnable heartbeatTask = taskCaptor.getValue(); assertNotNull(heartbeatTask); String id = "sess1"; TestPrincipal user = new TestPrincipal("joe"); Message<String> connectMessage = createConnectMessage(id, user, new long[] {0, 1}); this.messageHandler.handleMessage(connectMessage); Thread.sleep(10); heartbeatTask.run(); verify(this.clientOutboundChannel, times(2)).send(this.messageCaptor.capture()); List<Message<?>> messages = this.messageCaptor.getAllValues(); assertEquals(2, messages.size()); MessageHeaders headers = messages.get(0).getHeaders(); assertEquals(SimpMessageType.CONNECT_ACK, headers.get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER)); headers = messages.get(1).getHeaders(); assertEquals(SimpMessageType.HEARTBEAT, headers.get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER)); assertEquals(id, headers.get(SimpMessageHeaderAccessor.SESSION_ID_HEADER)); assertEquals(user, headers.get(SimpMessageHeaderAccessor.USER_HEADER)); } @Test public void readWriteIntervalCalculation() throws Exception { this.messageHandler.setHeartbeatValue(new long[] {1, 1}); this.messageHandler.setTaskScheduler(this.taskScheduler); this.messageHandler.start(); ArgumentCaptor<Runnable> taskCaptor = ArgumentCaptor.forClass(Runnable.class); verify(this.taskScheduler).scheduleWithFixedDelay(taskCaptor.capture(), eq(1L)); Runnable heartbeatTask = taskCaptor.getValue(); assertNotNull(heartbeatTask); String id = "sess1"; TestPrincipal user = new TestPrincipal("joe"); Message<String> connectMessage = createConnectMessage(id, user, new long[] {10000, 10000}); this.messageHandler.handleMessage(connectMessage); Thread.sleep(10); heartbeatTask.run(); verify(this.clientOutboundChannel, times(1)).send(this.messageCaptor.capture()); List<Message<?>> messages = this.messageCaptor.getAllValues(); assertEquals(1, messages.size()); assertEquals(SimpMessageType.CONNECT_ACK, messages.get(0).getHeaders().get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER)); } private Message<String> createSubscriptionMessage(String sessionId, String subcriptionId, String destination) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.SUBSCRIBE); headers.setSubscriptionId(subcriptionId); headers.setDestination(destination); headers.setSessionId(sessionId); return MessageBuilder.createMessage("", headers.getMessageHeaders()); } private Message<String> createConnectMessage(String sessionId, Principal user, long[] heartbeat) { SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); accessor.setSessionId(sessionId); accessor.setUser(user); accessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, heartbeat); return MessageBuilder.createMessage("", accessor.getMessageHeaders()); } private Message<String> createMessage(String destination, String payload) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); headers.setDestination(destination); return MessageBuilder.createMessage(payload, headers.getMessageHeaders()); } private boolean messageCaptured(String sessionId, String subcriptionId, String destination) { for (Message<?> message : this.messageCaptor.getAllValues()) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); if (sessionId.equals(headers.getSessionId())) { if (subcriptionId.equals(headers.getSubscriptionId())) { if (destination.equals(headers.getDestination())) { return true; } } } } return false; } }