/* * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.amqp.rabbit.listener; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.willAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import org.junit.Test; import org.mockito.Mockito; import org.mockito.stubbing.Answer; import org.springframework.amqp.AmqpRejectAndDontRequeueException; import org.springframework.amqp.ImmediateAcknowledgeAmqpException; import org.springframework.amqp.core.MessageListener; import org.springframework.amqp.rabbit.connection.AbstractConnectionFactory; import org.springframework.amqp.rabbit.connection.CachingConnectionFactory; import org.springframework.amqp.rabbit.connection.SingleConnectionFactory; import org.springframework.amqp.rabbit.core.ChannelAwareMessageListener; import org.springframework.amqp.rabbit.core.RabbitTemplate; import org.springframework.amqp.rabbit.transaction.ListenerFailedRuleBasedTransactionAttribute; import org.springframework.amqp.rabbit.transaction.RabbitTransactionManager; import org.springframework.beans.DirectFieldAccessor; import org.springframework.transaction.TransactionDefinition; import org.springframework.transaction.TransactionException; import org.springframework.transaction.interceptor.DefaultTransactionAttribute; import org.springframework.transaction.interceptor.NoRollbackRuleAttribute; import org.springframework.transaction.interceptor.RollbackRuleAttribute; import org.springframework.transaction.interceptor.RuleBasedTransactionAttribute; import org.springframework.transaction.support.AbstractPlatformTransactionManager; import org.springframework.transaction.support.DefaultTransactionStatus; import com.rabbitmq.client.AMQP.BasicProperties; import com.rabbitmq.client.Channel; import com.rabbitmq.client.Connection; import com.rabbitmq.client.ConnectionFactory; import com.rabbitmq.client.Consumer; import com.rabbitmq.client.Envelope; /** * @author Gary Russell * @since 1.1.2 * */ public abstract class ExternalTxManagerTests { /** * Verifies that an up-stack RabbitTemplate uses the listener's * channel (MessageListener). */ @Test public void testMessageListener() throws Exception { ConnectionFactory mockConnectionFactory = mock(ConnectionFactory.class); Connection mockConnection = mock(Connection.class); final Channel onlyChannel = mock(Channel.class); given(onlyChannel.isOpen()).willReturn(true); final CachingConnectionFactory cachingConnectionFactory = new CachingConnectionFactory(mockConnectionFactory); cachingConnectionFactory.setExecutor(mock(ExecutorService.class)); given(mockConnectionFactory.newConnection(any(ExecutorService.class), anyString())).willReturn(mockConnection); given(mockConnection.isOpen()).willReturn(true); final AtomicReference<Exception> tooManyChannels = new AtomicReference<Exception>(); willAnswer(ensureOneChannelAnswer(onlyChannel, tooManyChannels)).given(mockConnection).createChannel(); final AtomicReference<Consumer> consumer = new AtomicReference<Consumer>(); final CountDownLatch consumerLatch = new CountDownLatch(1); willAnswer(invocation -> { consumer.set(invocation.getArgument(6)); consumerLatch.countDown(); return "consumerTag"; }).given(onlyChannel) .basicConsume(anyString(), anyBoolean(), anyString(), anyBoolean(), anyBoolean(), anyMap(), any(Consumer.class)); final AtomicReference<CountDownLatch> commitLatch = new AtomicReference<>(new CountDownLatch(1)); willAnswer(invocation -> { commitLatch.get().countDown(); return null; }).given(onlyChannel).txCommit(); final AtomicReference<CountDownLatch> rollbackLatch = new AtomicReference<>(new CountDownLatch(1)); willAnswer(invocation -> { rollbackLatch.get().countDown(); return null; }).given(onlyChannel).txRollback(); willAnswer(invocation -> { return null; }).given(onlyChannel).basicAck(anyLong(), anyBoolean()); final CountDownLatch latch = new CountDownLatch(1); AbstractMessageListenerContainer container = createContainer(cachingConnectionFactory); container.setMessageListener((MessageListener) message -> { RabbitTemplate rabbitTemplate = new RabbitTemplate(cachingConnectionFactory); rabbitTemplate.setChannelTransacted(true); // should use same channel as container rabbitTemplate.convertAndSend("foo", "bar", "baz"); latch.countDown(); }); container.setQueueNames("queue"); container.setChannelTransacted(true); container.setShutdownTimeout(100); DummyTxManager transactionManager = new DummyTxManager(); container.setTransactionManager(transactionManager); RuleBasedTransactionAttribute transactionAttribute = new ListenerFailedRuleBasedTransactionAttribute(); List<RollbackRuleAttribute> rollbackRules = Collections.singletonList(new NoRollbackRuleAttribute(IllegalStateException.class)); transactionAttribute.setRollbackRules(rollbackRules); container.setTransactionAttribute(transactionAttribute); container.afterPropertiesSet(); container.start(); assertTrue(consumerLatch.await(10, TimeUnit.SECONDS)); consumer.get().handleDelivery("qux", new Envelope(1, false, "foo", "bar"), new BasicProperties(), new byte[] { 0 }); assertTrue(latch.await(10, TimeUnit.SECONDS)); Exception e = tooManyChannels.get(); if (e != null) { throw e; } verify(mockConnection, times(1)).createChannel(); assertTrue(commitLatch.get().await(10, TimeUnit.SECONDS)); verify(onlyChannel).basicAck(anyLong(), anyBoolean()); verify(onlyChannel).txCommit(); verify(onlyChannel).basicPublish(anyString(), anyString(), anyBoolean(), any(BasicProperties.class), any(byte[].class)); // verify close() was never called on the channel DirectFieldAccessor dfa = new DirectFieldAccessor(cachingConnectionFactory); List<?> channels = (List<?>) dfa.getPropertyValue("cachedChannelsTransactional"); assertEquals(0, channels.size()); assertTrue(transactionManager.committed); transactionManager.committed = false; transactionManager.latch = new CountDownLatch(1); container.setMessageListener(m -> { throw new RuntimeException(); }); commitLatch.set(new CountDownLatch(1)); consumer.get().handleDelivery("qux", new Envelope(1, false, "foo", "bar"), new BasicProperties(), new byte[] { 0 }); assertTrue(transactionManager.latch.await(10, TimeUnit.SECONDS)); assertTrue(commitLatch.get().await(10, TimeUnit.SECONDS)); assertTrue(transactionManager.rolledBack); assertTrue(rollbackLatch.get().await(10, TimeUnit.SECONDS)); verify(onlyChannel).basicReject(anyLong(), anyBoolean()); verify(onlyChannel, times(1)).txRollback(); transactionManager.rolledBack = false; transactionManager.latch = new CountDownLatch(1); container.setMessageListener(m -> { throw new IllegalStateException(); }); commitLatch.set(new CountDownLatch(1)); consumer.get().handleDelivery("qux", new Envelope(1, false, "foo", "bar"), new BasicProperties(), new byte[] { 0 }); assertTrue(transactionManager.latch.await(10, TimeUnit.SECONDS)); assertTrue(commitLatch.get().await(10, TimeUnit.SECONDS)); assertTrue(transactionManager.committed); verify(onlyChannel, times(2)).basicAck(anyLong(), anyBoolean()); verify(onlyChannel, times(3)).txCommit(); // previous + reject commit for above + this one transactionManager.committed = false; transactionManager.latch = new CountDownLatch(1); container.setMessageListener(m -> { throw new AmqpRejectAndDontRequeueException("foo", new ImmediateAcknowledgeAmqpException("bar")); }); commitLatch.set(new CountDownLatch(1)); rollbackLatch.set(new CountDownLatch(1)); consumer.get().handleDelivery("qux", new Envelope(1, false, "foo", "bar"), new BasicProperties(), new byte[] { 0 }); assertTrue(transactionManager.latch.await(10, TimeUnit.SECONDS)); assertTrue(transactionManager.rolledBack); assertTrue(rollbackLatch.get().await(10, TimeUnit.SECONDS)); assertTrue(commitLatch.get().await(10, TimeUnit.SECONDS)); verify(onlyChannel, times(2)).basicReject(anyLong(), anyBoolean()); verify(onlyChannel, times(2)).txRollback(); transactionManager.rolledBack = false; transactionManager.latch = new CountDownLatch(1); container.setMessageListener(m -> { throw new ImmediateAcknowledgeAmqpException("foo"); }); commitLatch.set(new CountDownLatch(1)); consumer.get().handleDelivery("qux", new Envelope(1, false, "foo", "bar"), new BasicProperties(), new byte[] { 0 }); assertTrue(transactionManager.latch.await(10, TimeUnit.SECONDS)); assertTrue(commitLatch.get().await(10, TimeUnit.SECONDS)); assertTrue(transactionManager.committed); verify(onlyChannel, times(3)).basicAck(anyLong(), anyBoolean()); verify(onlyChannel, times(5)).txCommit(); container.stop(); } @Test public void testMessageListenerRollback() throws Exception { testMessageListenerRollbackGuts(true, TransactionDefinition.PROPAGATION_REQUIRED); } @Test public void testMessageListenerRollbackDontRequeue() throws Exception { testMessageListenerRollbackGuts(false, TransactionDefinition.PROPAGATION_REQUIRED); } @Test public void testMessageListenerRollbackNoBoundTransaction() throws Exception { testMessageListenerRollbackGuts(true, TransactionDefinition.PROPAGATION_NEVER); } @Test public void testMessageListenerRollbackDontRequeueNoBoundTransaction() throws Exception { testMessageListenerRollbackGuts(false, TransactionDefinition.PROPAGATION_NEVER); } /** * Verifies that the channel is rolled back after an exception. */ private void testMessageListenerRollbackGuts(boolean expectRequeue, int propagation) throws Exception { ConnectionFactory mockConnectionFactory = mock(ConnectionFactory.class); Connection mockConnection = mock(Connection.class); final Channel channel = mock(Channel.class); given(channel.isOpen()).willReturn(true); final CachingConnectionFactory cachingConnectionFactory = new CachingConnectionFactory(mockConnectionFactory); cachingConnectionFactory.setExecutor(mock(ExecutorService.class)); given(mockConnectionFactory.newConnection(any(ExecutorService.class), anyString())).willReturn(mockConnection); given(mockConnection.isOpen()).willReturn(true); final AtomicReference<Exception> tooManyChannels = new AtomicReference<Exception>(); willAnswer(ensureOneChannelAnswer(channel, tooManyChannels)).given(mockConnection).createChannel(); willAnswer(invocation -> channel).given(mockConnection).createChannel(); final AtomicReference<Consumer> consumer = new AtomicReference<Consumer>(); final CountDownLatch consumerLatch = new CountDownLatch(1); willAnswer(invocation -> { consumer.set(invocation.getArgument(6)); consumerLatch.countDown(); return "consumerTag"; }).given(channel) .basicConsume(anyString(), anyBoolean(), anyString(), anyBoolean(), anyBoolean(), anyMap(), any(Consumer.class)); final CountDownLatch rollbackLatch = new CountDownLatch(1); willAnswer(invocation -> { rollbackLatch.countDown(); return null; }).given(channel).txRollback(); final CountDownLatch rejectLatch = new CountDownLatch(1); willAnswer(invocation -> { rejectLatch.countDown(); return null; }).given(channel).basicReject(anyLong(), anyBoolean()); willAnswer(invocation -> { rejectLatch.countDown(); return null; }).given(channel).basicNack(anyLong(), anyBoolean(), anyBoolean()); final CountDownLatch latch = new CountDownLatch(1); AbstractMessageListenerContainer container = createContainer(cachingConnectionFactory); container.setTransactionAttribute(new DefaultTransactionAttribute(propagation)); container.setMessageListener(message -> { latch.countDown(); throw expectRequeue ? new RuntimeException("force rollback") : new AmqpRejectAndDontRequeueException("force rollback"); }); container.setQueueNames("queue"); container.setChannelTransacted(true); container.setShutdownTimeout(100); container.setTransactionManager(new DummyTxManager()); container.afterPropertiesSet(); container.start(); assertTrue(consumerLatch.await(10, TimeUnit.SECONDS)); consumer.get().handleDelivery("qux", new Envelope(1, false, "foo", "bar"), new BasicProperties(), new byte[] { 0 }); assertTrue(latch.await(10, TimeUnit.SECONDS)); Exception e = tooManyChannels.get(); if (e != null) { throw e; } verify(mockConnection, times(1)).createChannel(); assertTrue(rejectLatch.await(10, TimeUnit.SECONDS)); assertTrue(rollbackLatch.await(10, TimeUnit.SECONDS)); if (propagation != TransactionDefinition.PROPAGATION_NEVER) { verify(channel).basicReject(anyLong(), eq(expectRequeue)); } else { verify(channel).basicNack(anyLong(), eq(Boolean.TRUE), eq(expectRequeue)); } container.stop(); } @Test public void testMessageListenerCommit() throws Exception { testMessageListenerCommitGuts(TransactionDefinition.PROPAGATION_REQUIRED); } @Test public void testMessageListenerCommitNoBoundTransaction() throws Exception { testMessageListenerCommitGuts(TransactionDefinition.PROPAGATION_NEVER); } /** * Verifies that the channel is committed. */ private void testMessageListenerCommitGuts(int propagation) throws Exception { ConnectionFactory mockConnectionFactory = mock(ConnectionFactory.class); Connection mockConnection = mock(Connection.class); final Channel channel = mock(Channel.class); given(channel.isOpen()).willReturn(true); final CachingConnectionFactory cachingConnectionFactory = new CachingConnectionFactory(mockConnectionFactory); cachingConnectionFactory.setExecutor(mock(ExecutorService.class)); given(mockConnectionFactory.newConnection(any(ExecutorService.class), anyString())).willReturn(mockConnection); given(mockConnection.isOpen()).willReturn(true); final AtomicReference<Exception> tooManyChannels = new AtomicReference<Exception>(); willAnswer(ensureOneChannelAnswer(channel, tooManyChannels)).given(mockConnection).createChannel(); willAnswer(invocation -> channel).given(mockConnection).createChannel(); final AtomicReference<Consumer> consumer = new AtomicReference<Consumer>(); final CountDownLatch consumerLatch = new CountDownLatch(1); willAnswer(invocation -> { consumer.set(invocation.getArgument(6)); consumerLatch.countDown(); return "consumerTag"; }).given(channel) .basicConsume(anyString(), anyBoolean(), anyString(), anyBoolean(), anyBoolean(), anyMap(), any(Consumer.class)); final CountDownLatch commitLatch = new CountDownLatch(1); willAnswer(invocation -> { commitLatch.countDown(); return null; }).given(channel).txCommit(); final CountDownLatch ackLatch = new CountDownLatch(1); willAnswer(invocation -> { ackLatch.countDown(); return null; }).given(channel).basicAck(anyLong(), anyBoolean()); final CountDownLatch latch = new CountDownLatch(1); AbstractMessageListenerContainer container = createContainer(cachingConnectionFactory); container.setTransactionAttribute(new DefaultTransactionAttribute(propagation)); container.setMessageListener(message -> { latch.countDown(); }); container.setQueueNames("queue"); container.setChannelTransacted(true); container.setShutdownTimeout(100); container.setTransactionManager(new DummyTxManager()); container.afterPropertiesSet(); container.start(); assertTrue(consumerLatch.await(10, TimeUnit.SECONDS)); consumer.get().handleDelivery("qux", new Envelope(1, false, "foo", "bar"), new BasicProperties(), new byte[] { 0 }); assertTrue(latch.await(10, TimeUnit.SECONDS)); Exception e = tooManyChannels.get(); if (e != null) { throw e; } verify(mockConnection, times(1)).createChannel(); assertTrue(ackLatch.await(10, TimeUnit.SECONDS)); assertTrue(commitLatch.await(10, TimeUnit.SECONDS)); verify(channel).basicAck(anyLong(), anyBoolean()); container.stop(); } /** * Verifies that an up-stack RabbitTemplate does not use the listener's * channel when it has its own connection factory. */ @Test public void testMessageListenerTemplateUsesDifferentConnectionFactory() throws Exception { ConnectionFactory listenerConnectionFactory = mock(ConnectionFactory.class); ConnectionFactory templateConnectionFactory = mock(ConnectionFactory.class); Connection listenerConnection = mock(Connection.class); Connection templateConnection = mock(Connection.class); final Channel listenerChannel = mock(Channel.class); Channel templateChannel = mock(Channel.class); given(listenerChannel.isOpen()).willReturn(true); given(templateChannel.isOpen()).willReturn(true); final CachingConnectionFactory cachingConnectionFactory = new CachingConnectionFactory( listenerConnectionFactory); ExecutorService mockExec = mock(ExecutorService.class); cachingConnectionFactory.setExecutor(mockExec); final CachingConnectionFactory cachingTemplateConnectionFactory = new CachingConnectionFactory( templateConnectionFactory); cachingTemplateConnectionFactory.setExecutor(mockExec); given(listenerConnectionFactory.newConnection(any(ExecutorService.class), anyString())) .willReturn(listenerConnection); given(listenerConnection.isOpen()).willReturn(true); given(templateConnectionFactory.newConnection(any(ExecutorService.class), anyString())) .willReturn(templateConnection); given(templateConnection.isOpen()).willReturn(true); given(templateConnection.createChannel()).willReturn(templateChannel); final AtomicReference<Exception> tooManyChannels = new AtomicReference<Exception>(); willAnswer(ensureOneChannelAnswer(listenerChannel, tooManyChannels)).given(listenerConnection).createChannel(); final AtomicReference<Consumer> consumer = new AtomicReference<Consumer>(); final CountDownLatch consumerLatch = new CountDownLatch(1); willAnswer(invocation -> { consumer.set(invocation.getArgument(6)); consumerLatch.countDown(); return "consumerTag"; }).given(listenerChannel) .basicConsume(anyString(), anyBoolean(), anyString(), anyBoolean(), anyBoolean(), anyMap(), any(Consumer.class)); final CountDownLatch commitLatch = new CountDownLatch(2); willAnswer(invocation -> { commitLatch.countDown(); return null; }).given(listenerChannel).txCommit(); willAnswer(invocation -> { commitLatch.countDown(); return null; }).given(templateChannel).txCommit(); final CountDownLatch latch = new CountDownLatch(1); AbstractMessageListenerContainer container = createContainer(cachingConnectionFactory); container.setMessageListener((MessageListener) message -> { RabbitTemplate rabbitTemplate = new RabbitTemplate(cachingTemplateConnectionFactory); rabbitTemplate.setChannelTransacted(true); // should use same channel as container rabbitTemplate.convertAndSend("foo", "bar", "baz"); latch.countDown(); }); container.setQueueNames("queue"); container.setChannelTransacted(true); container.setShutdownTimeout(100); container.setTransactionManager(new DummyTxManager()); container.afterPropertiesSet(); container.start(); assertTrue(consumerLatch.await(10, TimeUnit.SECONDS)); consumer.get().handleDelivery("qux", new Envelope(1, false, "foo", "bar"), new BasicProperties(), new byte[] {0}); assertTrue(latch.await(10, TimeUnit.SECONDS)); Exception e = tooManyChannels.get(); if (e != null) { throw e; } verify(listenerConnection, Mockito.times(1)).createChannel(); verify(templateConnection, Mockito.times(1)).createChannel(); assertTrue(commitLatch.await(10, TimeUnit.SECONDS)); verify(listenerChannel).txCommit(); verify(templateChannel).basicPublish(Mockito.anyString(), Mockito.anyString(), Mockito.anyBoolean(), Mockito.any(BasicProperties.class), Mockito.any(byte[].class)); verify(templateChannel).txCommit(); // verify close() was never called on the channel DirectFieldAccessor dfa = new DirectFieldAccessor(cachingConnectionFactory); List<?> channels = (List<?>) dfa.getPropertyValue("cachedChannelsTransactional"); assertEquals(0, channels.size()); container.stop(); } /** * Verifies that an up-stack RabbitTemplate uses the listener's * channel (ChannelAwareMessageListener). */ @Test public void testChannelAwareMessageListener() throws Exception { ConnectionFactory mockConnectionFactory = mock(ConnectionFactory.class); Connection mockConnection = mock(Connection.class); final Channel onlyChannel = mock(Channel.class); given(onlyChannel.isOpen()).willReturn(true); final SingleConnectionFactory singleConnectionFactory = new SingleConnectionFactory(mockConnectionFactory); singleConnectionFactory.setExecutor(mock(ExecutorService.class)); given(mockConnectionFactory.newConnection(any(ExecutorService.class), anyString())).willReturn(mockConnection); given(mockConnection.isOpen()).willReturn(true); final AtomicReference<Exception> tooManyChannels = new AtomicReference<Exception>(); willAnswer(ensureOneChannelAnswer(onlyChannel, tooManyChannels)).given(mockConnection).createChannel(); final AtomicReference<Consumer> consumer = new AtomicReference<Consumer>(); final CountDownLatch consumerLatch = new CountDownLatch(1); willAnswer(invocation -> { consumer.set(invocation.getArgument(6)); consumerLatch.countDown(); return "consumerTag"; }).given(onlyChannel) .basicConsume(anyString(), anyBoolean(), anyString(), anyBoolean(), anyBoolean(), anyMap(), any(Consumer.class)); final CountDownLatch commitLatch = new CountDownLatch(1); willAnswer(invocation -> { commitLatch.countDown(); return null; }).given(onlyChannel).txCommit(); final CountDownLatch latch = new CountDownLatch(1); final AtomicReference<Channel> exposed = new AtomicReference<Channel>(); AbstractMessageListenerContainer container = createContainer(singleConnectionFactory); container.setMessageListener((ChannelAwareMessageListener) (message, channel) -> { exposed.set(channel); RabbitTemplate rabbitTemplate = new RabbitTemplate(singleConnectionFactory); rabbitTemplate.setChannelTransacted(true); // should use same channel as container rabbitTemplate.convertAndSend("foo", "bar", "baz"); latch.countDown(); }); container.setQueueNames("queue"); container.setChannelTransacted(true); container.setShutdownTimeout(100); container.setTransactionManager(new DummyTxManager()); container.afterPropertiesSet(); container.start(); assertTrue(consumerLatch.await(10, TimeUnit.SECONDS)); consumer.get().handleDelivery("qux", new Envelope(1, false, "foo", "bar"), new BasicProperties(), new byte[] {0}); assertTrue(latch.await(10, TimeUnit.SECONDS)); Exception e = tooManyChannels.get(); if (e != null) { throw e; } verify(mockConnection, Mockito.times(1)).createChannel(); assertTrue(commitLatch.await(10, TimeUnit.SECONDS)); verify(onlyChannel).txCommit(); verify(onlyChannel).basicPublish(Mockito.anyString(), Mockito.anyString(), Mockito.anyBoolean(), Mockito.any(BasicProperties.class), Mockito.any(byte[].class)); // verify close() was never called on the channel verify(onlyChannel, Mockito.never()).close(); container.stop(); assertSame(onlyChannel, exposed.get()); } /** * Verifies that an up-stack RabbitTemplate uses the listener's * channel (ChannelAwareMessageListener). exposeListenerChannel=false * is ignored (ChannelAwareMessageListener). */ @Test public void testChannelAwareMessageListenerDontExpose() throws Exception { ConnectionFactory mockConnectionFactory = mock(ConnectionFactory.class); Connection mockConnection = mock(Connection.class); final Channel onlyChannel = mock(Channel.class); given(onlyChannel.isOpen()).willReturn(true); final SingleConnectionFactory singleConnectionFactory = new SingleConnectionFactory(mockConnectionFactory); singleConnectionFactory.setExecutor(mock(ExecutorService.class)); given(mockConnectionFactory.newConnection(any(ExecutorService.class), anyString())).willReturn(mockConnection); given(mockConnection.isOpen()).willReturn(true); final AtomicReference<Exception> tooManyChannels = new AtomicReference<Exception>(); willAnswer(ensureOneChannelAnswer(onlyChannel, tooManyChannels)).given(mockConnection).createChannel(); final AtomicReference<Consumer> consumer = new AtomicReference<Consumer>(); final CountDownLatch consumerLatch = new CountDownLatch(1); willAnswer(invocation -> { consumer.set(invocation.getArgument(6)); consumerLatch.countDown(); return "consumerTag"; }).given(onlyChannel) .basicConsume(anyString(), anyBoolean(), anyString(), anyBoolean(), anyBoolean(), anyMap(), any(Consumer.class)); final CountDownLatch commitLatch = new CountDownLatch(1); willAnswer(invocation -> { commitLatch.countDown(); return null; }).given(onlyChannel).txCommit(); final CountDownLatch latch = new CountDownLatch(1); final AtomicReference<Channel> exposed = new AtomicReference<Channel>(); SimpleMessageListenerContainer container = new SimpleMessageListenerContainer(singleConnectionFactory); container.setMessageListener((ChannelAwareMessageListener) (message, channel) -> { exposed.set(channel); RabbitTemplate rabbitTemplate = new RabbitTemplate(singleConnectionFactory); rabbitTemplate.setChannelTransacted(true); // should use same channel as container rabbitTemplate.convertAndSend("foo", "bar", "baz"); latch.countDown(); }); container.setQueueNames("queue"); container.setChannelTransacted(true); container.setExposeListenerChannel(false); container.setShutdownTimeout(100); container.setTransactionManager(new DummyTxManager()); container.afterPropertiesSet(); container.start(); assertTrue(consumerLatch.await(10, TimeUnit.SECONDS)); consumer.get().handleDelivery("qux", new Envelope(1, false, "foo", "bar"), new BasicProperties(), new byte[] {0}); assertTrue(latch.await(10, TimeUnit.SECONDS)); Exception e = tooManyChannels.get(); if (e != null) { throw e; } verify(mockConnection, Mockito.times(1)).createChannel(); assertTrue(commitLatch.await(10, TimeUnit.SECONDS)); verify(onlyChannel).txCommit(); verify(onlyChannel).basicPublish(Mockito.anyString(), Mockito.anyString(), Mockito.anyBoolean(), Mockito.any(BasicProperties.class), Mockito.any(byte[].class)); // verify close() was never called on the channel verify(onlyChannel, Mockito.never()).close(); container.stop(); assertSame(onlyChannel, exposed.get()); } /** * Verifies the proper channel is bound when using a RabbitTransactionManager. * Previously, the wrong channel was bound. See AMQP-260. * @throws Exception */ @Test public void testMessageListenerWithRabbitTxManager() throws Exception { ConnectionFactory mockConnectionFactory = mock(ConnectionFactory.class); Connection mockConnection = mock(Connection.class); final Channel onlyChannel = mock(Channel.class); given(onlyChannel.isOpen()).willReturn(true); final CachingConnectionFactory cachingConnectionFactory = new CachingConnectionFactory(mockConnectionFactory); cachingConnectionFactory.setExecutor(mock(ExecutorService.class)); given(mockConnectionFactory.newConnection(any(ExecutorService.class), anyString())).willReturn(mockConnection); given(mockConnection.isOpen()).willReturn(true); final AtomicReference<Exception> tooManyChannels = new AtomicReference<Exception>(); willAnswer(ensureOneChannelAnswer(onlyChannel, tooManyChannels)).given(mockConnection).createChannel(); final AtomicReference<Consumer> consumer = new AtomicReference<Consumer>(); final CountDownLatch consumerLatch = new CountDownLatch(1); willAnswer(invocation -> { consumer.set(invocation.getArgument(6)); consumerLatch.countDown(); return "consumerTag"; }).given(onlyChannel) .basicConsume(anyString(), anyBoolean(), anyString(), anyBoolean(), anyBoolean(), anyMap(), any(Consumer.class)); final CountDownLatch commitLatch = new CountDownLatch(1); willAnswer(invocation -> { commitLatch.countDown(); return null; }).given(onlyChannel).txCommit(); final CountDownLatch latch = new CountDownLatch(1); AbstractMessageListenerContainer container = createContainer(cachingConnectionFactory); container.setMessageListener((MessageListener) message -> { RabbitTemplate rabbitTemplate = new RabbitTemplate(cachingConnectionFactory); rabbitTemplate.setChannelTransacted(true); // should use same channel as container rabbitTemplate.convertAndSend("foo", "bar", "baz"); latch.countDown(); }); container.setQueueNames("queue"); container.setChannelTransacted(true); container.setShutdownTimeout(100); container.setTransactionManager(new RabbitTransactionManager(cachingConnectionFactory)); container.afterPropertiesSet(); container.start(); assertTrue(consumerLatch.await(10, TimeUnit.SECONDS)); consumer.get().handleDelivery("qux", new Envelope(1, false, "foo", "bar"), new BasicProperties(), new byte[] {0}); assertTrue(latch.await(10, TimeUnit.SECONDS)); Exception e = tooManyChannels.get(); if (e != null) { throw e; } verify(mockConnection, Mockito.times(1)).createChannel(); assertTrue(commitLatch.await(10, TimeUnit.SECONDS)); verify(onlyChannel).txCommit(); verify(onlyChannel).basicPublish(Mockito.anyString(), Mockito.anyString(), Mockito.anyBoolean(), Mockito.any(BasicProperties.class), Mockito.any(byte[].class)); // verify close() was never called on the channel DirectFieldAccessor dfa = new DirectFieldAccessor(cachingConnectionFactory); List<?> channels = (List<?>) dfa.getPropertyValue("cachedChannelsTransactional"); assertEquals(0, channels.size()); container.stop(); } private Answer<Channel> ensureOneChannelAnswer(final Channel onlyChannel, final AtomicReference<Exception> tooManyChannels) { final AtomicBoolean done = new AtomicBoolean(); return invocation -> { if (!done.get()) { done.set(true); return onlyChannel; } tooManyChannels.set(new Exception("More than one channel requested")); Channel channel = mock(Channel.class); given(channel.isOpen()).willReturn(true); return channel; }; } protected abstract AbstractMessageListenerContainer createContainer(AbstractConnectionFactory connectionFactory); @SuppressWarnings("serial") private static class DummyTxManager extends AbstractPlatformTransactionManager { private volatile boolean committed; private volatile boolean rolledBack; private volatile CountDownLatch latch = new CountDownLatch(1); @Override protected Object doGetTransaction() throws TransactionException { return new Object(); } @Override protected void doBegin(Object transaction, TransactionDefinition definition) throws TransactionException { } @Override protected void doCommit(DefaultTransactionStatus status) throws TransactionException { this.committed = true; this.latch.countDown(); } @Override protected void doRollback(DefaultTransactionStatus status) throws TransactionException { this.rolledBack = true; this.latch.countDown(); } } }