/* * JBoss, Home of Professional Open Source. * Copyright 2011, Red Hat, Inc., and individual contributors * as indicated by the @author tags. See the copyright.txt file in the * distribution for a full listing of individual contributors. * * This is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 of * the License, or (at your option) any later version. * * This software is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this software; if not, write to the Free * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA * 02110-1301 USA, or see the FSF site: http://www.fsf.org. */ package org.jboss.remoting3.remote; import static java.lang.Math.min; import static org.jboss.remoting3._private.Messages.log; import static org.jboss.remoting3._private.Messages.server; import java.io.IOException; import java.nio.BufferOverflowException; import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; import java.security.Principal; import java.util.HashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; import javax.security.sasl.SaslServer; import org.jboss.remoting3.RemotingOptions; import org.jboss.remoting3.Version; import org.jboss.remoting3.spi.ConnectionProviderContext; import org.wildfly.security.auth.principal.AnonymousPrincipal; import org.wildfly.security.auth.server.SaslAuthenticationFactory; import org.wildfly.security.auth.server.SecurityIdentity; import org.wildfly.security.sasl.WildFlySasl; import org.wildfly.security.sasl.util.PropertiesSaslServerFactory; import org.wildfly.security.sasl.util.ProtocolSaslServerFactory; import org.wildfly.security.sasl.util.SSLSaslServerFactory; import org.wildfly.security.sasl.util.ServerNameSaslServerFactory; import org.wildfly.security.ssl.SSLUtils; import org.xnio.Buffers; import org.xnio.ChannelListener; import org.xnio.OptionMap; import org.xnio.Options; import org.xnio.Pooled; import org.xnio.Property; import org.xnio.Sequence; import org.xnio.channels.Channels; import org.xnio.channels.SslChannel; import org.xnio.conduits.ConduitStreamSourceChannel; import org.xnio.sasl.SaslUtils; import org.xnio.sasl.SaslWrapper; /** * @author <a href="mailto:david.lloyd@redhat.com">David M. Lloyd</a> */ @SuppressWarnings("deprecation") final class ServerConnectionOpenListener implements ChannelListener<ConduitStreamSourceChannel> { private final RemoteConnection connection; private final ConnectionProviderContext connectionProviderContext; private final SaslAuthenticationFactory saslAuthenticationFactory; private final OptionMap optionMap; private final AtomicInteger retryCount = new AtomicInteger(); private final String serverName; ServerConnectionOpenListener(final RemoteConnection connection, final ConnectionProviderContext connectionProviderContext, final SaslAuthenticationFactory saslAuthenticationFactory, final OptionMap optionMap) { this.connection = connection; this.connectionProviderContext = connectionProviderContext; this.saslAuthenticationFactory = saslAuthenticationFactory; this.optionMap = optionMap; if (optionMap.contains(RemotingOptions.SERVER_NAME)) { serverName = optionMap.get(RemotingOptions.SERVER_NAME); } else { serverName = connection.getLocalAddress().getHostName(); } } public void handleEvent(final ConduitStreamSourceChannel channel) { final Pooled<ByteBuffer> pooled = connection.allocate(); boolean ok = false; try { ByteBuffer sendBuffer = pooled.getResource(); sendBuffer.put(Protocol.GREETING); ProtocolUtils.writeString(sendBuffer, Protocol.GRT_SERVER_NAME, serverName); sendBuffer.flip(); connection.setReadListener(new Initial(), true); connection.send(pooled); ok = true; return; } catch (BufferUnderflowException | BufferOverflowException e) { connection.handleException(log.invalidMessage(connection)); return; } finally { if (! ok) pooled.free(); } } private void saslDispose(final SaslServer saslServer) { if (saslServer != null) { try { saslServer.dispose(); } catch (SaslException e) { server.trace("Failure disposing of SaslServer", e); } } } final class Initial implements ChannelListener<ConduitStreamSourceChannel> { private boolean starttls; private Set<String> allowedMechanisms; private int version; private int channelsIn = 40; private int channelsOut = 40; private String remoteEndpointName; private int behavior = Protocol.BH_FAULTY_MSG_SIZE; private boolean authCap; Initial() { // Calculate our capabilities version = Protocol.VERSION; } void initialiseCapabilities() { final SslChannel sslChannel = connection.getSslChannel(); final boolean channelSecure = sslChannel != null && Channels.getOption(sslChannel, Options.SECURE, false); starttls = ! (sslChannel == null || channelSecure); final Set<String> foundMechanisms = new LinkedHashSet<String>(); boolean enableExternal = false; try { // only enable EXTERNAL if there is an external auth layer SSLSession sslSession; if (sslChannel != null && (sslSession = sslChannel.getSslSession()) != null) { connection.setIdentity((SecurityIdentity) sslSession.getValue(SSLUtils.SSL_SESSION_IDENTITY_KEY)); final Principal principal = sslSession.getPeerPrincipal(); // only enable EXTERNAL if there's a peer principal (else it's just ANONYMOUS) if (principal != null) { enableExternal = true; } else { server.trace("No EXTERNAL mechanism due to lack of peer principal"); } } else { server.trace("No EXTERNAL mechanism due to lack of SSL"); } } catch (SSLPeerUnverifiedException e) { server.trace("No EXTERNAL mechanism due to unverified SSL peer"); } int cnt = 0; for (String mechName : saslAuthenticationFactory.getMechanismNames()) { if (foundMechanisms.contains(mechName)) { server.tracef("Excluding repeated occurrence of mechanism %s", mechName); } else if (! enableExternal && mechName.equals("EXTERNAL")) { server.trace("Excluding EXTERNAL due to prior config"); } else { server.tracef("Added mechanism %s", mechName); foundMechanisms.add(mechName); cnt ++; } } retryCount.set(cnt); // No need to re-order as an initial order was not passed in. this.allowedMechanisms = foundMechanisms; } public void handleEvent(final ConduitStreamSourceChannel channel) { final Pooled<ByteBuffer> message; try { message = connection.getMessageReader().getMessage(); } catch (IOException e) { connection.handleException(e); return; } if (message == MessageReader.EOF_MARKER) { log.trace("Received connection end-of-stream"); connection.handlePreAuthCloseRequest(); return; } if (message == null) { return; } boolean free = true; try { final ByteBuffer receiveBuffer = message.getResource(); server.tracef("Received %s", receiveBuffer); final byte msgType = receiveBuffer.get(); switch (msgType) { case Protocol.CONNECTION_CLOSE: { server.trace("Server received connection close request"); connection.handlePreAuthCloseRequest(); return; } case Protocol.CONNECTION_ALIVE: { server.trace("Server received connection alive"); connection.sendAliveResponse(); return; } case Protocol.CONNECTION_ALIVE_ACK: { server.trace("Server received connection alive ack"); return; } case Protocol.CAPABILITIES: { server.trace("Server received capabilities request"); handleClientCapabilities(receiveBuffer); sendCapabilities(); return; } case Protocol.STARTTLS: { server.tracef("Server received STARTTLS request"); final Pooled<ByteBuffer> pooled = connection.allocate(); boolean ok = false; try { ByteBuffer sendBuffer = pooled.getResource(); sendBuffer.put(starttls ? Protocol.STARTTLS : Protocol.NAK); sendBuffer.flip(); connection.send(pooled); ok = true; if (starttls) { connection.send(RemoteConnection.STARTTLS_SENTINEL); } connection.setReadListener(new Initial(), true); return; } finally { if (! ok) pooled.free(); } } case Protocol.AUTH_REQUEST: { server.tracef("Server received authentication request"); if (retryCount.getAndDecrement() <= 0) { // no more tries left connection.handleException(new SaslException("Too many authentication failures; connection terminated"), false); return; } final String mechName; if (version < 1) { mechName = Buffers.getModifiedUtf8(receiveBuffer); } else { mechName = ProtocolUtils.readString(receiveBuffer); } final String protocol = optionMap.get(RemotingOptions.SASL_PROTOCOL, RemotingOptions.DEFAULT_SASL_PROTOCOL); final Map<String, String> saslProperties = getSaslProperties(optionMap); SaslServer saslServer; try { saslServer = saslAuthenticationFactory.createMechanism(mechName, saslServerFactory -> { saslServerFactory = "EXTERNAL".equals(mechName) ? new SSLSaslServerFactory(saslServerFactory, () -> connection.getSslChannel().getSslSession()) : saslServerFactory; saslServerFactory = new ServerNameSaslServerFactory(saslServerFactory, serverName); saslServerFactory = new ProtocolSaslServerFactory(saslServerFactory, protocol); saslServerFactory = saslProperties != null ? new PropertiesSaslServerFactory(saslServerFactory, saslProperties) : saslServerFactory; return saslServerFactory; }); } catch (SaslException e) { server.trace("Unable to create SaslServer", e); saslServer = null; } if (saslServer == null) { rejectAuthentication(mechName); return; } connection.getMessageReader().suspendReads(); connection.getExecutor().execute(new AuthStepRunnable(true, saslServer, message, remoteEndpointName, behavior, channelsIn, channelsOut, authCap, null)); free = false; return; } default: { server.unknownProtocolId(msgType); connection.handleException(log.invalidMessage(connection)); break; } } } catch (BufferUnderflowException | BufferOverflowException e) { connection.handleException(log.invalidMessage(connection)); return; } finally { if (free) message.free(); } } private Map<String, String> getSaslProperties(final OptionMap optionMap) { Map<String, String> saslProperties = null; final Sequence<Property> value = optionMap.get(Options.SASL_PROPERTIES); if (value != null) { saslProperties = new HashMap<>(value.size()); for (Property property : value) { saslProperties.put(property.getKey(), (String) property.getValue()); } } return saslProperties; } void rejectAuthentication(String mechName) { // reject log.rejectedInvalidMechanism(mechName); final Pooled<ByteBuffer> pooled = connection.allocate(); boolean ok = false; try { final ByteBuffer sendBuffer = pooled.getResource(); sendBuffer.put(Protocol.AUTH_REJECTED); sendBuffer.flip(); connection.send(pooled); ok = true; } finally { if (! ok) pooled.free(); } } void handleClientCapabilities(final ByteBuffer receiveBuffer) { boolean useDefaultChannels = true; int channelsIn = 40; int channelsOut = 40; boolean authCap = false; while (receiveBuffer.hasRemaining()) { final byte type = receiveBuffer.get(); final int len = receiveBuffer.get() & 0xff; final ByteBuffer data = Buffers.slice(receiveBuffer, len); switch (type) { case Protocol.CAP_VERSION: { final byte version = data.get(); server.tracef("Server received capability: version %d", Integer.valueOf(version & 0xff)); this.version = min(Protocol.VERSION, version & 0xff); break; } case Protocol.CAP_ENDPOINT_NAME: { remoteEndpointName = Buffers.getModifiedUtf8(data); server.tracef("Server received capability: remote endpoint name \"%s\"", remoteEndpointName); break; } case Protocol.CAP_MESSAGE_CLOSE: { behavior |= Protocol.BH_MESSAGE_CLOSE; // remote side must be >= 3.2.11.GA // but, we'll assume it's >= 3.2.14.GA because no AS or EAP release included 3.2.8.SP1 < x < 3.2.14.GA behavior &= ~Protocol.BH_FAULTY_MSG_SIZE; server.tracef("Server received capability: message close protocol supported"); break; } case Protocol.CAP_VERSION_STRING: { // remote side must be >= 3.2.16.GA behavior &= ~Protocol.BH_FAULTY_MSG_SIZE; final String remoteVersionString = Buffers.getModifiedUtf8(data); server.tracef("Server received capability: remote version is \"%s\"", remoteVersionString); break; } case Protocol.CAP_CHANNELS_IN: { useDefaultChannels = false; // their channels in is our channels out channelsOut = ProtocolUtils.readIntData(data, len); server.tracef("Server received capability: remote channels in is \"%d\"", channelsOut); break; } case Protocol.CAP_CHANNELS_OUT: { useDefaultChannels = false; // their channels out is our channels in channelsIn = ProtocolUtils.readIntData(data, len); server.tracef("Server received capability: remote channels out is \"%d\"", channelsIn); break; } case Protocol.CAP_AUTHENTICATION: { authCap = true; server.trace("Server received capability: authentication service"); break; } default: { server.tracef("Server received unknown capability %02x", Integer.valueOf(type & 0xff)); // unknown, skip it for forward compatibility. break; } } } if (! useDefaultChannels) { this.channelsIn = channelsIn; this.channelsOut = channelsOut; } this.authCap = authCap; } void sendCapabilities() { if (allowedMechanisms == null) { initialiseCapabilities(); } final Pooled<ByteBuffer> pooled = connection.allocate(); boolean ok = false; try { ByteBuffer sendBuffer = pooled.getResource(); sendBuffer.put(Protocol.CAPABILITIES); ProtocolUtils.writeByte(sendBuffer, Protocol.CAP_VERSION, version); final String localEndpointName = connectionProviderContext.getEndpoint().getName(); if (localEndpointName != null) { // don't send a name if we're anonymous ProtocolUtils.writeString(sendBuffer, Protocol.CAP_ENDPOINT_NAME, localEndpointName); } if (starttls) { ProtocolUtils.writeEmpty(sendBuffer, Protocol.CAP_STARTTLS); } for (String mechName : allowedMechanisms) { ProtocolUtils.writeString(sendBuffer, Protocol.CAP_SASL_MECH, mechName); } ProtocolUtils.writeEmpty(sendBuffer, Protocol.CAP_MESSAGE_CLOSE); ProtocolUtils.writeString(sendBuffer, Protocol.CAP_VERSION_STRING, Version.getVersionString()); ProtocolUtils.writeInt(sendBuffer, Protocol.CAP_CHANNELS_IN, optionMap.get(RemotingOptions.MAX_INBOUND_CHANNELS, RemotingOptions.DEFAULT_MAX_INBOUND_CHANNELS)); ProtocolUtils.writeInt(sendBuffer, Protocol.CAP_CHANNELS_OUT, optionMap.get(RemotingOptions.MAX_OUTBOUND_CHANNELS, RemotingOptions.DEFAULT_MAX_OUTBOUND_CHANNELS)); ProtocolUtils.writeEmpty(sendBuffer, Protocol.CAP_AUTHENTICATION); sendBuffer.flip(); connection.send(pooled); ok = true; return; } finally { if (! ok) pooled.free(); } } } final class AuthStepRunnable implements Runnable { private final boolean isInitial; private final SaslServer saslServer; private final Pooled<ByteBuffer> buffer; private final String remoteEndpointName; private final int behavior; private final int maxInboundChannels; private final int maxOutboundChannels; private final boolean authCap; private final Set<String> offeredMechanisms; AuthStepRunnable(final boolean isInitial, final SaslServer saslServer, final Pooled<ByteBuffer> buffer, final String remoteEndpointName, final int behavior, final int maxInboundChannels, final int maxOutboundChannels, final boolean authCap, final Set<String> offeredMechanisms) { this.isInitial = isInitial; this.saslServer = saslServer; this.buffer = buffer; this.remoteEndpointName = remoteEndpointName; this.behavior = behavior; this.maxInboundChannels = maxInboundChannels; this.maxOutboundChannels = maxOutboundChannels; this.authCap = authCap; this.offeredMechanisms = offeredMechanisms; } @Override public void run() { boolean ok = false; boolean close = false; try { final Pooled<ByteBuffer> pooled = connection.allocate(); try { final ByteBuffer sendBuffer = pooled.getResource(); int p = sendBuffer.position(); try { sendBuffer.put(Protocol.AUTH_COMPLETE); if (SaslUtils.evaluateResponse(saslServer, sendBuffer, buffer.getResource())) { server.tracef("Server sending authentication complete"); connectionProviderContext.accept(connectionContext -> { final Object qop = saslServer.getNegotiatedProperty(Sasl.QOP); if (!isInitial && ("auth-int".equals(qop) || "auth-conf".equals(qop))) { connection.setSaslWrapper(SaslWrapper.create(saslServer)); } final RemoteConnectionHandler connectionHandler = new RemoteConnectionHandler( connectionContext, connection, maxInboundChannels, maxOutboundChannels, AnonymousPrincipal.getInstance(), remoteEndpointName, behavior, authCap, offeredMechanisms); connection.getRemoteConnectionProvider().addConnectionHandler(connectionHandler); final SecurityIdentity identity = (SecurityIdentity) saslServer.getNegotiatedProperty(WildFlySasl.SECURITY_IDENTITY); connection.setIdentity(identity == null ? saslAuthenticationFactory.getSecurityDomain().getAnonymousSecurityIdentity() : identity); connection.setReadListener(new RemoteReadListener(connectionHandler, connection), false); return connectionHandler; }, saslAuthenticationFactory); } else { server.tracef("Server sending authentication challenge"); sendBuffer.put(p, Protocol.AUTH_CHALLENGE); if (isInitial) { connection.setReadListener(new Authentication(saslServer, remoteEndpointName, behavior, maxInboundChannels, maxOutboundChannels, authCap, offeredMechanisms), false); } } } catch (Throwable e) { server.tracef(e, "Server sending authentication rejected"); sendBuffer.put(p, Protocol.AUTH_REJECTED); saslDispose(saslServer); if (isInitial) { if (retryCount.decrementAndGet() <= 0) { close = true; } } else { connection.setReadListener(new Initial(), false); } } sendBuffer.flip(); connection.send(pooled, close); ok = true; connection.getMessageReader().resumeReads(); return; } finally { if (!ok) { pooled.free(); } } } finally { buffer.free(); } } } final class Authentication implements ChannelListener<ConduitStreamSourceChannel> { private final SaslServer saslServer; private final String remoteEndpointName; private final int behavior; private final int maxInboundChannels; private final int maxOutboundChannels; private final boolean authCap; private final Set<String> offeredMechanisms; Authentication(final SaslServer saslServer, final String remoteEndpointName, final int behavior, final int maxInboundChannels, final int maxOutboundChannels, final boolean authCap, final Set<String> offeredMechanisms) { this.saslServer = saslServer; this.remoteEndpointName = remoteEndpointName; this.behavior = behavior; this.maxInboundChannels = maxInboundChannels; this.maxOutboundChannels = maxOutboundChannels; this.authCap = authCap; this.offeredMechanisms = offeredMechanisms; } public void handleEvent(final ConduitStreamSourceChannel channel) { final Pooled<ByteBuffer> message; try { message = connection.getMessageReader().getMessage(); } catch (IOException e) { connection.handleException(e); saslDispose(saslServer); return; } if (message == MessageReader.EOF_MARKER) { log.trace("Received connection end-of-stream"); connection.handlePreAuthCloseRequest(); saslDispose(saslServer); return; } if (message == null) { return; } boolean free = true; try { final ByteBuffer buffer = message.getResource(); server.tracef("Received %s", buffer); final byte msgType = buffer.get(); switch (msgType) { case Protocol.CONNECTION_CLOSE: { server.trace("Server received connection close request"); connection.handlePreAuthCloseRequest(); saslDispose(saslServer); return; } case Protocol.AUTH_RESPONSE: { server.tracef("Server received authentication response"); connection.getMessageReader().suspendReads(); connection.getExecutor().execute(new AuthStepRunnable(false, saslServer, message, remoteEndpointName, behavior, maxInboundChannels, maxOutboundChannels, authCap, offeredMechanisms)); free = false; return; } case Protocol.CAPABILITIES: { server.trace("Server received capabilities request (cancelling authentication)"); saslDispose(saslServer); final Initial initial = new Initial(); connection.setReadListener(initial, true); initial.handleClientCapabilities(buffer); initial.sendCapabilities(); return; } default: { server.unknownProtocolId(msgType); connection.handleException(log.invalidMessage(connection)); saslDispose(saslServer); break; } } } catch (BufferUnderflowException | BufferOverflowException e) { connection.handleException(log.invalidMessage(connection)); saslDispose(saslServer); return; } finally { if (free) message.free(); } } } }