/* * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.messaging.simp.stomp; import java.lang.reflect.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import org.apache.activemq.broker.BrokerService; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TestName; import org.springframework.messaging.converter.StringMessageConverter; import org.springframework.messaging.simp.stomp.StompSession.Subscription; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.util.Assert; import org.springframework.util.SocketUtils; import org.springframework.util.concurrent.ListenableFuture; import static org.hamcrest.Matchers.*; import static org.junit.Assert.*; /** * Integration tests for {@link ReactorNettyTcpStompClient}. * * @author Rossen Stoyanchev */ public class ReactorNettyTcpStompClientTests { private static final Log logger = LogFactory.getLog(ReactorNettyTcpStompClientTests.class); @Rule public final TestName testName = new TestName(); private BrokerService activeMQBroker; private ReactorNettyTcpStompClient client; @Before public void setUp() throws Exception { logger.debug("Setting up before '" + this.testName.getMethodName() + "'"); int port = SocketUtils.findAvailableTcpPort(61613); this.activeMQBroker = new BrokerService(); this.activeMQBroker.addConnector("stomp://127.0.0.1:" + port); this.activeMQBroker.setStartAsync(false); this.activeMQBroker.setPersistent(false); this.activeMQBroker.setUseJmx(false); this.activeMQBroker.getSystemUsage().getMemoryUsage().setLimit(1024 * 1024 * 5); this.activeMQBroker.getSystemUsage().getTempUsage().setLimit(1024 * 1024 * 5); this.activeMQBroker.start(); ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler(); taskScheduler.afterPropertiesSet(); this.client = new ReactorNettyTcpStompClient("127.0.0.1", port); this.client.setMessageConverter(new StringMessageConverter()); this.client.setTaskScheduler(taskScheduler); } @After public void tearDown() throws Exception { try { this.client.shutdown(); } catch (Throwable ex) { logger.error("Failed to shut client", ex); } final CountDownLatch latch = new CountDownLatch(1); this.activeMQBroker.addShutdownHook(latch::countDown); logger.debug("Stopping ActiveMQ broker and will await shutdown"); this.activeMQBroker.stop(); if (!latch.await(5, TimeUnit.SECONDS)) { logger.debug("ActiveMQ broker did not shut in the expected time."); } } @Test public void publishSubscribe() throws Exception { String destination = "/topic/foo"; ConsumingHandler consumingHandler1 = new ConsumingHandler(destination); ListenableFuture<StompSession> consumerFuture1 = this.client.connect(consumingHandler1); ConsumingHandler consumingHandler2 = new ConsumingHandler(destination); ListenableFuture<StompSession> consumerFuture2 = this.client.connect(consumingHandler2); assertTrue(consumingHandler1.awaitForSubscriptions(5000)); assertTrue(consumingHandler2.awaitForSubscriptions(5000)); ProducingHandler producingHandler = new ProducingHandler(); producingHandler.addToSend(destination, "foo1"); producingHandler.addToSend(destination, "foo2"); ListenableFuture<StompSession> producerFuture = this.client.connect(producingHandler); assertTrue(consumingHandler1.awaitForMessageCount(2, 5000)); assertThat(consumingHandler1.getReceived(), containsInAnyOrder("foo1", "foo2")); assertTrue(consumingHandler2.awaitForMessageCount(2, 5000)); assertThat(consumingHandler2.getReceived(), containsInAnyOrder("foo1", "foo2")); consumerFuture1.get().disconnect(); consumerFuture2.get().disconnect(); producerFuture.get().disconnect(); } private static class LoggingSessionHandler extends StompSessionHandlerAdapter { @Override public void handleException(StompSession session, StompCommand command, StompHeaders headers, byte[] payload, Throwable ex) { logger.error(command + " " + headers, ex); } @Override public void handleFrame(StompHeaders headers, Object payload) { logger.error("STOMP error frame " + headers + " payload=" + payload); } @Override public void handleTransportError(StompSession session, Throwable exception) { logger.error(exception); } } private static class ConsumingHandler extends LoggingSessionHandler { private final List<String> topics; private final CountDownLatch subscriptionLatch; private final List<String> received = new ArrayList<>(); public ConsumingHandler(String... topics) { Assert.notEmpty(topics, "Topics must not be empty"); this.topics = Arrays.asList(topics); this.subscriptionLatch = new CountDownLatch(this.topics.size()); } public List<String> getReceived() { return this.received; } @Override public void afterConnected(StompSession session, StompHeaders connectedHeaders) { for (String topic : this.topics) { session.setAutoReceipt(true); Subscription subscription = session.subscribe(topic, new StompFrameHandler() { @Override public Type getPayloadType(StompHeaders headers) { return String.class; } @Override public void handleFrame(StompHeaders headers, Object payload) { received.add((String) payload); } }); subscription.addReceiptTask(subscriptionLatch::countDown); } } public boolean awaitForSubscriptions(long millisToWait) throws InterruptedException { if (logger.isDebugEnabled()) { logger.debug("Awaiting for subscription receipts"); } return this.subscriptionLatch.await(millisToWait, TimeUnit.MILLISECONDS); } public boolean awaitForMessageCount(int expected, long millisToWait) throws InterruptedException { if (logger.isDebugEnabled()) { logger.debug("Awaiting for message count: " + expected); } long startTime = System.currentTimeMillis(); while (this.received.size() < expected) { Thread.sleep(500); if ((System.currentTimeMillis() - startTime) > millisToWait) { return false; } } return true; } } private static class ProducingHandler extends LoggingSessionHandler { private final List<String> topics = new ArrayList<>(); private final List<Object> payloads = new ArrayList<>(); public ProducingHandler addToSend(String topic, Object payload) { this.topics.add(topic); this.payloads.add(payload); return this; } @Override public void afterConnected(StompSession session, StompHeaders connectedHeaders) { for (int i = 0; i < this.topics.size(); i++) { session.send(this.topics.get(i), this.payloads.get(i)); } } } }