/* * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.integration.ip.tcp.connection; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.Matchers.anyOf; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.io.IOException; import java.io.InputStream; import java.net.InetAddress; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketException; import java.nio.channels.ClosedChannelException; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import javax.net.ServerSocketFactory; import javax.net.SocketFactory; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLException; import javax.net.ssl.SSLServerSocket; import org.junit.Test; import org.mockito.Mockito; import org.springframework.integration.ip.tcp.serializer.ByteArrayCrLfSerializer; import org.springframework.integration.ip.util.TestingUtilities; import org.springframework.integration.test.util.TestUtils; import org.springframework.messaging.Message; import org.springframework.messaging.support.GenericMessage; /** * @author Gary Russell * @since 2.2 */ public class SocketSupportTests { @Test public void testNetClient() throws Exception { TcpSocketFactorySupport factorySupport = mock(TcpSocketFactorySupport.class); SocketFactory factory = Mockito.mock(SocketFactory.class); when(factorySupport.getSocketFactory()).thenReturn(factory); Socket socket = mock(Socket.class); InputStream is = mock(InputStream.class); when(is.read()).thenReturn(-1); when(socket.getInputStream()).thenReturn(is); InetAddress inetAddress = InetAddress.getLocalHost(); when(socket.getInetAddress()).thenReturn(inetAddress); when(factory.createSocket("x", 0)).thenReturn(socket); TcpSocketSupport socketSupport = Mockito.mock(TcpSocketSupport.class); TcpNetClientConnectionFactory connectionFactory = new TcpNetClientConnectionFactory("x", 0); connectionFactory.setTcpSocketFactorySupport(factorySupport); connectionFactory.setTcpSocketSupport(socketSupport); connectionFactory.start(); connectionFactory.getConnection(); verify(socketSupport).postProcessSocket(socket); connectionFactory.stop(); } @Test public void testNetServer() throws Exception { TcpSocketFactorySupport factorySupport = mock(TcpSocketFactorySupport.class); ServerSocketFactory factory = mock(ServerSocketFactory.class); when(factorySupport.getServerSocketFactory()).thenReturn(factory); Socket socket = mock(Socket.class); InputStream is = mock(InputStream.class); when(is.read()).thenReturn(-1); when(socket.getInputStream()).thenReturn(is); InetAddress inetAddress = InetAddress.getLocalHost(); when(socket.getInetAddress()).thenReturn(inetAddress); ServerSocket serverSocket = mock(ServerSocket.class); when(serverSocket.getInetAddress()).thenReturn(inetAddress); when(factory.createServerSocket(0, 5)).thenReturn(serverSocket); final CountDownLatch latch1 = new CountDownLatch(1); final CountDownLatch latch2 = new CountDownLatch(1); when(serverSocket.accept()).thenReturn(socket).then(invocation -> { latch1.countDown(); latch2.await(10, TimeUnit.SECONDS); return null; }); TcpSocketSupport socketSupport = mock(TcpSocketSupport.class); TcpNetServerConnectionFactory connectionFactory = new TcpNetServerConnectionFactory(0); connectionFactory.setTcpSocketFactorySupport(factorySupport); connectionFactory.setTcpSocketSupport(socketSupport); connectionFactory.registerListener(mock(TcpListener.class)); connectionFactory.start(); assertTrue(latch1.await(10, TimeUnit.SECONDS)); verify(socketSupport).postProcessServerSocket(serverSocket); verify(socketSupport).postProcessSocket(socket); latch2.countDown(); connectionFactory.stop(); } @Test public void testNioClientAndServer() throws Exception { TcpNioServerConnectionFactory serverConnectionFactory = new TcpNioServerConnectionFactory(0); serverConnectionFactory.registerListener(message -> false); final AtomicInteger ppSocketCountServer = new AtomicInteger(); final AtomicInteger ppServerSocketCountServer = new AtomicInteger(); final CountDownLatch latch = new CountDownLatch(1); TcpSocketSupport serverSocketSupport = new TcpSocketSupport() { @Override public void postProcessSocket(Socket socket) { ppSocketCountServer.incrementAndGet(); latch.countDown(); } @Override public void postProcessServerSocket(ServerSocket serverSocket) { ppServerSocketCountServer.incrementAndGet(); } }; serverConnectionFactory.setTcpSocketSupport(serverSocketSupport); serverConnectionFactory.start(); TestingUtilities.waitListening(serverConnectionFactory, null); TcpNioClientConnectionFactory clientConnectionFactory = new TcpNioClientConnectionFactory("localhost", serverConnectionFactory.getPort()); final AtomicInteger ppSocketCountClient = new AtomicInteger(); final AtomicInteger ppServerSocketCountClient = new AtomicInteger(); TcpSocketSupport clientSocketSupport = new TcpSocketSupport() { @Override public void postProcessSocket(Socket socket) { ppSocketCountClient.incrementAndGet(); } @Override public void postProcessServerSocket(ServerSocket serverSocket) { ppServerSocketCountClient.incrementAndGet(); } }; clientConnectionFactory.setTcpSocketSupport(clientSocketSupport); clientConnectionFactory.start(); clientConnectionFactory.getConnection().send(new GenericMessage<String>("Hello, world!")); assertTrue(latch.await(10, TimeUnit.SECONDS)); assertEquals(0, ppServerSocketCountClient.get()); assertEquals(1, ppSocketCountClient.get()); assertEquals(1, ppServerSocketCountServer.get()); assertEquals(1, ppSocketCountServer.get()); clientConnectionFactory.stop(); serverConnectionFactory.stop(); } /* $ keytool -genkeypair -alias sitestcertkey -keyalg RSA -validity 36500 -keystore src/test/resources/test.ks Enter keystore password: secret Re-enter new password: secret What is your first and last name? [Unknown]: Spring Integration What is the name of your organizational unit? [Unknown]: SpringSource What is the name of your organization? [Unknown]: VMware What is the name of your City or Locality? [Unknown]: Palo Alto What is the name of your State or Province? [Unknown]: CA What is the two-letter country code for this unit? [Unknown]: US Is CN=Spring Integration, OU=SpringSource, O=VMware, L=Palo Alto, ST=CA, C=US correct? [no]: yes Enter key password for <certificatekey> (RETURN if same as keystore password): $ keytool -list -v -keystore src/test/resources/test.ks Enter keystore password: secret Keystore type: JKS Keystore provider: SUN Your keystore contains 1 entry Alias name: sitestcertkey Creation date: Feb 25, 2012 Entry type: PrivateKeyEntry Certificate chain length: 1 Certificate[1]: Owner: CN=Spring Integration, OU=SpringSource, O=VMware, L=Palo Alto, ST=CA, C=US Issuer: CN=Spring Integration, OU=SpringSource, O=VMware, L=Palo Alto, ST=CA, C=US Serial number: 4f491902 Valid from: Sat Feb 25 12:23:14 EST 2012 until: Mon Feb 01 12:23:14 EST 2112 Certificate fingerprints: MD5: 4F:A9:76:0E:A9:C0:A8:B7:26:E7:7E:C7:E8:22:1F:8B SHA1: 88:AC:9E:4D:29:0D:3A:59:3B:73:95:4A:E1:BB:D0:22:89:37:64:4C Signature algorithm name: SHA1withRSA Version: 3 ******************************************* ******************************************* $ keytool -export -alias sitestcertkey -keystore src/test/resources/test.ks -rfc -file src/test/resources/test.cer Enter keystore password: Certificate stored in file <src/test/resources/test.cer> $ keytool -import -alias sitestcertkey -file src/test/resources/test.cer -keystore src/test/resources/test.truststore.ks Enter keystore password: secret Re-enter new password: secret Owner: CN=Spring Integration, OU=SpringSource, O=VMware, L=Palo Alto, ST=CA, C=US Issuer: CN=Spring Integration, OU=SpringSource, O=VMware, L=Palo Alto, ST=CA, C=US Serial number: 4f491902 Valid from: Sat Feb 25 12:23:14 EST 2012 until: Mon Feb 01 12:23:14 EST 2112 Certificate fingerprints: MD5: 4F:A9:76:0E:A9:C0:A8:B7:26:E7:7E:C7:E8:22:1F:8B SHA1: 88:AC:9E:4D:29:0D:3A:59:3B:73:95:4A:E1:BB:D0:22:89:37:64:4C Signature algorithm name: SHA1withRSA Version: 3 Trust this certificate? [no]: yes Certificate was added to keystore $ keytool -list -v -keystore src/test/resources/test.truststore.ks Enter keystore password: secret Keystore type: JKS Keystore provider: SUN Your keystore contains 1 entry Alias name: sitestcertkey Creation date: Feb 25, 2012 Entry type: trustedCertEntry Owner: CN=Spring Integration, OU=SpringSource, O=VMware, L=Palo Alto, ST=CA, C=US Issuer: CN=Spring Integration, OU=SpringSource, O=VMware, L=Palo Alto, ST=CA, C=US Serial number: 4f491902 Valid from: Sat Feb 25 12:23:14 EST 2012 until: Mon Feb 01 12:23:14 EST 2112 Certificate fingerprints: MD5: 4F:A9:76:0E:A9:C0:A8:B7:26:E7:7E:C7:E8:22:1F:8B SHA1: 88:AC:9E:4D:29:0D:3A:59:3B:73:95:4A:E1:BB:D0:22:89:37:64:4C Signature algorithm name: SHA1withRSA Version: 3 ******************************************* ******************************************* */ @Test public void testNetClientAndServerSSL() throws Exception { System.setProperty("javax.net.debug", "all"); // SSL activity in the console TcpNetServerConnectionFactory server = new TcpNetServerConnectionFactory(0); TcpSSLContextSupport sslContextSupport = new DefaultTcpSSLContextSupport("test.ks", "test.truststore.ks", "secret", "secret"); DefaultTcpNetSSLSocketFactorySupport tcpSocketFactorySupport = new DefaultTcpNetSSLSocketFactorySupport(sslContextSupport); server.setTcpSocketFactorySupport(tcpSocketFactorySupport); final List<Message<?>> messages = new ArrayList<Message<?>>(); final CountDownLatch latch = new CountDownLatch(1); server.registerListener(message -> { messages.add(message); latch.countDown(); return false; }); server.setMapper(new SSLMapper()); server.start(); TestingUtilities.waitListening(server, null); TcpNetClientConnectionFactory client = new TcpNetClientConnectionFactory("localhost", server.getPort()); client.setTcpSocketFactorySupport(tcpSocketFactorySupport); client.start(); TcpConnection connection = client.getConnection(); connection.send(new GenericMessage<String>("Hello, world!")); assertTrue(latch.await(10, TimeUnit.SECONDS)); assertEquals("Hello, world!", new String((byte[]) messages.get(0).getPayload())); assertNotNull(messages.get(0).getHeaders().get("cipher")); client.stop(); server.stop(); } @Test public void testNetClientAndServerSSLDifferentContexts() throws Exception { testNetClientAndServerSSLDifferentContexts(false); try { testNetClientAndServerSSLDifferentContexts(true); fail("expected Exception"); } catch (SSLException | SocketException e) { // NOSONAR } } private void testNetClientAndServerSSLDifferentContexts(boolean badClient) throws Exception { System.setProperty("javax.net.debug", "all"); // SSL activity in the console TcpNetServerConnectionFactory server = new TcpNetServerConnectionFactory(0); TcpSSLContextSupport serverSslContextSupport = new DefaultTcpSSLContextSupport("server.ks", "server.truststore.ks", "secret", "secret"); DefaultTcpNetSSLSocketFactorySupport serverTcpSocketFactorySupport = new DefaultTcpNetSSLSocketFactorySupport(serverSslContextSupport); server.setTcpSocketFactorySupport(serverTcpSocketFactorySupport); final List<Message<?>> messages = new ArrayList<Message<?>>(); final CountDownLatch latch = new CountDownLatch(1); server.registerListener(message -> { messages.add(message); latch.countDown(); return false; }); server.setTcpSocketSupport(new DefaultTcpSocketSupport() { @Override public void postProcessServerSocket(ServerSocket serverSocket) { ((SSLServerSocket) serverSocket).setNeedClientAuth(true); } }); server.start(); TestingUtilities.waitListening(server, null); TcpNetClientConnectionFactory client = new TcpNetClientConnectionFactory("localhost", server.getPort()); TcpSSLContextSupport clientSslContextSupport = new DefaultTcpSSLContextSupport( badClient ? "server.ks" : "client.ks", "client.truststore.ks", "secret", "secret"); DefaultTcpNetSSLSocketFactorySupport clientTcpSocketFactorySupport = new DefaultTcpNetSSLSocketFactorySupport(clientSslContextSupport); client.setTcpSocketFactorySupport(clientTcpSocketFactorySupport); client.start(); TcpConnection connection = client.getConnection(); connection.send(new GenericMessage<String>("Hello, world!")); assertTrue(latch.await(10, TimeUnit.SECONDS)); assertEquals("Hello, world!", new String((byte[]) messages.get(0).getPayload())); client.stop(); server.stop(); } @Test public void testNioClientAndServerSSL() throws Exception { System.setProperty("javax.net.debug", "all"); // SSL activity in the console TcpNioServerConnectionFactory server = new TcpNioServerConnectionFactory(0); server.setSslHandshakeTimeout(43); DefaultTcpSSLContextSupport sslContextSupport = new DefaultTcpSSLContextSupport("test.ks", "test.truststore.ks", "secret", "secret"); sslContextSupport.setProtocol("SSL"); DefaultTcpNioSSLConnectionSupport tcpNioConnectionSupport = new DefaultTcpNioSSLConnectionSupport(sslContextSupport); server.setTcpNioConnectionSupport(tcpNioConnectionSupport); final List<Message<?>> messages = new ArrayList<Message<?>>(); final CountDownLatch latch = new CountDownLatch(1); server.registerListener(message -> { messages.add(message); latch.countDown(); return false; }); server.setMapper(new SSLMapper()); final AtomicReference<String> serverConnectionId = new AtomicReference<>(); server.setApplicationEventPublisher(e -> { if (e instanceof TcpConnectionOpenEvent) { serverConnectionId.set(((TcpConnectionEvent) e).getConnectionId()); } }); server.start(); TestingUtilities.waitListening(server, null); TcpNioClientConnectionFactory client = new TcpNioClientConnectionFactory("localhost", server.getPort()); client.setSslHandshakeTimeout(34); client.setTcpNioConnectionSupport(tcpNioConnectionSupport); client.registerListener(message -> false); client.setApplicationEventPublisher(e -> { }); client.start(); TcpConnection connection = client.getConnection(); assertEquals(34, TestUtils.getPropertyValue(connection, "handshakeTimeout")); connection.send(new GenericMessage<String>("Hello, world!")); assertTrue(latch.await(10, TimeUnit.SECONDS)); assertEquals("Hello, world!", new String((byte[]) messages.get(0).getPayload())); assertNotNull(messages.get(0).getHeaders().get("cipher")); Map<?, ?> connections = TestUtils.getPropertyValue(server, "connections", Map.class); Object serverConnection = connections.get(serverConnectionId.get()); assertNotNull(serverConnection); assertEquals(43, TestUtils.getPropertyValue(serverConnection, "handshakeTimeout")); client.stop(); server.stop(); } @Test public void testNioClientAndServerSSLDifferentContexts() throws Exception { testNioClientAndServerSSLDifferentContexts(false); try { testNioClientAndServerSSLDifferentContexts(true); fail("expected Exception"); } catch (IOException e) { if (!(e instanceof ClosedChannelException)) { assertThat(e.getMessage(), anyOf( containsString("Socket closed during SSL Handshake"), containsString("Broken pipe"))); } } } private void testNioClientAndServerSSLDifferentContexts(boolean badClient) throws Exception { System.setProperty("javax.net.debug", "all"); // SSL activity in the console TcpNioServerConnectionFactory server = new TcpNioServerConnectionFactory(0); TcpSSLContextSupport serverSslContextSupport = new DefaultTcpSSLContextSupport("server.ks", "server.truststore.ks", "secret", "secret"); DefaultTcpNioSSLConnectionSupport tcpNioConnectionSupport = new DefaultTcpNioSSLConnectionSupport(serverSslContextSupport) { @Override protected void postProcessSSLEngine(SSLEngine sslEngine) { sslEngine.setNeedClientAuth(true); } }; server.setTcpNioConnectionSupport(tcpNioConnectionSupport); final List<Message<?>> messages = new ArrayList<Message<?>>(); final CountDownLatch latch = new CountDownLatch(1); server.registerListener(message -> { messages.add(message); latch.countDown(); return false; }); server.start(); TestingUtilities.waitListening(server, null); TcpNioClientConnectionFactory client = new TcpNioClientConnectionFactory("localhost", server.getPort()); TcpSSLContextSupport clientSslContextSupport = new DefaultTcpSSLContextSupport( badClient ? "server.ks" : "client.ks", "client.truststore.ks", "secret", "secret"); DefaultTcpNioSSLConnectionSupport clientTcpNioConnectionSupport = new DefaultTcpNioSSLConnectionSupport(clientSslContextSupport); client.setTcpNioConnectionSupport(clientTcpNioConnectionSupport); client.start(); TcpConnection connection = client.getConnection(); connection.send(new GenericMessage<String>("Hello, world!")); assertTrue(latch.await(10, TimeUnit.SECONDS)); assertEquals("Hello, world!", new String((byte[]) messages.get(0).getPayload())); client.stop(); server.stop(); } @Test public void testNioClientAndServerSSLDifferentContextsLargeDataWithReply() throws Exception { System.setProperty("javax.net.debug", "all"); // SSL activity in the console TcpNioServerConnectionFactory server = new TcpNioServerConnectionFactory(0); TcpSSLContextSupport serverSslContextSupport = new DefaultTcpSSLContextSupport("server.ks", "server.truststore.ks", "secret", "secret"); DefaultTcpNioSSLConnectionSupport serverTcpNioConnectionSupport = new DefaultTcpNioSSLConnectionSupport(serverSslContextSupport); server.setTcpNioConnectionSupport(serverTcpNioConnectionSupport); final List<Message<?>> messages = new ArrayList<Message<?>>(); final CountDownLatch latch = new CountDownLatch(2); final Replier replier = new Replier(); server.registerSender(replier); server.registerListener(message -> { messages.add(message); try { replier.send(message); } catch (Exception e) { throw new RuntimeException(e); } latch.countDown(); return false; }); ByteArrayCrLfSerializer deserializer = new ByteArrayCrLfSerializer(); deserializer.setMaxMessageSize(120000); server.setDeserializer(deserializer); final AtomicReference<String> serverConnectionId = new AtomicReference<>(); server.setApplicationEventPublisher(e -> { if (e instanceof TcpConnectionOpenEvent) { serverConnectionId.set(((TcpConnectionEvent) e).getConnectionId()); } }); server.start(); TestingUtilities.waitListening(server, null); TcpNioClientConnectionFactory client = new TcpNioClientConnectionFactory("localhost", server.getPort()); TcpSSLContextSupport clientSslContextSupport = new DefaultTcpSSLContextSupport("client.ks", "client.truststore.ks", "secret", "secret"); DefaultTcpNioSSLConnectionSupport clientTcpNioConnectionSupport = new DefaultTcpNioSSLConnectionSupport(clientSslContextSupport); client.setTcpNioConnectionSupport(clientTcpNioConnectionSupport); client.registerListener(message -> { messages.add(message); latch.countDown(); return false; }); client.setDeserializer(deserializer); client.setApplicationEventPublisher(e -> { }); client.start(); TcpConnection connection = client.getConnection(); assertEquals(30, TestUtils.getPropertyValue(connection, "handshakeTimeout")); byte[] bytes = new byte[100000]; connection.send(new GenericMessage<String>("Hello, world!" + new String(bytes))); assertTrue(latch.await(60, TimeUnit.SECONDS)); byte[] payload = (byte[]) messages.get(0).getPayload(); assertEquals(13 + bytes.length, payload.length); assertEquals("Hello, world!", new String(payload).substring(0, 13)); payload = (byte[]) messages.get(1).getPayload(); assertEquals(13 + bytes.length, payload.length); assertEquals("Hello, world!", new String(payload).substring(0, 13)); Map<?, ?> connections = TestUtils.getPropertyValue(server, "connections", Map.class); Object serverConnection = connections.get(serverConnectionId.get()); assertNotNull(serverConnection); assertEquals(30, TestUtils.getPropertyValue(serverConnection, "handshakeTimeout")); client.stop(); server.stop(); } private static class Replier implements TcpSender { private TcpConnection connection; @Override public void addNewConnection(TcpConnection connection) { this.connection = connection; } @Override public void removeDeadConnection(TcpConnection connection) { } public void send(Message<?> message) throws Exception { // force a renegotiation from the server side SSLEngine sslEngine = TestUtils.getPropertyValue(this.connection, "sslEngine", SSLEngine.class); sslEngine.getSession().invalidate(); sslEngine.beginHandshake(); this.connection.send(message); } } private static class SSLMapper extends TcpMessageMapper { @Override protected Map<String, ?> supplyCustomHeaders(TcpConnection connection) { return Collections.singletonMap("cipher", connection.getSslSession().getCipherSuite()); } } }