/* * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.integration.endpoint; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.mock; import java.util.Date; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import org.springframework.beans.factory.BeanFactory; import org.springframework.integration.MessageRejectedException; import org.springframework.integration.support.MessagingExceptionWrapper; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.PollableChannel; import org.springframework.messaging.support.GenericMessage; import org.springframework.scheduling.Trigger; import org.springframework.scheduling.TriggerContext; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.util.ErrorHandler; /** * @author Iwein Fuld * @author Mark Fisher */ @SuppressWarnings("unchecked") public class PollingConsumerEndpointTests { private PollingConsumer endpoint; private final TestTrigger trigger = new TestTrigger(); private final TestConsumer consumer = new TestConsumer(); @SuppressWarnings("rawtypes") private final Message message = new GenericMessage<String>("test"); @SuppressWarnings("rawtypes") private final Message badMessage = new GenericMessage<String>("bad"); private final TestErrorHandler errorHandler = new TestErrorHandler(); private final PollableChannel channelMock = Mockito.mock(PollableChannel.class); private final ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler(); @Before public void init() throws Exception { consumer.counter.set(0); trigger.reset(); endpoint = new PollingConsumer(channelMock, consumer); taskScheduler.setPoolSize(5); endpoint.setErrorHandler(errorHandler); endpoint.setTaskScheduler(taskScheduler); endpoint.setTrigger(trigger); endpoint.setBeanFactory(mock(BeanFactory.class)); endpoint.setReceiveTimeout(-1); endpoint.afterPropertiesSet(); taskScheduler.afterPropertiesSet(); Mockito.reset(channelMock); } @After public void stop() throws Exception { taskScheduler.destroy(); } @Test public void singleMessage() { Mockito.when(channelMock.receive()).thenReturn(message); endpoint.setMaxMessagesPerPoll(1); endpoint.start(); trigger.await(); endpoint.stop(); assertEquals(1, consumer.counter.get()); } @Test public void multipleMessages() { Mockito.when(channelMock.receive()).thenReturn(message, message, message, message, message); endpoint.setMaxMessagesPerPoll(5); endpoint.start(); trigger.await(); endpoint.stop(); assertEquals(5, consumer.counter.get()); } @Test public void multipleMessages_underrun() { Mockito.when(channelMock.receive()).thenReturn(message, message, message, message, message, null); endpoint.setMaxMessagesPerPoll(6); endpoint.start(); trigger.await(); endpoint.stop(); assertEquals(5, consumer.counter.get()); } @Test public void heavierLoadTest() throws Exception { for (int i = 0; i < 1000; i++) { this.init(); this.multipleMessages(); this.stop(); } } @Test(expected = MessageRejectedException.class) public void rejectedMessage() throws Throwable { Mockito.when(channelMock.receive()).thenReturn(badMessage); endpoint.start(); trigger.await(); endpoint.stop(); assertEquals(1, consumer.counter.get()); errorHandler.throwLastErrorIfAvailable(); } @Test(expected = MessageRejectedException.class) public void droppedMessage_onePerPoll() throws Throwable { Mockito.when(channelMock.receive()).thenReturn(badMessage); endpoint.setMaxMessagesPerPoll(10); endpoint.start(); trigger.await(); endpoint.stop(); assertEquals(1, consumer.counter.get()); errorHandler.throwLastErrorIfAvailable(); } @Test public void blockingSourceTimedOut() { // we don't need to await the timeout, returning null suffices Mockito.when(channelMock.receive()).thenReturn(null); endpoint.setReceiveTimeout(1); endpoint.start(); trigger.await(); endpoint.stop(); assertEquals(0, consumer.counter.get()); } @Test public void blockingSourceNotTimedOut() { Mockito.when(channelMock.receive(Mockito.eq(1L))).thenReturn(message); endpoint.setReceiveTimeout(1); endpoint.setMaxMessagesPerPoll(1); endpoint.start(); trigger.await(); endpoint.stop(); assertEquals(1, consumer.counter.get()); } private static class TestConsumer implements MessageHandler { private volatile AtomicInteger counter = new AtomicInteger(); TestConsumer() { super(); } @Override public void handleMessage(Message<?> message) { this.counter.incrementAndGet(); if ("bad".equals(message.getPayload().toString())) { throw new MessageRejectedException(message, "intentional test failure"); } } } private static class TestTrigger implements Trigger { private final AtomicBoolean hasRun = new AtomicBoolean(); private volatile CountDownLatch latch = new CountDownLatch(1); TestTrigger() { super(); } @Override public Date nextExecutionTime(TriggerContext triggerContext) { if (!this.hasRun.getAndSet(true)) { return new Date(); } this.latch.countDown(); return null; } public void reset() { this.latch = new CountDownLatch(1); this.hasRun.set(false); } public void await() { try { this.latch.await(5000, TimeUnit.MILLISECONDS); if (latch.getCount() != 0) { throw new RuntimeException("test latch.await() did not count down"); } } catch (InterruptedException e) { throw new RuntimeException("test latch.await() interrupted"); } } } private static class TestErrorHandler implements ErrorHandler { private volatile Throwable lastError; TestErrorHandler() { super(); } @Override public void handleError(Throwable t) { this.lastError = t; } public void throwLastErrorIfAvailable() throws Throwable { if (this.lastError instanceof MessagingExceptionWrapper) { this.lastError = this.lastError.getCause(); } Throwable t = this.lastError; this.lastError = null; throw t; } } }