/* * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.integration.ip.tcp.connection; import java.io.BufferedOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.SocketTimeoutException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.SocketChannel; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import javax.net.ssl.SSLSession; import org.springframework.context.ApplicationEventPublisher; import org.springframework.core.serializer.Serializer; import org.springframework.integration.ip.tcp.serializer.SoftEndOfStreamException; import org.springframework.integration.util.CompositeExecutor; import org.springframework.messaging.Message; import org.springframework.messaging.MessagingException; import org.springframework.util.Assert; /** * A TcpConnection that uses and underlying {@link SocketChannel}. * * @author Gary Russell * @author John Anderson * @since 2.0 * */ public class TcpNioConnection extends TcpConnectionSupport { private static final long DEFAULT_PIPE_TIMEOUT = 60000; private final SocketChannel socketChannel; private final ChannelOutputStream channelOutputStream; private final ChannelInputStream channelInputStream = new ChannelInputStream(); private volatile OutputStream bufferedOutputStream; private volatile boolean usingDirectBuffers; private volatile CompositeExecutor taskExecutor; private volatile ByteBuffer rawBuffer; private volatile int maxMessageSize = 60 * 1024; private volatile long lastRead; private volatile long lastSend; private final AtomicInteger executionControl = new AtomicInteger(); private volatile boolean writingToPipe; private volatile CountDownLatch writingLatch; private volatile long pipeTimeout = DEFAULT_PIPE_TIMEOUT; private volatile boolean timedOut; /** * Constructs a TcpNetConnection for the SocketChannel. * @param socketChannel The socketChannel. * @param server If true, this connection was created as * a result of an incoming request. * @param lookupHost true to perform reverse lookups. * @param applicationEventPublisher The event publisher. * @param connectionFactoryName The name of the connection factory creating this connection. * @throws Exception Any Exception. */ public TcpNioConnection(SocketChannel socketChannel, boolean server, boolean lookupHost, ApplicationEventPublisher applicationEventPublisher, String connectionFactoryName) throws Exception { super(socketChannel.socket(), server, lookupHost, applicationEventPublisher, connectionFactoryName); this.socketChannel = socketChannel; int receiveBufferSize = socketChannel.socket().getReceiveBufferSize(); if (receiveBufferSize <= 0) { receiveBufferSize = this.maxMessageSize; } this.channelOutputStream = new ChannelOutputStream(); } public void setPipeTimeout(long pipeTimeout) { this.pipeTimeout = pipeTimeout; } @Override public void close() { this.setNoReadErrorOnClose(true); doClose(); } private void doClose() { try { this.channelInputStream.close(); } catch (IOException e) { } try { this.socketChannel.close(); } catch (Exception e) { } super.close(); } @Override public boolean isOpen() { return this.socketChannel.isOpen(); } @Override @SuppressWarnings("unchecked") public void send(Message<?> message) throws Exception { synchronized (this.socketChannel) { if (this.bufferedOutputStream == null) { int writeBufferSize = this.socketChannel.socket().getSendBufferSize(); this.bufferedOutputStream = new BufferedOutputStream(this.getChannelOutputStream(), writeBufferSize > 0 ? writeBufferSize : 8192); } Object object = this.getMapper().fromMessage(message); this.lastSend = System.currentTimeMillis(); try { ((Serializer<Object>) this.getSerializer()).serialize(object, this.bufferedOutputStream); this.bufferedOutputStream.flush(); } catch (Exception e) { this.publishConnectionExceptionEvent(new MessagingException(message, "Failed TCP serialization", e)); this.closeConnection(true); throw e; } if (logger.isDebugEnabled()) { logger.debug(getConnectionId() + " Message sent " + message); } } } @Override public Object getPayload() throws Exception { return this.getDeserializer().deserialize(this.channelInputStream); } @Override public int getPort() { return this.socketChannel.socket().getPort(); } @Override public Object getDeserializerStateKey() { return this.channelInputStream; } @Override public SSLSession getSslSession() { return null; } /** * Allocates a ByteBuffer of the requested length using normal or * direct buffers, depending on the usingDirectBuffers field. * * @param length The buffer length. * @return The buffer. */ protected ByteBuffer allocate(int length) { ByteBuffer buffer; if (this.usingDirectBuffers) { buffer = ByteBuffer.allocateDirect(length); } else { buffer = ByteBuffer.allocate(length); } return buffer; } /** * If there is no listener, * this method exits. When there is a listener, this method assembles * data into messages by invoking convertAndSend whenever there is * data in the input Stream. Method exits when a message is complete * and there is no more data; thus freeing the thread to work on other * sockets. */ @Override public void run() { if (logger.isTraceEnabled()) { logger.trace(this.getConnectionId() + " Nio message assembler running..."); } boolean moreDataAvailable = true; while (moreDataAvailable) { try { try { if (dataAvailable()) { Message<?> message = convert(); if (dataAvailable()) { // there is more data in the pipe; run another assembler // to assemble the next message, while we send ours this.executionControl.incrementAndGet(); try { this.taskExecutor.execute2(this); } catch (RejectedExecutionException e) { this.executionControl.decrementAndGet(); if (logger.isInfoEnabled()) { logger.info(getConnectionId() + " Insufficient threads in the assembler fixed thread pool; consider " + "increasing this task executor pool size; data avail: " + this.channelInputStream.available()); } } } this.executionControl.decrementAndGet(); if (message != null) { sendToChannel(message); } } else { this.executionControl.decrementAndGet(); } } catch (Exception e) { if (logger.isTraceEnabled()) { logger.error("Read exception " + this.getConnectionId(), e); } else if (!this.isNoReadErrorOnClose()) { logger.error("Read exception " + this.getConnectionId() + " " + e.getClass().getSimpleName() + ":" + e.getCause() + ":" + e.getMessage()); } else { if (logger.isDebugEnabled()) { logger.debug("Read exception " + this.getConnectionId() + " " + e.getClass().getSimpleName() + ":" + e.getCause() + ":" + e.getMessage()); } } this.closeConnection(true); this.sendExceptionToListener(e); return; } } finally { moreDataAvailable = false; // Final check in case new data came in and the // timing was such that we were the last assembler and // a new one wasn't run try { if (dataAvailable()) { synchronized (this.executionControl) { if (this.executionControl.incrementAndGet() <= 1) { // only continue if we don't already have another assembler running this.executionControl.set(1); moreDataAvailable = true; } else { this.executionControl.decrementAndGet(); } } } if (moreDataAvailable) { if (logger.isTraceEnabled()) { logger.trace(this.getConnectionId() + " Nio message assembler continuing..."); } } else { if (logger.isTraceEnabled()) { logger.trace(this.getConnectionId() + " Nio message assembler exiting... avail: " + this.channelInputStream.available()); } } } catch (IOException e) { logger.error("Exception when checking for assembler", e); } } } } private boolean dataAvailable() throws IOException { if (logger.isTraceEnabled()) { logger.trace(getConnectionId() + " checking data avail: " + this.channelInputStream.available() + " pending: " + (this.writingToPipe)); } return this.writingToPipe || this.channelInputStream.available() > 0; } /** * Blocks until a complete message has been assembled. * Synchronized to avoid concurrency. * @return The Message or null if no data is available. * @throws IOException */ private synchronized Message<?> convert() throws Exception { if (logger.isTraceEnabled()) { logger.trace(getConnectionId() + " checking data avail (convert): " + this.channelInputStream.available() + " pending: " + (this.writingToPipe)); } if (this.channelInputStream.available() <= 0) { try { if (this.writingLatch.await(60, TimeUnit.SECONDS)) { if (this.channelInputStream.available() <= 0) { return null; } } else { // should never happen throw new IOException("Timed out waiting for IO"); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new IOException("Interrupted waiting for IO"); } } Message<?> message = null; try { message = this.getMapper().toMessage(this); } catch (Exception e) { this.closeConnection(true); if (e instanceof SocketTimeoutException) { if (logger.isDebugEnabled()) { logger.debug("Closing socket after timeout " + this.getConnectionId()); } } else { if (!(e instanceof SoftEndOfStreamException)) { throw e; } } return null; } return message; } private void sendToChannel(Message<?> message) { try { if (message != null) { TcpListener listener = getListener(); if (listener == null) { throw new NoListenerException("No listener"); } listener.onMessage(message); } } catch (Exception e) { if (e instanceof NoListenerException) { // could also be thrown by an interceptor if (logger.isWarnEnabled()) { logger.warn("Unexpected message - no endpoint registered with connection: " + getConnectionId() + " - " + message); } } else { logger.error("Exception sending message: " + message, e); } } } private void doRead() throws Exception { if (this.rawBuffer == null) { this.rawBuffer = allocate(this.maxMessageSize); } this.writingLatch = new CountDownLatch(1); this.writingToPipe = true; try { if (this.taskExecutor == null) { ExecutorService executor = Executors.newCachedThreadPool(); this.taskExecutor = new CompositeExecutor(executor, executor); } // If there is no assembler running, start one checkForAssembler(); if (logger.isTraceEnabled()) { logger.trace("Before read:" + this.rawBuffer.position() + "/" + this.rawBuffer.limit()); } int len = this.socketChannel.read(this.rawBuffer); if (len < 0) { this.writingToPipe = false; this.closeConnection(true); } if (logger.isTraceEnabled()) { logger.trace("After read:" + this.rawBuffer.position() + "/" + this.rawBuffer.limit()); } this.rawBuffer.flip(); if (logger.isTraceEnabled()) { logger.trace("After flip:" + this.rawBuffer.position() + "/" + this.rawBuffer.limit()); } if (logger.isDebugEnabled()) { logger.debug("Read " + this.rawBuffer.limit() + " into raw buffer"); } this.sendToPipe(this.rawBuffer); } catch (RejectedExecutionException e) { throw e; } catch (Exception e) { this.publishConnectionExceptionEvent(e); throw e; } finally { this.writingToPipe = false; this.writingLatch.countDown(); } } protected void sendToPipe(ByteBuffer rawBuffer) throws IOException { Assert.notNull(rawBuffer, "rawBuffer cannot be null"); if (logger.isTraceEnabled()) { logger.trace(this.getConnectionId() + " Sending " + rawBuffer.limit() + " to pipe"); } this.channelInputStream.write(rawBuffer); rawBuffer.clear(); } private void checkForAssembler() { synchronized (this.executionControl) { if (this.executionControl.incrementAndGet() <= 1) { // only execute run() if we don't already have one running this.executionControl.set(1); if (logger.isDebugEnabled()) { logger.debug(this.getConnectionId() + " Running an assembler"); } try { this.taskExecutor.execute2(this); } catch (RejectedExecutionException e) { this.executionControl.decrementAndGet(); if (logger.isInfoEnabled()) { logger.info("Insufficient threads in the assembler fixed thread pool; consider increasing " + "this task executor pool size"); } throw e; } } else { this.executionControl.decrementAndGet(); } } } /** * Invoked by the factory when there is data to be read. */ public void readPacket() { if (logger.isDebugEnabled()) { logger.debug(this.getConnectionId() + " Reading..."); } try { doRead(); } catch (ClosedChannelException cce) { if (logger.isDebugEnabled()) { logger.debug(this.getConnectionId() + " Channel is closed"); } this.closeConnection(true); } catch (RejectedExecutionException e) { throw e; } catch (Exception e) { logger.error("Exception on Read " + this.getConnectionId() + " " + e.getMessage(), e); this.closeConnection(true); } } /** * Close the socket due to timeout. */ void timeout() { this.timedOut = true; this.closeConnection(true); } /** * * @param taskExecutor the taskExecutor to set */ public void setTaskExecutor(Executor taskExecutor) { if (taskExecutor instanceof CompositeExecutor) { this.taskExecutor = (CompositeExecutor) taskExecutor; } else { this.taskExecutor = new CompositeExecutor(taskExecutor, taskExecutor); } } /** * If true, connection will attempt to use direct buffers where * possible. * @param usingDirectBuffers the usingDirectBuffers to set. */ public void setUsingDirectBuffers(boolean usingDirectBuffers) { this.usingDirectBuffers = usingDirectBuffers; } protected boolean isUsingDirectBuffers() { return this.usingDirectBuffers; } protected ChannelOutputStream getChannelOutputStream() { return this.channelOutputStream; } /** * * @return Time of last read. */ public long getLastRead() { return this.lastRead; } /** * * @param lastRead The time of the last read. */ public void setLastRead(long lastRead) { this.lastRead = lastRead; } /** * @return the time of the last send */ public long getLastSend() { return this.lastSend; } /** * OutputStream to wrap a SocketChannel; implements timeout on write. * */ class ChannelOutputStream extends OutputStream { private Selector selector; private int soTimeout; @Override public void write(int b) throws IOException { byte[] bytes = new byte[1]; bytes[0] = (byte) b; ByteBuffer buffer = ByteBuffer.wrap(bytes); doWrite(buffer); } @Override public void close() throws IOException { doClose(); } @Override public void flush() throws IOException { } @Override public void write(byte[] b, int off, int len) throws IOException { ByteBuffer buffer = ByteBuffer.wrap(b, off, len); doWrite(buffer); } @Override public void write(byte[] b) throws IOException { ByteBuffer buffer = ByteBuffer.wrap(b); doWrite(buffer); } protected synchronized void doWrite(ByteBuffer buffer) throws IOException { if (logger.isDebugEnabled()) { logger.debug(getConnectionId() + " writing " + buffer.remaining()); } TcpNioConnection.this.socketChannel.write(buffer); int remaining = buffer.remaining(); if (remaining == 0) { return; } if (this.selector == null) { this.selector = Selector.open(); this.soTimeout = TcpNioConnection.this.socketChannel.socket().getSoTimeout(); } TcpNioConnection.this.socketChannel.register(this.selector, SelectionKey.OP_WRITE); while (remaining > 0) { int selectionCount = this.selector.select(this.soTimeout); if (selectionCount == 0) { throw new SocketTimeoutException("Timeout on write"); } this.selector.selectedKeys().clear(); TcpNioConnection.this.socketChannel.write(buffer); remaining = buffer.remaining(); } } } /** * Provides an InputStream to receive data from {@link SocketChannel#read(ByteBuffer)} * operations. Each new buffer is added to a BlockingQueue; when the reading thread * exhausts the current buffer, it retrieves the next from the queue. * Writes block for up to the pipeTimeout if 5 buffers are queued to be read. * */ class ChannelInputStream extends InputStream { private static final int BUFFER_LIMIT = 5; private final BlockingQueue<byte[]> buffers = new LinkedBlockingQueue<byte[]>(BUFFER_LIMIT); private volatile byte[] currentBuffer; private volatile int currentOffset; private final AtomicInteger available = new AtomicInteger(); private volatile boolean isClosed; @Override public int read(byte[] b, int off, int len) throws IOException { Assert.notNull(b, "byte[] cannot be null"); if (off < 0 || len < 0 || len > b.length - off) { throw new IndexOutOfBoundsException(); } else if (len == 0) { return 0; } int n = 0; while ((this.available.get() > 0 || n == 0) && n < len) { int bite = read(); if (bite < 0) { if (n == 0) { return -1; } else { return n; } } b[off + n++] = (byte) bite; } return n; } @Override public synchronized int read() throws IOException { if (this.isClosed && this.available.get() == 0) { if (TcpNioConnection.this.timedOut) { throw new SocketTimeoutException("Connection has timed out"); } return -1; } if (this.currentBuffer == null) { this.currentBuffer = getNextBuffer(); this.currentOffset = 0; if (this.currentBuffer == null) { if (TcpNioConnection.this.timedOut) { throw new SocketTimeoutException("Connection has timed out"); } return -1; } } int bite; bite = this.currentBuffer[this.currentOffset++] & 0xff; this.available.decrementAndGet(); if (this.currentOffset >= this.currentBuffer.length) { this.currentBuffer = null; } return bite; } private byte[] getNextBuffer() throws IOException { byte[] buffer = null; while (buffer == null) { try { buffer = this.buffers.poll(1, TimeUnit.SECONDS); if (buffer == null && this.isClosed) { return null; } } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new IOException("Interrupted while waiting for data", e); } } return buffer; } /** * Blocks if the blocking queue already contains 5 buffers. * @param array * @param bytesToWrite * @throws IOException */ public void write(ByteBuffer byteBuffer) throws IOException { int bytesToWrite = byteBuffer.limit() - byteBuffer.position(); if (bytesToWrite > 0) { byte[] buffer = new byte[bytesToWrite]; byteBuffer.get(buffer); this.available.addAndGet(bytesToWrite); if (TcpNioConnection.this.writingLatch != null) { TcpNioConnection.this.writingLatch.countDown(); } try { if (!this.buffers.offer(buffer, TcpNioConnection.this.pipeTimeout, TimeUnit.MILLISECONDS)) { throw new IOException("Timed out waiting for buffer space"); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new IOException("Interrupted while waiting for buffer space", e); } TcpNioConnection.this.writingLatch = new CountDownLatch(1); } } @Override public void close() throws IOException { super.close(); this.isClosed = true; } @Override public int available() throws IOException { return this.available.get(); } } }