/* * JBoss, Home of Professional Open Source. * Copyright 2008, Red Hat Middleware LLC, and individual contributors * as indicated by the @author tags. See the copyright.txt file in the * distribution for a full listing of individual contributors. * * This is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 of * the License, or (at your option) any later version. * * This software is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this software; if not, write to the Free * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA * 02110-1301 USA, or see the FSF site: http://www.fsf.org. */ package org.picketlink.identity.federation.bindings.jboss.auth; import org.jboss.security.SecurityConstants; import org.jboss.security.SimplePrincipal; import org.jboss.security.auth.callback.ObjectCallback; import org.picketlink.common.ErrorCodes; import org.picketlink.common.exceptions.ProcessingException; import org.picketlink.common.exceptions.fed.WSTrustException; import org.picketlink.common.util.DocumentUtil; import org.picketlink.common.util.StringUtil; import org.picketlink.identity.federation.bindings.jboss.subject.PicketLinkGroup; import org.picketlink.identity.federation.bindings.jboss.subject.PicketLinkPrincipal; import org.picketlink.identity.federation.core.constants.AttributeConstants; import org.picketlink.identity.federation.core.constants.PicketLinkFederationConstants; import org.picketlink.identity.federation.core.factories.JBossAuthCacheInvalidationFactory.TimeCacheExpiry; import org.picketlink.identity.federation.core.saml.v2.util.AssertionUtil; import org.picketlink.identity.federation.core.wstrust.STSClient; import org.picketlink.identity.federation.core.wstrust.STSClientConfig; import org.picketlink.identity.federation.core.wstrust.STSClientPool; import org.picketlink.identity.federation.bindings.stspool.STSClientPoolFactory; import org.picketlink.identity.federation.core.wstrust.STSClientConfig.Builder; import org.picketlink.identity.federation.core.wstrust.SamlCredential; import org.picketlink.identity.federation.core.wstrust.auth.AbstractSTSLoginModule; import org.picketlink.identity.federation.core.wstrust.plugins.saml.SAMLUtil; import org.picketlink.identity.federation.saml.v2.assertion.AssertionType; import org.picketlink.identity.federation.saml.v2.assertion.BaseIDAbstractType; import org.picketlink.identity.federation.saml.v2.assertion.NameIDType; import org.picketlink.identity.federation.saml.v2.assertion.SubjectType; import org.w3c.dom.Element; import javax.security.auth.Subject; import javax.security.auth.callback.Callback; import javax.security.auth.callback.CallbackHandler; import javax.security.auth.login.LoginException; import javax.xml.datatype.XMLGregorianCalendar; import javax.xml.transform.Source; import javax.xml.ws.Dispatch; import java.security.Principal; import java.security.acl.Group; import java.util.ArrayList; import java.util.Date; import java.util.HashMap; import java.util.List; import java.util.Map; /** * <p> This {@code LoginModule} authenticates clients by validating their SAML assertions with an external security token service * (such as PicketLinkSTS). If the supplied assertion contains roles, these roles are extracted and included in the {@code Group} * returned by the {@code getRoleSets} method. </p> <p> This module defines the following module options: <ul> <li> configFile - * this property identifies the properties file that will be used to establish communication with the external security token * service. </li> <li> cache.invalidation: set it to true if you require invalidation of JBoss Auth Cache at SAML Principal * expiration. </li> <li> jboss.security.security_domain: name of the security domain where this login module is configured. This is * only required if the cache.invalidation option is configured. </li> <li> roleKey: a comma separated list of strings that define * the attributes in SAML assertion for user roles </li> <li> localValidation: if you want to validate the assertion locally for * signature and expiry </li> <li> localValidationSecurityDomain: the security domain for the trust store information (via the * JaasSecurityDomain) </li> <li> tokenEncodingType: encoding type of SAML token delivered via http request's header. Possible * values are: base64 - content encoded as base64. In case of encoding will vary between base64 and gzip use base64 and LoginModule * will detect gzipped data. gzip - gzipped content encoded as base64 none - content not encoded in any way </li> <li> * samlTokenHttpHeader - name of http request header to fetch SAML token from. For example: "Authorize" </li> <li> * samlTokenHttpHeaderRegEx - Java regular expression to be used to get SAML token from "samlTokenHttpHeader". Example: use: * ."(.)".* to parse SAML token from header content like this: SAML_assertion="HHDHS=", at the same time set * samlTokenHttpHeaderRegExGroup to 1. </li> <li> samlTokenHttpHeaderRegExGroup - Group value to be used when parsing out value of * http request header specified by "samlTokenHttpHeader" using "samlTokenHttpHeaderRegEx". </li> </ul> </p> <p> Any properties * specified besides the above properties are assumed to be used to configure how the {@code STSClient} will connect to the STS. For * example, the JBossWS {@code StubExt.PROPERTY_SOCKET_FACTORY} can be specified in order to inform the socket factory that must be * used to connect to the STS. All properties will be set in the request context of the {@code Dispatch} instance used by the {@code * STSClient} to send requests to the STS. </p> <p> An example of a {@code configFile} can be seen bellow: * * <pre> * serviceName=PicketLinkSTS * portName=PicketLinkSTSPort * endpointAddress=http://localhost:8080/picketlink-sts/PicketLinkSTS * username=JBoss * password=JBoss * </pre> * * The first three properties specify the STS endpoint URL, service name, and port name. The last two properties specify the * username and password that are to be used by the application server to authenticate to the STS and have the SAML assertions * validated. </p> <p> <b>NOTE:</b> Sub-classes can use {@link #getSTSClient()} method to customize the {@link STSClient} class to * make calls to STS/ </p> * * @author <a href="mailto:sguilhen@redhat.com">Stefan Guilhen</a> * @author Anil.Saldhana@redhat.com */ @SuppressWarnings("unchecked") public abstract class SAML2STSCommonLoginModule extends SAMLTokenFromHttpRequestAbstractLoginModule { protected String stsConfigurationFile; protected Principal principal; protected SamlCredential credential; protected AssertionType assertion; protected boolean enableCacheInvalidation = false; protected String securityDomain = null; protected boolean localValidation = false; protected String localValidationSecurityDomain; protected String roleKey = AttributeConstants.ROLE_IDENTIFIER_ASSERTION; /** * Maximal number of clients in the STS Client Pool. */ protected int initialClientsInPool = 0; /** * Options that are computed by this login module. Few options are removed and the rest are set in the dispatch sts call */ protected Map<String, Object> options = new HashMap<String, Object>(); /** * Original Options that are sent by the JDK JAAS Framework */ protected Map<String, Object> rawOptions = new HashMap<String, Object>(); /** * This is an option that should identify the configuration file for WSTrustClient. */ public static final String STS_CONFIG_FILE = "configFile"; /** * Key to specify the end point address */ public static final String ENDPOINT_ADDRESS = "endpointAddress"; /** * Key to specify the port name */ public static final String PORT_NAME = "portName"; /** * Key to specify the service name */ public static final String SERVICE_NAME = "serviceName"; /** * Key to specify the username */ public static final String USERNAME_KEY = "username"; /** * Key to specify the password */ public static final String PASSWORD_KEY = "password"; // A variable used by the unit test to pass local validation protected boolean localTestingOnly = false; /** * Paramater name. */ public static final String INITIAL_CLIENTS_IN_POOL = AbstractSTSLoginModule.INITIAL_CLIENTS_IN_POOL; /* * (non-Javadoc) * * @see org.jboss.security.auth.spi.AbstractServerLoginModule#initialize(javax.security.auth.Subject, * javax.security.auth.callback.CallbackHandler, java.util.Map, java.util.Map) */ @Override public void initialize(Subject subject, CallbackHandler callbackHandler, Map<String, ?> sharedState, Map<String, ?> options) { super.initialize(subject, callbackHandler, sharedState, options); this.options.putAll(options); this.rawOptions.putAll(options); if (logger.isTraceEnabled()) { logger.trace(options.toString()); } // save the config file and cache validation options, removing them from the map - all remaining properties will // be set in the request context of the Dispatch instance used to send requests to the STS. this.stsConfigurationFile = (String) this.options.remove(STS_CONFIG_FILE); String cacheInvalidation = (String) this.options.remove("cache.invalidation"); if (cacheInvalidation != null && !cacheInvalidation.isEmpty()) { this.enableCacheInvalidation = Boolean.parseBoolean(cacheInvalidation); this.securityDomain = (String) this.options.remove(SecurityConstants.SECURITY_DOMAIN_OPTION); if (this.securityDomain == null || this.securityDomain.isEmpty()) { throw logger.optionNotSet(SecurityConstants.SECURITY_DOMAIN_OPTION); } } String roleKeyStr = (String) options.get("roleKey"); if (StringUtil.isNotNull(roleKeyStr)) { roleKey = roleKeyStr.trim(); } String localValidationStr = (String) options.get("localValidation"); if (StringUtil.isNotNull(localValidationStr)) { localValidation = Boolean.parseBoolean(localValidationStr); localValidationSecurityDomain = (String) options.get("localValidationSecurityDomain"); if (localValidationSecurityDomain == null) { logger.error(ErrorCodes.LOCAL_VALIDATION_SEC_DOMAIN_MUST_BE_SPECIFIED); throw logger.optionNotSet("localValidationSecurityDomain"); } if (localValidationSecurityDomain.startsWith("java:") == false) { localValidationSecurityDomain = SecurityConstants.JAAS_CONTEXT_ROOT + "/" + localValidationSecurityDomain; } String localTestingOnlyStr = (String) options.get("localTestingOnly"); if (StringUtil.isNotNull(localTestingOnlyStr)) { localTestingOnly = Boolean.valueOf(localTestingOnlyStr); } } String initialClientsInPoolString = (String) options.get(INITIAL_CLIENTS_IN_POOL); if (StringUtil.isNotNull(initialClientsInPoolString)) { try { this.initialClientsInPool = Integer.parseInt(initialClientsInPoolString); } catch (Exception e) { logger.cannotParseParameterValue(initialClientsInPoolString, e); } } } /* * (non-Javadoc) * * @see org.jboss.security.auth.spi.AbstractServerLoginModule#login() */ @Override public boolean login() throws LoginException { // if shared data exists, set our principal and assertion variables. if (super.login()) { Object sharedPrincipal = super.sharedState.get("javax.security.auth.login.name"); if (sharedPrincipal instanceof Principal) { this.principal = (Principal) sharedPrincipal; } else { try { this.principal = createIdentity(sharedPrincipal.toString()); } catch (Exception e) { throw logger.authFailedToCreatePrincipal(e); } } Object credential = super.sharedState.get("javax.security.auth.login.password"); if (credential instanceof SamlCredential) { this.credential = (SamlCredential) credential; } else { throw logger.authSharedCredentialIsNotSAMLCredential(credential.getClass().getName()); } return true; } // obtain the assertion from the callback handler. ObjectCallback callback = new ObjectCallback(null); Element assertionElement = null; try { if (getSamlTokenHttpHeader() != null) { this.credential = getCredentialFromHttpRequest(); } else { super.callbackHandler.handle(new Callback[]{callback}); if (callback.getCredential() instanceof String) { callback.setCredential(new SamlCredential(DocumentUtil.getDocument(callback.getCredential().toString()) .getDocumentElement())); } if (callback.getCredential() instanceof SamlCredential == false) { throw logger.authSharedCredentialIsNotSAMLCredential(callback.getCredential().getClass().getName()); } this.credential = (SamlCredential) callback.getCredential(); } assertionElement = this.credential.getAssertionAsElement(); } catch (Exception e) { throw logger.authErrorHandlingCallback(e); } // if there is no shared data, validate the assertion using the STS. if (localValidation) { logger.trace("Local Validation is being Performed"); try { boolean isValid = localValidation(assertionElement); if (isValid) { logger.trace("Local Validation passed."); } } catch (Exception e) { LoginException le = new LoginException(); le.initCause(e); throw le; } } else { logger.trace("Local Validation is disabled. Verifying with STS"); // sts config file has to be present to call STS (using sts client) if (this.stsConfigurationFile == null) { throw logger.authSTSConfigFileNotFound(); } // send the assertion to the STS for validation. STSClient client = this.getSTSClient(); try { boolean isValid = client.validateToken(assertionElement); // if the STS says the assertion is invalid, throw an exception to signal that authentication has failed. if (isValid == false) { throw logger.authInvalidSAMLAssertionBySTS(); } } catch (WSTrustException we) { throw logger.authAssertionValidationError(we); } } // if the assertion is valid, create a principal containing the assertion subject. try { this.assertion = SAMLUtil.fromElement(assertionElement); SubjectType subject = assertion.getSubject(); if (subject != null) { BaseIDAbstractType baseID = subject.getSubType().getBaseID(); if (baseID instanceof NameIDType) { NameIDType nameID = (NameIDType) baseID; this.principal = new PicketLinkPrincipal(nameID.getValue()); // If the user has configured cache invalidation of subject based on saml token expiry if (enableCacheInvalidation) { TimeCacheExpiry cacheExpiry = this.getCacheExpiry(); XMLGregorianCalendar expiry = AssertionUtil.getExpiration(assertion); if (expiry != null) { Date expiryDate = expiry.toGregorianCalendar().getTime(); logger .trace("Creating Cache Entry for JBoss at [" + new Date() + "] , with expiration set to SAML expiry = " + expiryDate); cacheExpiry.register(securityDomain, expiryDate, principal); } else { logger.samlAssertionWithoutExpiration(assertion.getID()); } } } } } catch (Exception e) { throw logger.authFailedToParseSAMLAssertion(e); } // if password-stacking has been configured, set the principal and the assertion in the shared map. if (getUseFirstPass()) { super.sharedState.put("javax.security.auth.login.name", this.principal); super.sharedState.put("javax.security.auth.login.password", this.credential); } return (super.loginOk = true); } /* (non-Javadoc) * @see org.jboss.security.auth.spi.AbstractServerLoginModule#commit() */ @Override public boolean commit() throws LoginException { if (super.commit()) { final boolean added = subject.getPublicCredentials().add(this.credential); if (added && logger.isTraceEnabled()) { logger.trace("Added Credential " + this.credential); } return true; } else { return false; } } /** * Called if the overall authentication failed (phase 2). */ @Override public boolean abort() throws LoginException { clearState(); super.abort(); return true; } @Override public boolean logout() throws LoginException { clearState(); super.logout(); return true; } private void clearState() { AbstractSTSLoginModule.removeAllSamlCredentials(subject); credential = null; } /* * (non-Javadoc) * * @see org.jboss.security.auth.spi.AbstractServerLoginModule#getIdentity() */ @Override protected Principal getIdentity() { return this.principal; } /* * (non-Javadoc) * * @see org.jboss.security.auth.spi.AbstractServerLoginModule#getRoleSets() */ @Override protected Group[] getRoleSets() throws LoginException { if (this.assertion == null) { try { this.assertion = SAMLUtil.fromElement(this.credential.getAssertionAsElement()); } catch (Exception e) { throw logger.authFailedToParseSAMLAssertion(e); } } if (logger.isTraceEnabled()) { try { logger.trace("Assertion from where roles will be sought = " + AssertionUtil.asString(assertion)); } catch (ProcessingException ignore) { } } List<String> roleKeys = new ArrayList<String>(); if (StringUtil.isNotNull(roleKey)) { roleKeys.addAll(StringUtil.tokenize(roleKey)); } String groupName = SecurityConstants.ROLES_IDENTIFIER; Group rolesGroup = new PicketLinkGroup(groupName); List<String> roles = AssertionUtil.getRoles(assertion, roleKeys); for (String role : roles) { rolesGroup.addMember(new SimplePrincipal(role)); } return new Group[]{rolesGroup}; } /** * Get the {@link STSClient} object with which we can make calls to the STS * * @return */ protected STSClient getSTSClient() { /* * Builder builder = new Builder(this.stsConfigurationFile); STSClient client = new STSClient(builder.build()); */ Builder builder = null; STSClient client = null; if (rawOptions.containsKey(STS_CONFIG_FILE)) { builder = new Builder(this.stsConfigurationFile); } else { builder = new Builder(); builder.endpointAddress((String) rawOptions.get(ENDPOINT_ADDRESS)); builder.portName((String) rawOptions.get(PORT_NAME)).serviceName((String) rawOptions.get(SERVICE_NAME)); builder.username((String) rawOptions.get(USERNAME_KEY)).password((String) rawOptions.get(PASSWORD_KEY)); String passwordString = (String) rawOptions.get(PASSWORD_KEY); if (passwordString != null && passwordString.startsWith(PicketLinkFederationConstants.PASS_MASK_PREFIX)) { // password is masked String salt = (String) rawOptions.get(PicketLinkFederationConstants.SALT); if (StringUtil.isNullOrEmpty(salt)) { throw logger.optionNotSet("Salt"); } String iCount = (String) rawOptions.get(PicketLinkFederationConstants.ITERATION_COUNT); if (StringUtil.isNullOrEmpty(iCount)) { throw logger.optionNotSet("Iteration Count"); } int iterationCount = Integer.parseInt(iCount); try { builder.password(StringUtil.decode(passwordString, salt, iterationCount)); } catch (Exception e) { throw logger.unableToDecodePasswordError(passwordString); } } } STSClientConfig config = builder.build(); STSClientPool pool = STSClientPoolFactory.getPoolInstance(); if (initialClientsInPool > 0) { pool.createPool(initialClientsInPool, config); } client = pool.getClient(config); // if the login module options map still contains any properties, assume they are for configuring the connection // to the STS and set them in the Dispatch request context. if (!this.options.isEmpty()) { Dispatch<Source> dispatch = client.getDispatch(); for (Map.Entry<String, ?> entry : this.options.entrySet()) { dispatch.getRequestContext().put(entry.getKey(), entry.getValue()); } } return client; } /** * Locally validate the SAML Assertion element * * @param assertionElement * * @return * * @throws Exception */ protected abstract boolean localValidation(Element assertionElement) throws Exception; protected abstract TimeCacheExpiry getCacheExpiry() throws Exception; }