package org.apereo.cas.support.saml.util;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.apache.commons.lang3.StringUtils;
import org.apereo.cas.support.saml.OpenSamlConfigBean;
import org.apereo.cas.util.EncodingUtils;
import org.jdom.Document;
import org.jdom.input.DOMBuilder;
import org.jdom.input.SAXBuilder;
import org.jdom.output.XMLOutputter;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.Marshaller;
import org.opensaml.core.xml.io.MarshallerFactory;
import org.opensaml.core.xml.schema.XSString;
import org.opensaml.core.xml.schema.impl.XSStringBuilder;
import org.opensaml.saml.common.SAMLObject;
import org.opensaml.saml.common.SAMLObjectBuilder;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.soap.common.SOAPObject;
import org.opensaml.soap.common.SOAPObjectBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import javax.xml.XMLConstants;
import javax.xml.crypto.dsig.CanonicalizationMethod;
import javax.xml.crypto.dsig.DigestMethod;
import javax.xml.crypto.dsig.Reference;
import javax.xml.crypto.dsig.SignatureMethod;
import javax.xml.crypto.dsig.SignedInfo;
import javax.xml.crypto.dsig.Transform;
import javax.xml.crypto.dsig.XMLSignature;
import javax.xml.crypto.dsig.XMLSignatureFactory;
import javax.xml.crypto.dsig.dom.DOMSignContext;
import javax.xml.crypto.dsig.keyinfo.KeyInfo;
import javax.xml.crypto.dsig.keyinfo.KeyInfoFactory;
import javax.xml.crypto.dsig.keyinfo.KeyValue;
import javax.xml.crypto.dsig.spec.C14NMethodParameterSpec;
import javax.xml.crypto.dsig.spec.TransformParameterSpec;
import javax.xml.namespace.QName;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.transform.OutputKeys;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;
import java.io.ByteArrayInputStream;
import java.io.Serializable;
import java.io.StringReader;
import java.io.StringWriter;
import java.lang.reflect.Field;
import java.nio.charset.Charset;
import java.security.GeneralSecurityException;
import java.security.PrivateKey;
import java.security.Provider;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/**
 * An abstract builder to serve as the template handler
 * for SAML1 and SAML2 responses.
 *
 * @author Misagh Moayyed mmoayyed@unicon.net
 * @since 4.1
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY)
public abstract class AbstractSamlObjectBuilder implements Serializable {

    /**
     * Name of the public static field that OpenSAML object interfaces use to expose their element {@link QName}.
     */
    protected static final String DEFAULT_ELEMENT_NAME_FIELD = "DEFAULT_ELEMENT_NAME";

    /**
     * Name of the public static field that OpenSAML object interfaces use to expose their element local name.
     */
    protected static final String DEFAULT_ELEMENT_LOCAL_NAME_FIELD = "DEFAULT_ELEMENT_LOCAL_NAME";

    /** Number of random bytes used when generating SAML message IDs. */
    private static final int RANDOM_ID_SIZE = 16;

    /** Default JSR-105 provider class; may be overridden with the {@code jsr105Provider} system property. */
    private static final String SIGNATURE_FACTORY_PROVIDER_CLASS = "org.jcp.xml.dsig.internal.dom.XMLDSigRI";

    private static final long serialVersionUID = -6833230731146922780L;

    /** Namespace URI reserved for {@code xmlns} namespace-declaration attributes. */
    private static final String NAMESPACE_URI = "http://www.w3.org/2000/xmlns/";

    private static final Logger LOGGER = LoggerFactory.getLogger(AbstractSamlObjectBuilder.class);

    /**
     * The OpenSAML configuration bean.
     */
    protected OpenSamlConfigBean configBean;

    /**
     * Instantiates a new builder.
     *
     * @param configBean the OpenSAML configuration bean
     */
    public AbstractSamlObjectBuilder(final OpenSamlConfigBean configBean) {
        this.configBean = configBean;
    }

    /**
     * Create a new SAML object using the builder registered for the type's default element name.
     *
     * @param <T>        the generic type
     * @param objectType the SAML object type; must expose a public static {@code DEFAULT_ELEMENT_NAME} field
     * @return a freshly built object of the requested type
     * @throws IllegalStateException if no builder is registered or the element name cannot be resolved
     */
    public <T extends SAMLObject> T newSamlObject(final Class<T> objectType) {
        final QName qName = getSamlObjectQName(objectType);
        // The registry is keyed by QName only, so the builder's generic type cannot be checked here.
        @SuppressWarnings("unchecked")
        final SAMLObjectBuilder<T> builder = (SAMLObjectBuilder<T>)
            XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName);
        if (builder == null) {
            throw new IllegalStateException("No SAML object builder is registered for class " + objectType.getName());
        }
        return objectType.cast(builder.buildObject(qName));
    }

    /**
     * Create a new SOAP object using the builder registered for the type's default element name.
     *
     * @param <T>        the type parameter
     * @param objectType the SOAP object type; must expose a public static {@code DEFAULT_ELEMENT_NAME} field
     * @return a freshly built object of the requested type
     * @throws IllegalStateException if no builder is registered or the element name cannot be resolved
     */
    public <T extends SOAPObject> T newSoapObject(final Class<T> objectType) {
        final QName qName = getSamlObjectQName(objectType);
        // The registry is keyed by QName only, so the builder's generic type cannot be checked here.
        @SuppressWarnings("unchecked")
        final SOAPObjectBuilder<T> builder = (SOAPObjectBuilder<T>)
            XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName);
        if (builder == null) {
            throw new IllegalStateException("No SOAP object builder is registered for class " + objectType.getName());
        }
        return objectType.cast(builder.buildObject(qName));
    }

    /**
     * Gets the element QName of a SAML/SOAP object type by reading its public static
     * {@code DEFAULT_ELEMENT_NAME} field reflectively.
     *
     * @param objectType the object type
     * @return the saml object QName
     * @throws IllegalStateException if the field does not exist or cannot be accessed
     */
    public QName getSamlObjectQName(final Class<?> objectType) {
        try {
            final Field f = objectType.getField(DEFAULT_ELEMENT_NAME_FIELD);
            return (QName) f.get(null);
        } catch (final NoSuchFieldException e) {
            throw new IllegalStateException("Cannot find field " + objectType.getName() + '.' + DEFAULT_ELEMENT_NAME_FIELD, e);
        } catch (final IllegalAccessException e) {
            throw new IllegalStateException("Cannot access field " + objectType.getName() + '.' + DEFAULT_ELEMENT_NAME_FIELD, e);
        }
    }

    /**
     * Build an {@code xs:string}-typed attribute value element whose text is {@code value.toString()}.
     *
     * @param value       the value; must not be null
     * @param elementName the element name of the generated value element
     * @return the string-typed attribute value
     * @throws NullPointerException if {@code value} is null
     */
    protected XSString newAttributeValue(final Object value, final QName elementName) {
        Objects.requireNonNull(value, "SAML attribute value cannot be null");
        final XSStringBuilder attrValueBuilder = new XSStringBuilder();
        final XSString stringValue = attrValueBuilder.buildObject(elementName, XSString.TYPE_NAME);
        // String#toString returns the string itself, so no special-casing is needed.
        stringValue.setValue(value.toString());
        return stringValue;
    }

    /**
     * Generate a secure random id, suitable for SAML message IDs.
     * The id is an underscore followed by the hex encoding of {@value #RANDOM_ID_SIZE} random bytes.
     *
     * @return the secure id string
     * @throws IllegalStateException if the {@code SHA1PRNG} generator is unavailable
     */
    public String generateSecureRandomId() {
        try {
            final SecureRandom random = SecureRandom.getInstance("SHA1PRNG");
            final byte[] buf = new byte[RANDOM_ID_SIZE];
            random.nextBytes(buf);
            // The leading underscore keeps the id a valid xs:ID (an NCName cannot start with a digit).
            return "_".concat(EncodingUtils.hexEncode(buf));
        } catch (final Exception e) {
            throw new IllegalStateException("Cannot create secure random ID generator for SAML message IDs.", e);
        }
    }

    /**
     * Add attribute values to a SAML attribute's value list.
     * A {@link Collection} value produces one value element per non-null member, and null members are skipped.
     * Any other value produces a single value element. A null value is ignored entirely.
     *
     * @param attributeName      the attribute name (used for logging)
     * @param attributeValue     the attribute value, a single object or a collection
     * @param attributeList      the list receiving the generated value elements
     * @param defaultElementName the element name of the generated value elements
     */
    protected void addAttributeValuesToSamlAttribute(final String attributeName,
                                                     final Object attributeValue,
                                                     final List<XMLObject> attributeList,
                                                     final QName defaultElementName) {
        if (attributeValue == null) {
            LOGGER.debug("Skipping over SAML attribute [{}] since it has no value", attributeName);
            return;
        }
        LOGGER.debug("Attempting to generate SAML attribute [{}] with value(s) [{}]", attributeName, attributeValue);
        if (attributeValue instanceof Collection<?>) {
            final Collection<?> c = (Collection<?>) attributeValue;
            LOGGER.debug("Generating multi-valued SAML attribute [{}] with values [{}]", attributeName, c);
            for (final Object value : c) {
                if (value == null) {
                    // A null member cannot be rendered as xs:string; skip it rather than failing the whole attribute.
                    LOGGER.debug("Skipping null value of SAML attribute [{}]", attributeName);
                    continue;
                }
                attributeList.add(newAttributeValue(value, defaultElementName));
            }
        } else {
            LOGGER.debug("Generating SAML attribute [{}] with value [{}]", attributeName, attributeValue);
            attributeList.add(newAttributeValue(attributeValue, defaultElementName));
        }
    }

    /**
     * Marshal the SAML XML object to raw, indented XML without an XML declaration.
     * The default namespace is declared as the SAML 2.0 assertion namespace, and the {@code xenc} prefix is declared on the root element.
     *
     * @param object the object to marshal
     * @param writer the writer the XML is written into
     * @return the writer's full content after marshalling
     * @throws IllegalStateException if marshalling or serialization fails
     */
    public String marshalSamlXmlObject(final XMLObject object, final StringWriter writer) {
        try {
            final MarshallerFactory marshallerFactory = XMLObjectProviderRegistrySupport.getMarshallerFactory();
            final Marshaller marshaller = marshallerFactory.getMarshaller(object);
            if (marshaller == null) {
                throw new IllegalArgumentException("Cannot obtain marshaller for object " + object.getElementQName());
            }
            final Element element = marshaller.marshall(object);
            element.setAttributeNS(NAMESPACE_URI, "xmlns", SAMLConstants.SAML20_NS);
            element.setAttributeNS(NAMESPACE_URI, "xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#");
            final TransformerFactory transFactory = TransformerFactory.newInstance();
            final Transformer transformer = transFactory.newTransformer();
            transformer.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "yes");
            transformer.setOutputProperty(OutputKeys.INDENT, "yes");
            transformer.transform(new DOMSource(element), new StreamResult(writer));
            return writer.toString();
        } catch (final Exception e) {
            throw new IllegalStateException("An error has occurred while marshalling SAML object to xml", e);
        }
    }

    /**
     * Sign a SAML response with an enveloped XML signature over the whole document.
     *
     * @param samlResponse the SAML response XML
     * @param privateKey   the private key used to sign (DSA or RSA)
     * @param publicKey    the public key embedded in the signature's KeyInfo (DSA or RSA)
     * @return the signed response XML
     * @throws RuntimeException if the response cannot be parsed or signed
     */
    public static String signSamlResponse(final String samlResponse, final PrivateKey privateKey, final PublicKey publicKey) {
        final Document doc = constructDocumentFromXml(samlResponse);
        if (doc != null) {
            final org.jdom.Element signedElement = signSamlElement(doc.getRootElement(), privateKey, publicKey);
            // The signed element belongs to a document built during conversion; detach it before re-rooting.
            doc.setRootElement((org.jdom.Element) signedElement.detach());
            return new XMLOutputter().outputString(doc);
        }
        throw new RuntimeException("Error signing SAML Response: Null document");
    }

    /**
     * Construct a JDOM document from an XML string. DOCTYPE declarations and external general entities are disallowed.
     *
     * @param xmlString the xml string
     * @return the document, or {@code null} if the string is blank or cannot be parsed
     */
    public static Document constructDocumentFromXml(final String xmlString) {
        if (StringUtils.isBlank(xmlString)) {
            return null;
        }
        try {
            final SAXBuilder builder = new SAXBuilder();
            builder.setFeature("http://xml.org/sax/features/external-general-entities", false);
            builder.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true);
            // Parse characters directly. Round-tripping through platform-default bytes would corrupt
            // non-ASCII content whenever the platform charset differs from the declared encoding.
            return builder.build(new StringReader(xmlString));
        } catch (final Exception e) {
            LOGGER.warn("Unable to construct XML document: [{}]", e.getMessage());
            LOGGER.debug(e.getMessage(), e);
            return null;
        }
    }

    /**
     * Sign a SAML element, which must be the root element of its document.
     *
     * @param element the element
     * @param privKey the private key
     * @param pubKey  the public key
     * @return the signed element
     * @throws RuntimeException if signing fails
     */
    private static org.jdom.Element signSamlElement(final org.jdom.Element element, final PrivateKey privKey, final PublicKey pubKey) {
        try {
            final XMLSignatureFactory sigFactory = newXmlSignatureFactory();

            // Reference the whole document (URI="") with an enveloped transform and a SHA-1 digest.
            final List<Transform> envelopedTransform = Collections.singletonList(
                sigFactory.newTransform(Transform.ENVELOPED, (TransformParameterSpec) null));
            final Reference ref = sigFactory.newReference(StringUtils.EMPTY,
                sigFactory.newDigestMethod(DigestMethod.SHA1, null), envelopedTransform, null, null);

            final SignatureMethod signatureMethod = newSignatureMethod(sigFactory, pubKey);
            final CanonicalizationMethod canonicalizationMethod = sigFactory.newCanonicalizationMethod(
                CanonicalizationMethod.INCLUSIVE_WITH_COMMENTS, (C14NMethodParameterSpec) null);
            final SignedInfo signedInfo = sigFactory.newSignedInfo(
                canonicalizationMethod, signatureMethod, Collections.singletonList(ref));

            // Embed the DSA/RSA public key as a KeyValue inside KeyInfo.
            final KeyInfoFactory keyInfoFactory = sigFactory.getKeyInfoFactory();
            final KeyValue keyValuePair = keyInfoFactory.newKeyValue(pubKey);
            final KeyInfo keyInfo = keyInfoFactory.newKeyInfo(Collections.singletonList(keyValuePair));

            // The Java XML signature API requires a W3C DOM representation.
            final Element w3cElement = toDom(element);
            final DOMSignContext dsc = new DOMSignContext(privKey, w3cElement);
            // A null next sibling means the Signature is appended as the last child.
            dsc.setNextSibling(getXmlSignatureInsertLocation(w3cElement));

            final XMLSignature signature = sigFactory.newXMLSignature(signedInfo, keyInfo);
            signature.sign(dsc);
            return toJdom(w3cElement);
        } catch (final Exception e) {
            throw new RuntimeException("Error signing SAML element: " + e.getMessage(), e);
        }
    }

    /**
     * Create a DOM XMLSignatureFactory backed by the configured JSR-105 provider.
     *
     * @return the signature factory
     * @throws ReflectiveOperationException if the provider class cannot be loaded or instantiated
     */
    private static XMLSignatureFactory newXmlSignatureFactory() throws ReflectiveOperationException {
        final String providerName = System.getProperty("jsr105Provider", SIGNATURE_FACTORY_PROVIDER_CLASS);
        final Provider provider = (Provider) Class.forName(providerName).getDeclaredConstructor().newInstance();
        return XMLSignatureFactory.getInstance("DOM", provider);
    }

    /**
     * Select the signature method matching the public key's algorithm (DSA or RSA, both with SHA-1).
     *
     * @param sigFactory the signature factory
     * @param pubKey     the public key
     * @return the signature method
     * @throws GeneralSecurityException if the factory cannot create the method
     * @throws IllegalArgumentException if the key algorithm is neither DSA nor RSA
     */
    private static SignatureMethod newSignatureMethod(final XMLSignatureFactory sigFactory,
                                                      final PublicKey pubKey) throws GeneralSecurityException {
        final String algorithm = pubKey.getAlgorithm();
        switch (algorithm) {
            case "DSA":
                return sigFactory.newSignatureMethod(SignatureMethod.DSA_SHA1, null);
            case "RSA":
                return sigFactory.newSignatureMethod(SignatureMethod.RSA_SHA1, null);
            default:
                throw new IllegalArgumentException("Unsupported type of key: " + algorithm);
        }
    }

    /**
     * Gets the node the Signature element must be inserted before. This is the first direct
     * {@code samlp:Extensions} child if present, otherwise the first direct {@code samlp:Status} child.
     * Only direct children are considered, because the insertion point must be a child of {@code elem}.
     *
     * @param elem the element being signed
     * @return the insertion point, or {@code null} to append the signature as the last child
     */
    private static Node getXmlSignatureInsertLocation(final Element elem) {
        final Node extensions = findFirstChildElement(elem, SAMLConstants.SAML20P_NS, "Extensions");
        if (extensions != null) {
            return extensions;
        }
        return findFirstChildElement(elem, SAMLConstants.SAML20P_NS, "Status");
    }

    /**
     * Find the first direct child element of {@code parent} with the given namespace and local name.
     *
     * @param parent    the parent element
     * @param namespace the namespace URI
     * @param localName the local name
     * @return the matching child, or {@code null} if none exists
     */
    private static Node findFirstChildElement(final Element parent, final String namespace, final String localName) {
        for (Node child = parent.getFirstChild(); child != null; child = child.getNextSibling()) {
            if (child.getNodeType() == Node.ELEMENT_NODE
                && namespace.equals(child.getNamespaceURI())
                && localName.equals(child.getLocalName())) {
                return child;
            }
        }
        return null;
    }

    /**
     * Convert the element's owning JDOM document to W3C DOM and return its document element.
     * The element is expected to be the root of an attached document.
     *
     * @param element the element
     * @return the org.w3c.dom. element
     */
    private static Element toDom(final org.jdom.Element element) {
        return toDom(element.getDocument()).getDocumentElement();
    }

    /**
     * Convert a JDOM document to a namespace-aware W3C DOM document, with DOCTYPEs and external entities disabled.
     *
     * @param doc the doc
     * @return the org.w3c.dom. document
     * @throws IllegalStateException if serialization or parsing fails
     */
    private static org.w3c.dom.Document toDom(final Document doc) {
        try {
            final XMLOutputter xmlOutputter = new XMLOutputter();
            // Encode with the same charset the outputter declares in the XML prolog (UTF-8 by default),
            // so the parser decodes the bytes exactly as they were written.
            final Charset encoding = Charset.forName(xmlOutputter.getFormat().getEncoding());
            final byte[] xmlBytes = xmlOutputter.outputString(doc).getBytes(encoding);
            final DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
            dbf.setNamespaceAware(true);
            dbf.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true);
            dbf.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true);
            dbf.setFeature("http://apache.org/xml/features/validation/schema/normalized-value", false);
            dbf.setFeature("http://xml.org/sax/features/external-general-entities", false);
            dbf.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
            return dbf.newDocumentBuilder().parse(new ByteArrayInputStream(xmlBytes));
        } catch (final Exception e) {
            throw new IllegalStateException("Unable to convert JDOM document to a W3C DOM document", e);
        }
    }

    /**
     * Convert a W3C DOM element to a JDOM element.
     *
     * @param e the W3C element
     * @return the JDOM element
     */
    private static org.jdom.Element toJdom(final Element e) {
        return new DOMBuilder().build(e);
    }
}