/* * JBoss, Home of Professional Open Source * Copyright 2011, JBoss Inc., and individual contributors as indicated * by the @authors tag. See the copyright.txt in the distribution for a * full listing of individual contributors. * * This is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 of * the License, or (at your option) any later version. * * This software is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this software; if not, write to the Free * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA * 02110-1301 USA, or see the FSF site: http://www.fsf.org. */ package org.jboss.remoting3.remote; import java.io.IOException; import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; import java.nio.channels.Channel; import java.nio.charset.StandardCharsets; import org.jboss.remoting3.OpenListener; import org.jboss.remoting3.RemotingOptions; import org.jboss.remoting3.ServiceOpenException; import org.jboss.remoting3.spi.ConnectionHandlerContext; import org.jboss.remoting3.spi.RegisteredService; import org.jboss.remoting3.spi.SpiUtils; import org.xnio.Buffers; import org.xnio.ChannelListener; import org.xnio.IoUtils; import org.xnio.OptionMap; import org.xnio.Pooled; import org.xnio.conduits.ConduitStreamSourceChannel; import org.xnio.sasl.SaslWrapper; import static org.jboss.remoting3._private.Messages.log; /** * @author <a href="mailto:david.lloyd@redhat.com">David M. Lloyd</a> */ final class RemoteReadListener implements ChannelListener<ConduitStreamSourceChannel> { private final RemoteConnectionHandler handler; private final RemoteConnection connection; RemoteReadListener(final RemoteConnectionHandler handler, final RemoteConnection connection) { synchronized (connection.getLock()) { connection.getConnection().getCloseSetter().set((ChannelListener<Channel>) channel -> connection.getExecutor().execute(() -> { handler.handleConnectionClose(); handler.closeComplete(); })); } this.handler = handler; this.connection = connection; } public void handleEvent(final ConduitStreamSourceChannel channel) { SaslWrapper saslWrapper = connection.getSaslWrapper(); final Object lock = connection.getLock(); final MessageReader messageReader = connection.getMessageReader(); try { Pooled<ByteBuffer> message = null; ByteBuffer buffer = null; try { for (;;) try { boolean exit = false; message = messageReader.getMessage(); if (message == MessageReader.EOF_MARKER) { log.trace("Received connection end-of-stream"); exit = true; } else if (message == null) { log.trace("No message ready; returning"); return; } if (exit) { messageReader.shutdownReads(); handler.receiveCloseRequest(); return; } buffer = message.getResource(); if (saslWrapper != null) { final ByteBuffer source = buffer.duplicate(); buffer.clear(); saslWrapper.unwrap(buffer, source); buffer.flip(); } final byte protoId = buffer.get(); try { switch (protoId) { case Protocol.CONNECTION_ALIVE: { log.trace("Received connection alive"); connection.sendAliveResponse(); return; } case Protocol.CONNECTION_ALIVE_ACK: { log.trace("Received connection alive ack"); return; } case Protocol.CONNECTION_CLOSE: { log.trace("Received connection close request"); handler.receiveCloseRequest(); // do not return now so we can read once more, // thus making sure we are not skipping a // receive equal to -1 break; } case Protocol.CHANNEL_OPEN_REQUEST: { log.trace("Received channel open request"); int channelId = buffer.getInt() ^ 0x80000000; int requestedInboundWindow = Integer.MAX_VALUE; int requestedInboundMessages = 0xffff; int requestedOutboundWindow = Integer.MAX_VALUE; int requestedOutboundMessages = 0xffff; long requestedInboundMessageSize = Long.MAX_VALUE; long requestedOutboundMessageSize = Long.MAX_VALUE; // parse out request int b; String serviceType = null; OUT: for (;;) { b = buffer.get() & 0xff; switch (b) { case Protocol.O_END: break OUT; case Protocol.O_SERVICE_NAME: { serviceType = ProtocolUtils.readString(buffer); break; } case Protocol.O_MAX_INBOUND_MSG_WINDOW_SIZE: { requestedOutboundWindow = Math.min(requestedOutboundWindow, ProtocolUtils.readInt(buffer)); break; } case Protocol.O_MAX_INBOUND_MSG_COUNT: { requestedOutboundMessages = Math.min(requestedOutboundMessages, ProtocolUtils.readUnsignedShort(buffer)); break; } case Protocol.O_MAX_OUTBOUND_MSG_WINDOW_SIZE: { requestedInboundWindow = Math.min(requestedInboundWindow, ProtocolUtils.readInt(buffer)); break; } case Protocol.O_MAX_OUTBOUND_MSG_COUNT: { requestedInboundMessages = Math.min(requestedInboundMessages, ProtocolUtils.readUnsignedShort(buffer)); break; } case Protocol.O_MAX_INBOUND_MSG_SIZE: { requestedOutboundMessageSize = Math.min(requestedOutboundMessageSize, ProtocolUtils.readLong(buffer)); break; } case Protocol.O_MAX_OUTBOUND_MSG_SIZE: { requestedInboundMessageSize = Math.min(requestedInboundMessageSize, ProtocolUtils.readLong(buffer)); break; } default: { Buffers.skip(buffer, buffer.get() & 0xff); break; } } } if ((channelId & 0x80000000) != 0) { // invalid channel ID, original should have had MSB=1 and thus the complement should be MSB=0 refuseService(channelId, "Invalid channel ID"); break; } if (serviceType == null) { // invalid service reply refuseService(channelId, "Missing service name"); break; } final RegisteredService registeredService = handler.getConnectionContext().getRegisteredService(serviceType); if (registeredService == null) { refuseService(channelId, "Unknown service name"); break; } final OptionMap serviceOptionMap = registeredService.getOptionMap(); final int outboundWindowOptionValue = serviceOptionMap.get(RemotingOptions.TRANSMIT_WINDOW_SIZE, RemotingOptions.INCOMING_CHANNEL_DEFAULT_TRANSMIT_WINDOW_SIZE); final int outboundMessagesOptionValue = serviceOptionMap.get(RemotingOptions.MAX_OUTBOUND_MESSAGES, RemotingOptions.INCOMING_CHANNEL_DEFAULT_MAX_OUTBOUND_MESSAGES); final int inboundWindowOptionValue = serviceOptionMap.get(RemotingOptions.RECEIVE_WINDOW_SIZE, RemotingOptions.INCOMING_CHANNEL_DEFAULT_RECEIVE_WINDOW_SIZE); final int inboundMessagesOptionValue = serviceOptionMap.get(RemotingOptions.MAX_INBOUND_MESSAGES, RemotingOptions.DEFAULT_MAX_INBOUND_MESSAGES); final long outboundMessageSizeOptionValue = serviceOptionMap.get(RemotingOptions.MAX_OUTBOUND_MESSAGE_SIZE, RemotingOptions.DEFAULT_MAX_OUTBOUND_MESSAGE_SIZE); final long inboundMessageSizeOptionValue = serviceOptionMap.get(RemotingOptions.MAX_INBOUND_MESSAGE_SIZE, RemotingOptions.DEFAULT_MAX_INBOUND_MESSAGE_SIZE); final int outboundWindow = Math.min(requestedOutboundWindow, outboundWindowOptionValue); final int outboundMessages = Math.min(requestedOutboundMessages, outboundMessagesOptionValue); final int inboundWindow = Math.min(requestedInboundWindow, inboundWindowOptionValue); final int inboundMessages = Math.min(requestedInboundMessages, inboundMessagesOptionValue); final long outboundMessageSize = Math.min(requestedOutboundMessageSize, outboundMessageSizeOptionValue); final long inboundMessageSize = Math.min(requestedInboundMessageSize, inboundMessageSizeOptionValue); if (log.isTraceEnabled()) { log.tracef( "Inbound service request for channel %08x is configured as follows:\n" + " outbound window: req %10d, option %10d, grant %10d\n" + " inbound window: req %10d, option %10d, grant %10d\n" + " outbound msgs: req %10d, option %10d, grant %10d\n" + " inbound msgs: req %10d, option %10d, grant %10d\n" + " outbound msgsize: req %19d, option %19d, grant %19d\n" + " inbound msgsize: req %19d, option %19d, grant %19d", Integer.valueOf(channelId), Integer.valueOf(requestedOutboundWindow), Integer.valueOf(outboundWindowOptionValue), Integer.valueOf(outboundWindow), Integer.valueOf(requestedInboundWindow), Integer.valueOf(inboundWindowOptionValue), Integer.valueOf(inboundWindow), Integer.valueOf(requestedOutboundMessages), Integer.valueOf(outboundMessagesOptionValue), Integer.valueOf(outboundMessages), Integer.valueOf(requestedInboundMessages), Integer.valueOf(inboundMessagesOptionValue), Integer.valueOf(inboundMessages), Long.valueOf(requestedOutboundMessageSize), Long.valueOf(outboundMessageSizeOptionValue), Long.valueOf(outboundMessageSize), Long.valueOf(requestedInboundMessageSize), Long.valueOf(inboundMessageSizeOptionValue), Long.valueOf(inboundMessageSize) ); } final OpenListener openListener = registeredService.getOpenListener(); if (! handler.handleInboundChannelOpen()) { // refuse refuseService(channelId, "Channel refused"); break; } boolean ok1 = false; try { // construct the channel RemoteConnectionChannel connectionChannel = new RemoteConnectionChannel(handler, connection, channelId, outboundWindow, inboundWindow, outboundMessages, inboundMessages, outboundMessageSize, inboundMessageSize); RemoteConnectionChannel existing = handler.addChannel(connectionChannel); if (existing != null) { log.tracef("Encountered open request for duplicate %s", existing); // the channel already exists, which means the remote side "forgot" about it or we somehow missed the close message. // the only safe thing to do is to terminate the existing channel. try { refuseService(channelId, "Duplicate ID"); } finally { existing.handleRemoteClose(); } break; } // construct reply Pooled<ByteBuffer> pooledReply = connection.allocate(); boolean ok2 = false; try { ByteBuffer replyBuffer = pooledReply.getResource(); replyBuffer.clear(); replyBuffer.put(Protocol.CHANNEL_OPEN_ACK); replyBuffer.putInt(channelId); ProtocolUtils.writeInt(replyBuffer, Protocol.O_MAX_INBOUND_MSG_WINDOW_SIZE, inboundWindow); ProtocolUtils.writeShort(replyBuffer, Protocol.O_MAX_INBOUND_MSG_COUNT, inboundMessages); if (inboundMessageSize != Long.MAX_VALUE) { ProtocolUtils.writeLong(replyBuffer, Protocol.O_MAX_INBOUND_MSG_SIZE, inboundMessageSize); } ProtocolUtils.writeInt(replyBuffer, Protocol.O_MAX_OUTBOUND_MSG_WINDOW_SIZE, outboundWindow); ProtocolUtils.writeShort(replyBuffer, Protocol.O_MAX_OUTBOUND_MSG_COUNT, outboundMessages); if (outboundMessageSize != Long.MAX_VALUE) { ProtocolUtils.writeLong(replyBuffer, Protocol.O_MAX_OUTBOUND_MSG_SIZE, outboundMessageSize); } replyBuffer.put((byte) 0); replyBuffer.flip(); ok2 = true; // send takes ownership of the buffer connection.send(pooledReply); } finally { if (! ok2) pooledReply.free(); } ok1 = true; // Call the service open listener connection.getExecutor().execute(SpiUtils.getServiceOpenTask(connectionChannel, openListener)); break; } finally { // the inbound channel wasn't open so don't leak the ref count if (! ok1) handler.handleInboundChannelClosed(); } } case Protocol.MESSAGE_DATA: { log.trace("Received message data"); int channelId = buffer.getInt() ^ 0x80000000; RemoteConnectionChannel connectionChannel = handler.getChannel(channelId); if (connectionChannel == null) { // ignore the data log.tracef("Ignoring message data for expired channel"); break; } // protect against double-free if the method fails Pooled<ByteBuffer> messageCopy = message; message = null; buffer = null; connectionChannel.handleMessageData(messageCopy); break; } case Protocol.MESSAGE_WINDOW_OPEN: { log.trace("Received message window open"); int channelId = buffer.getInt() ^ 0x80000000; RemoteConnectionChannel connectionChannel = handler.getChannel(channelId); if (connectionChannel == null) { // ignore log.tracef("Ignoring window open for expired channel"); break; } connectionChannel.handleWindowOpen(message); break; } case Protocol.MESSAGE_CLOSE: { log.trace("Received message async close"); int channelId = buffer.getInt() ^ 0x80000000; RemoteConnectionChannel connectionChannel = handler.getChannel(channelId); if (connectionChannel == null) { break; } connectionChannel.handleAsyncClose(message); break; } case Protocol.CHANNEL_CLOSED: { log.trace("Received channel closed"); int channelId = buffer.getInt() ^ 0x80000000; RemoteConnectionChannel connectionChannel = handler.getChannel(channelId); if (connectionChannel == null) { break; } connectionChannel.handleRemoteClose(); break; } case Protocol.CHANNEL_SHUTDOWN_WRITE: { log.trace("Received channel shutdown write"); int channelId = buffer.getInt() ^ 0x80000000; RemoteConnectionChannel connectionChannel = handler.getChannel(channelId); if (connectionChannel == null) { break; } connectionChannel.handleIncomingWriteShutdown(); break; } case Protocol.CHANNEL_OPEN_ACK: { log.trace("Received channel open ack"); int channelId = buffer.getInt() ^ 0x80000000; if ((channelId & 0x80000000) == 0) { // invalid break; } PendingChannel pendingChannel = handler.removePendingChannel(channelId); if (pendingChannel == null) { // invalid break; } int requestedOutboundWindow = pendingChannel.getOutboundWindowSize(); int requestedInboundWindow = pendingChannel.getInboundWindowSize(); int requestedOutboundMessageCount = pendingChannel.getOutboundMessageCount(); int requestedInboundMessageCount = pendingChannel.getInboundMessageCount(); long requestedOutboundMessageSize = pendingChannel.getOutboundMessageSize(); long requestedInboundMessageSize = pendingChannel.getInboundMessageSize(); int outboundWindow = requestedOutboundWindow; int inboundWindow = requestedInboundWindow; int outboundMessageCount = requestedOutboundMessageCount; int inboundMessageCount = requestedInboundMessageCount; long outboundMessageSize = requestedOutboundMessageSize; long inboundMessageSize = requestedInboundMessageSize; OUT: for (;;) { switch (buffer.get() & 0xff) { case Protocol.O_MAX_INBOUND_MSG_WINDOW_SIZE: { outboundWindow = Math.min(outboundWindow, ProtocolUtils.readInt(buffer)); break; } case Protocol.O_MAX_INBOUND_MSG_COUNT: { outboundMessageCount = Math.min(outboundMessageCount, ProtocolUtils.readUnsignedShort(buffer)); break; } case Protocol.O_MAX_OUTBOUND_MSG_WINDOW_SIZE: { inboundWindow = Math.min(inboundWindow, ProtocolUtils.readInt(buffer)); break; } case Protocol.O_MAX_OUTBOUND_MSG_COUNT: { inboundMessageCount = Math.min(inboundMessageCount, ProtocolUtils.readUnsignedShort(buffer)); break; } case Protocol.O_MAX_INBOUND_MSG_SIZE: { outboundMessageSize = Math.min(outboundMessageSize, ProtocolUtils.readLong(buffer)); break; } case Protocol.O_MAX_OUTBOUND_MSG_SIZE: { inboundMessageSize = Math.min(inboundMessageSize, ProtocolUtils.readLong(buffer)); break; } case Protocol.O_END: { break OUT; } default: { // ignore unknown parameter Buffers.skip(buffer, buffer.get() & 0xff); break; } } } if (log.isTraceEnabled()) { log.tracef( "Inbound service acknowledgement for channel %08x is configured as follows:\n" + " outbound window: req %10d, use %10d\n" + " inbound window: req %10d, use %10d\n" + " outbound msgs: req %10d, use %10d\n" + " inbound msgs: req %10d, use %10d\n" + " outbound msgsize: req %19d, use %19d\n" + " inbound msgsize: req %19d, use %19d", Integer.valueOf(channelId), Integer.valueOf(requestedOutboundWindow), Integer.valueOf(outboundWindow), Integer.valueOf(requestedInboundWindow), Integer.valueOf(inboundWindow), Integer.valueOf(requestedOutboundMessageCount), Integer.valueOf(outboundMessageCount), Integer.valueOf(requestedInboundMessageCount), Integer.valueOf(inboundMessageCount), Long.valueOf(requestedOutboundMessageSize), Long.valueOf(outboundMessageSize), Long.valueOf(requestedInboundMessageSize), Long.valueOf(inboundMessageSize) ); } RemoteConnectionChannel newChannel = new RemoteConnectionChannel(handler, connection, channelId, outboundWindow, inboundWindow, outboundMessageCount, inboundMessageCount, outboundMessageSize, inboundMessageSize); handler.putChannel(newChannel); pendingChannel.getResult().setResult(newChannel); break; } case Protocol.SERVICE_ERROR: { log.trace("Received service error"); int channelId = buffer.getInt() ^ 0x80000000; PendingChannel pendingChannel = handler.removePendingChannel(channelId); if (pendingChannel == null) { // invalid break; } String reason = new String(Buffers.take(buffer), StandardCharsets.UTF_8); pendingChannel.getResult().setException(new ServiceOpenException(reason)); break; } case Protocol.APP_AUTH_REQUEST: { int id = buffer.getInt(); final int length = (buffer.get() - 1 & 0xff) + 1; // range: 1 - 256 final byte[] mechNameBytes = new byte[length]; buffer.get(mechNameBytes); String mechName = new String(mechNameBytes, StandardCharsets.UTF_8); log.tracef("Received authentication request, id %08x, mech %s", id, mechName); ConnectionHandlerContext c = handler.getConnectionContext(); final byte[] saslBytes; if (buffer.hasRemaining()) { saslBytes = new byte[buffer.remaining()]; buffer.get(saslBytes); } else { saslBytes = null; } c.receiveAuthRequest(id, mechName, saslBytes); break; } case Protocol.APP_AUTH_CHALLENGE: { int id = buffer.getInt(); log.tracef("Received authentication challenge, id %08x", id); ConnectionHandlerContext c = handler.getConnectionContext(); final byte[] saslBytes; saslBytes = new byte[buffer.remaining()]; buffer.get(saslBytes); c.receiveAuthChallenge(id, saslBytes); break; } case Protocol.APP_AUTH_RESPONSE: { int id = buffer.getInt(); log.tracef("Received authentication response, id %08x", id); ConnectionHandlerContext c = handler.getConnectionContext(); final byte[] saslBytes; saslBytes = new byte[buffer.remaining()]; buffer.get(saslBytes); c.receiveAuthResponse(id, saslBytes); break; } case Protocol.APP_AUTH_SUCCESS: { int id = buffer.getInt(); log.tracef("Received authentication success, id %08x", id); ConnectionHandlerContext c = handler.getConnectionContext(); final byte[] saslBytes; if (buffer.hasRemaining()) { saslBytes = new byte[buffer.remaining()]; buffer.get(saslBytes); } else { saslBytes = null; } c.receiveAuthSuccess(id, saslBytes); break; } case Protocol.APP_AUTH_REJECT: { int id = buffer.getInt(); log.tracef("Received authentication reject, id %08x", id); ConnectionHandlerContext c = handler.getConnectionContext(); c.receiveAuthReject(id); break; } case Protocol.APP_AUTH_DELETE: { int id = buffer.getInt(); log.tracef("Received authentication delete, id %08x", id); ConnectionHandlerContext c = handler.getConnectionContext(); c.receiveAuthDelete(id); break; } case Protocol.APP_AUTH_DELETE_ACK: { int id = buffer.getInt(); log.tracef("Received authentication delete ack, id %08x", id); ConnectionHandlerContext c = handler.getConnectionContext(); c.receiveAuthDeleteAck(id); break; } default: { log.unknownProtocolId(protoId); break; } } } catch (BufferUnderflowException e) { log.bufferUnderflow(protoId); } } catch (BufferUnderflowException e) { log.bufferUnderflowRaw(); } finally { if (buffer != null) buffer.clear(); } } finally { if (message != null) message.free(); } } catch (IOException e) { connection.handleException(e); synchronized (lock) { IoUtils.safeClose(channel); } } } private void refuseService(final int channelId, final String reason) { if (log.isTraceEnabled()) { log.tracef("Refusing service on channel %08x: %s", Integer.valueOf(channelId), reason); } Pooled<ByteBuffer> pooledReply = connection.allocate(); boolean ok = false; try { ByteBuffer replyBuffer = pooledReply.getResource(); replyBuffer.clear(); replyBuffer.put(Protocol.SERVICE_ERROR); replyBuffer.putInt(channelId); replyBuffer.put(reason.getBytes(StandardCharsets.UTF_8)); replyBuffer.flip(); ok = true; // send takes ownership of the buffer connection.send(pooledReply); } finally { if (! ok) pooledReply.free(); } } }