/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.cxf.rs.security.saml.sso; import java.io.ByteArrayInputStream; import java.io.InputStreamReader; import java.nio.charset.StandardCharsets; import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.logging.Level; import java.util.logging.Logger; import javax.crypto.Cipher; import javax.crypto.SecretKey; import javax.security.auth.callback.CallbackHandler; import org.w3c.dom.Document; import org.w3c.dom.Element; import org.w3c.dom.NodeList; import org.apache.cxf.common.logging.LogUtils; import org.apache.cxf.common.util.Base64Exception; import org.apache.cxf.common.util.Base64Utility; import org.apache.cxf.rs.security.common.RSSecurityUtils; import org.apache.cxf.rs.security.xml.EncryptionUtils; import org.apache.cxf.staxutils.StaxUtils; import org.apache.wss4j.common.WSS4JConstants; import org.apache.wss4j.common.crypto.Crypto; import org.apache.wss4j.common.ext.WSSecurityException; import org.apache.wss4j.common.saml.SAMLKeyInfo; import org.apache.wss4j.common.saml.SAMLUtil; import org.apache.wss4j.common.saml.SamlAssertionWrapper; import org.apache.wss4j.common.util.KeyUtils; import org.apache.wss4j.dom.WSDocInfo; import org.apache.wss4j.dom.engine.WSSConfig; import org.apache.wss4j.dom.handler.RequestData; import org.apache.wss4j.dom.saml.WSSSAMLKeyInfoProcessor; import org.apache.wss4j.dom.validate.Credential; import org.apache.wss4j.dom.validate.SignatureTrustValidator; import org.apache.wss4j.dom.validate.Validator; import org.apache.xml.security.encryption.XMLCipher; import org.apache.xml.security.encryption.XMLEncryptionException; import org.apache.xml.security.utils.Constants; import org.joda.time.DateTime; import org.opensaml.saml.common.SAMLVersion; import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; import org.opensaml.security.credential.BasicCredential; import org.opensaml.security.x509.BasicX509Credential; import org.opensaml.xmlsec.encryption.EncryptedData; import org.opensaml.xmlsec.signature.KeyInfo; import org.opensaml.xmlsec.signature.Signature; import org.opensaml.xmlsec.signature.support.SignatureException; import org.opensaml.xmlsec.signature.support.SignatureValidator; /** * Validate a SAML (1.1 or 2.0) Protocol Response. It validates the Response against the specs, * the signature of the Response (if it exists), and any internal Assertion stored in the Response * - including any signature. It validates the status code of the Response as well. */ public class SAMLProtocolResponseValidator { public static final String SAML2_STATUSCODE_SUCCESS = "urn:oasis:names:tc:SAML:2.0:status:Success"; public static final String SAML1_STATUSCODE_SUCCESS = "Success"; private static final Logger LOG = LogUtils.getL7dLogger(SAMLProtocolResponseValidator.class); private Validator signatureValidator = new SignatureTrustValidator(); private boolean keyInfoMustBeAvailable = true; /** * The time in seconds in the future within which the NotBefore time of an incoming * Assertion is valid. The default is 60 seconds. */ private int futureTTL = 60; /** * Validate a SAML 2 Protocol Response * @param samlResponse * @param sigCrypto * @param callbackHandler * @throws WSSecurityException */ public void validateSamlResponse( org.opensaml.saml.saml2.core.Response samlResponse, Crypto sigCrypto, CallbackHandler callbackHandler ) throws WSSecurityException { // Check the Status Code if (samlResponse.getStatus() == null || samlResponse.getStatus().getStatusCode() == null) { LOG.fine("Either the SAML Response Status or StatusCode is null"); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } if (!SAML2_STATUSCODE_SUCCESS.equals(samlResponse.getStatus().getStatusCode().getValue())) { LOG.fine( "SAML Status code of " + samlResponse.getStatus().getStatusCode().getValue() + "does not equal " + SAML2_STATUSCODE_SUCCESS ); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } if (samlResponse.getIssueInstant() != null) { DateTime currentTime = new DateTime(); currentTime = currentTime.plusSeconds(futureTTL); if (samlResponse.getIssueInstant().isAfter(currentTime)) { LOG.fine("SAML Response IssueInstant not met"); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } } if (SAMLVersion.VERSION_20 != samlResponse.getVersion()) { LOG.fine( "SAML Version of " + samlResponse.getVersion() + "does not equal " + SAMLVersion.VERSION_20 ); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } validateResponseSignature(samlResponse, sigCrypto, callbackHandler); Document doc = samlResponse.getDOM().getOwnerDocument(); // Decrypt any encrypted Assertions and add them to the Response (note that this will break any // signature on the Response) for (org.opensaml.saml.saml2.core.EncryptedAssertion assertion : samlResponse.getEncryptedAssertions()) { Element decAssertion = decryptAssertion(assertion, sigCrypto, callbackHandler); SamlAssertionWrapper wrapper = new SamlAssertionWrapper(decAssertion); samlResponse.getAssertions().add(wrapper.getSaml2()); } // Validate Assertions for (org.opensaml.saml.saml2.core.Assertion assertion : samlResponse.getAssertions()) { SamlAssertionWrapper wrapper = new SamlAssertionWrapper(assertion); validateAssertion(wrapper, sigCrypto, callbackHandler, doc, samlResponse.isSigned()); } } /** * Validate a SAML 1.1 Protocol Response * @param samlResponse * @param sigCrypto * @param callbackHandler * @throws WSSecurityException */ public void validateSamlResponse( org.opensaml.saml.saml1.core.Response samlResponse, Crypto sigCrypto, CallbackHandler callbackHandler ) throws WSSecurityException { // Check the Status Code if (samlResponse.getStatus() == null || samlResponse.getStatus().getStatusCode() == null || samlResponse.getStatus().getStatusCode().getValue() == null) { LOG.fine("Either the SAML Response Status or StatusCode is null"); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } String statusValue = samlResponse.getStatus().getStatusCode().getValue().getLocalPart(); if (!SAML1_STATUSCODE_SUCCESS.equals(statusValue)) { LOG.fine( "SAML Status code of " + samlResponse.getStatus().getStatusCode().getValue() + "does not equal " + SAML1_STATUSCODE_SUCCESS ); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } if (samlResponse.getIssueInstant() != null) { DateTime currentTime = new DateTime(); currentTime = currentTime.plusSeconds(futureTTL); if (samlResponse.getIssueInstant().isAfter(currentTime)) { LOG.fine("SAML Response IssueInstant not met"); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } } if (SAMLVersion.VERSION_11 != samlResponse.getVersion()) { LOG.fine( "SAML Version of " + samlResponse.getVersion() + "does not equal " + SAMLVersion.VERSION_11 ); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } validateResponseSignature(samlResponse, sigCrypto, callbackHandler); // Validate Assertions for (org.opensaml.saml.saml1.core.Assertion assertion : samlResponse.getAssertions()) { SamlAssertionWrapper wrapper = new SamlAssertionWrapper(assertion); validateAssertion( wrapper, sigCrypto, callbackHandler, samlResponse.getDOM().getOwnerDocument(), samlResponse.isSigned() ); } } /** * Validate the Response signature (if it exists) */ private void validateResponseSignature( org.opensaml.saml.saml2.core.Response samlResponse, Crypto sigCrypto, CallbackHandler callbackHandler ) throws WSSecurityException { if (!samlResponse.isSigned()) { return; } validateResponseSignature( samlResponse.getSignature(), samlResponse.getDOM().getOwnerDocument(), sigCrypto, callbackHandler ); } /** * Validate the Response signature (if it exists) */ private void validateResponseSignature( org.opensaml.saml.saml1.core.Response samlResponse, Crypto sigCrypto, CallbackHandler callbackHandler ) throws WSSecurityException { if (!samlResponse.isSigned()) { return; } validateResponseSignature( samlResponse.getSignature(), samlResponse.getDOM().getOwnerDocument(), sigCrypto, callbackHandler ); } /** * Validate the response signature */ private void validateResponseSignature( Signature signature, Document doc, Crypto sigCrypto, CallbackHandler callbackHandler ) throws WSSecurityException { RequestData requestData = new RequestData(); requestData.setSigVerCrypto(sigCrypto); WSSConfig wssConfig = WSSConfig.getNewInstance(); requestData.setWssConfig(wssConfig); requestData.setCallbackHandler(callbackHandler); requestData.setWsDocInfo(new WSDocInfo(doc)); SAMLKeyInfo samlKeyInfo = null; KeyInfo keyInfo = signature.getKeyInfo(); if (keyInfo != null) { try { samlKeyInfo = SAMLUtil.getCredentialFromKeyInfo( keyInfo.getDOM(), new WSSSAMLKeyInfoProcessor(requestData), sigCrypto ); } catch (WSSecurityException ex) { LOG.log(Level.FINE, "Error in getting KeyInfo from SAML Response: " + ex.getMessage(), ex); throw ex; } } else if (!keyInfoMustBeAvailable) { samlKeyInfo = createKeyInfoFromDefaultAlias(sigCrypto); } if (samlKeyInfo == null) { LOG.fine("No KeyInfo supplied in the SAMLResponse signature"); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } // Validate Signature against profiles validateSignatureAgainstProfiles(signature, samlKeyInfo); // Now verify trust on the signature Credential trustCredential = new Credential(); trustCredential.setPublicKey(samlKeyInfo.getPublicKey()); trustCredential.setCertificates(samlKeyInfo.getCerts()); try { signatureValidator.validate(trustCredential, requestData); } catch (WSSecurityException e) { LOG.log(Level.FINE, "Error in validating signature on SAML Response: " + e.getMessage(), e); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } } protected SAMLKeyInfo createKeyInfoFromDefaultAlias(Crypto sigCrypto) throws WSSecurityException { try { X509Certificate[] certs = RSSecurityUtils.getCertificates(sigCrypto, sigCrypto.getDefaultX509Identifier()); SAMLKeyInfo samlKeyInfo = new SAMLKeyInfo(new X509Certificate[]{certs[0]}); samlKeyInfo.setPublicKey(certs[0].getPublicKey()); return samlKeyInfo; } catch (Exception ex) { LOG.log(Level.FINE, "Error in loading the certificates: " + ex.getMessage(), ex); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILED_SIGNATURE, ex); } } /** * Validate a signature against the profiles */ private void validateSignatureAgainstProfiles( Signature signature, SAMLKeyInfo samlKeyInfo ) throws WSSecurityException { // Validate Signature against profiles SAMLSignatureProfileValidator validator = new SAMLSignatureProfileValidator(); try { validator.validate(signature); } catch (SignatureException ex) { LOG.log(Level.FINE, "Error in validating the SAML Signature: " + ex.getMessage(), ex); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } BasicCredential credential = null; if (samlKeyInfo.getCerts() != null) { credential = new BasicX509Credential(samlKeyInfo.getCerts()[0]); } else if (samlKeyInfo.getPublicKey() != null) { credential = new BasicCredential(samlKeyInfo.getPublicKey()); } else { LOG.fine("Can't get X509Certificate or PublicKey to verify signature"); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } try { SignatureValidator.validate(signature, credential); } catch (SignatureException ex) { LOG.log(Level.FINE, "Error in validating the SAML Signature: " + ex.getMessage(), ex); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } } /** * Validate an internal Assertion */ private void validateAssertion( SamlAssertionWrapper assertion, Crypto sigCrypto, CallbackHandler callbackHandler, Document doc, boolean signedResponse ) throws WSSecurityException { Credential credential = new Credential(); credential.setSamlAssertion(assertion); RequestData requestData = new RequestData(); requestData.setSigVerCrypto(sigCrypto); WSSConfig wssConfig = WSSConfig.getNewInstance(); requestData.setWssConfig(wssConfig); requestData.setCallbackHandler(callbackHandler); if (assertion.isSigned()) { if (assertion.getSaml1() != null) { assertion.getSaml1().getDOM().setIdAttributeNS(null, "AssertionID", true); } else { assertion.getSaml2().getDOM().setIdAttributeNS(null, "ID", true); } // Verify the signature try { Signature sig = assertion.getSignature(); WSDocInfo docInfo = new WSDocInfo(sig.getDOM().getOwnerDocument()); requestData.setWsDocInfo(docInfo); SAMLKeyInfo samlKeyInfo = null; KeyInfo keyInfo = sig.getKeyInfo(); if (keyInfo != null) { samlKeyInfo = SAMLUtil.getCredentialFromKeyInfo( keyInfo.getDOM(), new WSSSAMLKeyInfoProcessor(requestData), sigCrypto ); } else if (!keyInfoMustBeAvailable) { samlKeyInfo = createKeyInfoFromDefaultAlias(sigCrypto); } if (samlKeyInfo == null) { LOG.fine("No KeyInfo supplied in the SAMLResponse assertion signature"); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } assertion.verifySignature(samlKeyInfo); assertion.parseSubject( new WSSSAMLKeyInfoProcessor(requestData), requestData.getSigVerCrypto(), requestData.getCallbackHandler() ); } catch (WSSecurityException e) { LOG.log(Level.FINE, "Assertion failed signature validation", e); throw e; } } // Validate the Assertion & verify trust in the signature try { SamlSSOAssertionValidator assertionValidator = new SamlSSOAssertionValidator(signedResponse); assertionValidator.validate(credential, requestData); } catch (WSSecurityException ex) { LOG.log(Level.FINE, "Assertion validation failed: " + ex.getMessage(), ex); throw ex; } } private Element decryptAssertion( org.opensaml.saml.saml2.core.EncryptedAssertion assertion, Crypto sigCrypto, CallbackHandler callbackHandler ) throws WSSecurityException { EncryptedData encryptedData = assertion.getEncryptedData(); Element encryptedDataDOM = encryptedData.getDOM(); Element encKeyElement = getNode(assertion.getDOM(), WSS4JConstants.ENC_NS, "EncryptedKey", 0); if (encKeyElement == null) { encKeyElement = getNode(encryptedDataDOM, WSS4JConstants.ENC_NS, "EncryptedKey", 0); } if (encKeyElement == null) { LOG.log(Level.FINE, "EncryptedKey element is not available"); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } X509Certificate cert = loadCertificate(sigCrypto, encKeyElement); if (cert == null) { LOG.fine("X509Certificate cannot be retrieved from EncryptedKey element"); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } // now start decrypting String keyEncAlgo = getEncodingMethodAlgorithm(encKeyElement); String digestAlgo = getDigestMethodAlgorithm(encKeyElement); Element cipherValue = getNode(encKeyElement, WSS4JConstants.ENC_NS, "CipherValue", 0); if (cipherValue == null) { LOG.fine("CipherValue element is not available"); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } if (callbackHandler == null) { LOG.fine("A CallbackHandler must be configured to decrypt encrypted Assertions"); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } PrivateKey key = null; try { key = sigCrypto.getPrivateKey(cert, callbackHandler); } catch (Exception ex) { LOG.log(Level.FINE, "Encrypted key can not be decrypted", ex); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } Cipher cipher = EncryptionUtils.initCipherWithKey(keyEncAlgo, digestAlgo, Cipher.DECRYPT_MODE, key); byte[] decryptedBytes = null; try { byte[] encryptedBytes = Base64Utility.decode(cipherValue.getTextContent().trim()); decryptedBytes = cipher.doFinal(encryptedBytes); } catch (Base64Exception ex) { LOG.log(Level.FINE, "Base64 decoding has failed", ex); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } catch (Exception ex) { LOG.log(Level.FINE, "Encrypted key can not be decrypted", ex); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } String symKeyAlgo = getEncodingMethodAlgorithm(encryptedDataDOM); byte[] decryptedPayload = null; try { decryptedPayload = decryptPayload(encryptedDataDOM, decryptedBytes, symKeyAlgo); } catch (Exception ex) { LOG.log(Level.FINE, "Payload can not be decrypted", ex); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } Document payloadDoc = null; try { payloadDoc = StaxUtils.read(new InputStreamReader(new ByteArrayInputStream(decryptedPayload), StandardCharsets.UTF_8)); return payloadDoc.getDocumentElement(); } catch (Exception ex) { LOG.log(Level.FINE, "Payload document can not be created", ex); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } } private Element getNode(Element parent, String ns, String name, int index) { NodeList list = parent.getElementsByTagNameNS(ns, name); if (list != null && list.getLength() >= index + 1) { return (Element)list.item(index); } return null; } private X509Certificate loadCertificate(Crypto crypto, Element encKeyElement) throws WSSecurityException { Element certNode = getNode(encKeyElement, Constants.SignatureSpecNS, "X509Certificate", 0); if (certNode != null) { try { return RSSecurityUtils.loadX509Certificate(crypto, certNode); } catch (Exception ex) { LOG.log(Level.FINE, "X509Certificate can not be created", ex); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } } certNode = getNode(encKeyElement, Constants.SignatureSpecNS, "X509IssuerSerial", 0); if (certNode != null) { try { return RSSecurityUtils.loadX509IssuerSerial(crypto, certNode); } catch (Exception ex) { LOG.log(Level.FINE, "X509Certificate can not be created", ex); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } } if (crypto.getDefaultX509Identifier() != null) { try { X509Certificate[] certs = RSSecurityUtils.getCertificates(crypto, crypto.getDefaultX509Identifier()); if (certs.length > 0) { return certs[0]; } } catch (Exception ex) { LOG.log(Level.FINE, "X509Certificate can not be created", ex); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } } return null; } private String getEncodingMethodAlgorithm(Element parent) throws WSSecurityException { Element encMethod = getNode(parent, WSS4JConstants.ENC_NS, "EncryptionMethod", 0); if (encMethod == null) { LOG.fine("EncryptionMethod element is not available"); throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "invalidSAMLsecurity"); } return encMethod.getAttribute("Algorithm"); } private String getDigestMethodAlgorithm(Element parent) { Element encMethod = getNode(parent, WSS4JConstants.ENC_NS, "EncryptionMethod", 0); if (encMethod != null) { Element digestMethod = getNode(encMethod, WSS4JConstants.SIG_NS, "DigestMethod", 0); if (digestMethod != null) { return digestMethod.getAttributeNS(null, "Algorithm"); } } return null; } private byte[] decryptPayload( Element root, byte[] secretKeyBytes, String symEncAlgo ) throws WSSecurityException { SecretKey key = KeyUtils.prepareSecretKey(symEncAlgo, secretKeyBytes); try { XMLCipher xmlCipher = EncryptionUtils.initXMLCipher(symEncAlgo, XMLCipher.DECRYPT_MODE, key); return xmlCipher.decryptToByteArray(root); } catch (XMLEncryptionException ex) { throw new WSSecurityException(WSSecurityException.ErrorCode.UNSUPPORTED_ALGORITHM, ex); } } public void setKeyInfoMustBeAvailable(boolean keyInfoMustBeAvailable) { this.keyInfoMustBeAvailable = keyInfoMustBeAvailable; } public int getFutureTTL() { return futureTTL; } public void setFutureTTL(int futureTTL) { this.futureTTL = futureTTL; } }