/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.cxf.ws.security.trust;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.Principal;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.LoginException;
import javax.security.auth.spi.LoginModule;
import org.w3c.dom.Document;
import org.apache.cxf.Bus;
import org.apache.cxf.BusException;
import org.apache.cxf.BusFactory;
import org.apache.cxf.bus.spring.SpringBusFactory;
import org.apache.cxf.common.classloader.ClassLoaderUtils;
import org.apache.cxf.common.logging.LogUtils;
import org.apache.cxf.common.security.SimplePrincipal;
import org.apache.cxf.endpoint.EndpointException;
import org.apache.cxf.helpers.DOMUtils;
import org.apache.cxf.message.Message;
import org.apache.cxf.phase.PhaseInterceptorChain;
import org.apache.cxf.rt.security.claims.ClaimCollection;
import org.apache.cxf.rt.security.saml.utils.SAMLUtils;
import org.apache.cxf.rt.security.utils.SecurityUtils;
import org.apache.cxf.ws.security.SecurityConstants;
import org.apache.cxf.ws.security.tokenstore.EHCacheTokenStore;
import org.apache.cxf.ws.security.tokenstore.TokenStore;
import org.apache.cxf.ws.security.tokenstore.TokenStoreFactory;
import org.apache.cxf.ws.security.trust.claims.RoleClaimsCallbackHandler;
import org.apache.cxf.ws.security.wss4j.WSS4JInInterceptor;
import org.apache.wss4j.common.saml.SamlAssertionWrapper;
import org.apache.wss4j.common.util.Loader;
import org.apache.wss4j.dom.WSConstants;
import org.apache.wss4j.dom.handler.RequestData;
import org.apache.wss4j.dom.message.token.UsernameToken;
import org.apache.wss4j.dom.validate.Credential;
/**
* A JAAS LoginModule for authenticating a Username/Password to the STS. It can be configured
* either by specifying the various options (documented below) in the JAAS configuration, or
* else by picking up a CXF STSClient from the CXF bus (either the default one, or else one
* that has the same QName as the service name).
*/
public class STSLoginModule implements LoginModule {
/**
* Whether we require roles or not from the STS. If this is not set then the
* WS-Trust validate binding is used. If it is set then the issue binding is
* used, where the Username + Password credentials are passed via "OnBehalfOf"
* (unless the DISABLE_ON_BEHALF_OF property is set to "true", see below). In addition,
* claims are added to the request for the standard "role" ClaimType.
*/
public static final String REQUIRE_ROLES = "require.roles";
/**
* Whether to disable passing Username + Password credentials via "OnBehalfOf". If the
* REQUIRE_ROLES property (see above) is set to "true", then the Issue Binding is used
* and the credentials are passed via OnBehalfOf. If this (DISABLE_ON_BEHALF_OF) property
* is set to "true", then the credentials instead are passed through to the
* WS-SecurityPolicy layer and used depending on the security policy of the STS endpoint.
* For example, if the STS endpoint requires a WS-Security UsernameToken, then the
* credentials are inserted here.
*/
public static final String DISABLE_ON_BEHALF_OF = "disable.on.behalf.of";
/**
* Whether to disable caching of validated credentials or not. The default is "false", meaning that
* caching is enabled. However, caching only applies when token transformation takes place, i.e. when
* the "require.roles" property is set to "true".
*/
public static final String DISABLE_CACHING = "disable.caching";
/**
* The WSDL Location of the STS
*/
public static final String WSDL_LOCATION = "wsdl.location";
/**
* The Service QName of the STS
*/
public static final String SERVICE_NAME = "service.name";
/**
* The Endpoint QName of the STS
*/
public static final String ENDPOINT_NAME = "endpoint.name";
/**
* The default key size to use if using the SymmetricKey KeyType. Defaults to 256.
*/
public static final String KEY_SIZE = "key.size";
/**
* The key type to use. The default is the standard "Bearer" URI.
*/
public static final String KEY_TYPE = "key.type";
/**
* The token type to use. The default is the standard SAML 2.0 URI.
*/
public static final String TOKEN_TYPE = "token.type";
/**
* The WS-Trust namespace to use. The default is the WS-Trust 1.3 namespace.
*/
public static final String WS_TRUST_NAMESPACE = "ws.trust.namespace";
/**
* The location of a Spring configuration file that can be used to configure the
* STS client (for example, to configure the TrustStore if TLS is used). This is
* designed to be used if the service that is being secured is not CXF-based.
*/
public static final String CXF_SPRING_CFG = "cxf.spring.config";
private static final Logger LOG = LogUtils.getL7dLogger(STSLoginModule.class);
private static final String TOKEN_STORE_KEY = "sts.login.module.tokenstore";
private Set<Principal> roles = new HashSet<>();
private Principal userPrincipal;
private Subject subject;
private CallbackHandler callbackHandler;
private boolean requireRoles;
private boolean disableOnBehalfOf;
private boolean disableCaching;
private String wsdlLocation;
private String serviceName;
private String endpointName;
private String cxfSpringCfg;
private int keySize;
private String keyType = "http://docs.oasis-open.org/ws-sx/ws-trust/200512/Bearer";
private String tokenType = "http://docs.oasis-open.org/wss/oasis-wss-saml-token-profile-1.1#SAMLV2.0";
private String namespace;
private Map<String, Object> stsClientProperties = new HashMap<>();
@Override
public void initialize(Subject subj, CallbackHandler cbHandler, Map<String, ?> sharedState,
Map<String, ?> options) {
subject = subj;
callbackHandler = cbHandler;
if (options.containsKey(REQUIRE_ROLES)) {
requireRoles = Boolean.parseBoolean((String)options.get(REQUIRE_ROLES));
}
if (options.containsKey(DISABLE_ON_BEHALF_OF)) {
disableOnBehalfOf = Boolean.parseBoolean((String)options.get(DISABLE_ON_BEHALF_OF));
}
if (options.containsKey(DISABLE_CACHING)) {
disableCaching = Boolean.parseBoolean((String)options.get(DISABLE_CACHING));
}
if (options.containsKey(WSDL_LOCATION)) {
wsdlLocation = (String)options.get(WSDL_LOCATION);
}
if (options.containsKey(SERVICE_NAME)) {
serviceName = (String)options.get(SERVICE_NAME);
}
if (options.containsKey(ENDPOINT_NAME)) {
endpointName = (String)options.get(ENDPOINT_NAME);
}
if (options.containsKey(KEY_SIZE)) {
keySize = Integer.parseInt((String)options.get(KEY_SIZE));
}
if (options.containsKey(KEY_TYPE)) {
keyType = (String)options.get(KEY_TYPE);
}
if (options.containsKey(TOKEN_TYPE)) {
tokenType = (String)options.get(TOKEN_TYPE);
}
if (options.containsKey(WS_TRUST_NAMESPACE)) {
namespace = (String)options.get(WS_TRUST_NAMESPACE);
}
if (options.containsKey(CXF_SPRING_CFG)) {
cxfSpringCfg = (String)options.get(CXF_SPRING_CFG);
}
stsClientProperties.clear();
for (String s : SecurityConstants.ALL_PROPERTIES) {
if (options.containsKey(s)) {
stsClientProperties.put(s, options.get(s));
}
}
}
@Override
public boolean login() throws LoginException {
// Get username and password
Callback[] callbacks = new Callback[2];
callbacks[0] = new NameCallback("Username: ");
callbacks[1] = new PasswordCallback("Password: ", false);
try {
callbackHandler.handle(callbacks);
} catch (IOException ioException) {
throw new LoginException(ioException.getMessage());
} catch (UnsupportedCallbackException unsupportedCallbackException) {
throw new LoginException(unsupportedCallbackException.getMessage()
+ " not available to obtain information from user.");
}
String user = ((NameCallback) callbacks[0]).getName();
char[] tmpPassword = ((PasswordCallback) callbacks[1]).getPassword();
if (tmpPassword == null) {
tmpPassword = new char[0];
}
String password = new String(tmpPassword);
roles = new HashSet<>();
userPrincipal = null;
STSTokenValidator validator = new STSTokenValidator(true);
validator.setUseIssueBinding(requireRoles);
validator.setUseOnBehalfOf(!disableOnBehalfOf);
validator.setDisableCaching(!requireRoles || disableCaching);
// Authenticate token
try {
UsernameToken token = convertToToken(user, password);
Credential credential = new Credential();
credential.setUsernametoken(token);
RequestData data = new RequestData();
Message message = PhaseInterceptorChain.getCurrentMessage();
STSClient stsClient = configureSTSClient(message);
if (message != null) {
message.put(SecurityConstants.STS_CLIENT, stsClient);
data.setMsgContext(message);
} else {
TokenStore tokenStore = configureTokenStore();
validator.setStsClient(stsClient);
validator.setTokenStore(tokenStore);
}
credential = validator.validate(credential, data);
// Add user principal
userPrincipal = new SimplePrincipal(user);
// Add roles if a SAML Assertion was returned from the STS
roles.addAll(getRoles(message, credential));
} catch (Exception e) {
LOG.log(Level.INFO, "User " + user + " authentication failed", e);
throw new LoginException("User " + user + " authentication failed: " + e.getMessage());
}
return true;
}
private STSClient configureSTSClient(Message msg) throws BusException, EndpointException {
STSClient c = null;
if (cxfSpringCfg != null) {
SpringBusFactory bf = new SpringBusFactory();
URL busFile = Loader.getResource(cxfSpringCfg);
Bus bus = bf.createBus(busFile.toString());
SpringBusFactory.setDefaultBus(bus);
SpringBusFactory.setThreadDefaultBus(bus);
c = new STSClient(bus);
} else if (msg == null) {
Bus bus = BusFactory.getDefaultBus(true);
c = new STSClient(bus);
} else {
c = STSUtils.getClient(msg, "sts");
}
if (wsdlLocation != null) {
c.setWsdlLocation(wsdlLocation);
}
if (serviceName != null) {
c.setServiceName(serviceName);
}
if (endpointName != null) {
c.setEndpointName(endpointName);
}
if (keySize > 0) {
c.setKeySize(keySize);
}
if (keyType != null) {
c.setKeyType(keyType);
}
if (tokenType != null) {
c.setTokenType(tokenType);
}
if (namespace != null) {
c.setNamespace(namespace);
}
c.setProperties(stsClientProperties);
if (requireRoles && c.getClaimsCallbackHandler() == null) {
c.setClaimsCallbackHandler(new RoleClaimsCallbackHandler());
}
return c;
}
private TokenStore configureTokenStore() throws MalformedURLException {
if (TokenStoreFactory.isEhCacheInstalled()) {
String cfg = "cxf-ehcache.xml";
URL url = null;
if (url == null) {
url = ClassLoaderUtils.getResource(cfg, STSLoginModule.class);
}
if (url == null) {
url = new URL(cfg);
}
if (url != null) {
return new EHCacheTokenStore(TOKEN_STORE_KEY, BusFactory.getDefaultBus(), url);
}
}
return null;
}
private UsernameToken convertToToken(String username, String password)
throws Exception {
Document doc = DOMUtils.createDocument();
UsernameToken token = new UsernameToken(false, doc,
WSConstants.PASSWORD_TEXT);
token.setName(username);
token.setPassword(password);
return token;
}
private Set<Principal> getRoles(Message msg, Credential credential) {
SamlAssertionWrapper samlAssertion = credential.getTransformedToken();
if (samlAssertion == null) {
samlAssertion = credential.getSamlAssertion();
}
if (samlAssertion != null) {
String roleAttributeName = null;
if (msg != null) {
roleAttributeName =
(String)SecurityUtils.getSecurityPropertyValue(SecurityConstants.SAML_ROLE_ATTRIBUTENAME,
msg);
}
if (roleAttributeName == null || roleAttributeName.length() == 0) {
roleAttributeName = WSS4JInInterceptor.SAML_ROLE_ATTRIBUTENAME_DEFAULT;
}
ClaimCollection claims =
SAMLUtils.getClaims((SamlAssertionWrapper)samlAssertion);
return SAMLUtils.parseRolesFromClaims(claims, roleAttributeName, null);
}
return Collections.emptySet();
}
@Override
public boolean commit() throws LoginException {
if (userPrincipal == null) {
return false;
}
subject.getPrincipals().add(userPrincipal);
subject.getPrincipals().addAll(roles);
return true;
}
@Override
public boolean abort() throws LoginException {
return true;
}
@Override
public boolean logout() throws LoginException {
subject.getPrincipals().remove(userPrincipal);
subject.getPrincipals().removeAll(roles);
roles.clear();
userPrincipal = null;
return true;
}
}