/*
* Copyright 2015-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.xd.dirt.integration.rabbit;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyMap;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo;
import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.logging.Log;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.springframework.amqp.rabbit.connection.Connection;
import org.springframework.amqp.rabbit.connection.ConnectionFactory;
import org.springframework.amqp.rabbit.listener.SimpleMessageListenerContainer;
import org.springframework.amqp.utils.test.TestUtils;
import org.springframework.beans.DirectFieldAccessor;
import org.springframework.http.MediaType;
import org.springframework.test.web.client.MockRestServiceServer;
import org.springframework.web.client.RestTemplate;
import com.rabbitmq.client.Channel;
import com.rabbitmq.client.Consumer;
/**
*
* @author Gary Russell
*/
public class LocalizedQueueConnectionFactoryTests {
private final Map<String, ConnectionFactory> cfs = new HashMap<>();
private final Map<String, Connection> connections = new HashMap<>();
private final Map<String, Channel> channels = new HashMap<>();
private final Map<String, Consumer> consumers = new HashMap<>();
private final Map<String, String> consumerTags = new HashMap<>();
private final CountDownLatch latch = new CountDownLatch(2);
@SuppressWarnings("unchecked")
@Test
public void testFailOver() throws Exception {
ConnectionFactory defaultConnectionFactory = mockCF("localhost:1234");
String rabbit1 = "localhost:1235";
String rabbit2 = "localhost:1236";
String[] addresses = new String[] { rabbit1, rabbit2 };
String[] adminAddresses = new String[] { "http://localhost:11235", "http://localhost:11236" };
String[] nodes = new String[] { "rabbit@foo", "rabbit@bar" };
String vhost = "/";
String username = "guest";
String password = "guest";
final AtomicBoolean firstServer = new AtomicBoolean(true);
LocalizedQueueConnectionFactory lqcf = new LocalizedQueueConnectionFactory(defaultConnectionFactory, addresses,
adminAddresses, nodes, vhost, username, password, false, null, null,
null, null, null) {
private final String[] nodes = new String[] { "rabbit@foo", "rabbit@bar" };
@Override
protected RestTemplate createRestTemplate(String adminUri) {
return doCreateRestTemplate(adminUri, firstServer.get() ? nodes[0] : nodes[1]);
}
@Override
protected ConnectionFactory createConnectionFactory(String address) throws Exception {
return mockCF(address);
}
};
Log logger = spy(TestUtils.getPropertyValue(lqcf, "logger", Log.class));
new DirectFieldAccessor(lqcf).setPropertyValue("logger", logger);
when(logger.isDebugEnabled()).thenReturn(true);
ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
SimpleMessageListenerContainer container = new SimpleMessageListenerContainer(lqcf);
container.setQueueNames("q");
container.afterPropertiesSet();
container.start();
Channel channel = this.channels.get(rabbit1);
verify(channel).basicConsume(anyString(), anyBoolean(), anyString(), anyBoolean(),
anyBoolean(), anyMap(),
Matchers.any(Consumer.class));
verify(logger, atLeast(1)).debug(captor.capture());
assertTrue(assertLog(captor.getAllValues(), "Queue: q is on node: rabbit@foo at: localhost:1235"));
// Fail rabbit1 and verify the container switches to rabbit2
firstServer.set(false);
when(channel.isOpen()).thenReturn(false);
when(this.connections.get(rabbit1).isOpen()).thenReturn(false);
this.consumers.get(rabbit1).handleCancel(consumerTags.get(rabbit1));
assertTrue(latch.await(10, TimeUnit.SECONDS));
channel = this.channels.get(rabbit2);
verify(channel).basicConsume(anyString(), anyBoolean(), anyString(), anyBoolean(),
anyBoolean(), anyMap(),
Matchers.any(Consumer.class));
container.stop();
verify(logger, atLeast(1)).debug(captor.capture());
assertTrue(assertLog(captor.getAllValues(), "Queue: q is on node: rabbit@bar at: localhost:1236"));
}
private boolean assertLog(List<String> logRows, String expected) {
for (String log : logRows) {
if (log.contains(expected)) {
return true;
}
}
return false;
}
private RestTemplate doCreateRestTemplate(String uri, String node) {
RestTemplate template = new RestTemplate();
MockRestServiceServer server = MockRestServiceServer.createServer(template);
server.expect(requestTo(uri + "/api/queues/%2F/q"))
.andRespond(withSuccess("{ \"node\" : \""
+ node
+ "\" }", MediaType.APPLICATION_JSON));
return template;
}
@SuppressWarnings("unchecked")
private ConnectionFactory mockCF(final String address) throws Exception {
ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
Connection connection = mock(Connection.class);
Channel channel = mock(Channel.class);
when(connectionFactory.createConnection()).thenReturn(connection);
when(connection.createChannel(false)).thenReturn(channel);
when(connection.isOpen()).thenReturn(true);
when(channel.isOpen()).thenReturn(true);
doAnswer(new Answer<String>() {
@Override
public String answer(InvocationOnMock invocation) throws Throwable {
String tag = UUID.randomUUID().toString();
consumers.put(address, (Consumer) invocation.getArguments()[6]);
consumerTags.put(address, tag);
latch.countDown();
return tag;
}
}).when(channel).basicConsume(anyString(), anyBoolean(), anyString(), anyBoolean(), anyBoolean(), anyMap(),
any(Consumer.class));
when(connectionFactory.getHost()).thenReturn(address);
this.cfs.put(address, connectionFactory);
this.connections.put(address, connection);
this.channels.put(address, channel);
return connectionFactory;
}
}