/* * JBoss, Home of Professional Open Source * Copyright 2011, JBoss Inc., and individual contributors as indicated * by the @authors tag. See the copyright.txt in the distribution for a * full listing of individual contributors. * * This is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 of * the License, or (at your option) any later version. * * This software is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this software; if not, write to the Free * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA * 02110-1301 USA, or see the FSF site: http://www.fsf.org. */ package org.jboss.remoting3.remote; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.util.ArrayDeque; import java.util.Queue; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import org.jboss.logging.Logger; import org.jboss.remoting3.RemotingOptions; import org.jboss.remoting3._private.Messages; import org.jboss.remoting3.spi.ConnectionHandlerFactory; import org.wildfly.security.auth.server.SecurityIdentity; import org.xnio.Buffers; import org.xnio.ByteBufferPool; import org.xnio.ChannelListener; import org.xnio.Connection; import org.xnio.IoUtils; import org.xnio.OptionMap; import org.xnio.Pooled; import org.xnio.Result; import org.xnio.StreamConnection; import org.xnio.XnioExecutor; import org.xnio.channels.SslChannel; import org.xnio.conduits.ConduitStreamSinkChannel; import org.xnio.conduits.ConduitStreamSourceChannel; import org.xnio.sasl.SaslWrapper; /** * @author <a href="mailto:david.lloyd@redhat.com">David M. Lloyd</a> */ final class RemoteConnection { static final Pooled<ByteBuffer> STARTTLS_SENTINEL = Buffers.emptyPooledByteBuffer(); private static final String FQCN = RemoteConnection.class.getName(); private final StreamConnection connection; private final MessageReader messageReader; private final SslChannel sslChannel; private final OptionMap optionMap; private final RemoteWriteListener writeListener = new RemoteWriteListener(); private final Executor executor; private final int heartbeatInterval; private volatile Result<ConnectionHandlerFactory> result; private volatile SaslWrapper saslWrapper; private volatile SecurityIdentity identity; private final RemoteConnectionProvider remoteConnectionProvider; RemoteConnection(final StreamConnection connection, final SslChannel sslChannel, final OptionMap optionMap, final RemoteConnectionProvider remoteConnectionProvider) { this.connection = connection; this.messageReader = new MessageReader(connection.getSourceChannel(), writeListener.queue); this.sslChannel = sslChannel; this.optionMap = optionMap; heartbeatInterval = optionMap.get(RemotingOptions.HEARTBEAT_INTERVAL, RemotingOptions.DEFAULT_HEARTBEAT_INTERVAL); Messages.conn.tracef("Initialized connection from %s to %s with options %s", connection.getPeerAddress(), connection.getLocalAddress(), optionMap); this.executor = remoteConnectionProvider.getExecutor(); this.remoteConnectionProvider = remoteConnectionProvider; } Pooled<ByteBuffer> allocate() { return Buffers.globalPooledWrapper(ByteBufferPool.MEDIUM_DIRECT.allocate()); } void setReadListener(ChannelListener<ConduitStreamSourceChannel> listener, final boolean resume) { Messages.log.logf(RemoteConnection.class.getName(), Logger.Level.TRACE, null, "Setting read listener to %s", listener); messageReader.setReadListener(listener); if (listener != null && resume) { messageReader.resumeReads(); } } RemoteConnectionProvider getRemoteConnectionProvider() { return remoteConnectionProvider; } Result<ConnectionHandlerFactory> getResult() { return result; } void setResult(final Result<ConnectionHandlerFactory> result) { this.result = result; } void handleException(IOException e) { handleException(e, true); } void handleException(IOException e, boolean log) { Messages.conn.logf(RemoteConnection.class.getName(), Logger.Level.TRACE, e, "Connection error detail"); if (log) { Messages.conn.connectionError(e); } final XnioExecutor.Key key = writeListener.heartKey; if (key != null) { key.remove(); } synchronized (getLock()) { IoUtils.safeClose(connection); } final Result<ConnectionHandlerFactory> result = this.result; if (result != null) { result.setException(e); this.result = null; } } void send(final Pooled<ByteBuffer> pooled) { writeListener.send(pooled, false); } void send(final Pooled<ByteBuffer> pooled, boolean close) { writeListener.send(pooled, close); } void shutdownWrites() { writeListener.shutdownWrites(); } OptionMap getOptionMap() { return optionMap; } MessageReader getMessageReader() { return messageReader; } RemoteWriteListener getWriteListener() { return writeListener; } public Executor getExecutor() { return executor; } public SslChannel getSslChannel() { return sslChannel; } SaslWrapper getSaslWrapper() { return saslWrapper; } void setSaslWrapper(final SaslWrapper saslWrapper) { this.saslWrapper = saslWrapper; } void handlePreAuthCloseRequest() { try { terminateHeartbeat(); synchronized (getLock()) { connection.close(); } } catch (IOException e) { Messages.conn.debug("Error closing remoting channel", e); } } void sendAlive() { Messages.conn.trace("Sending connection alive"); final Pooled<ByteBuffer> pooled = allocate(); boolean ok = false; try { final ByteBuffer buffer = pooled.getResource(); buffer.put(Protocol.CONNECTION_ALIVE); buffer.limit(80); Buffers.addRandom(buffer); buffer.flip(); send(pooled); ok = true; messageReader.wakeupReads(); } finally { if (! ok) pooled.free(); } } void sendAliveResponse() { Messages.conn.trace("Sending connection alive ack"); final Pooled<ByteBuffer> pooled = allocate(); boolean ok = false; try { final ByteBuffer buffer = pooled.getResource(); buffer.put(Protocol.CONNECTION_ALIVE_ACK); buffer.limit(80); Buffers.addRandom(buffer); buffer.flip(); send(pooled); ok = true; } finally { if (! ok) pooled.free(); } } void terminateHeartbeat() { final XnioExecutor.Key key = writeListener.heartKey; if (key != null) { key.remove(); } } Object getLock() { return writeListener.queue; } SecurityIdentity getIdentity() { return identity; } void setIdentity(final SecurityIdentity identity) { this.identity = identity; } InetSocketAddress getPeerAddress() { return connection.getPeerAddress(InetSocketAddress.class); } InetSocketAddress getLocalAddress() { return connection.getLocalAddress(InetSocketAddress.class); } Connection getConnection() { return connection; } final class RemoteWriteListener implements ChannelListener<ConduitStreamSinkChannel> { private final Queue<Pooled<ByteBuffer>> queue = new ArrayDeque<Pooled<ByteBuffer>>(); private XnioExecutor.Key heartKey; private boolean closed; private ByteBuffer headerBuffer = ByteBuffer.allocateDirect(4); private final ByteBuffer[] cachedArray = new ByteBuffer[] { headerBuffer, null }; RemoteWriteListener() { } public void handleEvent(final ConduitStreamSinkChannel channel) { final ByteBuffer[] cachedArray = this.cachedArray; synchronized (queue) { Pooled<ByteBuffer> pooled; final Queue<Pooled<ByteBuffer>> queue = this.queue; try { ByteBuffer buffer = cachedArray[1]; if (buffer != null) { channel.write(cachedArray); if (buffer.hasRemaining()) { return; } } cachedArray[1] = null; while ((pooled = queue.peek()) != null) { buffer = pooled.getResource(); if (buffer.hasRemaining()) { // no empty messages headerBuffer.putInt(0, buffer.remaining()); headerBuffer.position(0); cachedArray[1] = buffer; final long res = channel.write(cachedArray); Messages.conn.tracef("Sent %d bytes", res); if (buffer.hasRemaining()) { // try again later return; } else { cachedArray[1] = null; queue.poll().free(); } } else { if (pooled == STARTTLS_SENTINEL) { if (channel.flush()) { Messages.conn.trace("Flushed channel"); final SslChannel sslChannel = getSslChannel(); assert sslChannel != null; // because STARTTLS would be false in this case sslChannel.startHandshake(); } else { // try again later Messages.conn.trace("Flush stalled"); return; } } // otherwise skip other empty message rather than try and write it queue.poll().free(); } } if (channel.flush()) { Messages.conn.trace("Flushed channel"); if (closed) { terminateHeartbeat(); // End of queue reached; shut down and try to flush the remainder channel.shutdownWrites(); if (channel.flush()) { Messages.conn.trace("Shut down writes on channel"); return; } // either this is successful and no more notifications will come, or not and it will be retried // either way we're done here return; } else { if (heartbeatInterval != Integer.MAX_VALUE) { this.heartKey = channel.getWriteThread().executeAfter(heartbeatCommand, heartbeatInterval, TimeUnit.MILLISECONDS); } } channel.suspendWrites(); } } catch (IOException e) { handleException(e, false); while ((pooled = queue.poll()) != null) { pooled.free(); } } // else try again later } } public void shutdownWrites() { synchronized (queue) { closed = true; terminateHeartbeat(); final ConduitStreamSinkChannel sinkChannel = connection.getSinkChannel(); try { if (! queue.isEmpty()) { sinkChannel.resumeWrites(); return; } sinkChannel.shutdownWrites(); if (! sinkChannel.flush()) { sinkChannel.resumeWrites(); return; } Messages.conn.logf(FQCN, Logger.Level.TRACE, null, "Shut down writes on channel"); } catch (IOException e) { handleException(e, false); Pooled<ByteBuffer> unqueued; while ((unqueued = queue.poll()) != null) { unqueued.free(); } } } } public void send(final Pooled<ByteBuffer> pooled, final boolean close) { connection.getIoThread().execute(() -> { synchronized (queue) { XnioExecutor.Key heartKey1 = RemoteWriteListener.this.heartKey; if (heartKey1 != null) heartKey1.remove(); if (closed) { pooled.free(); return; } if (close) { closed = true; } boolean free = true; try { final SaslWrapper wrapper = saslWrapper; if (wrapper != null) { final ByteBuffer buffer = pooled.getResource(); final ByteBuffer source = buffer.duplicate(); buffer.clear(); wrapper.wrap(buffer, source); buffer.flip(); } final boolean empty = queue.isEmpty(); queue.add(pooled); free = false; if (empty) { connection.getSinkChannel().resumeWrites(); } } catch (IOException e) { handleException(e, false); Pooled<ByteBuffer> unqueued; while ((unqueued = queue.poll()) != null) { unqueued.free(); } } finally { if (free) { pooled.free(); } } } }); } } private final Runnable heartbeatCommand = this::sendAlive; public String toString() { return String.format("Remoting connection %08x to %s of %s", Integer.valueOf(hashCode()), connection.getPeerAddress(), getRemoteConnectionProvider().getConnectionProviderContext().getEndpoint()); } }