/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.integration.aggregator;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.TaskExecutor;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.store.MessageGroup;
import org.springframework.integration.store.MessageGroupStore;
import org.springframework.integration.store.SimpleMessageStore;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessageHandlingException;
import org.springframework.messaging.MessageHeaders;
/**
 * Tests for {@link AggregatingMessageHandler} where the messages of one or more correlation
 * groups are handed to the aggregator from separate threads (via a {@link TaskExecutor}), and
 * the released, discarded or partially-released result is then verified.
 *
 * @author Mark Fisher
 * @author Marius Bogoevici
 * @author Iwein Fuld
 */
public class ConcurrentAggregatorTests {

	private TaskExecutor taskExecutor;

	private AggregatingMessageHandler aggregator;

	private final MessageGroupStore store = new SimpleMessageStore();

	/**
	 * Creates a {@link SimpleAsyncTaskExecutor} and an aggregator that multiplies the integer
	 * payloads of a group (see {@link MultiplyingProcessor}). The aggregator is backed by
	 * {@link #store} and uses a {@link SimpleSequenceSizeReleaseStrategy}.
	 */
	@Before
	public void configureAggregator() {
		this.taskExecutor = new SimpleAsyncTaskExecutor();
		this.aggregator = new AggregatingMessageHandler(new MultiplyingProcessor(), this.store);
		this.aggregator.setReleaseStrategy(new SimpleSequenceSizeReleaseStrategy());
	}

	@Test
	public void testCompleteGroupWithinTimeout() throws InterruptedException {
		QueueChannel replyChannel = new QueueChannel();
		Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel, null);
		Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel, null);
		Message<?> message3 = createMessage(7, "ABC", 3, 3, replyChannel, null);
		CountDownLatch latch = new CountDownLatch(3);
		this.taskExecutor.execute(new AggregatorTestTask(this.aggregator, message1, latch));
		this.taskExecutor.execute(new AggregatorTestTask(this.aggregator, message2, latch));
		this.taskExecutor.execute(new AggregatorTestTask(this.aggregator, message3, latch));
		assertTrue(latch.await(10, TimeUnit.SECONDS));
		assertThat(latch.getCount(), is(0L));
		Message<?> reply = replyChannel.receive(2000);
		assertNotNull(reply);
		// expected value first, so that a failure report reads correctly (3 * 5 * 7)
		assertEquals(105, reply.getPayload());
	}

	@Test
	@Ignore("Backwards compatibility for duplicate message IDs was dropped")
	public void testCompleteGroupWithinTimeoutWithSameId() throws InterruptedException {
		QueueChannel replyChannel = new QueueChannel();
		Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel, "ID#1");
		Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel, "ID#1");
		Message<?> message3 = createMessage(7, "ABC", 3, 3, replyChannel, "ID#1");
		CountDownLatch latch = new CountDownLatch(3);
		// for testing the duplication scenario, the messages must be processed synchronously
		new AggregatorTestTask(this.aggregator, message1, latch).run();
		new AggregatorTestTask(this.aggregator, message2, latch).run();
		new AggregatorTestTask(this.aggregator, message3, latch).run();
		Message<?> reply = replyChannel.receive(1000);
		assertNotNull(reply);
		assertEquals("123456789", reply.getPayload());
	}

	@Test
	public void testShouldNotSendPartialResultOnTimeoutByDefault() throws InterruptedException {
		QueueChannel discardChannel = new QueueChannel();
		this.aggregator.setDiscardChannel(discardChannel);
		QueueChannel replyChannel = new QueueChannel();
		Message<?> message = createMessage(3, "ABC", 2, 1, replyChannel, null);
		CountDownLatch latch = new CountDownLatch(1);
		AggregatorTestTask task = new AggregatorTestTask(this.aggregator, message, latch);
		this.taskExecutor.execute(task);
		assertTrue(latch.await(10, TimeUnit.SECONDS));
		assertEquals("Task should have completed within timeout", 0, latch.getCount());
		Message<?> reply = replyChannel.receive(1000);
		assertNull("No message should have been sent normally", reply);
		// a negative timeout is passed so that the incomplete group is expired immediately
		this.store.expireMessageGroups(-10000);
		Message<?> discardedMessage = discardChannel.receive(1000);
		assertNotNull("A message should have been discarded", discardedMessage);
		assertEquals(message, discardedMessage);
	}

	@Test
	public void testShouldSendPartialResultOnTimeoutTrue() throws InterruptedException {
		this.aggregator.setSendPartialResultOnExpiry(true);
		QueueChannel replyChannel = new QueueChannel();
		Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel, null);
		Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel, null);
		CountDownLatch latch = new CountDownLatch(2);
		AggregatorTestTask task1 = new AggregatorTestTask(this.aggregator, message1, latch);
		AggregatorTestTask task2 = new AggregatorTestTask(this.aggregator, message2, latch);
		this.taskExecutor.execute(task1);
		this.taskExecutor.execute(task2);
		assertTrue(latch.await(10, TimeUnit.SECONDS));
		assertEquals("handlers should have been invoked within time limit", 0, latch.getCount());
		// a negative timeout is passed so that the incomplete group is expired immediately
		this.store.expireMessageGroups(-10000);
		Message<?> reply = replyChannel.receive(1000);
		assertNotNull("A reply message should have been received", reply);
		assertEquals(15, reply.getPayload());
		assertNull(task1.getException());
		assertNull(task2.getException());
	}

	@Test
	public void testMultipleGroupsSimultaneously() throws InterruptedException {
		QueueChannel replyChannel1 = new QueueChannel();
		QueueChannel replyChannel2 = new QueueChannel();
		Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel1, null);
		Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel1, null);
		Message<?> message3 = createMessage(7, "ABC", 3, 3, replyChannel1, null);
		Message<?> message4 = createMessage(11, "XYZ", 3, 1, replyChannel2, null);
		Message<?> message5 = createMessage(13, "XYZ", 3, 2, replyChannel2, null);
		Message<?> message6 = createMessage(17, "XYZ", 3, 3, replyChannel2, null);
		CountDownLatch latch = new CountDownLatch(6);
		// interleave the two groups so that both are in flight at the same time
		this.taskExecutor.execute(new AggregatorTestTask(this.aggregator, message1, latch));
		this.taskExecutor.execute(new AggregatorTestTask(this.aggregator, message6, latch));
		this.taskExecutor.execute(new AggregatorTestTask(this.aggregator, message2, latch));
		this.taskExecutor.execute(new AggregatorTestTask(this.aggregator, message5, latch));
		this.taskExecutor.execute(new AggregatorTestTask(this.aggregator, message3, latch));
		this.taskExecutor.execute(new AggregatorTestTask(this.aggregator, message4, latch));
		assertTrue(latch.await(10, TimeUnit.SECONDS));
		@SuppressWarnings("unchecked")
		Message<Integer> reply1 = (Message<Integer>) replyChannel1.receive(1000);
		assertNotNull(reply1);
		assertThat(reply1.getPayload(), is(105));
		@SuppressWarnings("unchecked")
		Message<Integer> reply2 = (Message<Integer>) replyChannel2.receive(1000);
		assertNotNull(reply2);
		assertThat(reply2.getPayload(), is(2431));
	}

	@Test
	@Ignore("Backwards compatibility for setting the tracked correlation ID capacity was dropped"
			+ " (it is always Integer.MAX_VALUE)")
	public void testTrackedCorrelationIdsCapacityAtLimit() {
		QueueChannel replyChannel = new QueueChannel();
		QueueChannel discardChannel = new QueueChannel();
		// this test originally configured a capacity of 3 via setTrackedCorrelationIdCapacity(3)
		this.aggregator.setDiscardChannel(discardChannel);
		this.aggregator.handleMessage(createMessage(1, 1, 1, 1, replyChannel, null));
		assertEquals(1, replyChannel.receive(1000).getPayload());
		this.aggregator.handleMessage(createMessage(3, 2, 1, 1, replyChannel, null));
		assertEquals(3, replyChannel.receive(1000).getPayload());
		this.aggregator.handleMessage(createMessage(4, 3, 1, 1, replyChannel, null));
		assertEquals(4, replyChannel.receive(1000).getPayload());
		// next message with same correlation ID is discarded
		this.aggregator.handleMessage(createMessage(2, 1, 1, 1, replyChannel, null));
		assertEquals(2, discardChannel.receive(1000).getPayload());
	}

	@Test
	@Ignore("Backwards compatibility for setting the tracked correlation ID capacity was dropped"
			+ " (it is always Integer.MAX_VALUE)")
	public void testTrackedCorrelationIdsCapacityPassesLimit() {
		QueueChannel replyChannel = new QueueChannel();
		QueueChannel discardChannel = new QueueChannel();
		// this test originally configured a capacity of 3 via setTrackedCorrelationIdCapacity(3)
		this.aggregator.setDiscardChannel(discardChannel);
		this.aggregator.handleMessage(createMessage(1, 1, 1, 1, replyChannel, null));
		assertEquals(1, replyChannel.receive(1000).getPayload());
		this.aggregator.handleMessage(createMessage(2, 2, 1, 1, replyChannel, null));
		assertEquals(2, replyChannel.receive(1000).getPayload());
		this.aggregator.handleMessage(createMessage(3, 3, 1, 1, replyChannel, null));
		assertEquals(3, replyChannel.receive(1000).getPayload());
		this.aggregator.handleMessage(createMessage(4, 4, 1, 1, replyChannel, null));
		assertEquals(4, replyChannel.receive(1000).getPayload());
		this.aggregator.handleMessage(createMessage(5, 1, 1, 1, replyChannel, null));
		assertEquals(5, replyChannel.receive(1000).getPayload());
		assertNull(discardChannel.receive(0));
	}

	@Test(expected = MessageHandlingException.class)
	public void testExceptionThrownIfNoCorrelationId() {
		Message<?> message = createMessage(3, null, 2, 1, new QueueChannel(), null);
		this.aggregator.handleMessage(message);
	}

	@Test
	public void testAdditionalMessageAfterCompletion() throws InterruptedException {
		QueueChannel replyChannel = new QueueChannel();
		Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel, null);
		Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel, null);
		Message<?> message3 = createMessage(7, "ABC", 3, 3, replyChannel, null);
		// same sequence number as message3: arrives in addition to a complete group
		Message<?> message4 = createMessage(7, "ABC", 3, 3, replyChannel, null);
		CountDownLatch latch = new CountDownLatch(4);
		this.aggregator.setReleaseStrategy(new SequenceSizeReleaseStrategy());
		this.taskExecutor.execute(new AggregatorTestTask(this.aggregator, message1, latch));
		this.taskExecutor.execute(new AggregatorTestTask(this.aggregator, message2, latch));
		this.taskExecutor.execute(new AggregatorTestTask(this.aggregator, message3, latch));
		this.taskExecutor.execute(new AggregatorTestTask(this.aggregator, message4, latch));
		assertTrue(latch.await(10, TimeUnit.SECONDS));
		Message<?> reply = replyChannel.receive(10000);
		assertNotNull("A message should be aggregated", reply);
		assertThat(reply.getPayload(), is(105));
	}

	/**
	 * Builds a message with the given payload and the correlation and sequence headers, with
	 * the reply channel set.
	 *
	 * @param payload the message payload
	 * @param correlationId the correlation id header value (may be {@code null})
	 * @param sequenceSize the sequence size header value
	 * @param sequenceNumber the sequence number header value
	 * @param replyChannel the reply channel header value
	 * @param predefinedId if not {@code null}, explicitly set as the {@link MessageHeaders#ID} header
	 * @return the built message
	 */
	private static Message<?> createMessage(Object payload, Object correlationId, int sequenceSize,
			int sequenceNumber, MessageChannel replyChannel, String predefinedId) {

		MessageBuilder<Object> builder = MessageBuilder.withPayload(payload)
				.setCorrelationId(correlationId)
				.setSequenceSize(sequenceSize)
				.setSequenceNumber(sequenceNumber)
				.setReplyChannel(replyChannel);
		if (predefinedId != null) {
			builder.setHeader(MessageHeaders.ID, predefinedId);
		}
		return builder.build();
	}

	/**
	 * Hands a single message to a {@link MessageHandler}. Any exception thrown by the handler is
	 * recorded, and the latch is counted down whether or not the call succeeded.
	 */
	private static class AggregatorTestTask implements Runnable {

		private final MessageHandler aggregator;

		private final Message<?> message;

		private final CountDownLatch latch;

		// Written by the executing thread before latch.countDown(). Reading it after a
		// successful latch.await() is safe because countDown happens-before await returns.
		private Exception exception;

		AggregatorTestTask(MessageHandler aggregator, Message<?> message, CountDownLatch latch) {
			this.aggregator = aggregator;
			this.message = message;
			this.latch = latch;
		}

		/**
		 * @return the exception thrown by the handler, or {@code null} if none was thrown
		 */
		public Exception getException() {
			return this.exception;
		}

		@Override
		public void run() {
			try {
				this.aggregator.handleMessage(this.message);
			}
			catch (Exception e) {
				this.exception = e;
			}
			finally {
				this.latch.countDown();
			}
		}

	}

	/**
	 * Aggregates a group by multiplying all of its payloads. The payloads are cast to
	 * {@link Integer}.
	 */
	private static class MultiplyingProcessor implements MessageGroupProcessor {

		@Override
		public Object processMessageGroup(MessageGroup group) {
			// primitive accumulator avoids re-boxing on every iteration
			int product = 1;
			for (Message<?> message : group.getMessages()) {
				product *= (Integer) message.getPayload();
			}
			return product;
		}

	}

	/**
	 * Processor that always returns {@code null} (not used by any test in this class).
	 */
	@SuppressWarnings("unused")
	private static class NullReturningMessageProcessor implements MessageGroupProcessor {

		@Override
		public Object processMessageGroup(MessageGroup group) {
			return null;
		}

	}

}