/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.integration.ip.tcp;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import javax.net.ServerSocketFactory;
import javax.net.SocketFactory;
import org.junit.Test;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.integration.channel.DirectChannel;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.handler.ServiceActivatingHandler;
import org.springframework.integration.ip.tcp.connection.AbstractClientConnectionFactory;
import org.springframework.integration.ip.tcp.connection.AbstractServerConnectionFactory;
import org.springframework.integration.ip.tcp.connection.TcpNetClientConnectionFactory;
import org.springframework.integration.ip.tcp.connection.TcpNetServerConnectionFactory;
import org.springframework.integration.ip.tcp.connection.TcpNioServerConnectionFactory;
import org.springframework.integration.ip.util.TestingUtilities;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.support.GenericMessage;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
/**
 * Tests for {@link TcpInboundGateway} backed by Net and NIO server connection
 * factories, by a client connection factory (client mode), and with an error
 * channel configured.
 * <p>
 * Each test writes CRLF-terminated requests to a raw socket and asserts on the
 * CRLF-terminated bytes that come back over the same socket.
 *
 * @author Gary Russell
 * @since 2.0
 */
public class TcpInboundGatewayTests {

	/**
	 * Upper bound, in milliseconds, for a single blocking wait on a request
	 * message or on socket I/O, so that a missing reply fails the test instead
	 * of hanging it.
	 */
	private static final int TIMEOUT = 10000;

	/**
	 * Net server factory with single-use connections: two requests are sent on
	 * two separate sockets and each socket must receive its own echo.
	 */
	@Test
	public void testNetSingle() throws Exception {
		AbstractServerConnectionFactory scf = new TcpNetServerConnectionFactory(0);
		scf.setSingleUse(true);
		TcpInboundGateway gateway = new TcpInboundGateway();
		gateway.setConnectionFactory(scf);
		gateway.setBeanFactory(mock(BeanFactory.class));
		scf.start();
		try {
			TestingUtilities.waitListening(scf, 20000L);
			int port = scf.getPort();
			final QueueChannel channel = new QueueChannel();
			gateway.setRequestChannel(channel);
			ServiceActivatingHandler handler = new ServiceActivatingHandler(new Service());
			handler.setChannelResolver(channelName -> channel);
			try (Socket socket1 = connect(port); Socket socket2 = connect(port)) {
				socket1.getOutputStream().write("Test1\r\n".getBytes());
				socket2.getOutputStream().write("Test2\r\n".getBytes());
				handleNext(channel, handler);
				handleNext(channel, handler);
				assertEquals("Echo:Test1\r\n", readReply(socket1, 12));
				assertEquals("Echo:Test2\r\n", readReply(socket2, 12));
			}
		}
		finally {
			// Always release the listening port, even when an assertion fails.
			scf.stop();
		}
	}

	/**
	 * Net server factory with a shared connection: two requests are sent on one
	 * socket and the echoes must come back on it in request order.
	 */
	@Test
	public void testNetNotSingle() throws Exception {
		AbstractServerConnectionFactory scf = new TcpNetServerConnectionFactory(0);
		scf.setSingleUse(false);
		TcpInboundGateway gateway = new TcpInboundGateway();
		gateway.setConnectionFactory(scf);
		scf.start();
		try {
			TestingUtilities.waitListening(scf, 20000L);
			int port = scf.getPort();
			final QueueChannel channel = new QueueChannel();
			gateway.setRequestChannel(channel);
			gateway.setBeanFactory(mock(BeanFactory.class));
			ServiceActivatingHandler handler = new ServiceActivatingHandler(new Service());
			try (Socket socket = connect(port)) {
				socket.getOutputStream().write("Test1\r\n".getBytes());
				socket.getOutputStream().write("Test2\r\n".getBytes());
				handleNext(channel, handler);
				handleNext(channel, handler);
				assertEquals("Echo:Test1\r\n", readReply(socket, 12));
				assertEquals("Echo:Test2\r\n", readReply(socket, 12));
			}
		}
		finally {
			scf.stop();
		}
	}

	/**
	 * Client mode: the gateway connects out to a plain {@link ServerSocket}
	 * run on a background thread. That thread pushes two requests down the
	 * accepted connection and verifies both echoes; the three latches
	 * sequence "server is listening", "main thread has replied" and
	 * "server thread has finished".
	 */
	@Test
	public void testNetClientMode() throws Exception {
		final AtomicInteger port = new AtomicInteger();
		final CountDownLatch latch1 = new CountDownLatch(1);
		final CountDownLatch latch2 = new CountDownLatch(1);
		final CountDownLatch latch3 = new CountDownLatch(1);
		final AtomicBoolean done = new AtomicBoolean();
		ExecutorService exec = Executors.newSingleThreadExecutor();
		try {
			exec.execute(() -> {
				try {
					try (ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0, 10)) {
						// accept() is not interruptible; bound it so this thread cannot block forever.
						server.setSoTimeout(TIMEOUT);
						port.set(server.getLocalPort());
						latch1.countDown();
						try (Socket socket = server.accept()) {
							socket.setSoTimeout(TIMEOUT);
							socket.getOutputStream().write("Test1\r\nTest2\r\n".getBytes());
							assertEquals("Echo:Test1\r\n", readReply(socket, 12));
							assertEquals("Echo:Test2\r\n", readReply(socket, 12));
							// Keep the connection open until the main thread has handled both requests.
							latch2.await();
						}
					}
					done.set(true);
					latch3.countDown();
				}
				catch (Exception e) {
					if (!done.get()) {
						e.printStackTrace();
					}
				}
			});
			assertTrue(latch1.await(10, TimeUnit.SECONDS));
			AbstractClientConnectionFactory ccf = new TcpNetClientConnectionFactory("localhost", port.get());
			ccf.setSingleUse(false);
			TcpInboundGateway gateway = new TcpInboundGateway();
			gateway.setConnectionFactory(ccf);
			final QueueChannel channel = new QueueChannel();
			gateway.setRequestChannel(channel);
			gateway.setClientMode(true);
			gateway.setRetryInterval(10000);
			gateway.setBeanFactory(mock(BeanFactory.class));
			gateway.afterPropertiesSet();
			ServiceActivatingHandler handler = new ServiceActivatingHandler(new Service());
			ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler();
			taskScheduler.setPoolSize(1);
			taskScheduler.initialize();
			try {
				gateway.setTaskScheduler(taskScheduler);
				gateway.start();
				handleNext(channel, handler);
				handleNext(channel, handler);
				latch2.countDown();
				assertTrue(latch3.await(10, TimeUnit.SECONDS));
				assertTrue(done.get());
			}
			finally {
				gateway.stop();
				taskScheduler.shutdown();
			}
		}
		finally {
			// Interrupts the server thread if it is still parked on latch2 after a failure.
			exec.shutdownNow();
		}
	}

	/**
	 * NIO server factory with single-use connections: two requests are sent on
	 * two separate sockets and each socket must receive its own echo.
	 */
	@Test
	public void testNioSingle() throws Exception {
		AbstractServerConnectionFactory scf = new TcpNioServerConnectionFactory(0);
		scf.setSingleUse(true);
		TcpInboundGateway gateway = new TcpInboundGateway();
		gateway.setConnectionFactory(scf);
		scf.start();
		try {
			TestingUtilities.waitListening(scf, 20000L);
			int port = scf.getPort();
			final QueueChannel channel = new QueueChannel();
			gateway.setRequestChannel(channel);
			gateway.setBeanFactory(mock(BeanFactory.class));
			ServiceActivatingHandler handler = new ServiceActivatingHandler(new Service());
			handler.setChannelResolver(channelName -> channel);
			try (Socket socket1 = connect(port); Socket socket2 = connect(port)) {
				socket1.getOutputStream().write("Test1\r\n".getBytes());
				socket2.getOutputStream().write("Test2\r\n".getBytes());
				handleNext(channel, handler);
				handleNext(channel, handler);
				assertEquals("Echo:Test1\r\n", readReply(socket1, 12));
				assertEquals("Echo:Test2\r\n", readReply(socket2, 12));
			}
		}
		finally {
			scf.stop();
		}
	}

	/**
	 * NIO server factory with a shared connection: two requests are sent on one
	 * socket. The two echoes are collected into a set, so this test accepts
	 * them in either order.
	 */
	@Test
	public void testNioNotSingle() throws Exception {
		AbstractServerConnectionFactory scf = new TcpNioServerConnectionFactory(0);
		scf.setSingleUse(false);
		TcpInboundGateway gateway = new TcpInboundGateway();
		gateway.setConnectionFactory(scf);
		scf.start();
		try {
			TestingUtilities.waitListening(scf, 20000L);
			int port = scf.getPort();
			final QueueChannel channel = new QueueChannel();
			gateway.setRequestChannel(channel);
			gateway.setBeanFactory(mock(BeanFactory.class));
			ServiceActivatingHandler handler = new ServiceActivatingHandler(new Service());
			try (Socket socket = connect(port)) {
				socket.getOutputStream().write("Test1\r\n".getBytes());
				socket.getOutputStream().write("Test2\r\n".getBytes());
				handleNext(channel, handler);
				handleNext(channel, handler);
				Set<String> results = new HashSet<>();
				results.add(readReply(socket, 12));
				results.add(readReply(socket, 12));
				assertTrue(results.remove("Echo:Test1\r\n"));
				assertTrue(results.remove("Echo:Test2\r\n"));
			}
		}
		finally {
			scf.stop();
		}
	}

	/**
	 * Error flow: the service always throws, and a subscriber on the gateway's
	 * error channel sends a fixed text to the reply channel found in the error
	 * message's headers. Both sockets must receive that text.
	 */
	@Test
	public void testErrorFlow() throws Exception {
		AbstractServerConnectionFactory scf = new TcpNetServerConnectionFactory(0);
		scf.setSingleUse(true);
		TcpInboundGateway gateway = new TcpInboundGateway();
		gateway.setConnectionFactory(scf);
		SubscribableChannel errorChannel = new DirectChannel();
		final String errorMessage = "An error occurred";
		errorChannel.subscribe(message -> {
			MessageChannel replyChannel = (MessageChannel) message.getHeaders().getReplyChannel();
			replyChannel.send(new GenericMessage<>(errorMessage));
		});
		gateway.setErrorChannel(errorChannel);
		scf.start();
		try {
			TestingUtilities.waitListening(scf, 20000L);
			int port = scf.getPort();
			final SubscribableChannel channel = new DirectChannel();
			gateway.setRequestChannel(channel);
			gateway.setBeanFactory(mock(BeanFactory.class));
			ServiceActivatingHandler handler = new ServiceActivatingHandler(new FailingService());
			channel.subscribe(handler);
			try (Socket socket1 = connect(port); Socket socket2 = connect(port)) {
				socket1.getOutputStream().write("Test1\r\n".getBytes());
				socket2.getOutputStream().write("Test2\r\n".getBytes());
				// Expected reply is the error text plus the trailing CRLF.
				int replyLength = errorMessage.length() + 2;
				assertEquals(errorMessage + "\r\n", readReply(socket1, replyLength));
				assertEquals(errorMessage + "\r\n", readReply(socket2, replyLength));
			}
		}
		finally {
			scf.stop();
		}
	}

	/**
	 * Opens a client socket to {@code localhost:port} with a read timeout of
	 * {@link #TIMEOUT} milliseconds.
	 *
	 * @param port the port to connect to
	 * @return the connected socket; the caller is responsible for closing it
	 * @throws IOException if the connection cannot be established or configured
	 */
	private static Socket connect(int port) throws IOException {
		Socket socket = SocketFactory.getDefault().createSocket("localhost", port);
		try {
			socket.setSoTimeout(TIMEOUT);
		}
		catch (IOException e) {
			socket.close();
			throw e;
		}
		return socket;
	}

	/**
	 * Receives the next request from {@code channel}, asserts that one arrived
	 * within {@link #TIMEOUT} milliseconds, and passes it to {@code handler}.
	 *
	 * @param channel the gateway's request channel
	 * @param handler the handler that processes the request
	 */
	private static void handleNext(QueueChannel channel, ServiceActivatingHandler handler) {
		Message<?> message = channel.receive(TIMEOUT);
		assertNotNull(message);
		handler.handleMessage(message);
	}

	/**
	 * Reads exactly {@code length} bytes from the socket and returns them as a
	 * string decoded with the platform default charset.
	 *
	 * @param socket the socket to read from
	 * @param length the number of bytes expected
	 * @return the decoded reply
	 * @throws IOException if reading fails or the stream ends early
	 */
	private static String readReply(Socket socket, int length) throws IOException {
		byte[] bytes = new byte[length];
		readFully(socket.getInputStream(), bytes);
		return new String(bytes);
	}

	/**
	 * Fills {@code buff} completely from {@code is}.
	 *
	 * @param is the stream to read from
	 * @param buff the buffer to fill
	 * @throws EOFException if the stream ends before the buffer is full
	 * @throws IOException if the underlying read fails
	 */
	private static void readFully(InputStream is, byte[] buff) throws IOException {
		int offset = 0;
		while (offset < buff.length) {
			int count = is.read(buff, offset, buff.length - offset);
			if (count < 0) {
				// Without this check the -1 end-of-stream marker would be stored as data.
				throw new EOFException("Stream ended after " + offset + " of " + buff.length + " expected bytes");
			}
			offset += count;
		}
	}

	/**
	 * Service that echoes the request payload prefixed with {@code "Echo:"}.
	 */
	private static class Service {

		@SuppressWarnings("unused")
		public String serviceMethod(byte[] bytes) {
			return "Echo:" + new String(bytes);
		}

	}

	/**
	 * Service that always fails, used to drive the error-channel flow.
	 */
	private static class FailingService {

		@SuppressWarnings("unused")
		public String serviceMethod(byte[] bytes) {
			throw new RuntimeException("Planned Failure For Tests");
		}

	}

}