/* * Copyright 2016-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.integration.handler; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import org.apache.commons.logging.Log; import org.junit.Before; import org.junit.Test; import org.springframework.beans.DirectFieldAccessor; import org.springframework.beans.factory.BeanFactory; import org.springframework.integration.channel.DirectChannel; import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.endpoint.EventDrivenConsumer; import org.springframework.integration.gateway.GatewayProxyFactoryBean; import org.springframework.integration.test.util.TestUtils; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHandlingException; import org.springframework.messaging.MessagingException; import org.springframework.messaging.core.DestinationResolutionException; import org.springframework.messaging.support.GenericMessage; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.concurrent.SettableListenableFuture; /** * @author Gary Russell * @since 4.3 * */ public class AsyncHandlerTests { private final QueueChannel output = new QueueChannel(); private AbstractReplyProducingMessageHandler handler; private volatile CountDownLatch latch; private volatile int whichTest; private volatile Exception failedCallbackException; private volatile String failedCallbackMessage; private volatile CountDownLatch exceptionLatch = new CountDownLatch(1); @Before public void setup() { this.handler = new AbstractReplyProducingMessageHandler() { @Override protected Object handleRequestMessage(Message<?> requestMessage) { final SettableListenableFuture<String> future = new SettableListenableFuture<String>(); Executors.newSingleThreadExecutor().execute(() -> { try { latch.await(10, TimeUnit.SECONDS); switch (whichTest) { case 0: future.set("reply"); break; case 1: future.setException(new RuntimeException("foo")); break; case 2: future.setException(new MessagingException(requestMessage)); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); } }); return future; } }; this.handler.setAsync(true); this.handler.setOutputChannel(this.output); this.handler.setBeanFactory(mock(BeanFactory.class)); this.latch = new CountDownLatch(1); Log logger = spy(TestUtils.getPropertyValue(this.handler, "logger", Log.class)); new DirectFieldAccessor(this.handler).setPropertyValue("logger", logger); doAnswer(invocation -> { failedCallbackMessage = invocation.getArgument(0); failedCallbackException = invocation.getArgument(1); exceptionLatch.countDown(); return null; }).when(logger).error(anyString(), any(Throwable.class)); } @Test public void testGoodResult() { this.whichTest = 0; this.handler.handleMessage(new GenericMessage<String>("foo")); assertNull(this.output.receive(0)); this.latch.countDown(); Message<?> received = this.output.receive(10000); assertNotNull(received); assertEquals("reply", received.getPayload()); assertNull(this.failedCallbackException); } @Test public void testGoodResultWithReplyChannelHeader() { this.whichTest = 0; this.handler.setOutputChannel(null); QueueChannel replyChannel = new QueueChannel(); Message<?> message = MessageBuilder.withPayload("foo") .setReplyChannel(replyChannel) .build(); this.handler.handleMessage(message); assertNull(replyChannel.receive(0)); this.latch.countDown(); Message<?> received = replyChannel.receive(10000); assertNotNull(received); assertEquals("reply", received.getPayload()); assertNull(this.failedCallbackException); } @Test public void testGoodResultWithNoReplyChannelHeaderNoOutput() throws Exception { this.whichTest = 0; this.handler.setOutputChannel(null); QueueChannel errorChannel = new QueueChannel(); Message<String> message = MessageBuilder.withPayload("foo").setErrorChannel(errorChannel).build(); this.handler.handleMessage(message); assertNull(this.output.receive(0)); this.latch.countDown(); Message<?> errorMessage = errorChannel.receive(1000); assertNotNull(errorMessage); assertThat(errorMessage.getPayload(), instanceOf(DestinationResolutionException.class)); assertEquals("no output-channel or replyChannel header available", ((Throwable) errorMessage.getPayload()).getMessage()); assertNull(((MessagingException) errorMessage.getPayload()).getFailedMessage()); assertNotNull(this.failedCallbackException); assertThat(this.failedCallbackException.getMessage(), containsString("or replyChannel header")); } @Test public void testRuntimeException() { QueueChannel errorChannel = new QueueChannel(); Message<String> message = MessageBuilder.withPayload("foo") .setErrorChannel(errorChannel) .build(); this.handler.handleMessage(message); assertNull(this.output.receive(0)); this.whichTest = 1; this.latch.countDown(); Message<?> received = errorChannel.receive(10000); assertNotNull(received); assertThat(received.getPayload(), instanceOf(MessageHandlingException.class)); assertEquals("foo", ((Throwable) received.getPayload()).getCause().getMessage()); assertSame(message, ((MessagingException) received.getPayload()).getFailedMessage()); assertNull(this.failedCallbackException); } @Test public void testMessagingException() { QueueChannel errorChannel = new QueueChannel(); Message<String> message = MessageBuilder.withPayload("foo") .setErrorChannel(errorChannel) .build(); this.handler.handleMessage(message); assertNull(this.output.receive(0)); this.whichTest = 2; this.latch.countDown(); Message<?> received = errorChannel.receive(10000); assertNotNull(received); assertThat(received.getPayload(), instanceOf(MessagingException.class)); assertSame(message, ((MessagingException) received.getPayload()).getFailedMessage()); assertNull(this.failedCallbackException); } @Test public void testMessagingExceptionNoErrorChannel() throws Exception { Message<String> message = MessageBuilder.withPayload("foo") .build(); this.handler.handleMessage(message); assertNull(this.output.receive(0)); this.whichTest = 2; this.latch.countDown(); assertTrue(this.exceptionLatch.await(10, TimeUnit.SECONDS)); assertNotNull(this.failedCallbackException); assertThat(this.failedCallbackMessage, containsString("no 'errorChannel' header")); } @Test public void testGateway() throws Exception { this.whichTest = 0; GatewayProxyFactoryBean gpfb = new GatewayProxyFactoryBean(Foo.class); gpfb.setBeanFactory(mock(BeanFactory.class)); DirectChannel input = new DirectChannel(); gpfb.setDefaultRequestChannel(input); gpfb.setDefaultReplyTimeout(10000L); gpfb.afterPropertiesSet(); Foo foo = (Foo) gpfb.getObject(); this.handler.setOutputChannel(null); EventDrivenConsumer consumer = new EventDrivenConsumer(input, this.handler); consumer.afterPropertiesSet(); consumer.start(); this.latch.countDown(); String result = foo.exchange("foo"); assertEquals("reply", result); } @Test public void testGatewayWithException() throws Exception { this.whichTest = 0; GatewayProxyFactoryBean gpfb = new GatewayProxyFactoryBean(Foo.class); gpfb.setBeanFactory(mock(BeanFactory.class)); DirectChannel input = new DirectChannel(); gpfb.setDefaultRequestChannel(input); gpfb.setDefaultReplyTimeout(10000L); gpfb.afterPropertiesSet(); Foo foo = (Foo) gpfb.getObject(); this.handler.setOutputChannel(null); EventDrivenConsumer consumer = new EventDrivenConsumer(input, this.handler); consumer.afterPropertiesSet(); consumer.start(); this.latch.countDown(); try { foo.exchange("foo"); } catch (MessagingException e) { assertThat(e.getClass().getSimpleName(), equalTo("RuntimeException")); assertThat(e.getMessage(), equalTo("foo")); } } private interface Foo { String exchange(String payload); } }