/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.cxf.ws.security.wss4j; import java.net.URL; import java.security.Principal; import java.security.cert.Certificate; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Properties; import javax.security.auth.callback.CallbackHandler; import javax.xml.namespace.QName; import org.w3c.dom.Element; import org.apache.cxf.binding.soap.SoapMessage; import org.apache.cxf.common.classloader.ClassLoaderUtils; import org.apache.cxf.common.util.StringUtils; import org.apache.cxf.headers.Header; import org.apache.cxf.helpers.CastUtils; import org.apache.cxf.helpers.DOMUtils; import org.apache.cxf.interceptor.security.DefaultSecurityContext; import org.apache.cxf.rt.security.utils.SecurityUtils; import org.apache.cxf.security.SecurityContext; import org.apache.cxf.security.transport.TLSSessionInfo; import org.apache.cxf.ws.policy.AssertionInfo; import org.apache.cxf.ws.policy.AssertionInfoMap; import org.apache.cxf.ws.security.SecurityConstants; import org.apache.cxf.ws.security.policy.PolicyUtils; import org.apache.wss4j.common.crypto.Crypto; import org.apache.wss4j.common.crypto.CryptoFactory; import org.apache.wss4j.common.crypto.PasswordEncryptor; import org.apache.wss4j.common.ext.WSPasswordCallback; import org.apache.wss4j.common.ext.WSSecurityException; import org.apache.wss4j.common.saml.SAMLCallback; import org.apache.wss4j.common.saml.SAMLUtil; import org.apache.wss4j.common.saml.SamlAssertionWrapper; import org.apache.wss4j.common.saml.bean.Version; import org.apache.wss4j.dom.WSConstants; import org.apache.wss4j.dom.WSDocInfo; import org.apache.wss4j.dom.engine.WSSConfig; import org.apache.wss4j.dom.engine.WSSecurityEngineResult; import org.apache.wss4j.dom.handler.RequestData; import org.apache.wss4j.dom.handler.WSHandlerConstants; import org.apache.wss4j.dom.handler.WSHandlerResult; import org.apache.wss4j.dom.processor.SAMLTokenProcessor; import org.apache.wss4j.dom.saml.DOMSAMLUtil; import org.apache.wss4j.policy.SPConstants; import org.apache.wss4j.policy.model.AbstractToken; import org.apache.wss4j.policy.model.SamlToken; import org.apache.wss4j.policy.model.SamlToken.SamlTokenType; import org.opensaml.saml.common.SAMLVersion; /** * An interceptor to create and add a SAML token to the security header of an outbound * request, and to process a SAML Token on an inbound request. */ public class SamlTokenInterceptor extends AbstractTokenInterceptor { public SamlTokenInterceptor() { super(); } protected void processToken(SoapMessage message) { Header h = findSecurityHeader(message, false); if (h == null) { return; } Element el = (Element)h.getObject(); Element child = DOMUtils.getFirstElement(el); while (child != null) { if ("Assertion".equals(child.getLocalName()) && (WSConstants.SAML_NS.equals(child.getNamespaceURI()) || WSConstants.SAML2_NS.equals(child.getNamespaceURI()))) { try { List<WSSecurityEngineResult> samlResults = processToken(child, message); if (samlResults != null) { List<WSHandlerResult> results = CastUtils.cast((List<?>)message .get(WSHandlerConstants.RECV_RESULTS)); if (results == null) { results = new ArrayList<>(); message.put(WSHandlerConstants.RECV_RESULTS, results); } boolean signed = false; for (WSSecurityEngineResult result : samlResults) { SamlAssertionWrapper wrapper = (SamlAssertionWrapper)result.get(WSSecurityEngineResult.TAG_SAML_ASSERTION); if (wrapper.isSigned()) { signed = true; break; } } assertTokens(message, SPConstants.SAML_TOKEN, signed); Integer key = WSConstants.ST_UNSIGNED; if (signed) { key = WSConstants.ST_SIGNED; } WSHandlerResult rResult = new WSHandlerResult(null, samlResults, Collections.singletonMap(key, samlResults)); results.add(0, rResult); // Check version against policy AssertionInfoMap aim = message.get(AssertionInfoMap.class); for (AssertionInfo ai : PolicyUtils.getAllAssertionsByLocalname(aim, SPConstants.SAML_TOKEN)) { SamlToken samlToken = (SamlToken)ai.getAssertion(); for (WSSecurityEngineResult result : samlResults) { SamlAssertionWrapper assertionWrapper = (SamlAssertionWrapper)result.get(WSSecurityEngineResult.TAG_SAML_ASSERTION); if (!checkVersion(aim, samlToken, assertionWrapper)) { ai.setNotAsserted("Wrong SAML Version"); } TLSSessionInfo tlsInfo = message.get(TLSSessionInfo.class); Certificate[] tlsCerts = null; if (tlsInfo != null) { tlsCerts = tlsInfo.getPeerCertificates(); } if (!DOMSAMLUtil.checkHolderOfKey(assertionWrapper, null, tlsCerts)) { ai.setNotAsserted("Assertion fails holder-of-key requirements"); continue; } if (!DOMSAMLUtil.checkSenderVouches(assertionWrapper, tlsCerts, null, null)) { ai.setNotAsserted("Assertion fails sender-vouches requirements"); continue; } } } if (signed) { Principal principal = (Principal)samlResults.get(0).get(WSSecurityEngineResult.TAG_PRINCIPAL); SecurityContext sc = message.get(SecurityContext.class); if (sc == null || sc.getUserPrincipal() == null) { message.put(SecurityContext.class, new DefaultSecurityContext(principal, null)); } } } } catch (WSSecurityException ex) { throw WSS4JUtils.createSoapFault(message, message.getVersion(), ex); } } child = DOMUtils.getNextElement(child); } } private List<WSSecurityEngineResult> processToken(Element tokenElement, final SoapMessage message) throws WSSecurityException { RequestData data = new CXFRequestData(); Object o = SecurityUtils.getSecurityPropertyValue(SecurityConstants.CALLBACK_HANDLER, message); try { data.setCallbackHandler(SecurityUtils.getCallbackHandler(o)); } catch (Exception ex) { throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, ex); } data.setMsgContext(message); data.setWssConfig(WSSConfig.getNewInstance()); data.setSigVerCrypto(getCrypto(null, SecurityConstants.SIGNATURE_CRYPTO, SecurityConstants.SIGNATURE_PROPERTIES, message)); WSDocInfo wsDocInfo = new WSDocInfo(tokenElement.getOwnerDocument()); data.setWsDocInfo(wsDocInfo); SAMLTokenProcessor p = new SAMLTokenProcessor(); return p.handleToken(tokenElement, data); } protected AbstractToken assertTokens(SoapMessage message) { AssertionInfoMap aim = message.get(AssertionInfoMap.class); PolicyUtils.assertPolicy(aim, "WssSamlV11Token10"); PolicyUtils.assertPolicy(aim, "WssSamlV11Token11"); PolicyUtils.assertPolicy(aim, "WssSamlV20Token11"); return assertTokens(message, SPConstants.SAML_TOKEN, true); } protected void addToken(SoapMessage message) { WSSConfig.init(); SamlToken tok = (SamlToken)assertTokens(message); Header h = findSecurityHeader(message, true); try { SamlAssertionWrapper wrapper = addSamlToken(tok, message); if (wrapper == null) { AssertionInfoMap aim = message.get(AssertionInfoMap.class); Collection<AssertionInfo> ais = PolicyUtils.getAllAssertionsByLocalname(aim, SPConstants.SAML_TOKEN); for (AssertionInfo ai : ais) { if (ai.isAsserted()) { ai.setAsserted(false); } } return; } Element el = (Element)h.getObject(); el = (Element)DOMUtils.getDomElement(el); el.appendChild(wrapper.toDOM(el.getOwnerDocument())); } catch (WSSecurityException ex) { policyNotAsserted(tok, ex.getMessage(), message); } } private SamlAssertionWrapper addSamlToken( SamlToken token, SoapMessage message ) throws WSSecurityException { // // Get the SAML CallbackHandler // Object o = SecurityUtils.getSecurityPropertyValue(SecurityConstants.SAML_CALLBACK_HANDLER, message); CallbackHandler handler = null; if (o instanceof CallbackHandler) { handler = (CallbackHandler)o; } else if (o instanceof String) { try { handler = (CallbackHandler)ClassLoaderUtils .loadClass((String)o, this.getClass()).newInstance(); } catch (Exception e) { handler = null; } } if (handler == null) { return null; } AssertionInfoMap aim = message.get(AssertionInfoMap.class); SAMLCallback samlCallback = new SAMLCallback(); SamlTokenType tokenType = token.getSamlTokenType(); if (tokenType == SamlTokenType.WssSamlV11Token10 || tokenType == SamlTokenType.WssSamlV11Token11) { samlCallback.setSamlVersion(Version.SAML_11); PolicyUtils.assertPolicy(aim, "WssSamlV11Token10"); PolicyUtils.assertPolicy(aim, "WssSamlV11Token11"); } else if (tokenType == SamlTokenType.WssSamlV20Token11) { samlCallback.setSamlVersion(Version.SAML_20); PolicyUtils.assertPolicy(aim, "WssSamlV20Token11"); } SAMLUtil.doSAMLCallback(handler, samlCallback); SamlAssertionWrapper assertion = new SamlAssertionWrapper(samlCallback); if (samlCallback.isSignAssertion()) { String issuerName = samlCallback.getIssuerKeyName(); if (issuerName == null) { String userNameKey = SecurityConstants.SIGNATURE_USERNAME; issuerName = (String)SecurityUtils.getSecurityPropertyValue(userNameKey, message); } String password = samlCallback.getIssuerKeyPassword(); if (password == null) { password = (String)SecurityUtils.getSecurityPropertyValue(SecurityConstants.PASSWORD, message); if (StringUtils.isEmpty(password)) { password = getPassword(issuerName, token, WSPasswordCallback.SIGNATURE, message); } } Crypto crypto = samlCallback.getIssuerCrypto(); if (crypto == null) { crypto = getCrypto(token, SecurityConstants.SIGNATURE_CRYPTO, SecurityConstants.SIGNATURE_PROPERTIES, message); } assertion.signAssertion( issuerName, password, crypto, samlCallback.isSendKeyValue(), samlCallback.getCanonicalizationAlgorithm(), samlCallback.getSignatureAlgorithm() ); } return assertion; } private Crypto getCrypto( SamlToken samlToken, String cryptoKey, String propKey, SoapMessage message ) throws WSSecurityException { Crypto crypto = (Crypto)SecurityUtils.getSecurityPropertyValue(cryptoKey, message); if (crypto != null) { return crypto; } Object o = SecurityUtils.getSecurityPropertyValue(propKey, message); if (o == null) { return null; } URL propsURL = SecurityUtils.loadResource(message, o); Properties properties = WSS4JUtils.getProps(o, propsURL); if (properties != null) { PasswordEncryptor passwordEncryptor = WSS4JUtils.getPasswordEncryptor(message); crypto = CryptoFactory.getInstance(properties, this.getClass().getClassLoader(), passwordEncryptor); } return crypto; } /** * Check the policy version against the received assertion */ private boolean checkVersion( AssertionInfoMap aim, SamlToken samlToken, SamlAssertionWrapper assertionWrapper ) { SamlTokenType tokenType = samlToken.getSamlTokenType(); if ((tokenType == SamlTokenType.WssSamlV11Token10 || tokenType == SamlTokenType.WssSamlV11Token11) && assertionWrapper.getSamlVersion() != SAMLVersion.VERSION_11) { return false; } else if (tokenType == SamlTokenType.WssSamlV20Token11 && assertionWrapper.getSamlVersion() != SAMLVersion.VERSION_20) { return false; } PolicyUtils.assertPolicy(aim, new QName(samlToken.getVersion().getNamespace(), tokenType.name())); return true; } }