/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.activemq.transport.amqp.client.sasl; import javax.security.sasl.SaslException; import java.util.ArrayList; import java.util.Collections; import java.util.List; import org.apache.qpid.proton.engine.Sasl; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Manage the SASL authentication process */ public class SaslAuthenticator { private static final Logger LOG = LoggerFactory.getLogger(SaslAuthenticator.class); private final Sasl sasl; private final String username; private final String password; private final String authzid; private Mechanism mechanism; private String mechanismRestriction; /** * Create the authenticator and initialize it. * * @param sasl The Proton SASL entry point this class will use to manage the authentication. * @param username The user name that will be used to authenticate. * @param password The password that will be used to authenticate. * @param authzid The authzid used when authenticating (currently only with PLAIN) * @param mechanismRestriction A particular mechanism to use (if offered by the server) or null to allow selection. */ public SaslAuthenticator(Sasl sasl, String username, String password, String authzid, String mechanismRestriction) { this.sasl = sasl; this.username = username; this.password = password; this.authzid = authzid; this.mechanismRestriction = mechanismRestriction; } /** * Process the SASL authentication cycle until such time as an outcome is determine. This * method must be called by the managing entity until the return value is true indicating a * successful authentication or a JMSSecurityException is thrown indicating that the * handshake failed. * * @throws SecurityException */ public boolean authenticate() throws SecurityException { switch (sasl.getState()) { case PN_SASL_IDLE: handleSaslInit(); break; case PN_SASL_STEP: handleSaslStep(); break; case PN_SASL_FAIL: handleSaslFail(); break; case PN_SASL_PASS: return true; default: } return false; } private void handleSaslInit() throws SecurityException { try { String[] remoteMechanisms = sasl.getRemoteMechanisms(); if (remoteMechanisms != null && remoteMechanisms.length != 0) { mechanism = findMatchingMechanism(remoteMechanisms); if (mechanism != null) { mechanism.setUsername(username); mechanism.setPassword(password); mechanism.setAuthzid(authzid); // TODO - set additional options from URI. // TODO - set a host value. sasl.setMechanisms(mechanism.getName()); byte[] response = mechanism.getInitialResponse(); if (response != null && response.length != 0) { sasl.send(response, 0, response.length); } } else { // TODO - Better error message. throw new SecurityException("Could not find a matching SASL mechanism for the remote peer."); } } } catch (SaslException se) { // TODO - Better error message. SecurityException jmsse = new SecurityException("Exception while processing SASL init."); jmsse.initCause(se); throw jmsse; } } private Mechanism findMatchingMechanism(String... remoteMechanisms) { Mechanism match = null; List<Mechanism> found = new ArrayList<>(); for (String remoteMechanism : remoteMechanisms) { if (mechanismRestriction != null && !mechanismRestriction.equals(remoteMechanism)) { LOG.debug("Skipping {} mechanism because it is not the configured mechanism restriction {}", remoteMechanism, mechanismRestriction); continue; } Mechanism mechanism = null; if (remoteMechanism.equalsIgnoreCase("PLAIN")) { mechanism = new PlainMechanism(); } else if (remoteMechanism.equalsIgnoreCase("ANONYMOUS")) { mechanism = new AnonymousMechanism(); } else if (remoteMechanism.equalsIgnoreCase("CRAM-MD5")) { mechanism = new CramMD5Mechanism(); } else { LOG.debug("Unknown remote mechanism {}, skipping", remoteMechanism); continue; } if (mechanism.isApplicable(username, password)) { found.add(mechanism); } } if (!found.isEmpty()) { // Sorts by priority using Mechanism comparison and return the last value in // list which is the Mechanism deemed to be the highest priority match. Collections.sort(found); match = found.get(found.size() - 1); } LOG.info("Best match for SASL auth was: {}", match); return match; } private void handleSaslStep() throws SecurityException { try { if (sasl.pending() != 0) { byte[] challenge = new byte[sasl.pending()]; sasl.recv(challenge, 0, challenge.length); byte[] response = mechanism.getChallengeResponse(challenge); sasl.send(response, 0, response.length); } } catch (SaslException se) { // TODO - Better error message. SecurityException jmsse = new SecurityException("Exception while processing SASL step."); jmsse.initCause(se); throw jmsse; } } private void handleSaslFail() throws SecurityException { // TODO - Better error message. throw new SecurityException("Client failed to authenticate"); } }