/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.integration.ip.tcp.connection;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.contains;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.beans.DirectFieldAccessor;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.integration.channel.NullChannel;
import org.springframework.integration.context.IntegrationContextUtils;
import org.springframework.integration.ip.config.TcpConnectionFactoryFactoryBean;
import org.springframework.integration.ip.event.IpIntegrationEvent;
import org.springframework.integration.ip.tcp.TcpReceivingChannelAdapter;
import org.springframework.integration.test.support.LogAdjustingTestSupport;
import org.springframework.integration.test.util.TestUtils;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
/**
* @author Gary Russell
* @author Artem Bilan
* @since 3.0
*
*/
public class ConnectionFactoryTests extends LogAdjustingTestSupport {
@Test
public void factoryBeanTests() {
TcpConnectionFactoryFactoryBean fb = new TcpConnectionFactoryFactoryBean("client");
assertEquals(AbstractClientConnectionFactory.class, fb.getObjectType());
fb = new TcpConnectionFactoryFactoryBean("server");
assertEquals(AbstractServerConnectionFactory.class, fb.getObjectType());
fb = new TcpConnectionFactoryFactoryBean();
assertEquals(AbstractConnectionFactory.class, fb.getObjectType());
}
@Test
public void testObtainConnectionIdsNet() throws Exception {
TcpNetServerConnectionFactory serverFactory = new TcpNetServerConnectionFactory(0);
testObtainConnectionIds(serverFactory);
}
@Test
public void testObtainConnectionIdsNio() throws Exception {
TcpNioServerConnectionFactory serverFactory = new TcpNioServerConnectionFactory(0);
testObtainConnectionIds(serverFactory);
}
public void testObtainConnectionIds(AbstractServerConnectionFactory serverFactory) throws Exception {
final List<IpIntegrationEvent> events =
Collections.synchronizedList(new ArrayList<IpIntegrationEvent>());
int expectedEvents = serverFactory instanceof TcpNetServerConnectionFactory
? 7 // Listening, + OPEN, CLOSE, EXCEPTION for each side
: 5; // Listening, + OPEN, CLOSE (but we *might* get exceptions, depending on timing).
final CountDownLatch serverListeningLatch = new CountDownLatch(1);
final CountDownLatch eventLatch = new CountDownLatch(expectedEvents);
ApplicationEventPublisher publisher = new ApplicationEventPublisher() {
@Override
public void publishEvent(ApplicationEvent event) {
LogFactory.getLog(this.getClass()).trace("Received: " + event);
events.add((IpIntegrationEvent) event);
if (event instanceof TcpConnectionServerListeningEvent) {
serverListeningLatch.countDown();
}
eventLatch.countDown();
}
@Override
public void publishEvent(Object event) {
}
};
serverFactory.setBeanName("serverFactory");
serverFactory.setApplicationEventPublisher(publisher);
serverFactory = spy(serverFactory);
final CountDownLatch serverConnectionInitLatch = new CountDownLatch(1);
doAnswer(invocation -> {
Object result = invocation.callRealMethod();
serverConnectionInitLatch.countDown();
return result;
}).when(serverFactory).wrapConnection(any(TcpConnectionSupport.class));
ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler();
scheduler.setPoolSize(10);
scheduler.afterPropertiesSet();
BeanFactory bf = mock(BeanFactory.class);
when(bf.containsBean(IntegrationContextUtils.TASK_SCHEDULER_BEAN_NAME)).thenReturn(true);
when(bf.getBean(IntegrationContextUtils.TASK_SCHEDULER_BEAN_NAME, TaskScheduler.class)).thenReturn(scheduler);
serverFactory.setBeanFactory(bf);
TcpReceivingChannelAdapter adapter = new TcpReceivingChannelAdapter();
adapter.setOutputChannel(new NullChannel());
adapter.setConnectionFactory(serverFactory);
adapter.start();
assertTrue("Listening event not received", serverListeningLatch.await(10, TimeUnit.SECONDS));
assertThat(events.get(0), instanceOf(TcpConnectionServerListeningEvent.class));
assertThat(((TcpConnectionServerListeningEvent) events.get(0)).getPort(), equalTo(serverFactory.getPort()));
int port = serverFactory.getPort();
TcpNetClientConnectionFactory clientFactory = new TcpNetClientConnectionFactory("localhost", port);
clientFactory.registerListener(message -> false);
clientFactory.setBeanName("clientFactory");
clientFactory.setApplicationEventPublisher(publisher);
clientFactory.start();
TcpConnectionSupport client = clientFactory.getConnection();
List<String> clients = clientFactory.getOpenConnectionIds();
assertEquals(1, clients.size());
assertTrue(clients.contains(client.getConnectionId()));
assertTrue("Server connection failed to register", serverConnectionInitLatch.await(10, TimeUnit.SECONDS));
List<String> servers = serverFactory.getOpenConnectionIds();
assertEquals(1, servers.size());
assertTrue(serverFactory.closeConnection(servers.get(0)));
servers = serverFactory.getOpenConnectionIds();
assertEquals(0, servers.size());
int n = 0;
clients = clientFactory.getOpenConnectionIds();
while (n++ < 100 && clients.size() > 0) {
Thread.sleep(100);
clients = clientFactory.getOpenConnectionIds();
}
assertEquals(0, clients.size());
assertTrue(eventLatch.await(10, TimeUnit.SECONDS));
assertThat("Expected at least " + expectedEvents + " events; got: " + events.size() + " : " + events,
events.size(), greaterThanOrEqualTo(expectedEvents));
FooEvent event = new FooEvent(client, "foo");
client.publishEvent(event);
assertThat("Expected at least " + expectedEvents + " events; got: " + events.size() + " : " + events,
events.size(), greaterThanOrEqualTo(expectedEvents + 1));
try {
event = new FooEvent(mock(TcpConnectionSupport.class), "foo");
client.publishEvent(event);
fail("Expected exception");
}
catch (IllegalArgumentException e) {
assertTrue("Can only publish events with this as the source".equals(e.getMessage()));
}
SocketAddress address = serverFactory.getServerSocketAddress();
if (address instanceof InetSocketAddress) {
InetSocketAddress inetAddress = (InetSocketAddress) address;
assertEquals(port, inetAddress.getPort());
}
serverFactory.stop();
scheduler.shutdown();
}
@Test
public void testEarlyCloseNet() throws Exception {
AbstractServerConnectionFactory factory = new TcpNetServerConnectionFactory(0);
testEarlyClose(factory, "serverSocket", " stopped before accept");
}
@Test
public void testEarlyCloseNio() throws Exception {
AbstractServerConnectionFactory factory = new TcpNioServerConnectionFactory(0);
testEarlyClose(factory, "serverChannel", " stopped before registering the server channel");
}
private void testEarlyClose(final AbstractServerConnectionFactory factory, String property,
String message) throws Exception {
factory.setApplicationEventPublisher(mock(ApplicationEventPublisher.class));
factory.setBeanName("foo");
factory.registerListener(mock(TcpListener.class));
factory.afterPropertiesSet();
Log logger = spy(TestUtils.getPropertyValue(factory, "logger", Log.class));
new DirectFieldAccessor(factory).setPropertyValue("logger", logger);
final CountDownLatch latch1 = new CountDownLatch(1);
final CountDownLatch latch2 = new CountDownLatch(1);
final CountDownLatch latch3 = new CountDownLatch(1);
when(logger.isInfoEnabled()).thenReturn(true);
when(logger.isDebugEnabled()).thenReturn(true);
doAnswer(invocation -> {
latch1.countDown();
// wait until the stop nulls the channel
latch2.await(10, TimeUnit.SECONDS);
return null;
}).when(logger).info(contains("Listening"));
doAnswer(invocation -> {
latch3.countDown();
return null;
}).when(logger).debug(contains(message));
factory.start();
assertTrue("missing info log", latch1.await(10, TimeUnit.SECONDS));
// stop on a different thread because it waits for the executor
Executors.newSingleThreadExecutor().execute(() -> factory.stop());
int n = 0;
DirectFieldAccessor accessor = new DirectFieldAccessor(factory);
while (n++ < 200 && accessor.getPropertyValue(property) != null) {
Thread.sleep(100);
}
assertTrue("Stop was not invoked in time", n < 200);
latch2.countDown();
assertTrue("missing debug log", latch3.await(10, TimeUnit.SECONDS));
String expected = "foo, port=" + factory.getPort() + message;
ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
verify(logger, atLeast(1)).debug(captor.capture());
assertThat(captor.getAllValues(), hasItem(expected));
factory.stop();
}
@SuppressWarnings("serial")
private class FooEvent extends TcpConnectionOpenEvent {
FooEvent(TcpConnectionSupport connection, String connectionFactoryName) {
super(connection, connectionFactoryName);
}
}
}