/*
* JBoss, Home of Professional Open Source.
* Copyright 2015 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.wildfly.security.sasl.gs2;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.wildfly.security.sasl.gs2.Gs2.GS2_KRB5;
import static org.wildfly.security.sasl.gs2.Gs2.GS2_KRB5_PLUS;
import static org.wildfly.security.sasl.gs2.Gs2.OID_KRB5;
import static org.wildfly.security.sasl.gs2.Gs2.OID_SPNEGO;
import static org.wildfly.security.sasl.gs2.Gs2.SPNEGO;
import static org.wildfly.security.sasl.gs2.Gs2.SPNEGO_PLUS;
import static org.wildfly.security.sasl.gssapi.JaasUtil.loginClient;
import static org.wildfly.security.sasl.gssapi.JaasUtil.loginServer;
import static org.wildfly.security.sasl.gssapi.TestKDC.LDAP_PORT;
import java.io.IOException;
import java.net.URI;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;
import javax.naming.NamingException;
import javax.naming.directory.DirContext;
import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.LoginException;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import javax.security.sasl.SaslServerFactory;
import org.ietf.jgss.GSSCredential;
import org.ietf.jgss.GSSException;
import org.ietf.jgss.GSSManager;
import org.ietf.jgss.Oid;
import org.junit.BeforeClass;
import org.junit.Test;
import org.wildfly.common.function.ExceptionSupplier;
import org.wildfly.security.auth.callback.ChannelBindingCallback;
import org.wildfly.security.auth.client.AuthenticationConfiguration;
import org.wildfly.security.auth.client.AuthenticationContext;
import org.wildfly.security.auth.client.ClientUtils;
import org.wildfly.security.auth.client.MatchRule;
import org.wildfly.security.auth.realm.ldap.DirContextFactory;
import org.wildfly.security.auth.realm.ldap.LdapSecurityRealmBuilder;
import org.wildfly.security.auth.realm.ldap.SimpleDirContextFactoryBuilder;
import org.wildfly.security.auth.server.SecurityRealm;
import org.wildfly.security.auth.util.RegexNameRewriter;
import org.wildfly.security.credential.GSSKerberosCredential;
import org.wildfly.security.sasl.WildFlySasl;
import org.wildfly.security.sasl.gssapi.GssapiTestSuite;
import org.wildfly.security.sasl.test.BaseTestCase;
import org.wildfly.security.sasl.test.SaslServerBuilder;
import org.wildfly.security.sasl.util.ChannelBindingSaslClientFactory;
import org.wildfly.security.sasl.util.PropertiesSaslClientFactory;
import org.wildfly.security.sasl.util.ProtocolSaslClientFactory;
import org.wildfly.security.sasl.util.ServerNameSaslClientFactory;
/**
 * Client and server side tests for the GS2 SASL mechanism.
 *
 * @author <a href="mailto:fjuma@redhat.com">Farah Juma</a>
 */
public class Gs2SuiteChild extends BaseTestCase {

    private static final String TEST_SERVER_1 = "test_server_1";

    /** Strips the {@code @WILDFLY.ORG} realm suffix so Kerberos principal names can be looked up by LDAP {@code uid}. */
    private static final Pattern WILDFLY_REALM_SUFFIX = Pattern.compile("(.*)@WILDFLY\\.ORG");

    private static Subject clientSubject;
    private static Subject serverSubject;

    private SaslServer saslServer;
    private SaslClient saslClient;

    @BeforeClass
    public static void init() throws LoginException {
        clientSubject = loginClient();
        serverSubject = loginServer(GssapiTestSuite.serverKeyTab);
    }

    @Test
    public void testChannelBindingIndirect_Server() throws Exception {
        Map<String, Object> props = new HashMap<>();

        // No properties are set, an appropriate Gs2SaslServer should be returned
        saslServer = getIndirectSaslServer(GS2_KRB5, "sasl", TEST_SERVER_1, props, null, null);
        assertEquals(GS2_KRB5, saslServer.getMechanismName());

        // Require channel binding
        props.put(WildFlySasl.CHANNEL_BINDING_REQUIRED, Boolean.toString(true));
        saslServer = getIndirectSaslServer(GS2_KRB5_PLUS, "sasl", TEST_SERVER_1, props, "tls-unique", new byte[0]);
        assertEquals(GS2_KRB5_PLUS, saslServer.getMechanismName());

        // If channel binding is required even though a non-PLUS mechanism is specified, no server should be returned
        saslServer = getIndirectSaslServer(GS2_KRB5, "sasl", TEST_SERVER_1, props, null, null);
        assertNull(saslServer);
    }

    @Test
    public void testChannelBindingDirect_Server() {
        SaslServerFactory factory = obtainSaslServerFactory(Gs2SaslServerFactory.class);
        assertNotNull("SaslServerFactory not registered", factory);

        Map<String, Object> props = new HashMap<>();

        // No properties set
        String[] mechanisms = factory.getMechanismNames(props);
        assertMechanisms(new String[]{GS2_KRB5, GS2_KRB5_PLUS}, mechanisms);

        // Require channel binding
        props.put(WildFlySasl.CHANNEL_BINDING_REQUIRED, Boolean.toString(true));
        mechanisms = factory.getMechanismNames(props);
        assertMechanisms(new String[]{GS2_KRB5_PLUS}, mechanisms);
    }

    @Test
    public void testChannelBindingIndirect_Client() throws Exception {
        Map<String, Object> props = new HashMap<>();

        // No properties are set, an appropriate Gs2SaslClient should be returned
        saslClient = getIndirectSaslClient(new String[]{GS2_KRB5}, null, "sasl", TEST_SERVER_1, props, null, null);
        assertEquals(Gs2SaslClient.class, saslClient.getClass());
        assertEquals(GS2_KRB5, saslClient.getMechanismName());

        // If channel binding is required even though only non-PLUS mechanisms are specified, no client should be returned
        props.put(WildFlySasl.CHANNEL_BINDING_REQUIRED, Boolean.toString(true));
        saslClient = getIndirectSaslClient(new String[]{"GS2-DT4PIK22T6A", GS2_KRB5}, null, "sasl", TEST_SERVER_1, props, null, null);
        assertNull(saslClient);

        // If channel binding is required, an appropriate Gs2SaslClient should be returned
        saslClient = getIndirectSaslClient(new String[]{"GS2-DT4PIK22T6A-PLUS", GS2_KRB5_PLUS}, null, "sasl", TEST_SERVER_1, props, "tls-unique", new byte[0]);
        assertEquals(Gs2SaslClient.class, saslClient.getClass());
        assertEquals(GS2_KRB5_PLUS, saslClient.getMechanismName());
    }

    @Test
    public void testChannelBindingDirect_Client() {
        SaslClientFactory factory = obtainSaslClientFactory(Gs2SaslClientFactory.class);
        assertNotNull("SaslClientFactory not registered", factory);

        Map<String, Object> props = new HashMap<>();

        // No properties set
        String[] mechanisms = factory.getMechanismNames(props);
        assertMechanisms(new String[]{GS2_KRB5, GS2_KRB5_PLUS}, mechanisms);

        // Require channel binding
        props.put(WildFlySasl.CHANNEL_BINDING_REQUIRED, Boolean.toString(true));
        mechanisms = factory.getMechanismNames(props);
        assertMechanisms(new String[]{GS2_KRB5_PLUS}, mechanisms);
    }

    // -- Successful authentication exchanges --

    @Test
    public void testKrb5AuthenticationWithoutChannelBinding() throws Exception {
        saslServer = getSaslServer(GS2_KRB5, "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), null, null);
        assertNotNull(saslServer);
        assertEquals(GS2_KRB5, saslServer.getMechanismName());
        assertFalse(saslServer.isComplete());

        saslClient = getSaslClient(new String[]{GS2_KRB5}, null, "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), null, null);
        assertNotNull(saslClient);
        assertTrue(saslClient instanceof Gs2SaslClient);
        assertTrue(saslClient.hasInitialResponse());
        assertFalse(saslClient.isComplete());

        byte[] message = evaluateChallenge(new byte[0]);
        assertFalse(saslClient.isComplete());
        assertFalse(saslServer.isComplete());

        message = evaluateResponse(message);
        assertTrue(saslServer.isComplete());
        assertNotNull(message);
        assertFalse(saslClient.isComplete());

        message = evaluateChallenge(message);
        assertTrue(saslClient.isComplete());
        assertNull(message);
        assertEquals("jduke@WILDFLY.ORG", saslServer.getAuthorizationID());
    }

    @Test
    public void testKrb5AuthenticationWithChannelBinding() throws Exception {
        Map<String, Object> props = new HashMap<>();
        props.put(WildFlySasl.CHANNEL_BINDING_REQUIRED, Boolean.toString(true));

        saslServer = getSaslServer(GS2_KRB5_PLUS, "sasl", TEST_SERVER_1, props, "tls-unique", new byte[0]);
        assertNotNull(saslServer);
        assertEquals(GS2_KRB5_PLUS, saslServer.getMechanismName());
        assertFalse(saslServer.isComplete());

        saslClient = getSaslClient(new String[]{GS2_KRB5_PLUS}, "jduke@WILDFLY.ORG", "sasl", TEST_SERVER_1, props, "tls-unique", new byte[0]);
        assertNotNull(saslClient);
        assertTrue(saslClient instanceof Gs2SaslClient);
        assertTrue(saslClient.hasInitialResponse());
        assertFalse(saslClient.isComplete());

        byte[] message = evaluateChallenge(new byte[0]);
        assertFalse(saslClient.isComplete());
        assertFalse(saslServer.isComplete());

        message = evaluateResponse(message);
        assertTrue(saslServer.isComplete());
        assertNotNull(message);
        assertFalse(saslClient.isComplete());

        message = evaluateChallenge(message);
        assertTrue(saslClient.isComplete());
        assertNull(message);
        assertEquals("jduke@WILDFLY.ORG", saslServer.getAuthorizationID());
    }

    @Test
    public void testKrb5AuthenticationWithCredentialPassedInForClientAndServer() throws Exception {
        saslServer = getSaslServer(GS2_KRB5, "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), null, null, true);
        assertNotNull(saslServer);
        assertEquals(GS2_KRB5, saslServer.getMechanismName());
        assertFalse(saslServer.isComplete());

        saslClient = getSaslClient(new String[]{GS2_KRB5}, "jduke@WILDFLY.ORG", "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), null, null, true);
        assertNotNull(saslClient);
        assertTrue(saslClient instanceof Gs2SaslClient);
        assertTrue(saslClient.hasInitialResponse());
        assertFalse(saslClient.isComplete());

        // The exchange is driven directly rather than through evaluateChallenge/evaluateResponse (which wrap the
        // calls in Subject.doAs), so it runs only with the GSS credentials that were passed in explicitly.
        byte[] message = saslClient.evaluateChallenge(new byte[0]);
        assertFalse(saslClient.isComplete());
        assertFalse(saslServer.isComplete());

        message = saslServer.evaluateResponse(message);
        assertTrue(saslServer.isComplete());
        assertNotNull(message);
        assertFalse(saslClient.isComplete());

        message = saslClient.evaluateChallenge(message);
        assertTrue(saslClient.isComplete());
        assertNull(message);
        assertEquals("jduke@WILDFLY.ORG", saslServer.getAuthorizationID());
    }

    // -- Unsuccessful authentication exchanges --

    @Test
    public void testChannelBindingNotUsedByClientSupportedByServer() throws Exception {
        // gs2-cb-flag = "y"
        saslClient = getSaslClient(new String[]{GS2_KRB5}, null, "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(),
                "tls-unique", new byte[0]);
        assertNotNull(saslClient);

        saslServer = getSaslServer(GS2_KRB5_PLUS, "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), "tls-unique", new byte[0]);
        assertNotNull(saslServer);

        final byte[] message = evaluateChallenge(new byte[0]);
        assertResponseRejected(message);
    }

    @Test
    public void testChannelBindingUsedByClientUnsupportedByServer() throws Exception {
        // gs2-cb-flag = "p"
        Map<String, Object> props = new HashMap<>();
        props.put(WildFlySasl.CHANNEL_BINDING_REQUIRED, Boolean.toString(true));
        saslClient = getSaslClient(new String[]{GS2_KRB5_PLUS}, null, "sasl", TEST_SERVER_1, props, "tls-unique", new byte[0]);
        assertNotNull(saslClient);

        saslServer = getSaslServer(GS2_KRB5, "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), null, null);
        assertNotNull(saslServer);

        final byte[] message = evaluateChallenge(new byte[0]);
        assertResponseRejected(message);
    }

    @Test
    public void testChannelBindingUnsupportedByClientSupportedByServer() throws Exception {
        // gs2-cb-flag = "n"
        saslClient = getSaslClient(new String[]{GS2_KRB5}, null, "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), null, null);
        assertNotNull(saslClient);

        saslServer = getSaslServer(GS2_KRB5_PLUS, "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), "tls-unique", new byte[0]);
        assertNotNull(saslServer);

        final byte[] message = evaluateChallenge(new byte[0]);
        assertResponseRejected(message);
    }

    @Test
    public void testChannelBindingTypeMismatch() throws Exception {
        Map<String, Object> props = new HashMap<>();
        props.put(WildFlySasl.CHANNEL_BINDING_REQUIRED, Boolean.toString(true));
        saslClient = getSaslClient(new String[]{GS2_KRB5_PLUS}, null, "sasl", TEST_SERVER_1, props, "tls-unique", new byte[0]);
        assertNotNull(saslClient);

        saslServer = getSaslServer(GS2_KRB5_PLUS, "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), "tls-unique-for-telnet", new byte[0]);
        assertNotNull(saslServer);

        final byte[] message = evaluateChallenge(new byte[0]);
        assertResponseRejected(message);
    }

    @Test
    public void testChannelBindingDataMismatch() throws Exception {
        Map<String, Object> props = new HashMap<>();
        props.put(WildFlySasl.CHANNEL_BINDING_REQUIRED, Boolean.toString(true));
        saslClient = getSaslClient(new String[]{GS2_KRB5_PLUS}, null, "sasl", TEST_SERVER_1, props, "tls-unique", new byte[0]);
        assertNotNull(saslClient);

        saslServer = getSaslServer(GS2_KRB5_PLUS, "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), "tls-unique", new byte[1]);
        assertNotNull(saslServer);

        final byte[] message = evaluateChallenge(new byte[0]);
        assertResponseRejected(message);
    }

    @Test
    public void testUnauthorizedAuthorizationId() throws Exception {
        saslServer = getSaslServer(GS2_KRB5, "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), null, null);
        assertNotNull(saslServer);

        saslClient = getSaslClient(new String[]{GS2_KRB5}, "sasl/test_server_1@WILDFLY.ORG", "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), null, null);
        assertNotNull(saslClient);

        final byte[] message = evaluateChallenge(new byte[0]);
        assertResponseRejected(message);
    }

    @Test
    public void testUnneededNonStdFlag() throws Exception {
        saslServer = getSaslServer(GS2_KRB5, "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), null, null);
        assertNotNull(saslServer);

        saslClient = getSaslClient(new String[]{GS2_KRB5}, null, "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), null, null);
        assertNotNull(saslClient);

        byte[] origMessage = evaluateChallenge(new byte[0]);
        assertFalse(saslClient.isComplete());
        assertFalse(saslServer.isComplete());

        // Prefix the client's initial message with "F," (the gs2-nonstd-flag), which this mechanism does not need
        byte[] message = new byte[origMessage.length + 2];
        System.arraycopy(origMessage, 0, message, 2, origMessage.length);
        message[0] = (byte) 'F';
        message[1] = (byte) ',';
        assertResponseRejected(message);
    }

    @Test
    public void testInvalidGs2Header() throws Exception {
        saslServer = getSaslServer(GS2_KRB5, "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), null, null);
        assertNotNull(saslServer);

        // gs2-header starts with an invalid character: the message begins with "b,," (98 = 'b', 44 = ',')
        final byte[] message = new byte[] {98, 44, 44, 1, 0, 110, -126, 1, -13, 48, -126, 1, -17, -96, 3, 2, 1, 5, -95, 3, 2, 1, 14, -94, 7, 3, 5, 0, 32, 0, 0, 0, -93, -126, 1, 11, 97, -126, 1, 7, 48, -126, 1, 3, -96, 3, 2, 1, 5, -95, 13, 27, 11, 87, 73, 76, 68, 70, 76, 89, 46, 79, 82, 71, -94, 32, 48, 30, -96, 3, 2, 1, 0, -95, 23, 48, 21, 27, 4, 115, 97, 115, 108, 27, 13, 116, 101, 115, 116, 95, 115, 101, 114, 118, 101, 114, 95, 49, -93, -127, -54, 48, -127, -57, -96, 3, 2, 1, 16, -94, -127, -65, 4, -127, -68, 85, 26, 77, -98, -85, 110, 17, -61, 12, -36, 34, -105, 37, 126, 2, 74, -98, 47, -23, -108, 57, 2, -4, 110, -71, -79, -99, 8, 71, 11, -90, -118, -23, -122, -115, 3, -105, 31, 52, -50, -104, 35, -7, -14, -102, -39, 110, 74, -17, 55, 78, 67, -52, 74, -59, 85, 40, 89, -8, -61, -109, -69, -126, 31, -100, 62, 37, 78, -20, 99, -24, -28, -54, 112, 34, 87, -4, 57, -46, 97, 118, 43, 103, -74, -39, -59, -16, -88, 8, -122, 81, 83, -103, 83, 49, 54, -20, -125, -110, 18, 26, 87, -22, -111, 71, 122, 110, 83, -33, -92, -94, 114, -92, -30, 114, 22, 46, 73, 38, 58, -117, -118, -23, -18, -91, -14, -42, 84, 37, -4, 90, 116, -77, -41, 93, 82, 54, -69, 114, 124, -82, -102, -50, -83, 17, 117, -86, 106, 50, 78, -122, 54, 57, -27, -89, -85, 125, -104, 110, -38, 75, -25, -85, 91, -77, -7, -68, 112, 87, -125, -28, 34, 71, -62, -34, -110, -122, -120, -86, -93, -41, 41, -34, 91, 88, -114, 112, 83, -92, -127, -54, 48, -127, -57, -96, 3, 2, 1, 16, -94, -127, -65, 4, -127, -68, -12, -3, 100, 43, -53, 16, 56, -68, 107, -81, 105, 26, 123, 115, 94, -94, 119, 36, 65, 109, 68, 26, -61, 22, -68, -68, 29, -36, -80, 80, -66, 24, 74, -7, -5, -43, 37, -75, 26, -33, 50, 89, 81, 125, 67, 64, 27, 104, 24, -42, 37, -19, 13, 65, 95, -25, -19, 23, 58, -42, -43, 88, -42, -1, 121, 87, -12, 17, 55, -116, 81, -107, -22, -56, 0, 99, -56, 56, 67, 57, -127, -3, 73, -56, -100, -74, -78, 27, 7, 58, -47, 23, -12, 20, 15, 65, -77, -36, 14, 122, -95, 45, -9, -116, 89, 87, 82, -117, -60, 22, 55, 104, 103, -71, -12,
                -45, -1, -44, 106, -117, 91, 83, -44, -60, 122, -100, -89, -65, 43, 107, -124, -57, -82, 113, 72, 77, -84, 121, -90, 57, -28, 90, 80, -33, 97, -62, 10, 124, 67, 97, 110, 87, 20, -78, -14, -9, 84, 64, 78, 28, -63, -78, -29, -93, 29, 111, -34, -128, 96, -53, -25, -84, -39, -44, 85, 96, 0, -35, 35, -100, -123, 7, -112, -26, -89, 14, 92, -28};
        assertResponseRejected(message);
    }

    @Test
    public void testDisallowedMechanism() throws Exception {
        // SPNEGO must not be used as a GS2 mechanism (section 14.3 in RFC 5801)
        saslServer = getSaslServer(SPNEGO, "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), null, null);
        assertNull(saslServer);

        saslClient = getSaslClient(new String[]{SPNEGO, SPNEGO_PLUS}, "bsmith@WILDFLY.ORG", "sasl", TEST_SERVER_1, Collections.<String, Object>emptyMap(), null, null);
        assertNull(saslClient);
    }

    // -- Validate mapping SASL mechanism names to GSS-API OIDs and vice versa --

    @Test
    public void testGetSaslNameForMechanismOid() throws Exception {
        assertEquals(GS2_KRB5, Gs2.getSaslNameForMechanism(new Oid("1.2.840.113554.1.2.2"), false));
        assertEquals(SPNEGO_PLUS, Gs2.getSaslNameForMechanism(new Oid("1.3.6.1.5.5.2"), true));
        assertEquals("GS2-DT4PIK22T6A-PLUS", Gs2.getSaslNameForMechanism(new Oid("1.3.6.1.5.5.1.1"), true));
    }

    @Test
    public void testGetMechanismForSaslName() throws Exception {
        assertEquals(OID_KRB5, Gs2.getMechanismForSaslName(GSSManager.getInstance(), "GS2-KRB5-PLUS"));
        assertEquals(OID_SPNEGO, Gs2.getMechanismForSaslName(GSSManager.getInstance(), "SPNEGO"));
    }

    // -- Helpers --

    /**
     * Creates a server through the standard {@link Sasl#createSaslServer} lookup, running as the server subject.
     * Channel binding information is supplied through an {@link IndirectCallbackHandler}.
     */
    private SaslServer getIndirectSaslServer(final String mechanism, final String protocol, final String serverName, final Map<String, Object> props,
            final String bindingType, final byte[] bindingData) throws SaslException {
        //TODO I don't like people having to pass in a callback handler to get this information
        return runAs(serverSubject, () -> {
            final CallbackHandler cbh = new IndirectCallbackHandler(bindingType, bindingData);
            return Sasl.createSaslServer(mechanism, protocol, serverName, props, cbh);
        });
    }

    private SaslServer getSaslServer(final String mechanism, final String protocol, final String serverName, final Map<String, Object> props,
            final String bindingType, final byte[] bindingData) throws SaslException {
        return getSaslServer(mechanism, protocol, serverName, props, bindingType, bindingData, false);
    }

    /**
     * Builds a GS2 server backed by the test LDAP realm, running the build as the server subject.
     *
     * @param passCredential if {@code true}, an accept-only Kerberos {@link GSSCredential} is obtained up front and
     *                       handed to the server explicitly
     * @return the server, or {@code null} if the factory declines the mechanism
     * @throws SaslException if credential acquisition or server creation fails with a SASL error
     */
    private SaslServer getSaslServer(final String mechanism, final String protocol, final String serverName, final Map<String, Object> props,
            final String bindingType, final byte[] bindingData, final boolean passCredential) throws SaslException {
        final GSSCredential credential = passCredential
                ? createKerberosCredential(serverSubject, GSSCredential.ACCEPT_ONLY)
                : null;

        final SaslServerBuilder builder = new SaslServerBuilder(Gs2SaslServerFactory.class, mechanism)
                .setDontAssertBuiltServer();
        final String realmName = "ldapRealm";
        builder.addRealm(realmName, createLdapSecurityRealm());
        builder.setDefaultRealmName(realmName);
        if (protocol != null) {
            builder.setProtocol(protocol);
        }
        if (serverName != null) {
            builder.setServerName(serverName);
        }
        if (props != null) {
            builder.setProperties(props);
        }
        if (bindingType != null || bindingData != null) {
            builder.setChannelBinding(bindingType, bindingData);
        }
        if (credential != null) {
            builder.setCredential(new GSSKerberosCredential(credential));
        }
        return runAs(serverSubject, () -> builder.build());
    }

    /**
     * Creates the LDAP-backed realm used by the test server. Names of the form {@code user@WILDFLY.ORG} are rewritten
     * to {@code user} and then searched recursively under {@code dc=wildfly,dc=org} by {@code uid}.
     */
    private static SecurityRealm createLdapSecurityRealm() {
        final ExceptionSupplier<DirContext, NamingException> dirContextSupplier = () ->
                SimpleDirContextFactoryBuilder.builder()
                        .setProviderUrl(String.format("ldap://localhost:%d/", LDAP_PORT))
                        .setSecurityPrincipal("uid=Sasl_1,ou=Users,dc=wildfly,dc=org")
                        .setSecurityCredential("servicepwd")
                        .build().obtainDirContext(DirContextFactory.ReferralMode.IGNORE);
        return LdapSecurityRealmBuilder.builder()
                .setDirContextSupplier(dirContextSupplier)
                .setNameRewriter(new RegexNameRewriter(WILDFLY_REALM_SUFFIX, "$1", true))
                .identityMapping()
                .setSearchDn("dc=wildfly,dc=org")
                .searchRecursive()
                .setRdnIdentifier("uid")
                .build()
                .build();
    }

    /**
     * Creates a client through the standard {@link Sasl#createSaslClient} lookup, running as the client subject.
     * Channel binding information is supplied through an {@link IndirectCallbackHandler}.
     */
    private SaslClient getIndirectSaslClient(final String[] mechanisms, final String authorizationId, final String protocol, final String serverName,
            final Map<String, Object> props, final String bindingType, final byte[] bindingData) throws SaslException {
        //TODO I don't like people having to pass in a callback handler to get this information
        return runAs(clientSubject, () -> {
            final CallbackHandler cbh = new IndirectCallbackHandler(bindingType, bindingData);
            return Sasl.createSaslClient(mechanisms, authorizationId, protocol, serverName, props, cbh);
        });
    }

    private SaslClient getSaslClient(final String[] mechanisms, final String authorizationId, final String protocol, final String serverName,
            final Map<String, Object> props, final String bindingType, final byte[] bindingData) throws Exception {
        return getSaslClient(mechanisms, authorizationId, protocol, serverName, props, bindingType, bindingData, false);
    }

    /**
     * Creates a GS2 client from the registered {@link Gs2SaslClientFactory}, wrapped with the channel binding,
     * protocol, server name and property decorators for whichever of those arguments are non-null.
     *
     * @param passCredential if {@code true}, an initiate-only Kerberos {@link GSSCredential} is obtained up front and
     *                       supplied through the client's authentication configuration
     * @return the client, or {@code null} if no requested mechanism is acceptable
     */
    private SaslClient getSaslClient(final String[] mechanisms, final String authorizationId, final String protocol, final String serverName,
            final Map<String, Object> props, final String bindingType,
            final byte[] bindingData, final boolean passCredential) throws Exception {
        final GSSCredential credential = passCredential
                ? createKerberosCredential(clientSubject, GSSCredential.INITIATE_ONLY)
                : null;
        final CallbackHandler cbh = createClientCallbackHandler(mechanisms, authorizationId, credential);

        SaslClientFactory clientFactory = obtainSaslClientFactory(Gs2SaslClientFactory.class);
        assertNotNull(clientFactory);
        if (bindingType != null || bindingData != null) {
            clientFactory = new ChannelBindingSaslClientFactory(clientFactory, bindingType, bindingData);
        }
        if (protocol != null) {
            clientFactory = new ProtocolSaslClientFactory(clientFactory, protocol);
        }
        if (serverName != null) {
            clientFactory = new ServerNameSaslClientFactory(clientFactory, serverName);
        }
        if (props != null) {
            clientFactory = new PropertiesSaslClientFactory(clientFactory, props);
        }

        final SaslClientFactory factory = clientFactory;
        return runAs(clientSubject, () -> factory.createSaslClient(mechanisms, authorizationId, protocol, serverName, props, cbh));
    }

    private CallbackHandler createClientCallbackHandler(final String[] mechanisms, final String authorizationId, final GSSCredential credential) throws Exception {
        final AuthenticationContext context = AuthenticationContext.empty()
                .with(
                        MatchRule.ALL,
                        AuthenticationConfiguration.EMPTY
                                .useAuthorizationName(authorizationId)
                                .useGSSCredential(credential)
                                .allowSaslMechanisms(mechanisms));
        return ClientUtils.getCallbackHandler(new URI("remote://localhost"), context);
    }

    /**
     * Acquires a default-name Kerberos credential while running as {@code subject}.
     *
     * @param usage {@link GSSCredential#ACCEPT_ONLY} for the server or {@link GSSCredential#INITIATE_ONLY} for the client
     * @throws SaslException wrapping (with cause) any {@link GSSException} raised during acquisition
     */
    private static GSSCredential createKerberosCredential(final Subject subject, final int usage) throws SaslException {
        return runAs(subject, () -> {
            try {
                return GSSManager.getInstance().createCredential(null, GSSCredential.INDEFINITE_LIFETIME, OID_KRB5, usage);
            } catch (GSSException e) {
                // keep the GSSException as the cause so Kerberos failures remain diagnosable
                throw new SaslException(e.getMessage(), e);
            }
        });
    }

    /** Passes {@code response} to the server while running as the server subject. */
    private byte[] evaluateResponse(final byte[] response) throws SaslException {
        return runAs(serverSubject, () -> saslServer.evaluateResponse(response));
    }

    /** Passes {@code challenge} to the client while running as the client subject. */
    private byte[] evaluateChallenge(final byte[] challenge) throws SaslException {
        return runAs(clientSubject, () -> saslClient.evaluateChallenge(challenge));
    }

    /**
     * Feeds {@code response} to the server (as the server subject) and fails the test unless the server rejects it
     * with a {@link SaslException}.
     */
    private void assertResponseRejected(final byte[] response) {
        try {
            evaluateResponse(response);
            fail("Expected SaslException not thrown");
        } catch (SaslException expected) {
            // expected: the server must refuse this exchange
        }
    }

    /**
     * Runs {@code action} as {@code subject}. A {@link SaslException} thrown by the action is rethrown unchanged; any
     * other checked exception is wrapped in a {@link RuntimeException}. Unchecked exceptions propagate as-is.
     *
     * @param subject the subject to run as
     * @param action the action to run
     * @param <T> the action's result type
     * @return the action's result
     * @throws SaslException if the action throws one
     */
    private static <T> T runAs(final Subject subject, final PrivilegedExceptionAction<T> action) throws SaslException {
        try {
            return Subject.doAs(subject, action);
        } catch (PrivilegedActionException e) {
            final Throwable cause = e.getCause();
            if (cause instanceof SaslException) {
                throw (SaslException) cause;
            }
            throw new RuntimeException(cause);
        }
    }

    /**
     * Callback handler used by the indirect ({@link Sasl}-lookup) tests. It answers {@link ChannelBindingCallback}s
     * with fixed binding type and data and silently ignores every other callback.
     */
    //TODO I don't like the indirect tests having to pass in a callback handler to get this information
    private static final class IndirectCallbackHandler implements CallbackHandler {

        private final String bindingType;
        private final byte[] bindingData;

        private IndirectCallbackHandler(String bindingType, byte[] bindingData) {
            this.bindingType = bindingType;
            this.bindingData = bindingData;
        }

        @Override
        public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
            for (Callback callback : callbacks) {
                if (callback instanceof ChannelBindingCallback) {
                    final ChannelBindingCallback channelBindingCallback = (ChannelBindingCallback) callback;
                    channelBindingCallback.setBindingType(bindingType);
                    channelBindingCallback.setBindingData(bindingData);
                }
            }
        }
    }
}