/*
* JBoss, Home of Professional Open Source.
* Copyright 2016, Red Hat, Inc., and individual contributors
* as indicated by the @author tags. See the copyright.txt file in the
* distribution for a full listing of individual contributors.
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*/
package org.jboss.remoting3;
import static java.security.AccessController.doPrivileged;
import static org.jboss.remoting3._private.Messages.log;
import java.io.IOException;
import java.security.Principal;
import java.security.PrivilegedAction;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.UnaryOperator;
import javax.net.ssl.SSLSession;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;
import org.jboss.remoting3._private.IntIndexHashMap;
import org.jboss.remoting3.spi.ConnectionHandler;
import org.wildfly.common.Assert;
import org.wildfly.security.auth.AuthenticationException;
import org.wildfly.security.auth.client.AuthenticationConfiguration;
import org.wildfly.security.auth.client.AuthenticationContextConfigurationClient;
import org.wildfly.security.auth.client.PeerIdentityContext;
import org.wildfly.security.auth.principal.AnonymousPrincipal;
import org.wildfly.security.sasl.WildFlySasl;
import org.wildfly.security.sasl.util.ProtocolSaslClientFactory;
import org.wildfly.security.sasl.util.ServerNameSaslClientFactory;
import org.xnio.Cancellable;
import org.xnio.FinishedIoFuture;
import org.xnio.FutureResult;
import org.xnio.IoFuture;
/**
 * A peer identity context for a connection which supports remote authentication-based identity multiplexing.
 * <p>
 * Each remote authentication attempt is identified by a random {@code int} ID.
 * IDs {@code 0} and {@code 1} are reserved for the connection's own identity and the anonymous identity respectively.
 * Successful identities are cached per {@link AuthenticationConfiguration}; failed or cancelled attempts are
 * removed from the cache so that a later request can retry.
 *
 * @author <a href="mailto:david.lloyd@redhat.com">David M. Lloyd</a>
 */
public final class ConnectionPeerIdentityContext extends PeerIdentityContext {
    private static final byte[] NO_BYTES = new byte[0];
    private final ConnectionImpl connection;
    private final Collection<String> offeredMechanisms;
    private final ConnectionPeerIdentity anonymousIdentity;
    private final ConnectionPeerIdentity connectionIdentity;
    private final FinishedIoFuture<ConnectionPeerIdentity> connectionIdentityFuture;
    private final FinishedIoFuture<ConnectionPeerIdentity> anonymousIdentityFuture;
    // in-progress and established authentications, keyed by their int ID
    private final IntIndexHashMap<Authentication> authMap = new IntIndexHashMap<Authentication>(Authentication::getId);
    // one shared future per configuration; successful results stay cached, failures are removed
    private final ConcurrentHashMap<AuthenticationConfiguration, IoFuture<ConnectionPeerIdentity>> futureAuths = new ConcurrentHashMap<>();
    private final UnaryOperator<SaslClientFactory> factoryOperator;
    private static final AuthenticationContextConfigurationClient CLIENT = doPrivileged((PrivilegedAction<AuthenticationContextConfigurationClient>) AuthenticationContextConfigurationClient::new);

    /**
     * Construct a new instance.
     *
     * @param connection the connection whose identities are managed
     * @param offeredMechanisms the SASL mechanism names offered by the peer, or {@code null} for none
     * @param saslProtocol the SASL protocol name to configure on created SASL clients
     */
    ConnectionPeerIdentityContext(final ConnectionImpl connection, final Collection<String> offeredMechanisms, String saslProtocol) {
        this.connection = connection;
        this.offeredMechanisms = offeredMechanisms == null ? Collections.emptySet() : offeredMechanisms;
        connectionIdentity = constructIdentity(conf -> new ConnectionPeerIdentity(conf, connection.getPrincipal(), 0, connection));
        connectionIdentityFuture = new FinishedIoFuture<>(connectionIdentity);
        anonymousIdentity = constructIdentity(conf -> new ConnectionPeerIdentity(conf, AnonymousPrincipal.getInstance(), 1, connection));
        anonymousIdentityFuture = new FinishedIoFuture<>(anonymousIdentity);
        factoryOperator = factory -> new ProtocolSaslClientFactory(new ServerNameSaslClientFactory(factory, connection.getRemoteEndpointName()), saslProtocol);
    }

    // states of an asynchronous attempt: PENDING -> (running Thread | CANCELLED); running Thread -> null when done
    private static final Object PENDING = new Object();
    private static final Object CANCELLED = new Object();

    /**
     * Begin an asynchronous authentication.
     * <p>
     * The connection's own configuration and configurations with an anonymous authorization principal are answered
     * immediately. Otherwise, concurrent requests for an equal configuration share one future, and the first
     * requester schedules the SASL exchange on the endpoint executor.
     * <p>
     * NOTE: cancelling a running attempt interrupts its thread. However, {@link #doAuthenticate} only checks the
     * interrupt status before the exchange starts. Once it is waiting for the peer, the interrupt is deferred
     * until the peer answers.
     *
     * @param configuration the authentication configuration to use (must not be {@code null})
     * @return the future peer identity (not {@code null})
     */
    public IoFuture<ConnectionPeerIdentity> authenticateAsync(final AuthenticationConfiguration configuration) {
        Assert.checkNotNullParam("configuration", configuration);
        if (configuration.equals(connection.getAuthenticationConfiguration())) {
            return connectionIdentityFuture;
        } else if (CLIENT.getAuthorizationPrincipal(configuration) instanceof AnonymousPrincipal) {
            return anonymousIdentityFuture;
        }
        IoFuture<ConnectionPeerIdentity> ioFuture = futureAuths.get(configuration);
        if (ioFuture != null) {
            return ioFuture;
        }
        final FutureResult<ConnectionPeerIdentity> futureResult = new FutureResult<>(connection.getEndpoint().getExecutor());
        ioFuture = futureAuths.putIfAbsent(configuration, futureResult.getIoFuture());
        if (ioFuture != null) {
            // another caller registered an attempt first; share it
            return ioFuture;
        }
        final AtomicReference<Object> statRef = new AtomicReference<>(PENDING);
        connection.getEndpoint().getExecutor().execute(() -> {
            // before the task starts only PENDING or CANCELLED are possible, so a failed CAS means "cancelled"
            if (! statRef.compareAndSet(PENDING, Thread.currentThread())) {
                return;
            }
            try {
                // Must not call authenticate(configuration) here: it would find our own pending future in
                // futureAuths and block on it forever.
                doAuthenticate(configuration, futureResult);
            } finally {
                statRef.set(null);
            }
        });
        futureResult.addCancelHandler(new Cancellable() {
            @Override
            public Cancellable cancel() {
                for (;;) {
                    final Object oldVal = statRef.get();
                    if (oldVal == CANCELLED || oldVal == null) {
                        // already cancelled, or the attempt has already finished
                        return this;
                    }
                    if (oldVal instanceof Thread) {
                        // the attempt is running; ask it to stop
                        ((Thread) oldVal).interrupt();
                        return this;
                    }
                    if (statRef.compareAndSet(PENDING, CANCELLED)) {
                        // never started: acknowledge the cancellation and forget the cached future
                        futureAuths.remove(configuration, futureResult.getIoFuture());
                        futureResult.setCancelled();
                        return this;
                    }
                }
            }
        });
        return futureResult.getIoFuture();
    }

    /**
     * Perform an authentication.
     *
     * @param configuration the authentication configuration to use (must not be {@code null})
     * @return the peer identity (not {@code null})
     * @throws AuthenticationException if the authentication attempt failed
     */
    public ConnectionPeerIdentity authenticate(final AuthenticationConfiguration configuration) throws AuthenticationException {
        Assert.checkNotNullParam("configuration", configuration);
        if (configuration.equals(connection.getAuthenticationConfiguration())) {
            return connectionIdentity;
        } else if (CLIENT.getAuthorizationPrincipal(configuration) instanceof AnonymousPrincipal) {
            return anonymousIdentity;
        }
        IoFuture<ConnectionPeerIdentity> ioFuture = futureAuths.get(configuration);
        if (ioFuture == null) {
            FutureResult<ConnectionPeerIdentity> futureResult = new FutureResult<>(connection.getEndpoint().getExecutor());
            final IoFuture<ConnectionPeerIdentity> appearing = futureAuths.putIfAbsent(configuration, futureResult.getIoFuture());
            if (appearing != null) {
                ioFuture = appearing;
            } else {
                // run the exchange on the calling thread; cancellation interrupts this thread while it is running
                AtomicReference<Thread> threadRef = new AtomicReference<>(Thread.currentThread());
                futureResult.addCancelHandler(new Cancellable() {
                    @Override
                    public Cancellable cancel() {
                        final Thread thread = threadRef.get();
                        if (thread != null) {
                            thread.interrupt();
                        }
                        return this;
                    }
                });
                try {
                    doAuthenticate(configuration, futureResult);
                } finally {
                    threadRef.set(null);
                }
                ioFuture = futureResult.getIoFuture();
            }
        }
        try {
            return ioFuture.get();
        } catch (AuthenticationException e) {
            throw e;
        } catch (IOException e) {
            throw new AuthenticationException(e);
        }
    }

    /**
     * Run a complete SASL exchange with the peer and complete {@code futureResult} with the outcome.
     * <p>
     * Mechanisms offered by the peer are tried in turn. A mechanism that fails on the client side is dropped and
     * the next one is tried. A server rejection, a closed connection, or an I/O failure ends the attempt.
     * On every failure path the configuration's entry in {@code futureAuths} is removed.
     *
     * @param configuration the authentication configuration (must not be {@code null} and must not be the
     *     connection's own configuration)
     * @param futureResult the result to complete
     */
    void doAuthenticate(final AuthenticationConfiguration configuration, FutureResult<ConnectionPeerIdentity> futureResult) {
        Assert.checkNotNullParam("configuration", configuration);
        final ConnectionImpl connection = this.connection;
        assert ! configuration.equals(connection.getAuthenticationConfiguration());
        if (! connection.supportsRemoteAuth()) {
            failAuthentication(configuration, futureResult, log.authenticationNotSupported());
            return;
        }
        boolean intr = Thread.currentThread().isInterrupted();
        if (intr) {
            // checked before an ID is reserved so that an interrupted attempt leaves nothing behind in authMap
            failAuthentication(configuration, futureResult, log.authenticationInterrupted());
            return;
        }
        final AuthenticationContextConfigurationClient client = CLIENT;
        Authentication authentication;
        final IntIndexHashMap<Authentication> authMap = this.authMap;
        final ThreadLocalRandom random = ThreadLocalRandom.current();
        int id;
        // reserve a fresh random ID; 0 and 1 belong to the connection and anonymous identities
        do {
            id = random.nextInt();
        } while (id == 0 || id == 1 || authMap.containsKey(id) || authMap.putIfAbsent(authentication = new Authentication(id)) != null);
        final int finalId = id;
        SaslClient saslClient;
        try {
            final Principal principal = client.getPrincipal(configuration);
            final ConnectionHandler connectionHandler = connection.getConnectionHandler();
            // try each mech in turn, unless the peer explicitly rejects
            Set<String> mechanisms = new LinkedHashSet<>(offeredMechanisms);
            while (! mechanisms.isEmpty()) {
                final SSLSession sslSession = connectionHandler.getSslSession();
                UnaryOperator<SaslClientFactory> factoryOperator = this.factoryOperator;
                try {
                    saslClient = client.createSaslClient(connection.getPeerURI(), configuration, mechanisms, factoryOperator, sslSession);
                } catch (SaslException e) {
                    failAuthentication(configuration, futureResult, log.authenticationNoSaslClient(e));
                    return;
                }
                if (saslClient == null) {
                    // break out to "no mechs left" error
                    break;
                }
                byte[] response;
                try {
                    if (saslClient.hasInitialResponse()) {
                        response = saslClient.evaluateChallenge(NO_BYTES);
                    } else {
                        response = null;
                    }
                    connectionHandler.sendAuthRequest(id, saslClient.getMechanismName(), response);
                    if (! connectionHandler.isOpen()) {
                        safeDispose(saslClient);
                        failAuthentication(configuration, futureResult, log.authenticationExceptionClosed());
                        return;
                    }
                } catch (IOException e) {
                    // including SaslException
                    authMap.remove(authentication);
                    safeDispose(saslClient);
                    failAuthentication(configuration, futureResult, log.authenticationExceptionIo(e));
                    return;
                }
                // the main loop
                byte[] challenge;
                int status;
                for (;;) {
                    synchronized (authentication) {
                        status = authentication.getStatus();
                        while (status == WAITING) {
                            try {
                                authentication.wait();
                            } catch (InterruptedException e) {
                                // remember the interrupt (restored in the finally block) but keep waiting for the peer
                                intr = true;
                            }
                            status = authentication.getStatus();
                        }
                        challenge = authentication.getSaslBytes();
                        // consume the message so the next iteration waits for a new one
                        authentication.setStatus(WAITING);
                        authentication.setSaslBytes(null);
                    }
                    if (status == CHALLENGE) {
                        try {
                            response = saslClient.evaluateChallenge(challenge);
                        } catch (SaslException e) {
                            log.tracef(e, "Mechanism failed (client): \"%s\"", saslClient.getMechanismName());
                            mechanisms.remove(saslClient.getMechanismName());
                            safeDispose(saslClient);
                            // try the next mechanism
                            break;
                        }
                        try {
                            connectionHandler.sendAuthResponse(id, response);
                            if (! connectionHandler.isOpen()) {
                                safeDispose(saslClient);
                                failAuthentication(configuration, futureResult, log.authenticationExceptionClosed());
                                return;
                            }
                        } catch (IOException e) {
                            safeDispose(saslClient);
                            failAuthentication(configuration, futureResult, log.authenticationExceptionIo(e));
                            return;
                        }
                        // retry loop
                    } else if (status == SUCCESS) {
                        if (challenge != null) {
                            try {
                                response = saslClient.evaluateChallenge(challenge);
                            } catch (SaslException e) {
                                log.tracef(e, "Mechanism failed (client, possibly failed to verify server): \"%s\"", saslClient.getMechanismName());
                                mechanisms.remove(saslClient.getMechanismName());
                                safeDispose(saslClient);
                                break;
                            }
                            if (response != null) {
                                // the server declared success but the client still has something to say
                                try {
                                    connectionHandler.sendAuthDelete(id);
                                } catch (IOException ignored) {
                                    log.trace("Send failed", ignored);
                                }
                                safeDispose(saslClient);
                                failAuthentication(configuration, futureResult, log.authenticationExtraResponse());
                                return;
                            }
                        }
                        // Read the negotiated principal BEFORE dispose(): disposing invalidates the SaslClient.
                        final Object principalObj;
                        try {
                            principalObj = saslClient.getNegotiatedProperty(WildFlySasl.PRINCIPAL);
                        } finally {
                            safeDispose(saslClient);
                        }
                        final Principal authenticatedPrincipal = principalObj instanceof Principal ? (Principal) principalObj : principal;
                        // todo: we could use a phantom ref to clean up the ID, but the benefits are dubious
                        futureResult.setResult(constructIdentity(conf -> new ConnectionPeerIdentity(conf, authenticatedPrincipal, finalId, connection)));
                        return;
                    } else if (status == REJECT) {
                        // auth rejected (server)
                        try {
                            connectionHandler.sendAuthDelete(id);
                        } catch (IOException ignored) {
                            log.trace("Send failed", ignored);
                        }
                        safeDispose(saslClient);
                        failAuthentication(configuration, futureResult, log.serverRejectedAuthentication());
                        return;
                    } else if (status == CLOSED) {
                        safeDispose(saslClient);
                        failAuthentication(configuration, futureResult, log.authenticationExceptionClosed());
                        return;
                    } else if (status == DELETE) {
                        safeDispose(saslClient);
                        failAuthentication(configuration, futureResult, log.serverRejectedAuthentication());
                        return;
                    } else {
                        throw Assert.unreachableCode();
                    }
                }
            }
            // whatever was offered but is no longer in the candidate set is what we've tried
            Set<String> triedMechs = new HashSet<>(offeredMechanisms);
            triedMechs.removeAll(mechanisms);
            final String triedStr = triedMechs.isEmpty() ? "(none)" : String.join(",", triedMechs);
            failAuthentication(configuration, futureResult, log.noAuthMechanismsLeft(triedStr));
        } finally {
            if (intr) Thread.currentThread().interrupt();
        }
    }

    /**
     * Fail an authentication attempt and drop its cached future so that a later request can retry.
     *
     * @param configuration the configuration the attempt was registered under
     * @param futureResult the attempt's result
     * @param cause the failure to report
     */
    private void failAuthentication(final AuthenticationConfiguration configuration, final FutureResult<ConnectionPeerIdentity> futureResult, final IOException cause) {
        futureResult.setException(cause);
        futureAuths.remove(configuration, futureResult.getIoFuture());
    }

    /**
     * Dispose of a SASL client, ignoring any {@link SaslException} from the dispose call.
     *
     * @param saslClient the client to dispose
     */
    private static void safeDispose(final SaslClient saslClient) {
        try {
            saslClient.dispose();
        } catch (SaslException ignored) {
            // best effort: nothing useful can be done if disposal fails
        }
    }

    // per-attempt status values stored in Authentication.status
    private static final int WAITING = 0;
    private static final int CHALLENGE = 1;
    private static final int SUCCESS = 2;
    private static final int REJECT = 3;
    private static final int DELETE = 4;
    private static final int CLOSED = 5;

    /**
     * Deliver a peer challenge for authentication {@code id} and wake its waiting thread. Unknown IDs are ignored.
     *
     * @param id the authentication ID
     * @param challenge the challenge bytes
     */
    void receiveChallenge(final int id, final byte[] challenge) {
        final Authentication authentication = authMap.get(id);
        if (authentication != null) {
            synchronized (authentication) {
                authentication.setSaslBytes(challenge);
                authentication.setStatus(CHALLENGE);
                authentication.notifyAll();
            }
        }
    }

    /**
     * Deliver a success message, with optional final challenge bytes, for authentication {@code id} and wake its
     * waiting thread. Unknown IDs are ignored.
     *
     * @param id the authentication ID
     * @param challenge the final challenge bytes, or {@code null} if none
     */
    void receiveSuccess(final int id, final byte[] challenge) {
        final Authentication authentication = authMap.get(id);
        if (authentication != null) {
            synchronized (authentication) {
                authentication.setSaslBytes(challenge);
                authentication.setStatus(SUCCESS);
                authentication.notifyAll();
            }
        }
    }

    /**
     * Deliver a rejection for authentication {@code id} and wake its waiting thread. Unknown IDs are ignored.
     *
     * @param id the authentication ID
     */
    void receiveReject(final int id) {
        final Authentication authentication = authMap.get(id);
        if (authentication != null) {
            synchronized (authentication) {
                authentication.setStatus(REJECT);
                authentication.notifyAll();
            }
        }
    }

    /**
     * Handle a delete acknowledgement: remove authentication {@code id}, mark it deleted and wake any waiter.
     *
     * @param id the authentication ID
     */
    void receiveDeleteAck(final int id) {
        final Authentication authentication = authMap.removeKey(id);
        if (authentication != null) {
            synchronized (authentication) {
                authentication.setStatus(DELETE);
                authentication.notifyAll();
            }
        }
    }

    /**
     * Remove every tracked authentication, mark each one closed and wake any waiters.
     */
    void connectionClosed() {
        Iterator<Authentication> iterator = authMap.iterator();
        while (iterator.hasNext()) {
            final Authentication authentication = iterator.next();
            iterator.remove();
            synchronized (authentication) {
                authentication.setStatus(CLOSED);
                authentication.notifyAll();
            }
        }
    }

    /**
     * Get the anonymous identity for this context.
     *
     * @return the anonymous identity (not {@code null})
     */
    public ConnectionPeerIdentity getAnonymousIdentity() {
        return anonymousIdentity;
    }

    /**
     * Get the identity established by the connection's own authentication (ID {@code 0}).
     *
     * @return the connection identity (not {@code null})
     */
    ConnectionPeerIdentity getConnectionIdentity() {
        return connectionIdentity;
    }

    /**
     * Get the current identity.
     *
     * @return the current identity (not {@code null})
     */
    public ConnectionPeerIdentity getCurrentIdentity() {
        final ConnectionPeerIdentity currentIdentity = (ConnectionPeerIdentity) super.getCurrentIdentity();
        return currentIdentity == null ? anonymousIdentity : currentIdentity;
    }

    /**
     * State for one authentication attempt.
     * <p>
     * All reads and writes of the mutable fields happen while holding this object's monitor. The monitor is also
     * used to wait for and signal status changes.
     */
    static final class Authentication {
        private final int id;
        // bytes from the most recent CHALLENGE/SUCCESS message, or null once consumed
        private byte[] saslBytes;
        // one of the WAITING..CLOSED constants; starts as WAITING (0)
        private int status;

        Authentication(final int id) {
            this.id = id;
        }

        int getId() {
            return id;
        }

        byte[] getSaslBytes() {
            return saslBytes;
        }

        void setSaslBytes(final byte[] saslBytes) {
            this.saslBytes = saslBytes;
        }

        int getStatus() {
            return status;
        }

        void setStatus(final int status) {
            this.status = status;
        }
    }
}