/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.integration.channel.interceptor;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.hamcrest.Matchers;
import org.junit.Test;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.support.ClassPathXmlApplicationContext;
import org.springframework.integration.channel.AbstractMessageChannel;
import org.springframework.integration.channel.ChannelInterceptorAware;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.endpoint.PollingConsumer;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.integration.test.util.TestUtils;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.ChannelInterceptorAdapter;
import org.springframework.messaging.support.ExecutorChannelInterceptor;
import org.springframework.messaging.support.GenericMessage;
import org.springframework.util.StringUtils;
/**
 * Tests for {@link ChannelInterceptor} callbacks registered on message channels:
 * the send-side hooks ({@code preSend}, {@code postSend}, {@code afterSendCompletion}),
 * the receive-side hooks ({@code preReceive}, {@code postReceive},
 * {@code afterReceiveCompletion}) and the {@link ExecutorChannelInterceptor}
 * {@code beforeHandle} hook used together with a {@link PollingConsumer}.
 *
 * @author Mark Fisher
 * @author Oleg Zhurakousky
 * @author Artem Bilan
 */
public class ChannelInterceptorTests {

	private final QueueChannel channel = new QueueChannel();

	/**
	 * A message returned from {@code preSend} is the one that ends up in the channel,
	 * and {@code afterSendCompletion} is invoked for the interceptor.
	 */
	@Test
	public void testPreSendInterceptorReturnsMessage() {
		PreSendReturnsMessageInterceptor interceptor = new PreSendReturnsMessageInterceptor();
		channel.addInterceptor(interceptor);
		channel.send(new GenericMessage<>("test"));
		Message<?> result = channel.receive(0);
		assertNotNull(result);
		assertEquals("test", result.getPayload());
		// The interceptor stamps its own invocation count under its simple class name.
		assertEquals(1, result.getHeaders().get(PreSendReturnsMessageInterceptor.class.getSimpleName()));
		assertTrue(interceptor.wasAfterCompletionInvoked());
	}

	/**
	 * An interceptor returning {@code null} from {@code preSend} is invoked once for the
	 * first send; after it is removed, a subsequent send is received normally and the
	 * interceptor is not invoked again.
	 */
	@Test
	public void testPreSendInterceptorReturnsNull() {
		PreSendReturnsNullInterceptor interceptor = new PreSendReturnsNullInterceptor();
		channel.addInterceptor(interceptor);
		Message<?> message = new GenericMessage<>("test");
		channel.send(message);
		assertEquals(1, interceptor.getCount());
		assertTrue(channel.removeInterceptor(interceptor));
		channel.send(new GenericMessage<>("TEST"));
		// Interceptor was removed, so the count must not have moved.
		assertEquals(1, interceptor.getCount());
		Message<?> result = channel.receive(0);
		assertNotNull(result);
		assertEquals("TEST", result.getPayload());
	}

	/**
	 * {@code postSend} receives the message, the sending channel and {@code sent == true}
	 * when the send succeeds.
	 */
	@Test
	public void testPostSendInterceptorWithSentMessage() {
		final AtomicBoolean invoked = new AtomicBoolean(false);
		channel.addInterceptor(new ChannelInterceptorAdapter() {

			@Override
			public void postSend(Message<?> message, MessageChannel channel, boolean sent) {
				assertNotNull(message);
				assertNotNull(channel);
				assertSame(ChannelInterceptorTests.this.channel, channel);
				assertTrue(sent);
				invoked.set(true);
			}

		});
		channel.send(new GenericMessage<>("test"));
		assertTrue(invoked.get());
	}

	/**
	 * {@code postSend} is invoked for both successful and unsuccessful sends on a
	 * capacity-1 channel, with the {@code sent} flag distinguishing them; once the
	 * interceptor is removed it is no longer invoked.
	 */
	@Test
	public void testPostSendInterceptorWithUnsentMessage() {
		final AtomicInteger invokedCounter = new AtomicInteger(0);
		final AtomicInteger sentCounter = new AtomicInteger(0);
		final QueueChannel singleItemChannel = new QueueChannel(1);
		singleItemChannel.addInterceptor(new ChannelInterceptorAdapter() {

			@Override
			public void postSend(Message<?> message, MessageChannel channel, boolean sent) {
				assertNotNull(message);
				assertNotNull(channel);
				assertSame(singleItemChannel, channel);
				if (sent) {
					sentCounter.incrementAndGet();
				}
				invokedCounter.incrementAndGet();
			}

		});
		assertEquals(0, invokedCounter.get());
		assertEquals(0, sentCounter.get());
		singleItemChannel.send(new GenericMessage<>("test1"));
		assertEquals(1, invokedCounter.get());
		assertEquals(1, sentCounter.get());
		// Channel is full now; a zero-timeout send is expected to be reported as unsent.
		singleItemChannel.send(new GenericMessage<>("test2"), 0);
		assertEquals(2, invokedCounter.get());
		assertEquals(1, sentCounter.get());
		assertNotNull(singleItemChannel.removeInterceptor(0));
		singleItemChannel.send(new GenericMessage<>("test2"), 0);
		assertEquals(2, invokedCounter.get());
		assertEquals(1, sentCounter.get());
	}

	/**
	 * When {@code doSend} throws, the exception reaches the caller (with the simulated
	 * exception as its cause) and {@code afterSendCompletion} runs for every interceptor.
	 */
	@Test
	public void afterCompletionWithSendException() {
		final AbstractMessageChannel testChannel = new AbstractMessageChannel() {

			@Override
			protected boolean doSend(Message<?> message, long timeout) {
				throw new RuntimeException("Simulated exception");
			}

		};
		AfterCompletionTestInterceptor interceptor1 = new AfterCompletionTestInterceptor();
		AfterCompletionTestInterceptor interceptor2 = new AfterCompletionTestInterceptor();
		testChannel.addInterceptor(interceptor1);
		testChannel.addInterceptor(interceptor2);
		Exception thrown = null;
		try {
			testChannel.send(MessageBuilder.withPayload("test").build());
		}
		catch (Exception ex) {
			thrown = ex;
		}
		// Assert outside the catch so the test cannot pass when nothing is thrown.
		assertNotNull("send() was expected to throw", thrown);
		assertNotNull(thrown.getCause());
		assertEquals("Simulated exception", thrown.getCause().getMessage());
		assertTrue(interceptor1.wasAfterCompletionInvoked());
		assertTrue(interceptor2.wasAfterCompletionInvoked());
	}

	/**
	 * When the second interceptor throws from {@code preSend}, only the first interceptor
	 * (whose {@code preSend} completed) gets {@code afterSendCompletion}.
	 */
	@Test
	public void afterCompletionWithPreSendException() {
		AfterCompletionTestInterceptor interceptor1 = new AfterCompletionTestInterceptor();
		AfterCompletionTestInterceptor interceptor2 = new AfterCompletionTestInterceptor();
		interceptor2.setExceptionToRaise(new RuntimeException("Simulated exception"));
		this.channel.addInterceptor(interceptor1);
		this.channel.addInterceptor(interceptor2);
		Exception thrown = null;
		try {
			this.channel.send(MessageBuilder.withPayload("test").build());
		}
		catch (Exception ex) {
			thrown = ex;
		}
		// Assert outside the catch so the test cannot pass when nothing is thrown.
		assertNotNull("send() was expected to throw", thrown);
		assertNotNull(thrown.getCause());
		assertEquals("Simulated exception", thrown.getCause().getMessage());
		assertTrue(interceptor1.wasAfterCompletionInvoked());
		assertFalse(interceptor2.wasAfterCompletionInvoked());
	}

	/**
	 * An interceptor returning {@code true} from {@code preReceive} lets the receive
	 * proceed and gets {@code afterReceiveCompletion}.
	 */
	@Test
	public void testPreReceiveInterceptorReturnsTrue() {
		PreReceiveReturnsTrueInterceptor interceptor = new PreReceiveReturnsTrueInterceptor();
		channel.addInterceptor(interceptor);
		Message<?> message = new GenericMessage<>("test");
		channel.send(message);
		Message<?> result = channel.receive(0);
		assertEquals(1, interceptor.getCounter().get());
		assertNotNull(result);
		assertTrue(interceptor.wasAfterCompletionInvoked());
	}

	/**
	 * An interceptor returning {@code false} from {@code preReceive} is invoked once and
	 * the receive yields {@code null} even though a message was sent.
	 */
	@Test
	public void testPreReceiveInterceptorReturnsFalse() {
		PreReceiveReturnsFalseInterceptor interceptor = new PreReceiveReturnsFalseInterceptor();
		channel.addInterceptor(interceptor);
		Message<?> message = new GenericMessage<>("test");
		channel.send(message);
		Message<?> result = channel.receive(0);
		assertEquals(1, interceptor.getCounter().get());
		assertNull(result);
	}

	/**
	 * {@code postReceive} is invoked on every receive, with a {@code null} message when
	 * the channel is empty and with the received message otherwise.
	 */
	@Test
	public void testPostReceiveInterceptor() {
		final AtomicInteger invokedCount = new AtomicInteger();
		final AtomicInteger messageCount = new AtomicInteger();
		channel.addInterceptor(new ChannelInterceptorAdapter() {

			@Override
			public Message<?> postReceive(Message<?> message, MessageChannel channel) {
				assertNotNull(channel);
				assertSame(ChannelInterceptorTests.this.channel, channel);
				if (message != null) {
					messageCount.incrementAndGet();
				}
				invokedCount.incrementAndGet();
				return message;
			}

		});
		// Nothing has been sent yet, so this receive sees no message.
		channel.receive(0);
		assertEquals(1, invokedCount.get());
		assertEquals(0, messageCount.get());
		channel.send(new GenericMessage<>("test"));
		Message<?> result = channel.receive(0);
		assertNotNull(result);
		assertEquals(2, invokedCount.get());
		assertEquals(1, messageCount.get());
	}

	/**
	 * When the second interceptor throws from {@code preReceive}, the exception reaches
	 * the caller unchanged and only the first interceptor gets
	 * {@code afterReceiveCompletion}.
	 */
	@Test
	public void afterCompletionWithReceiveException() {
		PreReceiveReturnsTrueInterceptor interceptor1 = new PreReceiveReturnsTrueInterceptor();
		PreReceiveReturnsTrueInterceptor interceptor2 = new PreReceiveReturnsTrueInterceptor();
		interceptor2.setExceptionToRaise(new RuntimeException("Simulated exception"));
		channel.addInterceptor(interceptor1);
		channel.addInterceptor(interceptor2);
		Exception thrown = null;
		try {
			channel.receive(0);
		}
		catch (Exception ex) {
			thrown = ex;
		}
		// Assert outside the catch so the test cannot pass when nothing is thrown.
		assertNotNull("receive() was expected to throw", thrown);
		assertEquals("Simulated exception", thrown.getMessage());
		assertTrue(interceptor1.wasAfterCompletionInvoked());
		assertFalse(interceptor2.wasAfterCompletionInvoked());
	}

	/**
	 * The interceptor bean declared in {@code ChannelInterceptorTests-context.xml} is
	 * attached to the {@code input} channel with its {@code foo} property set to "foo".
	 */
	@Test
	public void testInterceptorBeanWithPNamespace() {
		ConfigurableApplicationContext ac =
				new ClassPathXmlApplicationContext("ChannelInterceptorTests-context.xml", ChannelInterceptorTests.class);
		try {
			ChannelInterceptorAware channel = ac.getBean("input", AbstractMessageChannel.class);
			List<ChannelInterceptor> interceptors = channel.getChannelInterceptors();
			ChannelInterceptor channelInterceptor = interceptors.get(0);
			assertThat(channelInterceptor, Matchers.instanceOf(PreSendReturnsMessageInterceptor.class));
			String foo = ((PreSendReturnsMessageInterceptor) channelInterceptor).getFoo();
			assertTrue(StringUtils.hasText(foo));
			assertEquals("foo", foo);
		}
		finally {
			// Close the context even when an assertion above fails.
			ac.close();
		}
	}

	/**
	 * An {@link ExecutorChannelInterceptor} added to the input channel of a running
	 * {@link PollingConsumer} is applied to subsequently handled messages: the first
	 * payload arrives unchanged, the second one upper-cased by {@code beforeHandle}.
	 */
	@Test
	public void testPollingConsumerWithExecutorInterceptor() throws InterruptedException {
		TestUtils.TestApplicationContext testApplicationContext = TestUtils.createTestApplicationContext();
		try {
			QueueChannel channel = new QueueChannel();
			final CountDownLatch latch1 = new CountDownLatch(1);
			final CountDownLatch latch2 = new CountDownLatch(2);
			final List<Message<?>> messages = new ArrayList<>();
			PollingConsumer consumer = new PollingConsumer(channel, message -> {
				messages.add(message);
				latch1.countDown();
				latch2.countDown();
			});
			testApplicationContext.registerBean("consumer", consumer);
			testApplicationContext.refresh();

			channel.send(new GenericMessage<>("foo"));
			// Wait for the first message to be handled before the interceptor is added.
			assertTrue(latch1.await(10, TimeUnit.SECONDS));

			channel.addInterceptor(new TestExecutorInterceptor());
			channel.send(new GenericMessage<>("foo"));
			assertTrue(latch2.await(10, TimeUnit.SECONDS));

			assertEquals(2, messages.size());
			assertEquals("foo", messages.get(0).getPayload());
			assertEquals("FOO", messages.get(1).getPayload());
		}
		finally {
			// Close the context even when an assertion above fails or times out.
			testApplicationContext.close();
		}
	}


	/**
	 * Interceptor whose {@code preSend} returns a copy of the message with an extra
	 * header (keyed by this class's simple name) carrying the per-instance invocation
	 * count. Also records whether {@code afterSendCompletion} was called.
	 */
	public static class PreSendReturnsMessageInterceptor extends ChannelInterceptorAdapter {

		private String foo;

		// Per-instance so that absolute-value assertions do not depend on other instances.
		private final AtomicInteger counter = new AtomicInteger();

		private volatile boolean afterCompletionInvoked;

		@Override
		public Message<?> preSend(Message<?> message, MessageChannel channel) {
			assertNotNull(message);
			return MessageBuilder.fromMessage(message)
					.setHeader(this.getClass().getSimpleName(), this.counter.incrementAndGet())
					.build();
		}

		public String getFoo() {
			return this.foo;
		}

		public void setFoo(String foo) {
			this.foo = foo;
		}

		public boolean wasAfterCompletionInvoked() {
			return this.afterCompletionInvoked;
		}

		@Override
		public void afterSendCompletion(Message<?> message, MessageChannel channel, boolean sent, Exception ex) {
			this.afterCompletionInvoked = true;
		}

	}


	/**
	 * Interceptor whose {@code preSend} counts invocations and always returns
	 * {@code null}.
	 */
	private static class PreSendReturnsNullInterceptor extends ChannelInterceptorAdapter {

		// Per-instance so that absolute-value assertions do not depend on other instances.
		private final AtomicInteger counter = new AtomicInteger();

		PreSendReturnsNullInterceptor() {
			super();
		}

		protected int getCount() {
			return this.counter.get();
		}

		@Override
		public Message<?> preSend(Message<?> message, MessageChannel channel) {
			assertNotNull(message);
			this.counter.incrementAndGet();
			return null;
		}

	}


	/**
	 * Interceptor that counts {@code preSend} calls, optionally throws a configured
	 * exception from {@code preSend}, and records whether {@code afterSendCompletion}
	 * was called.
	 */
	private static class AfterCompletionTestInterceptor extends ChannelInterceptorAdapter {

		private final AtomicInteger counter = new AtomicInteger();

		private volatile boolean afterCompletionInvoked;

		private RuntimeException exceptionToRaise;

		AfterCompletionTestInterceptor() {
			super();
		}

		public void setExceptionToRaise(RuntimeException exception) {
			this.exceptionToRaise = exception;
		}

		@SuppressWarnings("unused")
		public AtomicInteger getCounter() {
			return this.counter;
		}

		public boolean wasAfterCompletionInvoked() {
			return this.afterCompletionInvoked;
		}

		@Override
		public Message<?> preSend(Message<?> message, MessageChannel channel) {
			assertNotNull(message);
			this.counter.incrementAndGet();
			if (this.exceptionToRaise != null) {
				throw this.exceptionToRaise;
			}
			return message;
		}

		@Override
		public void afterSendCompletion(Message<?> message, MessageChannel channel, boolean sent, Exception ex) {
			this.afterCompletionInvoked = true;
		}

	}


	/**
	 * Interceptor whose {@code preReceive} counts invocations and returns {@code true},
	 * unless a configured exception is thrown instead. Records whether
	 * {@code afterReceiveCompletion} was called.
	 */
	private static class PreReceiveReturnsTrueInterceptor extends ChannelInterceptorAdapter {

		private final AtomicInteger counter = new AtomicInteger();

		private volatile boolean afterCompletionInvoked;

		private RuntimeException exceptionToRaise;

		PreReceiveReturnsTrueInterceptor() {
			super();
		}

		public void setExceptionToRaise(RuntimeException exception) {
			this.exceptionToRaise = exception;
		}

		public AtomicInteger getCounter() {
			return this.counter;
		}

		@Override
		public boolean preReceive(MessageChannel channel) {
			this.counter.incrementAndGet();
			if (this.exceptionToRaise != null) {
				throw this.exceptionToRaise;
			}
			return true;
		}

		public boolean wasAfterCompletionInvoked() {
			return this.afterCompletionInvoked;
		}

		@Override
		public void afterReceiveCompletion(Message<?> message, MessageChannel channel, Exception ex) {
			this.afterCompletionInvoked = true;
		}

	}


	/**
	 * Interceptor whose {@code preReceive} counts invocations and always returns
	 * {@code false}.
	 */
	private static class PreReceiveReturnsFalseInterceptor extends ChannelInterceptorAdapter {

		// Per-instance so that absolute-value assertions do not depend on other instances.
		private final AtomicInteger counter = new AtomicInteger();

		PreReceiveReturnsFalseInterceptor() {
			super();
		}

		public AtomicInteger getCounter() {
			return this.counter;
		}

		@Override
		public boolean preReceive(MessageChannel channel) {
			this.counter.incrementAndGet();
			return false;
		}

	}


	/**
	 * {@link ExecutorChannelInterceptor} whose {@code beforeHandle} replaces the message
	 * with one carrying the upper-cased {@code String} payload and the original headers;
	 * {@code afterMessageHandled} is a no-op.
	 */
	private static class TestExecutorInterceptor extends ChannelInterceptorAdapter
			implements ExecutorChannelInterceptor {

		TestExecutorInterceptor() {
			super();
		}

		@Override
		public Message<?> beforeHandle(Message<?> message, MessageChannel channel, MessageHandler handler) {
			return MessageBuilder.withPayload(((String) message.getPayload()).toUpperCase())
					.copyHeaders(message.getHeaders())
					.build();
		}

		@Override
		public void afterMessageHandled(Message<?> message, MessageChannel channel, MessageHandler handler,
				Exception ex) {
			// Intentionally empty: only beforeHandle matters for this test.
		}

	}

}