/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.cxf.ws.security.wss4j;
import java.net.URL;
import java.security.Principal;
import java.security.cert.Certificate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Properties;
import javax.security.auth.callback.CallbackHandler;
import javax.xml.namespace.QName;
import org.w3c.dom.Element;
import org.apache.cxf.binding.soap.SoapMessage;
import org.apache.cxf.common.classloader.ClassLoaderUtils;
import org.apache.cxf.common.util.StringUtils;
import org.apache.cxf.headers.Header;
import org.apache.cxf.helpers.CastUtils;
import org.apache.cxf.helpers.DOMUtils;
import org.apache.cxf.interceptor.security.DefaultSecurityContext;
import org.apache.cxf.rt.security.utils.SecurityUtils;
import org.apache.cxf.security.SecurityContext;
import org.apache.cxf.security.transport.TLSSessionInfo;
import org.apache.cxf.ws.policy.AssertionInfo;
import org.apache.cxf.ws.policy.AssertionInfoMap;
import org.apache.cxf.ws.security.SecurityConstants;
import org.apache.cxf.ws.security.policy.PolicyUtils;
import org.apache.wss4j.common.crypto.Crypto;
import org.apache.wss4j.common.crypto.CryptoFactory;
import org.apache.wss4j.common.crypto.PasswordEncryptor;
import org.apache.wss4j.common.ext.WSPasswordCallback;
import org.apache.wss4j.common.ext.WSSecurityException;
import org.apache.wss4j.common.saml.SAMLCallback;
import org.apache.wss4j.common.saml.SAMLUtil;
import org.apache.wss4j.common.saml.SamlAssertionWrapper;
import org.apache.wss4j.common.saml.bean.Version;
import org.apache.wss4j.dom.WSConstants;
import org.apache.wss4j.dom.WSDocInfo;
import org.apache.wss4j.dom.engine.WSSConfig;
import org.apache.wss4j.dom.engine.WSSecurityEngineResult;
import org.apache.wss4j.dom.handler.RequestData;
import org.apache.wss4j.dom.handler.WSHandlerConstants;
import org.apache.wss4j.dom.handler.WSHandlerResult;
import org.apache.wss4j.dom.processor.SAMLTokenProcessor;
import org.apache.wss4j.dom.saml.DOMSAMLUtil;
import org.apache.wss4j.policy.SPConstants;
import org.apache.wss4j.policy.model.AbstractToken;
import org.apache.wss4j.policy.model.SamlToken;
import org.apache.wss4j.policy.model.SamlToken.SamlTokenType;
import org.opensaml.saml.common.SAMLVersion;
/**
 * An interceptor to create and add a SAML token to the security header of an outbound
 * request, and to process a SAML Token on an inbound request.
 *
 * <p>Outbound, the assertion is produced by the {@link CallbackHandler} configured under
 * {@link SecurityConstants#SAML_CALLBACK_HANDLER} and is optionally signed before being
 * appended to the security header. Inbound, every SAML 1.1 / 2.0 {@code Assertion} element
 * directly under the security header is processed with WSS4J's {@link SAMLTokenProcessor},
 * the results are recorded under {@link WSHandlerConstants#RECV_RESULTS}, and the SamlToken
 * policy assertions are checked against the received assertion(s).
 */
public class SamlTokenInterceptor extends AbstractTokenInterceptor {

    public SamlTokenInterceptor() {
        super();
    }

    /**
     * Processes every SAML assertion found directly under the inbound security header.
     *
     * <p>Does nothing if the message has no security header. Any
     * {@link WSSecurityException} raised while processing or validating an assertion is
     * converted into a SOAP fault via {@link WSS4JUtils#createSoapFault}.
     *
     * @param message the inbound SOAP message
     */
    protected void processToken(SoapMessage message) {
        Header h = findSecurityHeader(message, false);
        if (h == null) {
            return;
        }
        Element el = (Element)h.getObject();
        for (Element child = DOMUtils.getFirstElement(el);
            child != null;
            child = DOMUtils.getNextElement(child)) {
            if (!isSamlAssertionElement(child)) {
                continue;
            }
            try {
                List<WSSecurityEngineResult> samlResults = processToken(child, message);
                if (samlResults != null) {
                    handleSamlResults(message, samlResults);
                }
            } catch (WSSecurityException ex) {
                throw WSS4JUtils.createSoapFault(message, message.getVersion(), ex);
            }
        }
    }

    /**
     * Returns {@code true} if the element is a SAML 1.x or SAML 2.0 {@code Assertion}.
     */
    private static boolean isSamlAssertionElement(Element element) {
        String ns = element.getNamespaceURI();
        return "Assertion".equals(element.getLocalName())
            && (WSConstants.SAML_NS.equals(ns) || WSConstants.SAML2_NS.equals(ns));
    }

    /**
     * Records processed SAML results on the message and validates them against policy.
     *
     * <p>The steps run in this order:
     * <ol>
     * <li>Assert the SamlToken policies.</li>
     * <li>Prepend a {@link WSHandlerResult} to {@code RECV_RESULTS}, creating the list if
     * absent.</li>
     * <li>Check the received assertions against each SamlToken policy assertion.</li>
     * <li>If any assertion is signed, install a {@link SecurityContext} for the first
     * result's principal, unless a context with a principal already exists.</li>
     * </ol>
     */
    private void handleSamlResults(SoapMessage message, List<WSSecurityEngineResult> samlResults)
        throws WSSecurityException {
        List<WSHandlerResult> results =
            CastUtils.cast((List<?>)message.get(WSHandlerConstants.RECV_RESULTS));
        if (results == null) {
            results = new ArrayList<>();
            message.put(WSHandlerConstants.RECV_RESULTS, results);
        }

        boolean signed = containsSignedAssertion(samlResults);
        assertTokens(message, SPConstants.SAML_TOKEN, signed);

        // The results are filed under the signed or unsigned SAML action key accordingly.
        Integer key = signed ? WSConstants.ST_SIGNED : WSConstants.ST_UNSIGNED;
        WSHandlerResult rResult =
            new WSHandlerResult(null, samlResults, Collections.singletonMap(key, samlResults));
        results.add(0, rResult);

        checkSamlPolicies(message, samlResults);

        if (signed) {
            // samlResults is non-empty here, because "signed" was derived from one of its entries.
            Principal principal =
                (Principal)samlResults.get(0).get(WSSecurityEngineResult.TAG_PRINCIPAL);
            SecurityContext sc = message.get(SecurityContext.class);
            if (sc == null || sc.getUserPrincipal() == null) {
                message.put(SecurityContext.class, new DefaultSecurityContext(principal, null));
            }
        }
    }

    /**
     * Returns {@code true} if at least one of the processed assertions is signed.
     */
    private static boolean containsSignedAssertion(List<WSSecurityEngineResult> samlResults) {
        for (WSSecurityEngineResult result : samlResults) {
            SamlAssertionWrapper wrapper =
                (SamlAssertionWrapper)result.get(WSSecurityEngineResult.TAG_SAML_ASSERTION);
            if (wrapper.isSigned()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Checks each received assertion against every SamlToken policy assertion in effect.
     *
     * <p>The assertion is marked as not asserted if any of these checks fails:
     * <ul>
     * <li>the SAML version mismatches;</li>
     * <li>the holder-of-key requirements fail;</li>
     * <li>the sender-vouches requirements fail.</li>
     * </ul>
     */
    private void checkSamlPolicies(SoapMessage message, List<WSSecurityEngineResult> samlResults)
        throws WSSecurityException {
        AssertionInfoMap aim = message.get(AssertionInfoMap.class);

        // The TLS peer certificates do not depend on the assertion, so look them up once.
        TLSSessionInfo tlsInfo = message.get(TLSSessionInfo.class);
        Certificate[] tlsCerts = tlsInfo != null ? tlsInfo.getPeerCertificates() : null;

        for (AssertionInfo ai
            : PolicyUtils.getAllAssertionsByLocalname(aim, SPConstants.SAML_TOKEN)) {
            SamlToken samlToken = (SamlToken)ai.getAssertion();
            for (WSSecurityEngineResult result : samlResults) {
                SamlAssertionWrapper assertionWrapper =
                    (SamlAssertionWrapper)result.get(WSSecurityEngineResult.TAG_SAML_ASSERTION);

                if (!checkVersion(aim, samlToken, assertionWrapper)) {
                    ai.setNotAsserted("Wrong SAML Version");
                }
                if (!DOMSAMLUtil.checkHolderOfKey(assertionWrapper, null, tlsCerts)) {
                    ai.setNotAsserted("Assertion fails holder-of-key requirements");
                    continue;
                }
                if (!DOMSAMLUtil.checkSenderVouches(assertionWrapper, tlsCerts, null, null)) {
                    ai.setNotAsserted("Assertion fails sender-vouches requirements");
                }
            }
        }
    }

    /**
     * Runs a single SAML assertion element through WSS4J's {@link SAMLTokenProcessor}.
     *
     * <p>The request data is configured as follows:
     * <ul>
     * <li>the callback handler comes from {@link SecurityConstants#CALLBACK_HANDLER};</li>
     * <li>the signature-verification crypto comes from the signature crypto/properties
     * settings;</li>
     * <li>a fresh {@link WSSConfig} is used.</li>
     * </ul>
     *
     * @param tokenElement the {@code Assertion} element to process
     * @param message the current message
     * @return the processor's results
     * @throws WSSecurityException if the callback handler cannot be resolved or processing fails
     */
    private List<WSSecurityEngineResult> processToken(Element tokenElement, final SoapMessage message)
        throws WSSecurityException {
        RequestData data = new CXFRequestData();
        Object o = SecurityUtils.getSecurityPropertyValue(SecurityConstants.CALLBACK_HANDLER, message);
        try {
            data.setCallbackHandler(SecurityUtils.getCallbackHandler(o));
        } catch (Exception ex) {
            throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, ex);
        }
        data.setMsgContext(message);
        data.setWssConfig(WSSConfig.getNewInstance());
        data.setSigVerCrypto(getCrypto(null, SecurityConstants.SIGNATURE_CRYPTO,
                                       SecurityConstants.SIGNATURE_PROPERTIES, message));

        data.setWsDocInfo(new WSDocInfo(tokenElement.getOwnerDocument()));

        SAMLTokenProcessor p = new SAMLTokenProcessor();
        return p.handleToken(tokenElement, data);
    }

    /**
     * Asserts the SAML token-type policies and the SamlToken policy itself (as signed).
     *
     * @param message the current message
     * @return the token returned by {@code assertTokens(message, SAML_TOKEN, true)}
     */
    protected AbstractToken assertTokens(SoapMessage message) {
        AssertionInfoMap aim = message.get(AssertionInfoMap.class);
        PolicyUtils.assertPolicy(aim, "WssSamlV11Token10");
        PolicyUtils.assertPolicy(aim, "WssSamlV11Token11");
        PolicyUtils.assertPolicy(aim, "WssSamlV20Token11");
        return assertTokens(message, SPConstants.SAML_TOKEN, true);
    }

    /**
     * Creates a SAML assertion and appends it to the outbound security header.
     *
     * <p>If no SAML callback handler is available, no token is added. In that case every
     * currently-asserted SamlToken policy assertion is reset to not asserted. A
     * {@link WSSecurityException} while building or signing the assertion is reported via
     * {@code policyNotAsserted}.
     *
     * @param message the outbound SOAP message
     */
    protected void addToken(SoapMessage message) {
        WSSConfig.init();
        SamlToken tok = (SamlToken)assertTokens(message);
        Header h = findSecurityHeader(message, true);
        try {
            SamlAssertionWrapper wrapper = addSamlToken(tok, message);
            if (wrapper == null) {
                AssertionInfoMap aim = message.get(AssertionInfoMap.class);
                Collection<AssertionInfo> ais =
                    PolicyUtils.getAllAssertionsByLocalname(aim, SPConstants.SAML_TOKEN);
                for (AssertionInfo ai : ais) {
                    if (ai.isAsserted()) {
                        ai.setAsserted(false);
                    }
                }
                return;
            }
            Element el = (Element)h.getObject();
            el = (Element)DOMUtils.getDomElement(el);
            el.appendChild(wrapper.toDOM(el.getOwnerDocument()));
        } catch (WSSecurityException ex) {
            policyNotAsserted(tok, ex.getMessage(), message);
        }
    }

    /**
     * Builds (and, if the callback requests it, signs) a SAML assertion.
     *
     * <p>The SAML version is taken from the policy's token type when one is specified.
     * Otherwise it is left to the callback handler.
     *
     * <p>When the callback does not supply them, the issuer name, password and crypto are
     * resolved from the message as follows:
     * <ul>
     * <li>issuer name: the {@link SecurityConstants#SIGNATURE_USERNAME} property;</li>
     * <li>password: the {@link SecurityConstants#PASSWORD} property, then
     * {@code getPassword(...)};</li>
     * <li>crypto: the signature crypto/properties settings.</li>
     * </ul>
     *
     * @param token the SamlToken policy assertion; may be {@code null}, which is tolerated
     * @param message the current message
     * @return the assertion, or {@code null} if no SAML callback handler could be obtained
     * @throws WSSecurityException if the callback or signing fails
     */
    private SamlAssertionWrapper addSamlToken(
        SamlToken token, SoapMessage message
    ) throws WSSecurityException {
        CallbackHandler handler = resolveSamlCallbackHandler(message);
        if (handler == null) {
            return null;
        }

        AssertionInfoMap aim = message.get(AssertionInfoMap.class);
        SAMLCallback samlCallback = new SAMLCallback();
        // Guard against a missing policy token rather than failing with a NullPointerException.
        SamlTokenType tokenType = token != null ? token.getSamlTokenType() : null;
        if (tokenType == SamlTokenType.WssSamlV11Token10 || tokenType == SamlTokenType.WssSamlV11Token11) {
            samlCallback.setSamlVersion(Version.SAML_11);
            PolicyUtils.assertPolicy(aim, "WssSamlV11Token10");
            PolicyUtils.assertPolicy(aim, "WssSamlV11Token11");
        } else if (tokenType == SamlTokenType.WssSamlV20Token11) {
            samlCallback.setSamlVersion(Version.SAML_20);
            PolicyUtils.assertPolicy(aim, "WssSamlV20Token11");
        }
        SAMLUtil.doSAMLCallback(handler, samlCallback);
        SamlAssertionWrapper assertion = new SamlAssertionWrapper(samlCallback);

        if (samlCallback.isSignAssertion()) {
            String issuerName = samlCallback.getIssuerKeyName();
            if (issuerName == null) {
                String userNameKey = SecurityConstants.SIGNATURE_USERNAME;
                issuerName = (String)SecurityUtils.getSecurityPropertyValue(userNameKey, message);
            }
            String password = samlCallback.getIssuerKeyPassword();
            if (password == null) {
                password =
                    (String)SecurityUtils.getSecurityPropertyValue(SecurityConstants.PASSWORD, message);
                if (StringUtils.isEmpty(password)) {
                    password =
                        getPassword(issuerName, token, WSPasswordCallback.SIGNATURE, message);
                }
            }
            Crypto crypto = samlCallback.getIssuerCrypto();
            if (crypto == null) {
                crypto =
                    getCrypto(token, SecurityConstants.SIGNATURE_CRYPTO,
                              SecurityConstants.SIGNATURE_PROPERTIES, message);
            }

            assertion.signAssertion(
                issuerName,
                password,
                crypto,
                samlCallback.isSendKeyValue(),
                samlCallback.getCanonicalizationAlgorithm(),
                samlCallback.getSignatureAlgorithm()
            );
        }

        return assertion;
    }

    /**
     * Resolves the SAML callback handler from {@link SecurityConstants#SAML_CALLBACK_HANDLER}.
     *
     * <p>The property may hold a {@link CallbackHandler} instance, or a class name that is
     * loaded and instantiated through its no-argument constructor.
     *
     * @return the handler, or {@code null} if none is configured or it cannot be instantiated
     */
    private CallbackHandler resolveSamlCallbackHandler(SoapMessage message) {
        Object o =
            SecurityUtils.getSecurityPropertyValue(SecurityConstants.SAML_CALLBACK_HANDLER, message);
        if (o instanceof CallbackHandler) {
            return (CallbackHandler)o;
        }
        if (o instanceof String) {
            try {
                return (CallbackHandler)ClassLoaderUtils
                    .loadClass((String)o, this.getClass())
                    .getDeclaredConstructor()
                    .newInstance();
            } catch (Exception e) {
                // Deliberate best effort: an unloadable handler is treated as "no handler",
                // so addToken adds no token and un-asserts the SamlToken policies.
                return null;
            }
        }
        return null;
    }

    /**
     * Resolves a {@link Crypto} from the message.
     *
     * <p>The instance stored under {@code cryptoKey} is preferred. Otherwise one is built with
     * {@link CryptoFactory} from the properties referenced by {@code propKey}.
     *
     * @param samlToken currently unused
     * @param cryptoKey property name holding a ready-made {@link Crypto}
     * @param propKey property name holding the crypto properties (resource, URL or Properties)
     * @param message the current message
     * @return the crypto, or {@code null} if neither property yields one
     * @throws WSSecurityException if {@link CryptoFactory} fails to create the crypto
     */
    private Crypto getCrypto(
        SamlToken samlToken,
        String cryptoKey,
        String propKey,
        SoapMessage message
    ) throws WSSecurityException {
        Crypto crypto = (Crypto)SecurityUtils.getSecurityPropertyValue(cryptoKey, message);
        if (crypto != null) {
            return crypto;
        }

        Object o = SecurityUtils.getSecurityPropertyValue(propKey, message);
        if (o == null) {
            return null;
        }

        URL propsURL = SecurityUtils.loadResource(message, o);
        Properties properties = WSS4JUtils.getProps(o, propsURL);
        if (properties != null) {
            PasswordEncryptor passwordEncryptor = WSS4JUtils.getPasswordEncryptor(message);
            crypto = CryptoFactory.getInstance(properties, this.getClass().getClassLoader(), passwordEncryptor);
        }
        return crypto;
    }

    /**
     * Checks the policy's SAML token type against the version of the received assertion.
     *
     * <p>When the versions match and the policy names a token type, the corresponding
     * token-type policy is asserted. A policy without a token type accepts any version. That
     * case previously failed with a NullPointerException on {@code tokenType.name()}.
     *
     * @return {@code false} if the policy requires a SAML version the assertion does not have
     */
    private boolean checkVersion(
        AssertionInfoMap aim,
        SamlToken samlToken,
        SamlAssertionWrapper assertionWrapper
    ) {
        SamlTokenType tokenType = samlToken.getSamlTokenType();
        if ((tokenType == SamlTokenType.WssSamlV11Token10
            || tokenType == SamlTokenType.WssSamlV11Token11)
            && assertionWrapper.getSamlVersion() != SAMLVersion.VERSION_11) {
            return false;
        } else if (tokenType == SamlTokenType.WssSamlV20Token11
            && assertionWrapper.getSamlVersion() != SAMLVersion.VERSION_20) {
            return false;
        }

        if (tokenType != null) {
            PolicyUtils.assertPolicy(aim, new QName(samlToken.getVersion().getNamespace(), tokenType.name()));
        }
        return true;
    }
}