/*
* JBoss, Home of Professional Open Source
* Copyright 2011, JBoss Inc., and individual contributors as indicated
* by the @authors tag. See the copyright.txt in the distribution for a
* full listing of individual contributors.
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*/
package org.jboss.remoting3.remote;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import org.jboss.logging.Logger;
import org.jboss.remoting3.RemotingOptions;
import org.jboss.remoting3._private.Messages;
import org.jboss.remoting3.spi.ConnectionHandlerFactory;
import org.wildfly.security.auth.server.SecurityIdentity;
import org.xnio.Buffers;
import org.xnio.ByteBufferPool;
import org.xnio.ChannelListener;
import org.xnio.Connection;
import org.xnio.IoUtils;
import org.xnio.OptionMap;
import org.xnio.Pooled;
import org.xnio.Result;
import org.xnio.StreamConnection;
import org.xnio.XnioExecutor;
import org.xnio.channels.SslChannel;
import org.xnio.conduits.ConduitStreamSinkChannel;
import org.xnio.conduits.ConduitStreamSourceChannel;
import org.xnio.sasl.SaslWrapper;
/**
* @author <a href="mailto:david.lloyd@redhat.com">David M. Lloyd</a>
*/
final class RemoteConnection {
static final Pooled<ByteBuffer> STARTTLS_SENTINEL = Buffers.emptyPooledByteBuffer();
private static final String FQCN = RemoteConnection.class.getName();
private final StreamConnection connection;
private final MessageReader messageReader;
private final SslChannel sslChannel;
private final OptionMap optionMap;
private final RemoteWriteListener writeListener = new RemoteWriteListener();
private final Executor executor;
private final int heartbeatInterval;
private volatile Result<ConnectionHandlerFactory> result;
private volatile SaslWrapper saslWrapper;
private volatile SecurityIdentity identity;
private final RemoteConnectionProvider remoteConnectionProvider;
RemoteConnection(final StreamConnection connection, final SslChannel sslChannel, final OptionMap optionMap, final RemoteConnectionProvider remoteConnectionProvider) {
this.connection = connection;
this.messageReader = new MessageReader(connection.getSourceChannel(), writeListener.queue);
this.sslChannel = sslChannel;
this.optionMap = optionMap;
heartbeatInterval = optionMap.get(RemotingOptions.HEARTBEAT_INTERVAL, RemotingOptions.DEFAULT_HEARTBEAT_INTERVAL);
Messages.conn.tracef("Initialized connection from %s to %s with options %s", connection.getPeerAddress(), connection.getLocalAddress(), optionMap);
this.executor = remoteConnectionProvider.getExecutor();
this.remoteConnectionProvider = remoteConnectionProvider;
}
Pooled<ByteBuffer> allocate() {
return Buffers.globalPooledWrapper(ByteBufferPool.MEDIUM_DIRECT.allocate());
}
void setReadListener(ChannelListener<ConduitStreamSourceChannel> listener, final boolean resume) {
Messages.log.logf(RemoteConnection.class.getName(), Logger.Level.TRACE, null, "Setting read listener to %s", listener);
messageReader.setReadListener(listener);
if (listener != null && resume) {
messageReader.resumeReads();
}
}
RemoteConnectionProvider getRemoteConnectionProvider() {
return remoteConnectionProvider;
}
Result<ConnectionHandlerFactory> getResult() {
return result;
}
void setResult(final Result<ConnectionHandlerFactory> result) {
this.result = result;
}
void handleException(IOException e) {
handleException(e, true);
}
void handleException(IOException e, boolean log) {
Messages.conn.logf(RemoteConnection.class.getName(), Logger.Level.TRACE, e, "Connection error detail");
if (log) {
Messages.conn.connectionError(e);
}
final XnioExecutor.Key key = writeListener.heartKey;
if (key != null) {
key.remove();
}
synchronized (getLock()) {
IoUtils.safeClose(connection);
}
final Result<ConnectionHandlerFactory> result = this.result;
if (result != null) {
result.setException(e);
this.result = null;
}
}
void send(final Pooled<ByteBuffer> pooled) {
writeListener.send(pooled, false);
}
void send(final Pooled<ByteBuffer> pooled, boolean close) {
writeListener.send(pooled, close);
}
void shutdownWrites() {
writeListener.shutdownWrites();
}
OptionMap getOptionMap() {
return optionMap;
}
MessageReader getMessageReader() {
return messageReader;
}
RemoteWriteListener getWriteListener() {
return writeListener;
}
public Executor getExecutor() {
return executor;
}
public SslChannel getSslChannel() {
return sslChannel;
}
SaslWrapper getSaslWrapper() {
return saslWrapper;
}
void setSaslWrapper(final SaslWrapper saslWrapper) {
this.saslWrapper = saslWrapper;
}
void handlePreAuthCloseRequest() {
try {
terminateHeartbeat();
synchronized (getLock()) {
connection.close();
}
} catch (IOException e) {
Messages.conn.debug("Error closing remoting channel", e);
}
}
void sendAlive() {
Messages.conn.trace("Sending connection alive");
final Pooled<ByteBuffer> pooled = allocate();
boolean ok = false;
try {
final ByteBuffer buffer = pooled.getResource();
buffer.put(Protocol.CONNECTION_ALIVE);
buffer.limit(80);
Buffers.addRandom(buffer);
buffer.flip();
send(pooled);
ok = true;
messageReader.wakeupReads();
} finally {
if (! ok) pooled.free();
}
}
void sendAliveResponse() {
Messages.conn.trace("Sending connection alive ack");
final Pooled<ByteBuffer> pooled = allocate();
boolean ok = false;
try {
final ByteBuffer buffer = pooled.getResource();
buffer.put(Protocol.CONNECTION_ALIVE_ACK);
buffer.limit(80);
Buffers.addRandom(buffer);
buffer.flip();
send(pooled);
ok = true;
} finally {
if (! ok) pooled.free();
}
}
void terminateHeartbeat() {
final XnioExecutor.Key key = writeListener.heartKey;
if (key != null) {
key.remove();
}
}
Object getLock() {
return writeListener.queue;
}
SecurityIdentity getIdentity() {
return identity;
}
void setIdentity(final SecurityIdentity identity) {
this.identity = identity;
}
InetSocketAddress getPeerAddress() {
return connection.getPeerAddress(InetSocketAddress.class);
}
InetSocketAddress getLocalAddress() {
return connection.getLocalAddress(InetSocketAddress.class);
}
Connection getConnection() {
return connection;
}
final class RemoteWriteListener implements ChannelListener<ConduitStreamSinkChannel> {
private final Queue<Pooled<ByteBuffer>> queue = new ArrayDeque<Pooled<ByteBuffer>>();
private XnioExecutor.Key heartKey;
private boolean closed;
private ByteBuffer headerBuffer = ByteBuffer.allocateDirect(4);
private final ByteBuffer[] cachedArray = new ByteBuffer[] { headerBuffer, null };
RemoteWriteListener() {
}
public void handleEvent(final ConduitStreamSinkChannel channel) {
final ByteBuffer[] cachedArray = this.cachedArray;
synchronized (queue) {
Pooled<ByteBuffer> pooled;
final Queue<Pooled<ByteBuffer>> queue = this.queue;
try {
ByteBuffer buffer = cachedArray[1];
if (buffer != null) {
channel.write(cachedArray);
if (buffer.hasRemaining()) {
return;
}
}
cachedArray[1] = null;
while ((pooled = queue.peek()) != null) {
buffer = pooled.getResource();
if (buffer.hasRemaining()) { // no empty messages
headerBuffer.putInt(0, buffer.remaining());
headerBuffer.position(0);
cachedArray[1] = buffer;
final long res = channel.write(cachedArray);
Messages.conn.tracef("Sent %d bytes", res);
if (buffer.hasRemaining()) {
// try again later
return;
} else {
cachedArray[1] = null;
queue.poll().free();
}
} else {
if (pooled == STARTTLS_SENTINEL) {
if (channel.flush()) {
Messages.conn.trace("Flushed channel");
final SslChannel sslChannel = getSslChannel();
assert sslChannel != null; // because STARTTLS would be false in this case
sslChannel.startHandshake();
} else {
// try again later
Messages.conn.trace("Flush stalled");
return;
}
}
// otherwise skip other empty message rather than try and write it
queue.poll().free();
}
}
if (channel.flush()) {
Messages.conn.trace("Flushed channel");
if (closed) {
terminateHeartbeat();
// End of queue reached; shut down and try to flush the remainder
channel.shutdownWrites();
if (channel.flush()) {
Messages.conn.trace("Shut down writes on channel");
return;
}
// either this is successful and no more notifications will come, or not and it will be retried
// either way we're done here
return;
} else {
if (heartbeatInterval != Integer.MAX_VALUE) {
this.heartKey = channel.getWriteThread().executeAfter(heartbeatCommand, heartbeatInterval, TimeUnit.MILLISECONDS);
}
}
channel.suspendWrites();
}
} catch (IOException e) {
handleException(e, false);
while ((pooled = queue.poll()) != null) {
pooled.free();
}
}
// else try again later
}
}
public void shutdownWrites() {
synchronized (queue) {
closed = true;
terminateHeartbeat();
final ConduitStreamSinkChannel sinkChannel = connection.getSinkChannel();
try {
if (! queue.isEmpty()) {
sinkChannel.resumeWrites();
return;
}
sinkChannel.shutdownWrites();
if (! sinkChannel.flush()) {
sinkChannel.resumeWrites();
return;
}
Messages.conn.logf(FQCN, Logger.Level.TRACE, null, "Shut down writes on channel");
} catch (IOException e) {
handleException(e, false);
Pooled<ByteBuffer> unqueued;
while ((unqueued = queue.poll()) != null) {
unqueued.free();
}
}
}
}
public void send(final Pooled<ByteBuffer> pooled, final boolean close) {
connection.getIoThread().execute(() -> {
synchronized (queue) {
XnioExecutor.Key heartKey1 = RemoteWriteListener.this.heartKey;
if (heartKey1 != null) heartKey1.remove();
if (closed) { pooled.free(); return; }
if (close) { closed = true; }
boolean free = true;
try {
final SaslWrapper wrapper = saslWrapper;
if (wrapper != null) {
final ByteBuffer buffer = pooled.getResource();
final ByteBuffer source = buffer.duplicate();
buffer.clear();
wrapper.wrap(buffer, source);
buffer.flip();
}
final boolean empty = queue.isEmpty();
queue.add(pooled);
free = false;
if (empty) {
connection.getSinkChannel().resumeWrites();
}
} catch (IOException e) {
handleException(e, false);
Pooled<ByteBuffer> unqueued;
while ((unqueued = queue.poll()) != null) {
unqueued.free();
}
} finally {
if (free) {
pooled.free();
}
}
}
});
}
}
private final Runnable heartbeatCommand = this::sendAlive;
public String toString() {
return String.format("Remoting connection %08x to %s of %s", Integer.valueOf(hashCode()), connection.getPeerAddress(), getRemoteConnectionProvider().getConnectionProviderContext().getEndpoint());
}
}