/* * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.integration.stream; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.mock; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.Date; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.springframework.beans.factory.BeanFactory; import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.endpoint.PollingConsumer; import org.springframework.messaging.support.GenericMessage; import org.springframework.scheduling.Trigger; import org.springframework.scheduling.TriggerContext; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; /** * @author Mark Fisher */ public class ByteStreamWritingMessageHandlerTests { private ByteArrayOutputStream stream; private ByteStreamWritingMessageHandler handler; private QueueChannel channel; private PollingConsumer endpoint; private final TestTrigger trigger = new TestTrigger(); private ThreadPoolTaskScheduler scheduler; @Before public void initialize() { stream = new ByteArrayOutputStream(); handler = new ByteStreamWritingMessageHandler(stream); this.channel = new QueueChannel(10); this.endpoint = new PollingConsumer(channel, handler); scheduler = new ThreadPoolTaskScheduler(); this.endpoint.setTaskScheduler(scheduler); scheduler.afterPropertiesSet(); trigger.reset(); endpoint.setTrigger(trigger); endpoint.setBeanFactory(mock(BeanFactory.class)); } @After public void stop() throws Exception { scheduler.destroy(); } @Test public void singleByteArray() { handler.handleMessage(new GenericMessage<byte[]>(new byte[] {1, 2, 3})); byte[] result = stream.toByteArray(); assertEquals(3, result.length); assertEquals(1, result[0]); assertEquals(2, result[1]); assertEquals(3, result[2]); } @Test public void singleString() { handler.handleMessage(new GenericMessage<String>("foo")); byte[] result = stream.toByteArray(); assertEquals(3, result.length); assertEquals("foo", new String(result)); } @Test public void maxMessagesPerTaskSameAsMessageCount() { endpoint.setTrigger(trigger); endpoint.setMaxMessagesPerPoll(3); channel.send(new GenericMessage<byte[]>(new byte[] {1, 2, 3}), 0); channel.send(new GenericMessage<byte[]>(new byte[] {4, 5, 6}), 0); channel.send(new GenericMessage<byte[]>(new byte[] {7, 8, 9}), 0); endpoint.start(); trigger.await(); endpoint.stop(); byte[] result = stream.toByteArray(); assertEquals(9, result.length); assertEquals(1, result[0]); assertEquals(9, result[8]); } @Test public void maxMessagesPerTaskLessThanMessageCount() { endpoint.setTrigger(trigger); endpoint.setMaxMessagesPerPoll(2); channel.send(new GenericMessage<byte[]>(new byte[] {1, 2, 3}), 0); channel.send(new GenericMessage<byte[]>(new byte[] {4, 5, 6}), 0); channel.send(new GenericMessage<byte[]>(new byte[] {7, 8, 9}), 0); endpoint.start(); trigger.await(); endpoint.stop(); byte[] result = stream.toByteArray(); assertEquals(6, result.length); assertEquals(1, result[0]); } @Test public void maxMessagesPerTaskExceedsMessageCount() { endpoint.setTrigger(trigger); endpoint.setMaxMessagesPerPoll(5); endpoint.setReceiveTimeout(0); channel.send(new GenericMessage<byte[]>(new byte[] {1, 2, 3}), 0); channel.send(new GenericMessage<byte[]>(new byte[] {4, 5, 6}), 0); channel.send(new GenericMessage<byte[]>(new byte[] {7, 8, 9}), 0); endpoint.start(); trigger.await(); endpoint.stop(); byte[] result = stream.toByteArray(); assertEquals(9, result.length); assertEquals(1, result[0]); } @Test public void testMaxMessagesLessThanMessageCountWithMultipleDispatches() { endpoint.setTrigger(trigger); endpoint.setMaxMessagesPerPoll(2); endpoint.setReceiveTimeout(0); channel.send(new GenericMessage<byte[]>(new byte[] {1, 2, 3}), 0); channel.send(new GenericMessage<byte[]>(new byte[] {4, 5, 6}), 0); channel.send(new GenericMessage<byte[]>(new byte[] {7, 8, 9}), 0); endpoint.start(); trigger.await(); endpoint.stop(); byte[] result1 = stream.toByteArray(); assertEquals(6, result1.length); assertEquals(1, result1[0]); trigger.reset(); endpoint.start(); trigger.await(); endpoint.stop(); byte[] result2 = stream.toByteArray(); assertEquals(9, result2.length); assertEquals(1, result2[0]); assertEquals(7, result2[6]); } @Test public void testMaxMessagesExceedsMessageCountWithMultipleDispatches() { endpoint.setTrigger(trigger); endpoint.setMaxMessagesPerPoll(5); endpoint.setReceiveTimeout(0); channel.send(new GenericMessage<byte[]>(new byte[] {1, 2, 3}), 0); channel.send(new GenericMessage<byte[]>(new byte[] {4, 5, 6}), 0); channel.send(new GenericMessage<byte[]>(new byte[] {7, 8, 9}), 0); endpoint.start(); trigger.await(); endpoint.stop(); byte[] result1 = stream.toByteArray(); assertEquals(9, result1.length); assertEquals(1, result1[0]); trigger.reset(); endpoint.start(); trigger.await(); endpoint.stop(); byte[] result2 = stream.toByteArray(); assertEquals(9, result2.length); assertEquals(1, result2[0]); } @Test public void testStreamResetBetweenDispatches() { endpoint.setMaxMessagesPerPoll(2); endpoint.setTrigger(trigger); endpoint.setReceiveTimeout(0); channel.send(new GenericMessage<byte[]>(new byte[] {1, 2, 3}), 0); channel.send(new GenericMessage<byte[]>(new byte[] {4, 5, 6}), 0); channel.send(new GenericMessage<byte[]>(new byte[] {7, 8, 9}), 0); endpoint.start(); trigger.await(); endpoint.stop(); byte[] result1 = stream.toByteArray(); assertEquals(6, result1.length); stream.reset(); trigger.reset(); endpoint.start(); trigger.await(); endpoint.stop(); byte[] result2 = stream.toByteArray(); assertEquals(3, result2.length); assertEquals(7, result2[0]); } @Test public void testStreamWriteBetweenDispatches() throws IOException { endpoint.setTrigger(trigger); endpoint.setMaxMessagesPerPoll(2); endpoint.setReceiveTimeout(0); channel.send(new GenericMessage<byte[]>(new byte[] {1, 2, 3}), 0); channel.send(new GenericMessage<byte[]>(new byte[] {4, 5, 6}), 0); channel.send(new GenericMessage<byte[]>(new byte[] {7, 8, 9}), 0); endpoint.start(); trigger.await(); endpoint.stop(); byte[] result1 = stream.toByteArray(); assertEquals(6, result1.length); stream.write(new byte[] {123}); stream.flush(); trigger.reset(); endpoint.start(); trigger.await(); endpoint.stop(); byte[] result2 = stream.toByteArray(); assertEquals(10, result2.length); assertEquals(1, result2[0]); assertEquals(123, result2[6]); assertEquals(7, result2[7]); } private static class TestTrigger implements Trigger { private final AtomicBoolean hasRun = new AtomicBoolean(); private volatile CountDownLatch latch = new CountDownLatch(1); TestTrigger() { super(); } @Override public Date nextExecutionTime(TriggerContext triggerContext) { if (!hasRun.getAndSet(true)) { return new Date(); } this.latch.countDown(); return null; } public void reset() { this.latch = new CountDownLatch(1); this.hasRun.set(false); } public void await() { try { this.latch.await(3000, TimeUnit.MILLISECONDS); if (latch.getCount() != 0) { throw new RuntimeException("test timeout"); } } catch (InterruptedException e) { throw new RuntimeException("test latch.await() interrupted"); } } } }