/* * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.integration.ip.tcp.connection; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.net.Socket; import java.util.ArrayList; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import javax.net.SocketFactory; import org.junit.Test; import org.springframework.integration.ip.tcp.serializer.AbstractByteArraySerializer; import org.springframework.integration.ip.tcp.serializer.ByteArrayCrLfSerializer; import org.springframework.integration.ip.tcp.serializer.ByteArrayLengthHeaderSerializer; import org.springframework.integration.ip.tcp.serializer.ByteArrayStxEtxSerializer; import org.springframework.integration.ip.util.SocketTestUtils; import org.springframework.integration.ip.util.TestingUtilities; import org.springframework.messaging.Message; import org.springframework.messaging.support.ErrorMessage; /** * @author Gary Russell * @author Artem Bilan * * @since 2.0 */ public class TcpNioConnectionReadTests { private final CountDownLatch latch = new CountDownLatch(1); private AbstractServerConnectionFactory getConnectionFactory( AbstractByteArraySerializer serializer, TcpListener listener) throws Exception { return getConnectionFactory(serializer, listener, null); } private AbstractServerConnectionFactory getConnectionFactory( AbstractByteArraySerializer serializer, TcpListener listener, TcpSender sender) throws Exception { TcpNioServerConnectionFactory scf = new TcpNioServerConnectionFactory(0); scf.setUsingDirectBuffers(true); scf.setApplicationEventPublisher(e -> { }); scf.setSerializer(serializer); scf.setDeserializer(serializer); scf.registerListener(listener); if (sender != null) { scf.registerSender(sender); } scf.start(); TestingUtilities.waitListening(scf, null); return scf; } @Test public void testReadLength() throws Exception { ByteArrayLengthHeaderSerializer serializer = new ByteArrayLengthHeaderSerializer(); final List<Message<?>> responses = new ArrayList<Message<?>>(); final Semaphore semaphore = new Semaphore(0); AbstractServerConnectionFactory scf = getConnectionFactory(serializer, message -> { responses.add(message); semaphore.release(); return false; }); // Fire up the sender. CountDownLatch done = SocketTestUtils.testSendLength(scf.getPort(), latch); latch.countDown(); assertTrue(semaphore.tryAcquire(1, 10000, TimeUnit.MILLISECONDS)); assertTrue(semaphore.tryAcquire(1, 10000, TimeUnit.MILLISECONDS)); assertEquals("Did not receive data", 2, responses.size()); assertEquals("Data", SocketTestUtils.TEST_STRING + SocketTestUtils.TEST_STRING, new String((byte[]) responses.get(0).getPayload())); assertEquals("Data", SocketTestUtils.TEST_STRING + SocketTestUtils.TEST_STRING, new String((byte[]) responses.get(1).getPayload())); scf.stop(); done.countDown(); } @SuppressWarnings("unchecked") @Test public void testFragmented() throws Exception { ByteArrayLengthHeaderSerializer serializer = new ByteArrayLengthHeaderSerializer(); final List<Message<?>> responses = new ArrayList<Message<?>>(); final Semaphore semaphore = new Semaphore(0); AbstractServerConnectionFactory scf = getConnectionFactory(serializer, message -> { responses.add(message); try { Thread.sleep(1000); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } semaphore.release(); return false; }); int howMany = 2; scf.setBacklog(howMany + 5); // Fire up the sender. CountDownLatch done = SocketTestUtils.testSendFragmented(scf.getPort(), howMany, false); assertTrue(semaphore.tryAcquire(howMany, 20000, TimeUnit.MILLISECONDS)); assertEquals("Expected", howMany, responses.size()); for (int i = 0; i < howMany; i++) { assertEquals("Data", "xx", new String(((Message<byte[]>) responses.get(0)).getPayload())); } scf.stop(); done.countDown(); } @SuppressWarnings("unchecked") @Test public void testReadStxEtx() throws Exception { ByteArrayStxEtxSerializer serializer = new ByteArrayStxEtxSerializer(); final List<Message<?>> responses = new ArrayList<Message<?>>(); final Semaphore semaphore = new Semaphore(0); AbstractServerConnectionFactory scf = getConnectionFactory(serializer, message -> { responses.add(message); semaphore.release(); return false; }); // Fire up the sender. CountDownLatch done = SocketTestUtils.testSendStxEtx(scf.getPort(), latch); latch.countDown(); assertTrue(semaphore.tryAcquire(1, 10000, TimeUnit.MILLISECONDS)); assertTrue(semaphore.tryAcquire(1, 10000, TimeUnit.MILLISECONDS)); assertEquals("Did not receive data", 2, responses.size()); assertEquals("Data", SocketTestUtils.TEST_STRING + SocketTestUtils.TEST_STRING, new String(((Message<byte[]>) responses.get(0)).getPayload())); assertEquals("Data", SocketTestUtils.TEST_STRING + SocketTestUtils.TEST_STRING, new String(((Message<byte[]>) responses.get(1)).getPayload())); scf.stop(); done.countDown(); } @SuppressWarnings("unchecked") @Test public void testReadCrLf() throws Exception { ByteArrayCrLfSerializer serializer = new ByteArrayCrLfSerializer(); final List<Message<?>> responses = new ArrayList<Message<?>>(); final Semaphore semaphore = new Semaphore(0); AbstractServerConnectionFactory scf = getConnectionFactory(serializer, message -> { responses.add(message); semaphore.release(); return false; }); // Fire up the sender. CountDownLatch done = SocketTestUtils.testSendCrLf(scf.getPort(), latch); latch.countDown(); assertTrue(semaphore.tryAcquire(1, 10000, TimeUnit.MILLISECONDS)); assertTrue(semaphore.tryAcquire(1, 10000, TimeUnit.MILLISECONDS)); assertEquals("Did not receive data", 2, responses.size()); assertEquals("Data", SocketTestUtils.TEST_STRING + SocketTestUtils.TEST_STRING, new String(((Message<byte[]>) responses.get(0)).getPayload())); assertEquals("Data", SocketTestUtils.TEST_STRING + SocketTestUtils.TEST_STRING, new String(((Message<byte[]>) responses.get(1)).getPayload())); scf.stop(); done.countDown(); } @Test public void testReadLengthOverflow() throws Exception { ByteArrayLengthHeaderSerializer serializer = new ByteArrayLengthHeaderSerializer(); final Semaphore semaphore = new Semaphore(0); final List<TcpConnection> added = new ArrayList<TcpConnection>(); final List<TcpConnection> removed = new ArrayList<TcpConnection>(); final CountDownLatch errorMessageLetch = new CountDownLatch(1); final AtomicReference<Throwable> errorMessageRef = new AtomicReference<Throwable>(); AbstractServerConnectionFactory scf = getConnectionFactory(serializer, message -> { if (message instanceof ErrorMessage) { errorMessageRef.set(((ErrorMessage) message).getPayload()); errorMessageLetch.countDown(); } return false; }, new TcpSender() { @Override public void addNewConnection(TcpConnection connection) { added.add(connection); semaphore.release(); } @Override public void removeDeadConnection(TcpConnection connection) { removed.add(connection); semaphore.release(); } }); // Fire up the sender. CountDownLatch done = SocketTestUtils.testSendLengthOverflow(scf.getPort()); whileOpen(semaphore, added); assertEquals(1, added.size()); assertTrue(errorMessageLetch.await(10, TimeUnit.SECONDS)); assertThat(errorMessageRef.get().getMessage(), anyOf(containsString("Message length 2147483647 exceeds max message length: 2048"), containsString("Connection is closed"))); assertTrue(semaphore.tryAcquire(10000, TimeUnit.MILLISECONDS)); assertTrue(removed.size() > 0); scf.stop(); done.countDown(); } @Test public void testReadStxEtxOverflow() throws Exception { ByteArrayStxEtxSerializer serializer = new ByteArrayStxEtxSerializer(); serializer.setMaxMessageSize(1024); final Semaphore semaphore = new Semaphore(0); final List<TcpConnection> added = new ArrayList<TcpConnection>(); final List<TcpConnection> removed = new ArrayList<TcpConnection>(); final CountDownLatch errorMessageLetch = new CountDownLatch(1); final AtomicReference<Throwable> errorMessageRef = new AtomicReference<Throwable>(); AbstractServerConnectionFactory scf = getConnectionFactory(serializer, message -> { if (message instanceof ErrorMessage) { errorMessageRef.set(((ErrorMessage) message).getPayload()); errorMessageLetch.countDown(); } return false; }, new TcpSender() { @Override public void addNewConnection(TcpConnection connection) { added.add(connection); semaphore.release(); } @Override public void removeDeadConnection(TcpConnection connection) { removed.add(connection); semaphore.release(); } }); // Fire up the sender. CountDownLatch done = SocketTestUtils.testSendStxEtxOverflow(scf.getPort()); whileOpen(semaphore, added); assertEquals(1, added.size()); assertTrue(errorMessageLetch.await(10, TimeUnit.SECONDS)); assertThat(errorMessageRef.get().getMessage(), anyOf(containsString("Connection is closed"), containsString("ETX not found before max message length: 1024"))); assertTrue(semaphore.tryAcquire(10000, TimeUnit.MILLISECONDS)); assertTrue(removed.size() > 0); scf.stop(); done.countDown(); } @Test public void testReadCrLfOverflow() throws Exception { ByteArrayCrLfSerializer serializer = new ByteArrayCrLfSerializer(); serializer.setMaxMessageSize(1024); final Semaphore semaphore = new Semaphore(0); final List<TcpConnection> added = new ArrayList<TcpConnection>(); final List<TcpConnection> removed = new ArrayList<TcpConnection>(); final CountDownLatch errorMessageLetch = new CountDownLatch(1); final AtomicReference<Throwable> errorMessageRef = new AtomicReference<Throwable>(); AbstractServerConnectionFactory scf = getConnectionFactory(serializer, message -> { if (message instanceof ErrorMessage) { errorMessageRef.set(((ErrorMessage) message).getPayload()); errorMessageLetch.countDown(); } return false; }, new TcpSender() { @Override public void addNewConnection(TcpConnection connection) { added.add(connection); semaphore.release(); } @Override public void removeDeadConnection(TcpConnection connection) { removed.add(connection); semaphore.release(); } }); // Fire up the sender. CountDownLatch done = SocketTestUtils.testSendCrLfOverflow(scf.getPort()); whileOpen(semaphore, added); assertEquals(1, added.size()); assertTrue(errorMessageLetch.await(10, TimeUnit.SECONDS)); assertThat(errorMessageRef.get().getMessage(), anyOf(containsString("Connection is closed"), containsString("CRLF not found before max message length: 1024"))); assertTrue(semaphore.tryAcquire(10000, TimeUnit.MILLISECONDS)); assertTrue(removed.size() > 0); scf.stop(); done.countDown(); } /** * Tests socket closure when no data received. * @throws Exception */ @Test public void testCloseCleanupNoData() throws Exception { ByteArrayCrLfSerializer serializer = new ByteArrayCrLfSerializer(); serializer.setMaxMessageSize(1024); final Semaphore semaphore = new Semaphore(0); final List<TcpConnection> added = new ArrayList<TcpConnection>(); final List<TcpConnection> removed = new ArrayList<TcpConnection>(); final CountDownLatch errorMessageLetch = new CountDownLatch(1); final AtomicReference<Throwable> errorMessageRef = new AtomicReference<Throwable>(); AbstractServerConnectionFactory scf = getConnectionFactory(serializer, message -> { if (message instanceof ErrorMessage) { errorMessageRef.set(((ErrorMessage) message).getPayload()); errorMessageLetch.countDown(); } return false; }, new TcpSender() { @Override public void addNewConnection(TcpConnection connection) { added.add(connection); semaphore.release(); } @Override public void removeDeadConnection(TcpConnection connection) { removed.add(connection); semaphore.release(); } }); Socket socket = SocketFactory.getDefault().createSocket("localhost", scf.getPort()); socket.close(); whileOpen(semaphore, added); assertEquals(1, added.size()); assertTrue(errorMessageLetch.await(10, TimeUnit.SECONDS)); assertThat(errorMessageRef.get().getMessage(), anyOf(containsString("Connection is closed"), containsString("Stream closed after 2 of 3"))); assertTrue(semaphore.tryAcquire(10000, TimeUnit.MILLISECONDS)); assertTrue(removed.size() > 0); scf.stop(); } /** * Tests socket closure when no data received. * @throws Exception */ @Test public void testCloseCleanupPartialData() throws Exception { ByteArrayCrLfSerializer serializer = new ByteArrayCrLfSerializer(); serializer.setMaxMessageSize(1024); final Semaphore semaphore = new Semaphore(0); final List<TcpConnection> added = new ArrayList<TcpConnection>(); final List<TcpConnection> removed = new ArrayList<TcpConnection>(); final CountDownLatch errorMessageLetch = new CountDownLatch(1); final AtomicReference<Throwable> errorMessageRef = new AtomicReference<Throwable>(); AbstractServerConnectionFactory scf = getConnectionFactory(serializer, message -> { if (message instanceof ErrorMessage) { errorMessageRef.set(((ErrorMessage) message).getPayload()); errorMessageLetch.countDown(); } return false; }, new TcpSender() { @Override public void addNewConnection(TcpConnection connection) { added.add(connection); semaphore.release(); } @Override public void removeDeadConnection(TcpConnection connection) { removed.add(connection); semaphore.release(); } }); Socket socket = SocketFactory.getDefault().createSocket("localhost", scf.getPort()); socket.getOutputStream().write("partial".getBytes()); socket.close(); whileOpen(semaphore, added); assertEquals(1, added.size()); assertTrue(errorMessageLetch.await(10, TimeUnit.SECONDS)); assertThat(errorMessageRef.get().getMessage(), anyOf(containsString("Connection is closed"), containsString("Socket closed during message assembly"))); assertTrue(semaphore.tryAcquire(10000, TimeUnit.MILLISECONDS)); assertTrue(removed.size() > 0); scf.stop(); } /** * Tests socket closure when mid-message * @throws Exception */ @Test public void testCloseCleanupCrLf() throws Exception { ByteArrayCrLfSerializer serializer = new ByteArrayCrLfSerializer(); testClosureMidMessageGuts(serializer, "xx"); } /** * Tests socket closure when mid-message * @throws Exception */ @Test public void testCloseCleanupStxEtx() throws Exception { ByteArrayCrLfSerializer serializer = new ByteArrayCrLfSerializer(); testClosureMidMessageGuts(serializer, ByteArrayStxEtxSerializer.STX + "xx"); } /** * Tests socket closure when mid-message * @throws Exception */ @Test public void testCloseCleanupLengthHeader() throws Exception { ByteArrayLengthHeaderSerializer serializer = new ByteArrayLengthHeaderSerializer(); testClosureMidMessageGuts(serializer, "\u0000\u0000\u0000\u0003xx"); } private void testClosureMidMessageGuts(AbstractByteArraySerializer serializer, String shortMessage) throws Exception { final Semaphore semaphore = new Semaphore(0); final List<TcpConnection> added = new ArrayList<TcpConnection>(); final List<TcpConnection> removed = new ArrayList<TcpConnection>(); final CountDownLatch errorMessageLetch = new CountDownLatch(1); final AtomicReference<Throwable> errorMessageRef = new AtomicReference<Throwable>(); AbstractServerConnectionFactory scf = getConnectionFactory(serializer, message -> { if (message instanceof ErrorMessage) { errorMessageRef.set(((ErrorMessage) message).getPayload()); errorMessageLetch.countDown(); } return false; }, new TcpSender() { @Override public void addNewConnection(TcpConnection connection) { added.add(connection); semaphore.release(); } @Override public void removeDeadConnection(TcpConnection connection) { removed.add(connection); semaphore.release(); } }); Socket socket = SocketFactory.getDefault().createSocket("localhost", scf.getPort()); socket.getOutputStream().write(shortMessage.getBytes()); socket.close(); whileOpen(semaphore, added); assertEquals(1, added.size()); assertTrue(errorMessageLetch.await(10, TimeUnit.SECONDS)); assertThat(errorMessageRef.get().getMessage(), anyOf(containsString("Connection is closed"), containsString("Socket closed during message assembly"), containsString("Stream closed after 2 of 3"))); assertTrue(semaphore.tryAcquire(10000, TimeUnit.MILLISECONDS)); assertTrue(removed.size() > 0); scf.stop(); } private void whileOpen(Semaphore semaphore, final List<TcpConnection> added) throws InterruptedException { int n = 0; assertTrue(semaphore.tryAcquire(10000, TimeUnit.MILLISECONDS)); while (added.get(0).isOpen()) { Thread.sleep(50); if (n++ > 200) { fail("Failed to close socket"); } } } }