/* * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.integration.ip.tcp.serializer; import static org.hamcrest.Matchers.containsString; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.ServerSocket; import java.net.Socket; import java.nio.ByteBuffer; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicReference; import javax.net.ServerSocketFactory; import org.junit.Rule; import org.junit.Test; import org.springframework.beans.factory.BeanFactory; import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationEventPublisher; import org.springframework.core.serializer.DefaultDeserializer; import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.ip.tcp.TcpInboundGateway; import org.springframework.integration.ip.tcp.TcpOutboundGateway; import org.springframework.integration.ip.tcp.connection.TcpNioClientConnectionFactory; import org.springframework.integration.ip.tcp.connection.TcpNioServerConnectionFactory; import org.springframework.integration.ip.util.SocketTestUtils; import org.springframework.integration.ip.util.TestingUtilities; import org.springframework.integration.support.MessageBuilder; import org.springframework.integration.test.support.LongRunningIntegrationTest; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.support.GenericMessage; /** * @author Gary Russell * @author Gavin Gray * @since 2.0 */ public class DeserializationTests { @Rule public LongRunningIntegrationTest longRunningIntegrationTest = new LongRunningIntegrationTest(); @Test public void testReadLength() throws Exception { ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0); int port = server.getLocalPort(); server.setSoTimeout(10000); CountDownLatch done = SocketTestUtils.testSendLength(port, null); Socket socket = server.accept(); socket.setSoTimeout(5000); ByteArrayLengthHeaderSerializer serializer = new ByteArrayLengthHeaderSerializer(); byte[] out = serializer.deserialize(socket.getInputStream()); assertEquals("Data", SocketTestUtils.TEST_STRING + SocketTestUtils.TEST_STRING, new String(out)); out = serializer.deserialize(socket.getInputStream()); assertEquals("Data", SocketTestUtils.TEST_STRING + SocketTestUtils.TEST_STRING, new String(out)); server.close(); done.countDown(); } @Test public void testReadStxEtx() throws Exception { ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0); int port = server.getLocalPort(); server.setSoTimeout(10000); CountDownLatch done = SocketTestUtils.testSendStxEtx(port, null); Socket socket = server.accept(); socket.setSoTimeout(5000); ByteArrayStxEtxSerializer serializer = new ByteArrayStxEtxSerializer(); byte[] out = serializer.deserialize(socket.getInputStream()); assertEquals("Data", SocketTestUtils.TEST_STRING + SocketTestUtils.TEST_STRING, new String(out)); out = serializer.deserialize(socket.getInputStream()); assertEquals("Data", SocketTestUtils.TEST_STRING + SocketTestUtils.TEST_STRING, new String(out)); server.close(); done.countDown(); } @Test public void testReadCrLf() throws Exception { ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0); int port = server.getLocalPort(); server.setSoTimeout(10000); CountDownLatch done = SocketTestUtils.testSendCrLf(port, null); Socket socket = server.accept(); socket.setSoTimeout(5000); ByteArrayCrLfSerializer serializer = new ByteArrayCrLfSerializer(); byte[] out = serializer.deserialize(socket.getInputStream()); assertEquals("Data", SocketTestUtils.TEST_STRING + SocketTestUtils.TEST_STRING, new String(out)); out = serializer.deserialize(socket.getInputStream()); assertEquals("Data", SocketTestUtils.TEST_STRING + SocketTestUtils.TEST_STRING, new String(out)); server.close(); done.countDown(); } @Test public void testReadRaw() throws Exception { ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0); int port = server.getLocalPort(); server.setSoTimeout(10000); SocketTestUtils.testSendRaw(port); Socket socket = server.accept(); socket.setSoTimeout(5000); ByteArrayRawSerializer serializer = new ByteArrayRawSerializer(); byte[] out = serializer.deserialize(socket.getInputStream()); assertEquals("Data", SocketTestUtils.TEST_STRING + SocketTestUtils.TEST_STRING, new String(out)); server.close(); } @Test public void testReadSerialized() throws Exception { ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0); int port = server.getLocalPort(); server.setSoTimeout(10000); CountDownLatch done = SocketTestUtils.testSendSerialized(port); Socket socket = server.accept(); socket.setSoTimeout(5000); DefaultDeserializer deserializer = new DefaultDeserializer(); Object out = deserializer.deserialize(socket.getInputStream()); assertEquals("Data", SocketTestUtils.TEST_STRING, out); out = deserializer.deserialize(socket.getInputStream()); assertEquals("Data", SocketTestUtils.TEST_STRING, out); server.close(); done.countDown(); } @Test public void testReadLengthOverflow() throws Exception { ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0); int port = server.getLocalPort(); server.setSoTimeout(10000); CountDownLatch done = SocketTestUtils.testSendLengthOverflow(port); Socket socket = server.accept(); socket.setSoTimeout(5000); ByteArrayLengthHeaderSerializer serializer = new ByteArrayLengthHeaderSerializer(); try { serializer.deserialize(socket.getInputStream()); fail("Expected message length exceeded exception"); } catch (IOException e) { if (!e.getMessage().startsWith("Message length")) { e.printStackTrace(); fail("Unexpected IO Error:" + e.getMessage()); } } server.close(); done.countDown(); } @Test public void testReadStxEtxTimeout() throws Exception { ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0); int port = server.getLocalPort(); server.setSoTimeout(10000); CountDownLatch done = SocketTestUtils.testSendStxEtxOverflow(port); Socket socket = server.accept(); socket.setSoTimeout(500); ByteArrayStxEtxSerializer serializer = new ByteArrayStxEtxSerializer(); try { serializer.deserialize(socket.getInputStream()); fail("Expected timeout exception"); } catch (IOException e) { if (!e.getMessage().startsWith("Read timed out")) { e.printStackTrace(); fail("Unexpected IO Error:" + e.getMessage()); } } server.close(); done.countDown(); } @Test public void testReadStxEtxOverflow() throws Exception { ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0); int port = server.getLocalPort(); server.setSoTimeout(10000); CountDownLatch done = SocketTestUtils.testSendStxEtxOverflow(port); Socket socket = server.accept(); socket.setSoTimeout(5000); ByteArrayStxEtxSerializer serializer = new ByteArrayStxEtxSerializer(); serializer.setMaxMessageSize(1024); try { serializer.deserialize(socket.getInputStream()); fail("Expected message length exceeded exception"); } catch (IOException e) { if (!e.getMessage().startsWith("ETX not found")) { e.printStackTrace(); fail("Unexpected IO Error:" + e.getMessage()); } } server.close(); done.countDown(); } @Test public void testReadCrLfTimeout() throws Exception { ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0); int port = server.getLocalPort(); server.setSoTimeout(10000); CountDownLatch latch = SocketTestUtils.testSendCrLfOverflow(port); Socket socket = server.accept(); socket.setSoTimeout(500); ByteArrayCrLfSerializer serializer = new ByteArrayCrLfSerializer(); try { serializer.deserialize(socket.getInputStream()); fail("Expected timout exception"); } catch (IOException e) { if (!e.getMessage().startsWith("Read timed out")) { e.printStackTrace(); fail("Unexpected IO Error:" + e.getMessage()); } } server.close(); latch.countDown(); } @Test public void testReadCrLfOverflow() throws Exception { ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0); int port = server.getLocalPort(); server.setSoTimeout(10000); CountDownLatch latch = SocketTestUtils.testSendCrLfOverflow(port); Socket socket = server.accept(); socket.setSoTimeout(5000); ByteArrayCrLfSerializer serializer = new ByteArrayCrLfSerializer(); serializer.setMaxMessageSize(1024); try { serializer.deserialize(socket.getInputStream()); fail("Expected message length exceeded exception"); } catch (IOException e) { if (!e.getMessage().startsWith("CRLF not found")) { e.printStackTrace(); fail("Unexpected IO Error:" + e.getMessage()); } } server.close(); latch.countDown(); } @Test public void canDeserializeMultipleSubsequentTerminators() throws IOException { byte terminator = (byte) '\n'; ByteArraySingleTerminatorSerializer serializer = new ByteArraySingleTerminatorSerializer(terminator); ByteArrayInputStream inputStream = new ByteArrayInputStream("s\n\n".getBytes()); try { byte[] bytes = serializer.deserialize(inputStream); assertEquals(1, bytes.length); assertEquals("s".getBytes()[0], bytes[0]); bytes = serializer.deserialize(inputStream); assertEquals(0, bytes.length); } finally { inputStream.close(); } } @Test public void deserializationEvents() throws Exception { doDeserialize(new ByteArrayCrLfSerializer(), "CRLF not found before max message length: 5"); doDeserialize(new ByteArrayLengthHeaderSerializer(), "Message length 1718579042 exceeds max message length: 5"); TcpDeserializationExceptionEvent event = doDeserialize(new ByteArrayLengthHeaderSerializer(), "Stream closed after 3 of 4", new byte[] { 0, 0, 0 }, 5); // closed during header read assertEquals(-1, event.getOffset()); assertEquals(new String(new byte[] { 0, 0, 0 }), new String(event.getBuffer()).substring(0, 3)); event = doDeserialize(new ByteArrayLengthHeaderSerializer(), "Stream closed after 1 of 2", new byte[] { 0, 0, 0, 2, 7 }, 5); // closed during data read assertEquals(-1, event.getOffset()); assertEquals(new String(new byte[] { 7 }), new String(event.getBuffer()).substring(0, 1)); doDeserialize(new ByteArrayLfSerializer(), "Terminator '0xa' not found before max message length: 5"); doDeserialize(new ByteArrayRawSerializer(), "Socket was not closed before max message length: 5"); doDeserialize(new ByteArraySingleTerminatorSerializer((byte) 0xfe), "Terminator '0xfe' not found before max message length: 5"); doDeserialize(new ByteArrayStxEtxSerializer(), "Expected STX to begin message"); event = doDeserialize(new ByteArrayStxEtxSerializer(), "Socket closed during message assembly", new byte[] { 0x02, 0, 0 }, 5); assertEquals(2, event.getOffset()); } private TcpDeserializationExceptionEvent doDeserialize(AbstractByteArraySerializer deser, String expectedMessage) { return doDeserialize(deser, expectedMessage, "foobar".getBytes(), 5); } private TcpDeserializationExceptionEvent doDeserialize(AbstractByteArraySerializer deser, String expectedMessage, byte[] data, int mms) { final AtomicReference<TcpDeserializationExceptionEvent> event = new AtomicReference<TcpDeserializationExceptionEvent>(); class Publisher implements ApplicationEventPublisher { @Override public void publishEvent(ApplicationEvent anEvent) { event.set((TcpDeserializationExceptionEvent) anEvent); } @Override public void publishEvent(Object event) { } } Publisher publisher = new Publisher(); ByteArrayInputStream bais = new ByteArrayInputStream(data); deser.setApplicationEventPublisher(publisher); deser.setMaxMessageSize(mms); try { deser.deserialize(bais); fail("expected exception"); } catch (Exception e) { assertNotNull(event.get()); assertSame(e, event.get().getCause()); assertThat(e.getMessage(), containsString(expectedMessage)); } return event.get(); } @Test public void testTimeoutWithCustomDeserializer() throws Exception { testTimeoutWhileDecoding(new CustomDeserializer(), "\u0000\u0002\u0000\u0005reply"); } @Test public void testTimeoutWithRawDeserializer() throws Exception { testTimeoutWhileDecoding(new ByteArrayRawSerializer(), "reply"); } public void testTimeoutWhileDecoding(AbstractByteArraySerializer deserializer, String reply) throws Exception { ByteArrayRawSerializer serializer = new ByteArrayRawSerializer(); TcpNioServerConnectionFactory serverNio = new TcpNioServerConnectionFactory(0); ByteArrayLengthHeaderSerializer lengthHeaderSerializer = new ByteArrayLengthHeaderSerializer(1); serverNio.setDeserializer(lengthHeaderSerializer); serverNio.setSerializer(serializer); serverNio.afterPropertiesSet(); TcpInboundGateway in = new TcpInboundGateway(); in.setConnectionFactory(serverNio); QueueChannel serverSideChannel = new QueueChannel(); in.setRequestChannel(serverSideChannel); in.setBeanFactory(mock(BeanFactory.class)); in.afterPropertiesSet(); in.start(); TestingUtilities.waitListening(serverNio, null); TcpNioClientConnectionFactory clientNio = new TcpNioClientConnectionFactory("localhost", serverNio.getPort()); clientNio.setSerializer(serializer); clientNio.setDeserializer(deserializer); clientNio.setSoTimeout(1000); clientNio.afterPropertiesSet(); final TcpOutboundGateway out = new TcpOutboundGateway(); out.setConnectionFactory(clientNio); QueueChannel outputChannel = new QueueChannel(); out.setOutputChannel(outputChannel); out.setRemoteTimeout(60000); out.setBeanFactory(mock(BeanFactory.class)); out.afterPropertiesSet(); out.start(); Runnable command = () -> { try { out.handleMessage(MessageBuilder.withPayload("\u0004Test").build()); } catch (Exception e) { // eat SocketTimeoutException. Doesn't matter for this test } }; ExecutorService exec = Executors.newSingleThreadExecutor(); Message<?> message; // short reply should not be received. exec.execute(command); message = serverSideChannel.receive(10000); assertNotNull(message); assertEquals("Test", new String((byte[]) message.getPayload())); String shortReply = reply.substring(0, reply.length() - 1); ((MessageChannel) message.getHeaders().getReplyChannel()).send(new GenericMessage<String>(shortReply)); message = outputChannel.receive(6000); assertNull(message); // good message should be received if ((deserializer instanceof ByteArrayRawSerializer)) { // restore old behavior clientNio.setDeserializer(new ByteArrayRawSerializer(true)); } exec.execute(command); message = serverSideChannel.receive(10000); assertNotNull(message); assertEquals("Test", new String((byte[]) message.getPayload())); ((MessageChannel) message.getHeaders().getReplyChannel()).send(new GenericMessage<String>(reply)); message = outputChannel.receive(10000); assertNotNull(message); assertEquals(reply, new String(((byte[]) message.getPayload()))); } private static class CustomDeserializer extends AbstractByteArraySerializer { @Override public byte[] deserialize(InputStream inputStream) throws IOException { if (logger.isDebugEnabled()) { logger.debug("Available to read:" + inputStream.available()); } byte[] header = new byte[2]; header[0] = (byte) inputStream.read(); if (header[0] < 0) { throw new SoftEndOfStreamException("Stream closed between payloads"); } header[1] = (byte) inputStream.read(); if (header[1] < 0) { checkClosure(-1); } ByteBuffer headerBB = ByteBuffer.wrap(header); int val = headerBB.getShort(); byte[] length = new byte[val]; for (int i = 0; i < val; i++) { length[i] = (byte) inputStream.read(); } ByteBuffer lengthBB = ByteBuffer.wrap(length); int messageLength; if (val == 2) { messageLength = lengthBB.getShort(); } else if (val == 4) { messageLength = lengthBB.getInt(); } else { throw new IOException("Unexpected count of bytes that holds message length"); } byte[] answer = new byte[messageLength]; for (int i = 0; i < messageLength; i++) { int bite = inputStream.read(); if (bite < 0) { checkClosure(-1); } answer[i] = (byte) bite; } ByteBuffer b = ByteBuffer.allocate(2 + val + messageLength); b.put(header); b.put(length); b.put(answer); return b.array(); } @Override public void serialize(byte[] object, OutputStream outputStream) throws IOException { } } }