/*
* Atricore IDBus
*
* Copyright (c) 2009, Atricore Inc.
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*/
package org.atricore.idbus.capabilities.sso.support.core.signature;
import oasis.names.tc.saml._2_0.assertion.AssertionType;
import oasis.names.tc.saml._2_0.metadata.KeyDescriptorType;
import oasis.names.tc.saml._2_0.metadata.KeyTypes;
import oasis.names.tc.saml._2_0.metadata.RoleDescriptorType;
import oasis.names.tc.saml._2_0.protocol.*;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.atricore.idbus.capabilities.sso.support.SAMLR11Constants;
import org.atricore.idbus.capabilities.sso.support.SAMLR2Constants;
import org.atricore.idbus.capabilities.sso.support.core.SSOKeyResolver;
import org.atricore.idbus.capabilities.sso.support.core.SSOKeyResolverException;
import org.atricore.idbus.capabilities.sso.support.core.util.NamespaceFilterXMLStreamWriter;
import org.atricore.idbus.capabilities.sso.support.core.util.XmlUtils;
import org.w3._2000._09.xmldsig_.X509DataType;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;
import sun.security.x509.KeyUsageExtension;
import javax.xml.bind.*;
import javax.xml.bind.annotation.XmlType;
import javax.xml.crypto.*;
import javax.xml.crypto.MarshalException;
import javax.xml.crypto.dsig.*;
import javax.xml.crypto.dsig.dom.DOMSignContext;
import javax.xml.crypto.dsig.dom.DOMValidateContext;
import javax.xml.crypto.dsig.keyinfo.KeyInfo;
import javax.xml.crypto.dsig.keyinfo.KeyInfoFactory;
import javax.xml.crypto.dsig.keyinfo.KeyValue;
import javax.xml.crypto.dsig.keyinfo.X509Data;
import javax.xml.crypto.dsig.spec.C14NMethodParameterSpec;
import javax.xml.crypto.dsig.spec.TransformParameterSpec;
import javax.xml.namespace.NamespaceContext;
import javax.xml.namespace.QName;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.XMLStreamWriter;
import javax.xml.xpath.*;
import java.io.*;
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.security.*;
import java.security.acl.NotOwnerException;
import java.security.cert.Certificate;
import java.security.cert.*;
import java.util.*;
/**
* This will sign and verify SAML 2.0 identity artifact (assertion, request, response) signatures using a JSR 105 Provider.
* <p/>
* The provider can be injected or a FQCN can be specified as a system property. A default value will be used if no provider
* is injected nor configured as system property.
*
* @author <a href="mailto:sgonzalez@atricore.org">Sebastian Gonzalez Oyuela</a>
* @version $Id$
* @org.apache.xbean.XBean element="samlr2-signer"
*/
public class JSR105SamlR2SignerImpl implements SamlR2Signer {
/**
* Name of the system property that may hold the fully-qualified class name of the JSR 105 provider
* created by {@link #init()} when no provider has been injected.
*/
public static final String JSR105_PROVIDER_PROPERTY = "jsr105Provider";
/**
* Default JSR 105 Provider FQCN, used when the {@link #JSR105_PROVIDER_PROPERTY} system property is not set.
*/
public static final String DEFAULT_JSR105_PROVIDER_FQCN = "org.jcp.xml.dsig.internal.dom.XMLDSigRI";
// JCA algorithm names accepted by java.security.Signature.getInstance(...).
// TODO : Support SHA-256, make dynamic !
private static final String SHA1_WITH_DSA = "SHA1withDSA";
private static final String SHA1_WITH_RSA = "SHA1withRSA";
private static final String SHA256_WITH_RSA = "SHA256withRSA";
private static final Log logger = LogFactory.getLog(JSR105SamlR2SignerImpl.class);
/**
* JSR 105 Provider. Either injected through {@link #setProvider(Provider)} or instantiated by {@link #init()}.
*/
private Provider provider;
// Source of the private key / certificate used when signing.
private SSOKeyResolver keyResolver;
// Validate certificate expiration, CA, etc. When false, certificate problems are meant to be logged only.
private boolean validateCertificate = false;
/**
 * @return the JSR 105 provider, or {@code null} if none was injected and {@link #init()} has not run yet
 */
public Provider getProvider() {
    return this.provider;
}
/**
 * Injects the JSR 105 provider. When set, {@link #init()} keeps it instead of creating one.
 *
 * @param provider provider to use
 */
public void setProvider(Provider provider) {
    this.provider = provider;
}
/**
 * @return the resolver that supplies signing key material
 */
public SSOKeyResolver getKeyResolver() {
    return this.keyResolver;
}
/**
 * @return the strict certificate validation flag
 */
public boolean isValidateCertificate() {
    return this.validateCertificate;
}
/**
 * @param validateCertificate new value of the strict certificate validation flag
 */
public void setValidateCertificate(boolean validateCertificate) {
    this.validateCertificate = validateCertificate;
}
/**
 * Sets the resolver that supplies signing key material.
 *
 * @org.apache.xbean.Property alias="key-resolver"
 */
public void setKeyResolver(SSOKeyResolver keyResolver) {
    this.keyResolver = keyResolver;
}
/**
 * Resolves the JSR 105 provider class name from the {@link #JSR105_PROVIDER_PROPERTY} system property,
 * falling back to {@link #DEFAULT_JSR105_PROVIDER_FQCN}.
 *
 * @return provider fully-qualified class name
 */
public String getProviderFQCN() {
    final String fqcn = System.getProperty(JSR105_PROVIDER_PROPERTY, DEFAULT_JSR105_PROVIDER_FQCN);
    return fqcn;
}
/**
 * Makes sure a JSR 105 provider is available.
 * <p/>
 * An injected provider is kept as is. Otherwise one is instantiated from the class name returned by
 * {@link #getProviderFQCN()}, which is read once so that logs and errors report the class actually used.
 *
 * @throws RuntimeException if the class cannot be loaded or instantiated, or is not a {@link Provider}
 * @org.apache.xbean.InitMethod
 */
public void init() {
    // If a provider was already 'injected', use it.
    if (provider != null)
        return;

    String fqcn = getProviderFQCN();
    if (logger.isDebugEnabled())
        logger.debug("Creating JSR 105 Provider : " + fqcn);

    try {
        Class<?> providerClass = Class.forName(fqcn);
        // Fail with a descriptive error instead of a bare ClassCastException.
        if (!Provider.class.isAssignableFrom(providerClass))
            throw new RuntimeException("Error creating default provider: " + fqcn + " is not a " + Provider.class.getName());
        this.provider = (Provider) providerClass.newInstance();
    } catch (ClassNotFoundException e) {
        throw new RuntimeException("Error creating default provider: " + fqcn, e);
    } catch (InstantiationException e) {
        throw new RuntimeException("Error creating default provider: " + fqcn, e);
    } catch (IllegalAccessException e) {
        throw new RuntimeException("Error creating default provider: " + fqcn, e);
    }
}
/**
 * Signs a SAML 2.0 assertion.
 * <p/>
 * The assertion is marshalled to DOM and signed with a signature referencing its ID. The signed tree is then
 * unmarshalled into a new {@link AssertionType}.
 *
 * @param assertion assertion to sign
 * @return signed copy of the assertion
 * @throws SamlR2SignatureException if marshalling, signing or unmarshalling fails
 */
public AssertionType sign(AssertionType assertion) throws SamlR2SignatureException {
    try {
        if (logger.isDebugEnabled()) {
            logger.debug("Marshalling SAMLR2 Assertion to DOM Tree [" + assertion.getID() + "]");
        }
        String[] marshalPackages = {SAMLR2Constants.SAML_ASSERTION_PKG};
        Document unsignedDom = XmlUtils.marshalSamlR2AsDom(assertion,
                SAMLR2Constants.SAML_ASSERTION_NS,
                "Assertion",
                marshalPackages);

        Document signedDom = sign(unsignedDom, assertion.getID());

        if (logger.isDebugEnabled()) {
            logger.debug("Unmarshalling SAMLR2 Assertion from DOM Tree [" + assertion.getID() + "]");
        }
        String[] unmarshalPackages = {SAMLR2Constants.SAML_ASSERTION_PKG};
        Object signedAssertion = XmlUtils.unmarshal(signedDom, unmarshalPackages);
        return (AssertionType) signedAssertion;
    } catch (JAXBException jaxbError) {
        throw new SamlR2SignatureException("JAXB Error signing SAMLR2 Assertion " + assertion.getID(), jaxbError);
    } catch (ParserConfigurationException parserError) {
        throw new SamlR2SignatureException("XML Parser Error signing SAMLR2 Assertion " + assertion.getID(), parserError);
    } catch (Exception otherError) {
        throw new SamlR2SignatureException("XML Parser Error signing SAMLR2 Assertion " + assertion.getID(), otherError);
    }
}
/**
 * Validates the XML signature(s) of a SAML 2.0 assertion against the signer metadata.
 * <p/>
 * Any failure, including marshalling errors, is reported as a {@link SamlR2SignatureValidationException}.
 *
 * @param md        signer metadata
 * @param assertion assertion whose signature is checked
 */
public void validate(RoleDescriptorType md, AssertionType assertion) throws SamlR2SignatureException, SamlR2SignatureValidationException {
    if (logger.isDebugEnabled()) {
        logger.debug("Marshalling SAMLR2 Assertion to DOM Tree [" + assertion.getID() + "]");
    }
    try {
        Document assertionDom = XmlUtils.marshalSamlR2AssertionAsDom(assertion);
        validate(md, assertionDom);
    } catch (Exception cause) {
        throw new SamlR2SignatureValidationException(cause);
    }
}
/**
 * Signs a SAML 2.0 request.
 * <p/>
 * The request is marshalled to DOM and signed with a signature referencing its ID. The signed tree is then
 * unmarshalled back into a request object.
 *
 * @param request request to sign
 * @return signed copy of the request
 * @throws SamlR2SignatureException if marshalling, signing or unmarshalling fails
 */
public RequestAbstractType sign(RequestAbstractType request) throws SamlR2SignatureException {
    try {
        // Marshall the request object as a DOM tree:
        if (logger.isDebugEnabled())
            logger.debug("Marshalling SAMLR2 Request to DOM Tree [" + request.getID() + "]");
        Document doc = XmlUtils.marshalSamlR2RequestAsDom(request);

        doc = sign(doc, request.getID());

        if (logger.isDebugEnabled())
            logger.debug("Unmarshalling SAMLR2 Request from DOM Tree [" + request.getID() + "]");
        // Unmarshall the signed request
        return XmlUtils.unmarshalSamlR2Request(doc);
    } catch (JAXBException e) {
        throw new SamlR2SignatureException("JAXB Error signing SAMLR2 Request " + request.getID(), e);
    } catch (ParserConfigurationException e) {
        throw new SamlR2SignatureException("XML Parser Error signing SAMLR2 Request " + request.getID(), e);
    } catch (Exception e) {
        throw new SamlR2SignatureException("Error signing SAMLR2 Request " + request.getID(), e);
    }
}
/**
 * Signs a SAML 2.0 status response.
 *
 * @param response response to sign
 * @param element  local name, in the SAML 2.0 protocol namespace, of the root element used to marshal the response
 * @return signed copy of the response
 * @throws SamlR2SignatureException if marshalling, signing or unmarshalling fails
 */
public StatusResponseType sign(StatusResponseType response, String element) throws SamlR2SignatureException {
    final String[] jaxbPackages = {SAMLR2Constants.SAML_PROTOCOL_PKG, SAMLR2Constants.SAML_ASSERTION_PKG};
    try {
        if (logger.isDebugEnabled()) {
            logger.debug("Marshalling SAMLR2 Response to DOM Tree [" + response.getID() + "]");
        }
        Document unsignedDom = XmlUtils.marshalSamlR2AsDom(response,
                SAMLR2Constants.SAML_PROTOCOL_NS,
                element,
                jaxbPackages);

        Document signedDom = sign(unsignedDom, response.getID());

        if (logger.isDebugEnabled()) {
            logger.debug("Unmarshalling SAMLR2 Response from DOM Tree [" + response.getID() + "]");
        }
        return XmlUtils.unmarshalSamlR2Response(signedDom);
    } catch (Exception failure) {
        throw new SamlR2SignatureException("XML Parser Error signing SAMLR2 Response " + response.getID(), failure);
    }
}
/**
 * Signs a URL-encoded SAML 2.0 HTTP-Redirect query string.
 * <p/>
 * A {@code SigAlg} parameter is appended, chosen from the resolver's private key algorithm (RSA-SHA1 or DSA-SHA1).
 * The resulting string is signed, and the signature is appended as a Base64, URL-encoded {@code Signature}
 * parameter.
 *
 * @param queryString URL-encoded query string to sign (e.g. SAMLRequest/SAMLResponse and RelayState parameters)
 * @return the query string with {@code SigAlg} and {@code Signature} appended
 * @throws SamlR2SignatureException if the query string is empty, the key algorithm is unsupported or signing fails
 */
public String signQueryString(String queryString) throws SamlR2SignatureException {
    try {
        if (queryString == null || queryString.length() == 0) {
            logger.error("SAML 2.0 Qery string null");
            throw new SamlR2SignatureException("SAML 2.0 Qery string null");
        }

        if (logger.isDebugEnabled())
            logger.debug("Received SAML 2.0 Query string [" + queryString + "] for signing");

        PrivateKey privateKey = (PrivateKey) this.getKeyResolver().getPrivateKey();
        String keyAlgorithm = privateKey.getAlgorithm();

        Signature signature;
        String algURI;
        if ("RSA".equals(keyAlgorithm)) {
            signature = Signature.getInstance(SHA1_WITH_RSA);
            algURI = SignatureMethod.RSA_SHA1;
        } else if ("DSA".equals(keyAlgorithm)) {
            signature = Signature.getInstance(SHA1_WITH_DSA);
            algURI = SignatureMethod.DSA_SHA1;
        } else {
            throw new SamlR2SignatureException("SAML 2.0 Signature does not support provided key's algorithm " + keyAlgorithm);
        }

        StringBuilder toSign = new StringBuilder(queryString);
        if (queryString.charAt(queryString.length() - 1) != '&')
            toSign.append('&');
        toSign.append("SigAlg=").append(URLEncoder.encode(algURI, "UTF-8"));
        String signedContent = toSign.toString();

        if (logger.isTraceEnabled())
            logger.trace("Signing SAML 2.0 Query string [" + signedContent + "]");

        signature.initSign(privateKey);
        // The signature covers exact octets: use an explicit charset, not the platform default.
        signature.update(signedContent.getBytes("UTF-8"));
        byte[] sigBytes = signature.sign();

        if (sigBytes == null || sigBytes.length == 0) {
            logger.error("Cannot generate signed query string, Signature created 'null' value.");
            throw new SamlR2SignatureException("Cannot generate signed query string, Signature created 'null' value.");
        }

        String encodedSig = new String(new Base64().encode(sigBytes), "UTF-8");
        String signedQueryString = signedContent + "&Signature=" + URLEncoder.encode(encodedSig, "UTF-8");

        if (logger.isTraceEnabled())
            logger.trace("Signed SAML 2.0 Query string [" + signedQueryString + "]");

        return signedQueryString;
    } catch (SamlR2SignatureException e) {
        // Already descriptive: do not wrap our own exceptions a second time.
        throw e;
    } catch (Exception e) {
        throw new SamlR2SignatureException("Error generating SAML 2.0 Query string signature " + e.getMessage(), e);
    }
}
/**
 * Validates the XML signature(s) of a SAML 2.0 status response against the signer metadata.
 *
 * @param md       signer metadata
 * @param response response whose signature is checked
 * @param element  local name of the root element used to marshal the response
 * @throws SamlR2SignatureException on any marshalling or validation failure
 */
public void validate(RoleDescriptorType md, StatusResponseType response, String element) throws SamlR2SignatureException {
    try {
        if (logger.isDebugEnabled()) {
            logger.debug("Marshalling SAMLR2 Status Response to DOM Tree [" + response.getID() + "]");
        }
        Document responseDom = XmlUtils.marshalSamlR2ResponseAsDom(response, element);
        validate(md, responseDom);
    } catch (Exception cause) {
        throw new SamlR2SignatureException("Error verifying signature for SAMLR2 response" + response.getID(), cause);
    }
}
/**
 * Validates the XML signature(s) of a SAML 2.0 authentication request against the signer metadata.
 *
 * @param md      signer metadata
 * @param request authentication request whose signature is checked
 * @throws SamlR2SignatureException on any marshalling or validation failure
 */
public void validate(RoleDescriptorType md, AuthnRequestType request) throws SamlR2SignatureException, SamlR2SignatureValidationException {
    try {
        if (logger.isDebugEnabled()) {
            logger.debug("Marshalling SAMLR2 Status Authn Request to DOM Tree [" + request.getID() + "]");
        }
        Document requestDom = XmlUtils.marshalSamlR2RequestAsDom(request);
        validate(md, requestDom);
    } catch (Exception cause) {
        throw new SamlR2SignatureException("Error verifying signature for SAMLR2 authn request " + request.getID(), cause);
    }
}
/**
 * Validates the signature of a SAML 2.0 HTTP-Redirect query string against the signing certificate published
 * in the signer metadata.
 * <p/>
 * The signed octet string is rebuilt from the raw, still URL-encoded parameters in the order
 * {@code SAMLRequest|SAMLResponse [&RelayState] &SigAlg}. Parameters are matched by exact name.
 * The query string is rejected if it repeats a SAML parameter or carries both {@code SAMLRequest} and
 * {@code SAMLResponse}. Otherwise the value verified here could differ from the one the application consumes.
 * Supported algorithms: DSA-SHA1, RSA-SHA1 and RSA-SHA256.
 *
 * @param md          signer metadata providing the verification certificate
 * @param queryString raw, URL-encoded query string as received
 * @throws SamlR2SignatureException if no certificate is available, the query string is malformed, or the
 *                                  signature or certificate is not valid. SamlR2SignatureException instances raised
 *                                  during validation are rethrown without being wrapped again.
 */
public void validateQueryString(RoleDescriptorType md, String queryString) throws SamlR2SignatureException, SamlR2SignatureValidationException {
    // XML-DSig "more" identifier for RSA with SHA-256 (RFC 4051)
    final String rsaSha256Uri = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256";
    try {
        X509Certificate cert = getX509Certificate(md);
        if (cert == null) {
            logger.error("No Certificate found in Metadata " + md.getID());
            throw new SamlR2SignatureException("No Certificate found in Metadata " + md.getID());
        }

        if (queryString == null || queryString.length() == 0) {
            logger.error("SAML 2.0 Qery string null");
            throw new SamlR2SignatureException("SAML 2.0 Qery string null");
        }

        if (logger.isTraceEnabled())
            logger.trace("SAML 2.0 Query string to validate [" + queryString + "]");

        // Keep each SAML parameter as its raw "name=value" token: the signature covers the encoded form.
        List<String> samlParamNames = Arrays.asList("SAMLRequest", "SAMLResponse", "RelayState", "SigAlg", "Signature");
        Map<String, String> samlParams = new HashMap<String, String>();
        StringTokenizer st = new StringTokenizer(queryString, "&");
        while (st.hasMoreTokens()) {
            String samlParam = st.nextToken();
            int nameEnd = samlParam.indexOf('=');
            String paramName = nameEnd < 0 ? samlParam : samlParam.substring(0, nameEnd);

            if (!samlParamNames.contains(paramName)) {
                // Ignore this token ...
                logger.warn("Non-SAML 2.0 parameter ignored " + samlParam);
                continue;
            }

            // A repeated parameter could let us verify one value while the application processes another.
            if (samlParams.put(paramName, samlParam) != null)
                throw new SamlR2SignatureValidationException("SAML 2.0 Query string MUST NOT contain duplicated '" + paramName + "' parameters");
        }

        String samlRequest = samlParams.get("SAMLRequest");
        String samlResponse = samlParams.get("SAMLResponse");
        String relayState = samlParams.get("RelayState");
        String sigAlg = samlParams.get("SigAlg");
        String encSig = samlParams.get("Signature");

        if (samlRequest == null && samlResponse == null)
            throw new SamlR2SignatureValidationException("SAML 2.0 Query string MUST contain either 'SAMLRequest' or 'SAMLResponse' parameter");

        if (samlRequest != null && samlResponse != null)
            throw new SamlR2SignatureValidationException("SAML 2.0 Query string MUST NOT contain both 'SAMLRequest' and 'SAMLResponse' parameters");

        if (sigAlg == null)
            throw new SamlR2SignatureValidationException("SAML 2.0 Query string MUST contain a 'SigAlg' parameter");

        if (encSig == null)
            throw new SamlR2SignatureValidationException("SAML 2.0 Query string MUST contain a 'Signature' parameter");

        // Re-order parameters just in case they were mixed-up while getting here.
        StringBuilder signedContent = new StringBuilder(samlRequest != null ? samlRequest : samlResponse);
        if (relayState != null)
            signedContent.append('&').append(relayState);
        signedContent.append('&').append(sigAlg);
        String newQueryString = signedContent.toString();

        if (logger.isDebugEnabled())
            logger.debug("SAML 2.0 Query string signature validation for (re-arranged) [" + newQueryString + "]");

        // Get Signature Algorithm
        int sigAlgValueIndex = sigAlg.indexOf('=');
        String sigAlgValue = sigAlgValueIndex < 0 ? "" : sigAlg.substring(sigAlgValueIndex + 1);
        if (sigAlgValue.length() == 0)
            throw new SamlR2SignatureValidationException("SAML 2.0 Query string MUST contain a 'SigAlg' parameter value");
        sigAlgValue = URLDecoder.decode(sigAlgValue, "UTF-8");

        if (logger.isTraceEnabled())
            logger.trace("SigAlg=" + sigAlgValue);

        // Get Signature value
        int encSigValueIndex = encSig.indexOf('=');
        String signatureEnc = encSigValueIndex < 0 ? "" : encSig.substring(encSigValueIndex + 1);
        if (signatureEnc.length() == 0)
            throw new SamlR2SignatureValidationException("SAML 2.0 Query string MUST contain a 'Signature' parameter value");
        signatureEnc = URLDecoder.decode(signatureEnc, "UTF-8");

        if (logger.isTraceEnabled())
            logger.trace("Signature=" + signatureEnc);

        // base-64 decode the signature value
        byte[] signatureBin = new Base64().decode(signatureEnc.getBytes("UTF-8"));

        // get Signature instance based on algorithm
        Signature signature;
        if (SignatureMethod.DSA_SHA1.equals(sigAlgValue)) {
            signature = Signature.getInstance(SHA1_WITH_DSA);
        } else if (SignatureMethod.RSA_SHA1.equals(sigAlgValue)) {
            signature = Signature.getInstance(SHA1_WITH_RSA);
        } else if (rsaSha256Uri.equals(sigAlgValue)) {
            signature = Signature.getInstance(SHA256_WITH_RSA);
        } else {
            throw new SamlR2SignatureException("SAML 2.0 Siganture does not support algorithm " + sigAlgValue);
        }

        // now verify signature over the exact (ASCII, URL-encoded) octets
        signature.initVerify(cert);
        signature.update(newQueryString.getBytes("UTF-8"));
        if (!signature.verify(signatureBin)) {
            throw new SamlR2SignatureValidationException("Invalid digital signature");
        }

        if (!validateCertificate(md, null)) {
            throw new SamlR2SignatureValidationException("Certificate is not valid, check logs for details");
        }

    } catch (SamlR2SignatureException e) {
        logger.error("Cannot verify digital SAML 2.0 Query string signature " + e.getMessage(), e);
        throw e;
    } catch (Exception e) {
        logger.error("Cannot verify digital SAML 2.0 Query string signature " + e.getMessage(), e);
        throw new SamlR2SignatureException("Cannot verify digital SAML 2.0 Query string signature " + e.getMessage(), e);
    }
}
/**
 * Validates an HTTP-Redirect signature from individual, decoded parameter values.
 * <p/>
 * The query string is rebuilt, URL-encoding every value (including RelayState), and checked with
 * {@link #validateQueryString(RoleDescriptorType, String)}.
 *
 * @param md         signer metadata providing the verification certificate
 * @param msg        SAMLRequest / SAMLResponse value (decoded; it is URL-encoded here)
 * @param relayState optional RelayState value (decoded), may be {@code null} or empty
 * @param sigAlg     SigAlg value (decoded)
 * @param signature  Base64 signature value (decoded)
 * @param isResponse {@code true} if {@code msg} is a SAMLResponse, {@code false} for a SAMLRequest
 * @throws SamlR2SignatureException if a required value is missing or the signature is not valid
 */
public void validateQueryString(RoleDescriptorType md, String msg, String relayState, String sigAlg, String signature, boolean isResponse) throws SamlR2SignatureException, SamlR2SignatureValidationException {
    if (msg == null || sigAlg == null || signature == null) {
        logger.error("Cannot verify digital SAML 2.0 Query string signature: missing message, SigAlg or Signature");
        throw new SamlR2SignatureException("Cannot verify digital SAML 2.0 Query string signature: missing message, SigAlg or Signature");
    }

    String queryStr;
    try {
        StringBuilder qs = new StringBuilder(isResponse ? "SAMLResponse=" : "SAMLRequest=");
        qs.append(URLEncoder.encode(msg, "UTF-8"));
        if (relayState != null && relayState.length() > 0)
            // Encode like the other values; a raw '&' would otherwise split the parameter.
            qs.append("&RelayState=").append(URLEncoder.encode(relayState, "UTF-8"));
        qs.append("&SigAlg=").append(URLEncoder.encode(sigAlg, "UTF-8"));
        qs.append("&Signature=").append(URLEncoder.encode(signature, "UTF-8"));
        queryStr = qs.toString();
    } catch (UnsupportedEncodingException e) {
        logger.error("Cannot verify digital SAML 2.0 Query string signature " + e.getMessage(), e);
        throw new SamlR2SignatureException("Cannot verify digital SAML 2.0 Query string signature " + e.getMessage(), e);
    }

    // Actually verify the rebuilt query string.
    validateQueryString(md, queryStr);
}
/**
 * Validates the XML signature(s) of a SAML 2.0 logout request against the signer metadata.
 *
 * @param md      signer metadata
 * @param request logout request whose signature is checked
 * @throws SamlR2SignatureException on any marshalling or validation failure
 */
public void validate(RoleDescriptorType md, LogoutRequestType request) throws SamlR2SignatureException, SamlR2SignatureValidationException {
    try {
        // Marshall the logout request object as a DOM tree:
        if (logger.isDebugEnabled())
            logger.debug("Marshalling SAMLR2 Logout Request to DOM Tree [" + request.getID() + "]");

        Document doc = XmlUtils.marshalSamlR2RequestAsDom(request);
        validate(md, doc);
    } catch (Exception e) {
        throw new SamlR2SignatureException("Error verifying signature for SAMLR2 logout request " + request.getID(), e);
    }
}
/**
 * Validates the XML signature(s) of a SAML 2.0 ManageNameID request against the signer metadata.
 *
 * @param md                  signer metadata
 * @param manageNameIDRequest request whose signature is checked
 * @throws SamlR2SignatureException on any marshalling or validation failure
 */
public void validate(RoleDescriptorType md, ManageNameIDRequestType manageNameIDRequest) throws SamlR2SignatureException {
    try {
        // Marshall the ManageNameID object as a DOM tree:
        if (logger.isDebugEnabled())
            logger.debug("Marshalling SAMLR2 ManageNameID to DOM Tree [" + manageNameIDRequest.getID() + "]");

        Document doc = XmlUtils.marshalSamlR2RequestAsDom(manageNameIDRequest);
        validate(md, doc);
    } catch (Exception e) {
        throw new SamlR2SignatureException("Error verifying signature for SAMLR2 ManageNameID request " + manageNameIDRequest.getID(), e);
    }
}
/**
 * Signs a SAML 2.0 ManageNameID request.
 *
 * @param manageNameIDRequest request to sign
 * @return signed copy of the request
 * @throws SamlR2SignatureException if marshalling, signing or unmarshalling fails
 */
public ManageNameIDRequestType sign(ManageNameIDRequestType manageNameIDRequest) throws SamlR2SignatureException {
    try {
        // Marshall the request object as a DOM tree:
        if (logger.isDebugEnabled())
            logger.debug("Marshalling SAMLR2 ManageNameIDRequestType to DOM Tree [" + manageNameIDRequest.getID() + "]");

        Document doc = XmlUtils.marshalSamlR2RequestAsDom(manageNameIDRequest);
        doc = sign(doc, manageNameIDRequest.getID());

        if (logger.isDebugEnabled())
            logger.debug("Unmarshalling SAMLR2 ManageNameIDRequestType from DOM Tree [" + manageNameIDRequest.getID() + "]");

        // Use the unmarshaller matching marshalSamlR2RequestAsDom. XmlUtils.unmarshal expects JAXB package
        // names, not a namespace URI such as SAML_PROTOCOL_NS.
        return (ManageNameIDRequestType) XmlUtils.unmarshalSamlR2Request(doc);
    } catch (Exception e) {
        throw new SamlR2SignatureException("Error signing SAMLR2 ManageNameIDRequest " + manageNameIDRequest.getID(), e);
    }
}
// SAML 1.1
/**
 * Signs a SAML 1.1 protocol response.
 * <p/>
 * The response is marshalled through {@link NamespaceFilterXMLStreamWriter}, so that embedded signed assertions
 * still validate. It is then parsed into a namespace-aware DOM, signed with a signature referencing its
 * ResponseID, and unmarshalled back.
 *
 * @param response SAML 1.1 response to sign
 * @return signed copy of the response
 * @throws SamlR2SignatureException if marshalling, parsing, signing or unmarshalling fails
 */
public oasis.names.tc.saml._1_0.protocol.ResponseType sign(oasis.names.tc.saml._1_0.protocol.ResponseType response) throws SamlR2SignatureException {
    try {
        if (logger.isDebugEnabled())
            logger.debug("Marshalling SAMLR11 Response to DOM Tree [" + response.getResponseID() + "]");

        // XML Signature needs to be namespace aware
        javax.xml.parsers.DocumentBuilderFactory dbf =
                javax.xml.parsers.DocumentBuilderFactory.newInstance();
        dbf.setNamespaceAware(true);
        javax.xml.parsers.DocumentBuilder db = dbf.newDocumentBuilder();

        JAXBContext context = JAXBContext.newInstance(SAMLR11Constants.SAML_PROTOCOL_PKG,
                response.getClass().getClassLoader());
        Marshaller m = context.createMarshaller();

        @SuppressWarnings("unchecked")
        Class<oasis.names.tc.saml._1_0.protocol.ResponseType> clazz =
                (Class<oasis.names.tc.saml._1_0.protocol.ResponseType>) response.getClass();

        // Remove the 'Type' suffix from the xml type name and use it as XML element!
        XmlType t = clazz.getAnnotation(XmlType.class);
        String element = t.name().substring(0, t.name().length() - 4);

        JAXBElement<oasis.names.tc.saml._1_0.protocol.ResponseType> jaxbResponse =
                new JAXBElement<oasis.names.tc.saml._1_0.protocol.ResponseType>(
                        new QName(SAMLR11Constants.SAML_PROTOCOL_NS, element),
                        clazz,
                        response);

        // remove prefixes from signature elements of embedded signed assertion so that signature validation -
        // which removes those prefixes - doesn't fail
        StringWriter swrsp = new StringWriter();
        XMLStreamWriter sw = new NamespaceFilterXMLStreamWriter(swrsp);
        m.marshal(jaxbResponse, sw);
        sw.flush();

        // Parse the characters directly: converting to bytes with the platform charset could corrupt
        // non-ASCII content before it gets signed.
        Document doc = db.parse(new org.xml.sax.InputSource(new StringReader(swrsp.toString())));

        doc = sign(doc, response.getResponseID());

        if (logger.isDebugEnabled())
            logger.debug("Unmarshalling SAMLR11 Response from DOM Tree [" + response.getResponseID() + "]");

        Unmarshaller u = context.createUnmarshaller();
        @SuppressWarnings("unchecked")
        JAXBElement<oasis.names.tc.saml._1_0.protocol.ResponseType> signedResponse =
                (JAXBElement<oasis.names.tc.saml._1_0.protocol.ResponseType>) u.unmarshal(doc);
        return signedResponse.getValue();

    } catch (JAXBException e) {
        throw new SamlR2SignatureException("JAXB Error signing SAMLR11 Response " + response.getResponseID(), e);
    } catch (ParserConfigurationException e) {
        throw new SamlR2SignatureException("XML Parser Error signing SAMLR11 Response " + response.getResponseID(), e);
    } catch (XMLStreamException e) {
        throw new SamlR2SignatureException("XML Parser Error signing SAMLR11 Response " + response.getResponseID(), e);
    } catch (IOException e) {
        throw new SamlR2SignatureException("I/O Error signing SAMLR11 Response " + response.getResponseID(), e);
    } catch (SAXException e) {
        throw new SamlR2SignatureException("XML Parser Error signing SAMLR11 Response " + response.getResponseID(), e);
    }
}
// Primitives
/**
 * Parses a serialized SAML document and validates its signature(s), using the document element as the
 * element consumed by the application.
 *
 * @param md     signer metadata
 * @param domStr serialized XML document (untrusted input)
 * @throws SamlR2SignatureException if the document cannot be parsed or its signature is not valid
 */
public void validateDom(RoleDescriptorType md, String domStr) throws SamlR2SignatureException {
    Document doc = parseDomWithoutDtd(domStr);
    validate(md, doc);
}

/**
 * Parses a serialized SAML document and validates its signature(s), using the element whose {@code ID}
 * attribute equals {@code elementId} as the element consumed by the application.
 *
 * @param md        signer metadata
 * @param domStr    serialized XML document (untrusted input)
 * @param elementId ID of the element that a signature must reference
 * @throws SamlR2SignatureException if parsing fails, the ID is missing or duplicated, or the signature is not valid
 */
public void validateDom(RoleDescriptorType md, String domStr, String elementId) throws SamlR2SignatureException {
    Document doc = parseDomWithoutDtd(domStr);

    // Walk the DOM instead of concatenating elementId into an XPath expression,
    // so quotes in the ID cannot change the query.
    Node signedElement = null;
    int matches = 0;
    NodeList elements = doc.getElementsByTagNameNS("*", "*");
    for (int i = 0; i < elements.getLength(); i++) {
        Node candidate = elements.item(i);
        Node idAttr = candidate.getAttributes().getNamedItemNS(null, "ID");
        if (idAttr != null && idAttr.getNodeValue().equals(elementId)) {
            if (signedElement == null)
                signedElement = candidate;
            matches++;
        }
    }

    if (matches > 1)
        throw new SamlR2SignatureException("Duplicate ID [" + elementId + "] in document ");

    if (matches < 1)
        throw new SamlR2SignatureException("Invalid element ID " + elementId);

    validate(md, doc, signedElement);
}

/**
 * Parses untrusted XML into a namespace-aware DOM, refusing DOCTYPE declarations (SAML messages must not
 * carry a DTD), which blocks XXE and entity-expansion attacks.
 */
private static Document parseDomWithoutDtd(String domStr) throws SamlR2SignatureException {
    try {
        javax.xml.parsers.DocumentBuilderFactory dbf =
                javax.xml.parsers.DocumentBuilderFactory.newInstance();
        dbf.setNamespaceAware(true);
        dbf.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true);
        dbf.setXIncludeAware(false);
        dbf.setExpandEntityReferences(false);
        javax.xml.parsers.DocumentBuilder db = dbf.newDocumentBuilder();
        // Parse characters directly instead of platform-charset bytes.
        return db.parse(new org.xml.sax.InputSource(new StringReader(domStr)));
    } catch (ParserConfigurationException e) {
        throw new SamlR2SignatureException(e);
    } catch (SAXException e) {
        throw new SamlR2SignatureException(e);
    } catch (IOException e) {
        throw new SamlR2SignatureException(e);
    }
}
/**
 * Validates every XML Signature found in {@code doc} and checks that at least one of them references {@code root}.
 * <p/>
 * The document is rejected if any of the following holds:
 * element IDs are duplicated; {@code root} has no {@code ID}; no signature is present; a signature fails core
 * validation; a signature has other than exactly one reference, or a reference that is not {@code #id};
 * certificate validation fails; or no signature references {@code root}'s ID (signature wrapping protection).
 *
 * @param md   signer metadata, used for certificate validation
 * @param doc  document containing the signatures
 * @param root element consumed by the application
 * @throws SamlR2SignatureException if validation fails or a signature cannot be processed
 */
public void validate(RoleDescriptorType md, Document doc, Node root) throws SamlR2SignatureException {
    try {
        // Check for duplicate IDs among XML elements (ambiguous ID references enable signature wrapping)
        NodeList idNodes = evaluateXPath(doc, "//*/@ID");
        boolean duplicateIdExists = false;
        Set<String> ids = new HashSet<String>();
        for (int i = 0; i < idNodes.getLength(); i++) {
            String id = idNodes.item(i).getNodeValue();
            if (!ids.add(id)) {
                duplicateIdExists = true;
                logger.error("Duplicated Element ID in XML Document : " + id);
            }
        }

        if (duplicateIdExists) {
            throw new SamlR2SignatureException("Duplicate IDs in document ");
        }

        // In SAML the root element is the one used by the application; a signature must refer to it.
        Node rootIdAttr = root.getAttributes().getNamedItem("ID");
        if (rootIdAttr == null)
            throw new SamlR2SignatureException("SAML document does not have an ID ");

        // Find Signature elements
        NodeList signatureNodes = doc.getElementsByTagNameNS(XMLSignature.XMLNS, "Signature");
        if (signatureNodes.getLength() == 0) {
            throw new SamlR2SignatureException("Cannot find Signature elements");
        }

        // DOM XMLSignatureFactory used to unmarshal each XMLSignature
        XMLSignatureFactory fac = XMLSignatureFactory.getInstance("DOM", provider);

        // Validate all Signature elements
        boolean rootIdMatched = false;
        for (int k = 0; k < signatureNodes.getLength(); k++) {

            DOMValidateContext valContext = new DOMValidateContext(new RawX509KeySelector(), signatureNodes.item(k));
            XMLSignature signature = fac.unmarshalXMLSignature(valContext);

            @SuppressWarnings("unchecked")
            List<Reference> refs = signature.getSignedInfo().getReferences();

            if (!signature.validate(valContext)) {
                if (logger.isDebugEnabled())
                    logger.debug("Signature failed core validation");

                boolean sv = signature.getSignatureValue().validate(valContext);
                if (logger.isDebugEnabled())
                    logger.debug("signature validation status: " + sv);

                // check the validation status of each Reference (should be only one!)
                boolean refValid = true;
                for (int j = 0; j < refs.size(); j++) {
                    Reference ref = refs.get(j);
                    boolean b = ref.validate(valContext);
                    if (logger.isDebugEnabled())
                        logger.debug("ref[" + j + "] " + ref.getId() + " validity status: " + b);

                    if (!b) {
                        refValid = false;
                        logger.error("Signature failed reference validation " + ref.getId());
                    }
                }
                throw new SamlR2SignatureValidationException("Signature failed core validation" + (refValid ? " but passed all Reference validations" : " and some/all Reference validation"));
            }

            if (logger.isDebugEnabled())
                logger.debug("Singnature passed Core validation");

            // The Signature must contain only one reference, and it must be the signed top element's ID.
            if (refs.size() != 1) {
                throw new SamlR2SignatureValidationException("Invalid number of 'Reference' elements in signature : "
                        + refs.size() + " [" + signature.getId() + "]");
            }

            String referenceURI = refs.get(0).getURI();
            if (referenceURI == null || !referenceURI.startsWith("#"))
                throw new SamlR2SignatureValidationException("Signature reference URI format not supported " + referenceURI);

            if (referenceURI.substring(1).equals(rootIdAttr.getNodeValue()))
                rootIdMatched = true;

            Key key = signature.getKeySelectorResult().getKey();
            if (!validateCertificate(md, key)) {
                throw new SamlR2SignatureValidationException("Signature failed Certificate validation");
            }

            if (logger.isDebugEnabled())
                logger.debug("Signature passed Certificate validation");
        }

        // Check that any of the Signatures matched the root element ID
        if (!rootIdMatched) {
            logger.error("No Signature element refers to signed element (possible signature wrapping attack)");
            throw new SamlR2SignatureValidationException("No Signature element refers to signed element");
        }

    } catch (MarshalException e) {
        // Report through the declared exception type instead of an unchecked RuntimeException.
        throw new SamlR2SignatureException("Cannot unmarshal XML Signature: " + e.getMessage(), e);
    } catch (XMLSignatureException e) {
        throw new SamlR2SignatureException("Cannot validate XML Signature: " + e.getMessage(), e);
    }
}
/**
 * Validates the XML Digital Signature(s) of a SAML 2.0 document (request, response, assertion, etc.),
 * treating the document element as the element consumed by the application.
 *
 * @param md  signer SAML 2.0 metadata
 * @param doc DOM representation of the document
 * @throws SamlR2SignatureException if the signature is invalid
 */
public void validate(RoleDescriptorType md, Document doc) throws SamlR2SignatureException {
    validate(md, doc, doc.getDocumentElement());
}
/**
 * Returns the encoded bytes of the first X.509 certificate published in a signing key descriptor of the role.
 * <p/>
 * Key descriptors with a {@code use} other than signing are skipped. Descriptors without a {@code use}
 * attribute are considered, because SAML 2.0 metadata allows such keys for both signing and encryption.
 *
 * @param md role metadata
 * @return the certificate bytes, or {@code null} if none is found
 */
protected byte[] getBinCertificate(RoleDescriptorType md) {
    byte[] x509CertificateBin = null;

    if (md.getKeyDescriptor() != null && md.getKeyDescriptor().size() > 0) {
        keyDescriptors:
        for (KeyDescriptorType keyMd : md.getKeyDescriptor()) {

            // A missing 'use' must not cause a NullPointerException.
            if (keyMd.getUse() != null && !keyMd.getUse().equals(KeyTypes.SIGNING))
                continue;

            if (keyMd.getKeyInfo() == null) {
                logger.debug("Metadata Key Descriptor does not have KeyInfo " + keyMd.toString());
                continue;
            }

            // Get inside Key Info
            List contentMd = keyMd.getKeyInfo().getContent();
            if (contentMd == null)
                continue;

            for (Object o : contentMd) {
                if (!(o instanceof JAXBElement) || !(((JAXBElement) o).getValue() instanceof X509DataType))
                    continue;

                X509DataType x509Data = (X509DataType) ((JAXBElement) o).getValue();
                for (Object x509Content : x509Data.getX509IssuerSerialOrX509SKIOrX509SubjectName()) {
                    if (!(x509Content instanceof JAXBElement))
                        continue;

                    JAXBElement x509Certificate = (JAXBElement) x509Content;
                    QName name = x509Certificate.getName();
                    if (XMLSignature.XMLNS.equals(name.getNamespaceURI()) &&
                            "X509Certificate".equals(name.getLocalPart())) {
                        x509CertificateBin = (byte[]) x509Certificate.getValue();
                        if (x509CertificateBin != null)
                            break keyDescriptors;
                        // Only the first X509Certificate of this X509Data is considered.
                        break;
                    }
                }
            }
        }
    } else {
        logger.debug("Metadata does not have Key Descriptors: " + md.getID());
    }

    if (logger.isTraceEnabled()) {
        logger.trace("MD Sign Certificate: " + Arrays.toString(x509CertificateBin));
    }

    return x509CertificateBin;
}
/**
 * Decodes the signing certificate published in the role metadata.
 *
 * @param md role metadata
 * @return the certificate, or {@code null} if none is published or it cannot be decoded (the error is logged)
 */
protected X509Certificate getX509Certificate(RoleDescriptorType md) {
    final byte[] encoded = getBinCertificate(md);
    if (encoded != null) {
        try {
            CertificateFactory x509Factory = CertificateFactory.getInstance("X.509");
            return (X509Certificate) x509Factory.generateCertificate(new ByteArrayInputStream(encoded));
        } catch (CertificateException decodeError) {
            logger.error("Cannot get X509 Certificate " + decodeError.getMessage(), decodeError);
        }
    }
    return null;
}
/**
 * Checks the signer certificate published in the role metadata.
 * <p/>
 * When {@code publicKey} is given, it must equal the public key of the metadata signing certificate. Without
 * this check, a signature produced with an arbitrary key carried in the message would be accepted. The
 * certificate validity period is then checked. An expired or not-yet-valid certificate fails only when
 * {@link #isValidateCertificate()} is {@code true}; otherwise a warning is logged.
 *
 * @param md        signer role metadata
 * @param publicKey key that verified the signature, or {@code null} to skip the key comparison
 * @return {@code true} if the certificate is acceptable
 */
protected boolean validateCertificate(RoleDescriptorType md, Key publicKey) {
    X509Certificate x509Cert = getX509Certificate(md);
    if (x509Cert == null) {
        logger.error("No X509 Signing certificate found in SAML 2.0 Metadata Role " + md.getID());
        return false;
    }

    // Only compare with public key, if provided.
    if (publicKey != null) {
        byte[] configuredKey = x509Cert.getPublicKey().getEncoded();
        byte[] usedKey = publicKey.getEncoded();

        if (logger.isTraceEnabled()) {
            logger.trace("Configured public key: " + Arrays.toString(configuredKey));
            logger.trace("Used public key: " + Arrays.toString(usedKey));
        }

        // Validate that the used key is the one configured for the entity
        if (!Arrays.equals(configuredKey, usedKey)) {
            logger.error("Certificate used for signing is not the one configured in SAML 2.0 Metadata Role " + md.getID());
            return false;
        }
    }

    try {
        x509Cert.checkValidity();

        Calendar aMonthFromNow = Calendar.getInstance();
        aMonthFromNow.add(Calendar.DAY_OF_MONTH, 30);
        // Just print-out that the certificate will expire soon.
        if (x509Cert.getNotAfter().before(aMonthFromNow.getTime()))
            logger.warn("X509 Certificate will expire in less than 30 days for SAML 2.0 Metadata Role " + md.getID());

    } catch (CertificateExpiredException e) {
        if (validateCertificate) {
            logger.error("X509 Certificate has expired " + x509Cert.getNotAfter());
            return false;
        }
        logger.warn("X509 Certificate has expired " + x509Cert.getNotAfter());
    } catch (CertificateNotYetValidException e) {
        if (validateCertificate) {
            logger.error("Certificate should not be used before " + x509Cert.getNotBefore());
            return false;
        }
        logger.warn("Certificate should not be used before " + x509Cert.getNotBefore());
    }

    // TODO : Validate CRLs , etc !!!!
    return true;
}
/**
 * Signs a SAMLR2 identity artifact (assertion, request or response) represented as a DOM tree.
 * <p>
 * An enveloped XML-DSig signature is created over the element whose ID is {@code id}. The
 * reference uses a SHA-1 digest with the enveloped-signature and exclusive-C14N transforms.
 * SignedInfo is canonicalized with exclusive C14N. The signature method is RSA-SHA1 when the
 * private key algorithm is "RSA" and DSA-SHA1 otherwise. The signing certificate is embedded
 * in the KeyInfo as X509Data.
 * <p>
 * The signature element is inserted as the first child of the document's root element. The
 * document is modified in place.
 *
 * @param doc the DOM document holding the artifact to sign
 * @param id  the ID attribute value of the element to reference (without the leading '#')
 * @return the same {@code doc} instance, now containing the signature
 * @throws SamlR2SignatureException if the key material cannot be resolved or the signature
 *                                  cannot be built or computed
 */
protected Document sign(Document doc, String id) throws SamlR2SignatureException {
    try {
        // Resolve key material once, so the signature algorithm is chosen from the very same
        // key that is later used to sign.
        Certificate cert = keyResolver.getCertificate();
        Key signingKey = keyResolver.getPrivateKey();

        // DOM XMLSignatureFactory used to build the enveloped signature
        XMLSignatureFactory fac = XMLSignatureFactory.getInstance("DOM", provider);

        if (logger.isDebugEnabled())
            logger.debug("Creating XML DOM Digital Siganture (not signing yet!)");

        // Reference to the artifact by its ID, with SHA-1 digest and ENVELOPED transform.
        List<Transform> transforms = new ArrayList<Transform>();
        transforms.add(fac.newTransform(Transform.ENVELOPED, (TransformParameterSpec) null));
        // Exclusive C14N on the reference lets an assertion signature still validate when the
        // assertion is later embedded in a signed response. Exclusive C14N does not pull in
        // namespace declarations inherited from ancestor elements.
        transforms.add(fac.newTransform(CanonicalizationMethod.EXCLUSIVE, (TransformParameterSpec) null));

        Reference ref = fac.newReference("#" + id,
                fac.newDigestMethod(DigestMethod.SHA1, null),
                transforms,
                null, null);

        // Signature method follows the key algorithm: RSA -> RSA-SHA1, anything else -> DSA-SHA1
        String signatureMethod = SignatureMethod.DSA_SHA1;
        if (signingKey.getAlgorithm().equals("RSA"))
            signatureMethod = SignatureMethod.RSA_SHA1;
        logger.debug("Using signature method " + signatureMethod);

        SignedInfo si = fac.newSignedInfo(
                fac.newCanonicalizationMethod(CanonicalizationMethod.EXCLUSIVE,
                        (C14NMethodParameterSpec) null),
                fac.newSignatureMethod(signatureMethod, null),
                Collections.singletonList(ref));

        // KeyInfo carries the signing certificate as X509Data
        KeyInfoFactory kif = fac.getKeyInfoFactory();
        X509Data x509Data = kif.newX509Data(Collections.singletonList(cert));
        KeyInfo ki = kif.newKeyInfo(Collections.singletonList(x509Data));

        javax.xml.crypto.dsig.XMLSignature signature = fac.newXMLSignature(si, ki);

        if (logger.isDebugEnabled())
            logger.debug("Signing SAMLR2 Identity Artifact ...");

        // Place the signature before the root's current first child. DOMSignContext rejects a
        // null nextSibling, so for a childless root the signature is appended instead, which
        // still makes it the first child.
        Node root = doc.getDocumentElement();
        Node firstChild = root.getFirstChild();
        DOMSignContext dsc = (firstChild != null)
                ? new DOMSignContext(signingKey, root, firstChild)
                : new DOMSignContext(signingKey, root);

        signature.sign(dsc);

        if (logger.isDebugEnabled())
            logger.debug("Signing SAMLR2 Identity Artifact ... DONE!");

        return doc;

    } catch (NoSuchAlgorithmException e) {
        throw new SamlR2SignatureException(e.getMessage(), e);
    } catch (XMLSignatureException e) {
        throw new SamlR2SignatureException(e.getMessage(), e);
    } catch (InvalidAlgorithmParameterException e) {
        throw new SamlR2SignatureException(e.getMessage(), e);
    } catch (MarshalException e) {
        throw new SamlR2SignatureException(e.getMessage(), e);
    } catch (SSOKeyResolverException e) {
        throw new SamlR2SignatureException(e.getMessage(), e);
    }
}
/**
 * Compiles {@code expression} and evaluates it against {@code doc} as a node-set. Namespace
 * prefixes in the expression are resolved through {@link #getNamespaceContext()}.
 *
 * @param doc        document to evaluate against
 * @param expression XPath 1.0 expression, possibly using the prefixes bound by the namespace context
 * @return the matching nodes
 * @throws SamlR2SignatureException wrapping any XPath compilation or evaluation error
 */
protected NodeList evaluateXPath(Document doc, String expression) throws SamlR2SignatureException {
    XPath evaluator = XPathFactory.newInstance().newXPath();
    evaluator.setNamespaceContext(getNamespaceContext());

    try {
        XPathExpression compiled = evaluator.compile(expression);
        return (NodeList) compiled.evaluate(doc, XPathConstants.NODESET);
    } catch (XPathExpressionException xpe) {
        throw new SamlR2SignatureException(xpe);
    }
}
/**
 * Builds the {@link NamespaceContext} used to resolve prefixes in XPath expressions.
 * <p>
 * Bound prefixes:
 * <ul>
 *   <li>{@code ds}: XML Digital Signature namespace</li>
 *   <li>{@code p}: SAML 2.0 protocol namespace</li>
 *   <li>{@code a}: SAML 2.0 assertion namespace</li>
 *   <li>{@code xml} / {@code xmlns}: the reserved bindings required by the
 *       {@link NamespaceContext} contract</li>
 * </ul>
 * Per that contract, a {@code null} argument raises {@link IllegalArgumentException}.
 * An unbound prefix maps to {@link javax.xml.XMLConstants#NULL_NS_URI} rather than {@code null}.
 *
 * @return a new namespace context; never {@code null}
 */
protected NamespaceContext getNamespaceContext() {
    return new NamespaceContext() {

        public String getNamespaceURI(String prefix) {
            if (prefix == null)
                throw new IllegalArgumentException("Namespace prefix must not be null");
            if ("ds".equals(prefix))
                return org.apache.xml.security.utils.Constants.SignatureSpecNS;
            if ("p".equals(prefix))
                return SAMLR2Constants.SAML_PROTOCOL_NS;
            if ("a".equals(prefix))
                return SAMLR2Constants.SAML_ASSERTION_NS;
            if (javax.xml.XMLConstants.XML_NS_PREFIX.equals(prefix))
                return javax.xml.XMLConstants.XML_NS_URI;
            if (javax.xml.XMLConstants.XMLNS_ATTRIBUTE.equals(prefix))
                return javax.xml.XMLConstants.XMLNS_ATTRIBUTE_NS_URI;
            // Unbound prefix: the contract mandates NULL_NS_URI (""), not null
            return javax.xml.XMLConstants.NULL_NS_URI;
        }

        public String getPrefix(String uri) {
            if (uri == null)
                throw new IllegalArgumentException("Namespace URI must not be null");
            if (uri.equals(org.apache.xml.security.utils.Constants.SignatureSpecNS))
                return "ds";
            if (uri.equals(SAMLR2Constants.SAML_PROTOCOL_NS))
                return "p";
            if (uri.equals(SAMLR2Constants.SAML_ASSERTION_NS))
                return "a";
            if (uri.equals(javax.xml.XMLConstants.XML_NS_URI))
                return javax.xml.XMLConstants.XML_NS_PREFIX;
            if (uri.equals(javax.xml.XMLConstants.XMLNS_ATTRIBUTE_NS_URI))
                return javax.xml.XMLConstants.XMLNS_ATTRIBUTE;
            return null;
        }

        // Each URI has at most one prefix here, so reuse getPrefix. The contract requires an
        // (possibly empty) Iterator, never null.
        public Iterator getPrefixes(String uri) {
            String prefix = getPrefix(uri);
            if (prefix == null)
                return Collections.emptyList().iterator();
            return Collections.singletonList(prefix).iterator();
        }
    };
}
/**
 * {@link KeySelector} that returns the public key carried by a {@code KeyValue} entry of the
 * signature's {@code KeyInfo}.
 * A KeyValue whose key algorithm is not compatible with the signature method is skipped.
 */
private static class KeyValueKeySelector extends KeySelector {

    public KeySelectorResult select(KeyInfo keyInfo,
                                    KeySelector.Purpose purpose,
                                    AlgorithmMethod method,
                                    XMLCryptoContext context)
            throws KeySelectorException {
        if (keyInfo == null)
            throw new KeySelectorException("Null KeyInfo object!");

        SignatureMethod signatureMethod = (SignatureMethod) method;

        Iterator<?> entries = keyInfo.getContent().iterator();
        while (entries.hasNext()) {
            XMLStructure entry = (XMLStructure) entries.next();
            if (!(entry instanceof KeyValue))
                continue;

            PublicKey candidate;
            try {
                candidate = ((KeyValue) entry).getPublicKey();
            } catch (KeyException e) {
                throw new KeySelectorException(e);
            }

            // Accept only keys whose algorithm matches the signature method
            if (algEquals(signatureMethod.getAlgorithm(), candidate.getAlgorithm()))
                return new SimpleKeySelectorResult(candidate);
        }

        throw new KeySelectorException("No KeyValue element found!");
    }

    /**
     * Tells whether a JCA key algorithm name is usable with a signature method URI.
     * Only DSA/DSA-SHA1 and RSA/RSA-SHA1 pairs are recognized; any other pair is logged.
     */
    static boolean algEquals(String algURI, String algName) {
        boolean dsaPair = algName.equalsIgnoreCase("DSA")
                && algURI.equalsIgnoreCase(SignatureMethod.DSA_SHA1);
        boolean rsaPair = algName.equalsIgnoreCase("RSA")
                && algURI.equalsIgnoreCase(SignatureMethod.RSA_SHA1);

        if (dsaPair || rsaPair)
            return true;

        logger.error("Unsupported Key Algorithm found in signature: " + algName);
        return false;
    }
}
/**
 * KeySelector that retrieves the first usable X509Certificate from the KeyInfo's
 * X509Data elements and returns its public key.
 * NOTE: If an X509CRL is present in the same X509Data element, certificates revoked by
 * that CRL are ignored, whatever the selection purpose.
 */
public static class RawX509KeySelector extends KeySelector {

    /**
     * Selects the public key of the first non-revoked X509Certificate found in {@code keyInfo}.
     *
     * @throws KeySelectorException if {@code keyInfo} is null or no acceptable certificate is present
     */
    public KeySelectorResult select(KeyInfo keyInfo,
                                    KeySelector.Purpose purpose,
                                    AlgorithmMethod method,
                                    XMLCryptoContext context)
            throws KeySelectorException {
        if (keyInfo == null) {
            throw new KeySelectorException("Null KeyInfo object!");
        }

        // Search for X509Data in keyinfo
        for (Object kiType : keyInfo.getContent()) {
            if (!(kiType instanceof X509Data)) {
                continue;
            }
            List<?> entries = ((X509Data) kiType).getContent();

            // Look for a CRL before inspecting certificates. The previous loop condition
            // ('crl != null') never ran, so revocation was never checked.
            X509CRL crl = findCrl(entries);

            for (Object entry : entries) {
                // Skip non-X509Certificate entries
                if (!(entry instanceof X509Certificate)) {
                    continue;
                }
                X509Certificate cert = (X509Certificate) entry;
                // Ignore revoked certificates. This previously did not apply to
                // Purpose.VERIFY, contradicting this class's contract.
                if (crl != null && crl.isRevoked(cert)) {
                    continue;
                }
                return new SimpleKeySelectorResult(cert.getPublicKey());
            }
        }
        throw new KeySelectorException("No X509Certificate found!");
    }

    /** Returns the first X509CRL in the given X509Data content, or {@code null} if there is none. */
    private static X509CRL findCrl(List<?> entries) {
        for (Object entry : entries) {
            if (entry instanceof X509CRL) {
                return (X509CRL) entry;
            }
        }
        return null;
    }
}
/**
 * Minimal {@link KeySelectorResult} that hands back an already-selected public key unchanged.
 */
private static class SimpleKeySelectorResult implements KeySelectorResult {

    // The selected key, fixed at construction time
    private final PublicKey pk;

    SimpleKeySelectorResult(PublicKey selected) {
        pk = selected;
    }

    public Key getKey() {
        return this.pk;
    }
}
/**
 * Describes this signer: the superclass representation, followed by the security provider's
 * name and info and the key resolver.
 * The fields are comma-separated. A missing provider is rendered as {@code null} instead of
 * failing with a NullPointerException.
 */
@Override
public String toString() {
    String providerName = (provider != null) ? provider.getName() : null;
    String providerInfo = (provider != null) ? provider.getInfo() : null;
    // The separator before "provider.info=" was missing, which fused the two fields together
    return super.toString() + "[provider.name=" + providerName +
            ",provider.info=" + providerInfo +
            ",keyResolver=" + keyResolver +
            "]";
}
}