/* * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.integration.aggregator; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.lessThan; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.locks.ReentrantLock; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import org.springframework.beans.factory.BeanFactory; import org.springframework.integration.channel.DirectChannel; import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.handler.AbstractMessageHandler; import org.springframework.integration.store.MessageGroup; import org.springframework.integration.store.SimpleMessageGroupFactory; import org.springframework.integration.store.SimpleMessageStore; import org.springframework.integration.support.MessageBuilder; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageHandlingException; import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.support.GenericMessage; import org.springframework.util.StopWatch; /** * @author Mark Fisher * @author Marius Bogoevici * @author Iwein Fuld * @author Gary Russell * @author Artem Bilan */ public class AggregatorTests { private static final Log logger = LogFactory.getLog(AggregatorTests.class); private AggregatingMessageHandler aggregator; private final SimpleMessageStore store = new SimpleMessageStore(50); private final List<MessageGroupExpiredEvent> expiryEvents = new ArrayList<MessageGroupExpiredEvent>(); @Before public void configureAggregator() { this.aggregator = new AggregatingMessageHandler(new MultiplyingProcessor(), store); this.aggregator.setBeanFactory(mock(BeanFactory.class)); this.aggregator.setApplicationEventPublisher(event -> expiryEvents.add((MessageGroupExpiredEvent) event)); this.aggregator.setBeanName("testAggregator"); this.aggregator.afterPropertiesSet(); expiryEvents.clear(); } @Test public void testAggPerf() throws InterruptedException, ExecutionException, TimeoutException { AggregatingMessageHandler handler = new AggregatingMessageHandler(new DefaultAggregatingMessageGroupProcessor()); handler.setCorrelationStrategy(message -> "foo"); handler.setReleaseStrategy(new MessageCountReleaseStrategy(60000)); handler.setExpireGroupsUponCompletion(true); handler.setSendPartialResultOnExpiry(true); DirectChannel outputChannel = new DirectChannel(); handler.setOutputChannel(outputChannel); final CompletableFuture<Collection<?>> resultFuture = new CompletableFuture<>(); outputChannel.subscribe(message -> { Collection<?> payload = (Collection<?>) message.getPayload(); logger.warn("Received " + payload.size()); resultFuture.complete(payload); }); SimpleMessageStore store = new SimpleMessageStore(); SimpleMessageGroupFactory messageGroupFactory = new SimpleMessageGroupFactory(SimpleMessageGroupFactory.GroupType.BLOCKING_QUEUE); store.setMessageGroupFactory(messageGroupFactory); handler.setMessageStore(store); Message<?> message = new GenericMessage<String>("foo"); StopWatch stopwatch = new StopWatch(); stopwatch.start(); for (int i = 0; i < 120000; i++) { if (i % 10000 == 0) { stopwatch.stop(); logger.warn("Sent " + i + " in " + stopwatch.getTotalTimeSeconds() + " (10k in " + stopwatch.getLastTaskTimeMillis() + "ms)"); stopwatch.start(); } handler.handleMessage(message); } stopwatch.stop(); logger.warn("Sent " + 120000 + " in " + stopwatch.getTotalTimeSeconds() + " (10k in " + stopwatch.getLastTaskTimeMillis() + "ms)"); Collection<?> result = resultFuture.get(10, TimeUnit.SECONDS); assertNotNull(result); assertEquals(60000, result.size()); } @Test public void testAggPerfDefaultPartial() throws InterruptedException, ExecutionException, TimeoutException { AggregatingMessageHandler handler = new AggregatingMessageHandler(new DefaultAggregatingMessageGroupProcessor()); handler.setCorrelationStrategy(message -> "foo"); handler.setReleasePartialSequences(true); DirectChannel outputChannel = new DirectChannel(); handler.setOutputChannel(outputChannel); final CompletableFuture<Collection<?>> resultFuture = new CompletableFuture<>(); outputChannel.subscribe(message -> { Collection<?> payload = (Collection<?>) message.getPayload(); logger.warn("Received " + payload.size()); resultFuture.complete(payload); }); SimpleMessageStore store = new SimpleMessageStore(); SimpleMessageGroupFactory messageGroupFactory = new SimpleMessageGroupFactory(SimpleMessageGroupFactory.GroupType.BLOCKING_QUEUE); store.setMessageGroupFactory(messageGroupFactory); handler.setMessageStore(store); StopWatch stopwatch = new StopWatch(); stopwatch.start(); for (int i = 0; i < 120000; i++) { if (i % 10000 == 0) { stopwatch.stop(); logger.warn("Sent " + i + " in " + stopwatch.getTotalTimeSeconds() + " (10k in " + stopwatch.getLastTaskTimeMillis() + "ms)"); stopwatch.start(); } handler.handleMessage(MessageBuilder.withPayload("foo") .setSequenceSize(120000) .setSequenceNumber(i + 1) .build()); } stopwatch.stop(); logger.warn("Sent " + 120000 + " in " + stopwatch.getTotalTimeSeconds() + " (10k in " + stopwatch.getLastTaskTimeMillis() + "ms)"); Collection<?> result = resultFuture.get(10, TimeUnit.SECONDS); assertNotNull(result); assertEquals(120000, result.size()); assertThat(stopwatch.getTotalTimeSeconds(), lessThan(60.0)); // actually < 2.0, was many minutes } @Test public void testCustomAggPerf() throws InterruptedException, ExecutionException, TimeoutException { class CustomHandler extends AbstractMessageHandler { // custom aggregator, only handles a single correlation private final ReentrantLock lock = new ReentrantLock(); private final Collection<Message<?>> messages = new ArrayList<Message<?>>(60000); private final MessageChannel outputChannel; private CustomHandler(MessageChannel outputChannel) { this.outputChannel = outputChannel; } @Override public void handleMessageInternal(Message<?> requestMessage) { lock.lock(); try { this.messages.add(requestMessage); if (this.messages.size() == 60000) { List<Object> payloads = new ArrayList<Object>(this.messages.size()); for (Message<?> message : this.messages) { payloads.add(message.getPayload()); } this.messages.clear(); outputChannel.send(getMessageBuilderFactory().withPayload(payloads) .copyHeaders(requestMessage.getHeaders()) .build()); } } finally { lock.unlock(); } } } DirectChannel outputChannel = new DirectChannel(); CustomHandler handler = new CustomHandler(outputChannel); final CompletableFuture<Collection<?>> resultFuture = new CompletableFuture<>(); outputChannel.subscribe(message -> { Collection<?> payload = (Collection<?>) message.getPayload(); logger.warn("Received " + payload.size()); resultFuture.complete(payload); }); Message<?> message = new GenericMessage<String>("foo"); StopWatch stopwatch = new StopWatch(); stopwatch.start(); for (int i = 0; i < 120000; i++) { if (i % 10000 == 0) { stopwatch.stop(); logger.warn("Sent " + i + " in " + stopwatch.getTotalTimeSeconds() + " (10k in " + stopwatch.getLastTaskTimeMillis() + "ms)"); stopwatch.start(); } handler.handleMessage(message); } stopwatch.stop(); logger.warn("Sent " + 120000 + " in " + stopwatch.getTotalTimeSeconds() + " (10k in " + stopwatch.getLastTaskTimeMillis() + "ms)"); Collection<?> result = resultFuture.get(10, TimeUnit.SECONDS); assertNotNull(result); assertEquals(60000, result.size()); } @Test public void testCompleteGroupWithinTimeout() throws InterruptedException { QueueChannel replyChannel = new QueueChannel(); Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel, null); Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel, null); Message<?> message3 = createMessage(7, "ABC", 3, 3, replyChannel, null); this.aggregator.handleMessage(message1); this.aggregator.handleMessage(message2); this.aggregator.handleMessage(message3); Message<?> reply = replyChannel.receive(10000); assertNotNull(reply); assertEquals(reply.getPayload(), 105); } @Test public void testShouldNotSendPartialResultOnTimeoutByDefault() throws InterruptedException { QueueChannel discardChannel = new QueueChannel(); this.aggregator.setDiscardChannel(discardChannel); QueueChannel replyChannel = new QueueChannel(); Message<?> message = createMessage(3, "ABC", 2, 1, replyChannel, null); this.aggregator.handleMessage(message); this.store.expireMessageGroups(-10000); Message<?> reply = replyChannel.receive(1000); assertNull("No message should have been sent normally", reply); Message<?> discardedMessage = discardChannel.receive(1000); assertNotNull("A message should have been discarded", discardedMessage); assertEquals(message, discardedMessage); assertEquals(1, expiryEvents.size()); assertSame(this.aggregator, expiryEvents.get(0).getSource()); assertEquals("ABC", this.expiryEvents.get(0).getGroupId()); assertEquals(1, this.expiryEvents.get(0).getMessageCount()); assertEquals(true, this.expiryEvents.get(0).isDiscarded()); } @Test public void testShouldSendPartialResultOnTimeoutTrue() throws InterruptedException { this.aggregator.setSendPartialResultOnExpiry(true); QueueChannel replyChannel = new QueueChannel(); Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel, null); Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel, null); this.aggregator.handleMessage(message1); this.aggregator.handleMessage(message2); this.store.expireMessageGroups(-10000); Message<?> reply = replyChannel.receive(1000); assertNotNull("A reply message should have been received", reply); assertEquals(15, reply.getPayload()); assertEquals(1, expiryEvents.size()); assertSame(this.aggregator, expiryEvents.get(0).getSource()); assertEquals("ABC", this.expiryEvents.get(0).getGroupId()); assertEquals(2, this.expiryEvents.get(0).getMessageCount()); assertEquals(false, this.expiryEvents.get(0).isDiscarded()); Message<?> message3 = createMessage(5, "ABC", 3, 3, replyChannel, null); this.aggregator.handleMessage(message3); assertEquals(1, this.store.getMessageGroup("ABC").size()); } @Test public void testGroupRemainsAfterTimeout() throws InterruptedException { this.aggregator.setSendPartialResultOnExpiry(true); this.aggregator.setExpireGroupsUponTimeout(false); QueueChannel replyChannel = new QueueChannel(); QueueChannel discardChannel = new QueueChannel(); this.aggregator.setDiscardChannel(discardChannel); Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel, null); Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel, null); this.aggregator.handleMessage(message1); this.aggregator.handleMessage(message2); this.store.expireMessageGroups(-10000); Message<?> reply = replyChannel.receive(1000); assertNotNull("A reply message should have been received", reply); assertEquals(15, reply.getPayload()); assertEquals(1, expiryEvents.size()); assertSame(this.aggregator, expiryEvents.get(0).getSource()); assertEquals("ABC", this.expiryEvents.get(0).getGroupId()); assertEquals(2, this.expiryEvents.get(0).getMessageCount()); assertEquals(false, this.expiryEvents.get(0).isDiscarded()); assertEquals(0, this.store.getMessageGroup("ABC").size()); Message<?> message3 = createMessage(5, "ABC", 3, 3, replyChannel, null); this.aggregator.handleMessage(message3); assertEquals(0, this.store.getMessageGroup("ABC").size()); Message<?> discardedMessage = discardChannel.receive(1000); assertNotNull("A message should have been discarded", discardedMessage); assertSame(message3, discardedMessage); } @Test public void testMultipleGroupsSimultaneously() throws InterruptedException { QueueChannel replyChannel1 = new QueueChannel(); QueueChannel replyChannel2 = new QueueChannel(); Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel1, null); Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel1, null); Message<?> message3 = createMessage(7, "ABC", 3, 3, replyChannel1, null); Message<?> message4 = createMessage(11, "XYZ", 3, 1, replyChannel2, null); Message<?> message5 = createMessage(13, "XYZ", 3, 2, replyChannel2, null); Message<?> message6 = createMessage(17, "XYZ", 3, 3, replyChannel2, null); aggregator.handleMessage(message1); aggregator.handleMessage(message5); aggregator.handleMessage(message3); aggregator.handleMessage(message6); aggregator.handleMessage(message4); aggregator.handleMessage(message2); @SuppressWarnings("unchecked") Message<Integer> reply1 = (Message<Integer>) replyChannel1.receive(1000); assertNotNull(reply1); assertThat(reply1.getPayload(), is(105)); @SuppressWarnings("unchecked") Message<Integer> reply2 = (Message<Integer>) replyChannel2.receive(1000); assertNotNull(reply2); assertThat(reply2.getPayload(), is(2431)); } @Test @Ignore // dropped backwards compatibility for setting capacity limit (it's always Integer.MAX_VALUE) public void testTrackedCorrelationIdsCapacityAtLimit() { QueueChannel replyChannel = new QueueChannel(); QueueChannel discardChannel = new QueueChannel(); this.aggregator.setDiscardChannel(discardChannel); this.aggregator.handleMessage(createMessage(1, 1, 1, 1, replyChannel, null)); assertEquals(1, replyChannel.receive(1000).getPayload()); this.aggregator.handleMessage(createMessage(3, 2, 1, 1, replyChannel, null)); assertEquals(3, replyChannel.receive(1000).getPayload()); this.aggregator.handleMessage(createMessage(4, 3, 1, 1, replyChannel, null)); assertEquals(4, replyChannel.receive(1000).getPayload()); // next message with same correllation ID is discarded this.aggregator.handleMessage(createMessage(2, 1, 1, 1, replyChannel, null)); assertEquals(2, discardChannel.receive(1000).getPayload()); } @Test @Ignore // dropped backwards compatibility for setting capacity limit (it's always Integer.MAX_VALUE) public void testTrackedCorrelationIdsCapacityPassesLimit() { QueueChannel replyChannel = new QueueChannel(); QueueChannel discardChannel = new QueueChannel(); this.aggregator.setDiscardChannel(discardChannel); this.aggregator.handleMessage(createMessage(1, 1, 1, 1, replyChannel, null)); assertEquals(1, replyChannel.receive(1000).getPayload()); this.aggregator.handleMessage(createMessage(2, 2, 1, 1, replyChannel, null)); assertEquals(2, replyChannel.receive(1000).getPayload()); this.aggregator.handleMessage(createMessage(3, 3, 1, 1, replyChannel, null)); assertEquals(3, replyChannel.receive(1000).getPayload()); this.aggregator.handleMessage(createMessage(4, 4, 1, 1, replyChannel, null)); assertEquals(4, replyChannel.receive(1000).getPayload()); this.aggregator.handleMessage(createMessage(5, 1, 1, 1, replyChannel, null)); assertEquals(5, replyChannel.receive(1000).getPayload()); assertNull(discardChannel.receive(0)); } @Test(expected = MessageHandlingException.class) public void testExceptionThrownIfNoCorrelationId() throws InterruptedException { Message<?> message = createMessage(3, null, 2, 1, new QueueChannel(), null); this.aggregator.handleMessage(message); } @Test public void testAdditionalMessageAfterCompletion() throws InterruptedException { QueueChannel replyChannel = new QueueChannel(); Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel, null); Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel, null); Message<?> message3 = createMessage(7, "ABC", 3, 3, replyChannel, null); Message<?> message4 = createMessage(7, "ABC", 3, 3, replyChannel, null); this.aggregator.handleMessage(message1); this.aggregator.handleMessage(message2); this.aggregator.handleMessage(message3); this.aggregator.handleMessage(message4); Message<?> reply = replyChannel.receive(10000); assertNotNull("A message should be aggregated", reply); assertThat(((Integer) reply.getPayload()), is(105)); } @Test public void shouldRejectDuplicatedSequenceNumbers() throws InterruptedException { QueueChannel replyChannel = new QueueChannel(); Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel, null); Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel, null); Message<?> message3 = createMessage(7, "ABC", 3, 3, replyChannel, null); Message<?> message4 = createMessage(7, "ABC", 3, 3, replyChannel, null); this.aggregator.setReleaseStrategy(new SequenceSizeReleaseStrategy()); this.aggregator.handleMessage(message1); this.aggregator.handleMessage(message3); // duplicated sequence number, either message3 or message4 should be rejected this.aggregator.handleMessage(message4); this.aggregator.handleMessage(message2); Message<?> reply = replyChannel.receive(10000); assertNotNull("A message should be aggregated", reply); assertThat(((Integer) reply.getPayload()), is(105)); } private static Message<?> createMessage(Object payload, Object correlationId, int sequenceSize, int sequenceNumber, MessageChannel replyChannel, String predefinedId) { MessageBuilder<Object> builder = MessageBuilder.withPayload(payload).setCorrelationId(correlationId) .setSequenceSize(sequenceSize).setSequenceNumber(sequenceNumber).setReplyChannel(replyChannel); if (predefinedId != null) { builder.setHeader(MessageHeaders.ID, predefinedId); } return builder.build(); } private class MultiplyingProcessor implements MessageGroupProcessor { MultiplyingProcessor() { super(); } @Override public Object processMessageGroup(MessageGroup group) { Integer product = 1; for (Message<?> message : group.getMessages()) { product *= (Integer) message.getPayload(); } return product; } } }