package org.apereo.cas.support.wsfederation; import com.google.common.base.Throwables; import net.shibboleth.utilities.java.support.resolver.CriteriaSet; import org.apereo.cas.support.saml.OpenSamlConfigBean; import org.apereo.cas.support.saml.SamlUtils; import org.apereo.cas.support.wsfederation.authentication.principal.WsFederationCredential; import org.bouncycastle.jce.provider.BouncyCastleProvider; import org.bouncycastle.jce.provider.X509CertParser; import org.bouncycastle.jce.provider.X509CertificateObject; import org.bouncycastle.openssl.PEMDecryptorProvider; import org.bouncycastle.openssl.PEMEncryptedKeyPair; import org.bouncycastle.openssl.PEMKeyPair; import org.bouncycastle.openssl.PEMParser; import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter; import org.bouncycastle.openssl.jcajce.JcePEMDecryptorProviderBuilder; import org.opensaml.core.criterion.EntityIdCriterion; import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.io.Unmarshaller; import org.opensaml.core.xml.io.UnmarshallerFactory; import org.opensaml.core.xml.schema.XSAny; import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.criterion.EntityRoleCriterion; import org.opensaml.saml.criterion.ProtocolCriterion; import org.opensaml.saml.saml1.core.Assertion; import org.opensaml.saml.saml1.core.Conditions; import org.opensaml.saml.saml2.encryption.Decrypter; import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver; import org.opensaml.saml.saml2.metadata.IDPSSODescriptor; import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; import org.opensaml.security.SecurityException; import org.opensaml.security.credential.Credential; import org.opensaml.security.credential.CredentialResolver; import org.opensaml.security.credential.UsageType; import org.opensaml.security.credential.impl.StaticCredentialResolver; import org.opensaml.security.criteria.UsageCriterion; import org.opensaml.security.x509.BasicX509Credential; import org.opensaml.soap.wsfed.RequestSecurityTokenResponse; import org.opensaml.soap.wsfed.RequestedSecurityToken; import org.opensaml.xmlsec.encryption.EncryptedData; import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver; import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver; import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver; import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver; import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver; import org.opensaml.xmlsec.keyinfo.impl.StaticKeyInfoCredentialResolver; import org.opensaml.xmlsec.signature.support.SignatureException; import org.opensaml.xmlsec.signature.support.SignaturePrevalidator; import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.w3c.dom.Document; import org.w3c.dom.Element; import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.InputStream; import java.io.InputStreamReader; import java.nio.charset.StandardCharsets; import java.security.KeyPair; import java.security.Security; import java.time.ZonedDateTime; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.stream.Collectors; import java.util.stream.IntStream; /** * Helper class that does the heavy lifting with the openSaml library. * * @author John Gasper * @since 4.2.0 */ public class WsFederationHelper { private static final Logger LOGGER = LoggerFactory.getLogger(WsFederationHelper.class); private OpenSamlConfigBean configBean; /** * private constructor. */ public WsFederationHelper() { } /** * createCredentialFromToken converts a SAML 1.1 assertion to a WSFederationCredential. * * @param assertion the provided assertion * @return an equivalent credential. */ public WsFederationCredential createCredentialFromToken(final Assertion assertion) { final ZonedDateTime retrievedOn = ZonedDateTime.now(); LOGGER.debug("Retrieved on [{}]", retrievedOn); final WsFederationCredential credential = new WsFederationCredential(); credential.setRetrievedOn(retrievedOn); credential.setId(assertion.getID()); credential.setIssuer(assertion.getIssuer()); credential.setIssuedOn(ZonedDateTime.parse(assertion.getIssueInstant().toDateTimeISO().toString())); final Conditions conditions = assertion.getConditions(); if (conditions != null) { credential.setNotBefore(ZonedDateTime.parse(conditions.getNotBefore().toDateTimeISO().toString())); credential.setNotOnOrAfter(ZonedDateTime.parse(conditions.getNotOnOrAfter().toDateTimeISO().toString())); if (!conditions.getAudienceRestrictionConditions().isEmpty()) { credential.setAudience(conditions.getAudienceRestrictionConditions().get(0).getAudiences().get(0).getUri()); } } if (!assertion.getAuthenticationStatements().isEmpty()) { credential.setAuthenticationMethod(assertion.getAuthenticationStatements().get(0).getAuthenticationMethod()); } //retrieve an attributes from the assertion final HashMap<String, List<Object>> attributes = new HashMap<>(); assertion.getAttributeStatements().stream().flatMap(attributeStatement -> attributeStatement.getAttributes().stream()).forEach(item -> { LOGGER.debug("Processed attribute: [{}]", item.getAttributeName()); final List<Object> itemList = IntStream.range(0, item.getAttributeValues().size()) .mapToObj(i -> ((XSAny) item.getAttributeValues().get(i)).getTextContent()) .collect(Collectors.toList()); if (!itemList.isEmpty()) { attributes.put(item.getAttributeName(), itemList); } }); credential.setAttributes(attributes); LOGGER.debug("Credential: [{}]", credential); return credential; } /** * parseTokenFromString converts a raw wresult and extracts it into an assertion. * * @param wresult the raw token returned by the IdP * @param config the config * @return an assertion */ public Assertion parseTokenFromString(final String wresult, final WsFederationConfiguration config) { LOGGER.debug("Result token received from ADFS is [{}]", wresult); try (InputStream in = new ByteArrayInputStream(wresult.getBytes(StandardCharsets.UTF_8))) { LOGGER.debug("Parsing token into a document"); final Document document = configBean.getParserPool().parse(in); final Element metadataRoot = document.getDocumentElement(); final UnmarshallerFactory unmarshallerFactory = configBean.getUnmarshallerFactory(); final Unmarshaller unmarshaller = unmarshallerFactory.getUnmarshaller(metadataRoot); if (unmarshaller == null) { throw new IllegalArgumentException("Unmarshaller for the metadata root element cannot be determined"); } LOGGER.debug("Unmarshalling the document into a security token response"); final RequestSecurityTokenResponse rsToken = (RequestSecurityTokenResponse) unmarshaller.unmarshall(metadataRoot); if (rsToken == null || rsToken.getRequestedSecurityToken() == null) { throw new IllegalArgumentException("Request security token response is null"); } //Get our SAML token LOGGER.debug("Locating list of requested security tokens"); final List<RequestedSecurityToken> rst = rsToken.getRequestedSecurityToken(); if (rst.isEmpty()) { throw new IllegalArgumentException("No requested security token response is provided in the response"); } LOGGER.debug("Locating the first occurrence of a requested security token in the list"); final RequestedSecurityToken reqToken = rst.get(0); if (reqToken.getSecurityTokens() == null || reqToken.getSecurityTokens().isEmpty()) { throw new IllegalArgumentException("Requested security token response is not carrying any security tokens"); } Assertion assertion = null; LOGGER.debug("Locating the first occurrence of a security token from the requested security token"); XMLObject securityToken = reqToken.getSecurityTokens().get(0); if (securityToken instanceof EncryptedData) { try { LOGGER.debug("Security token is encrypted. Attempting to decrypt to extract the assertion"); final EncryptedData encryptedData = EncryptedData.class.cast(securityToken); final Decrypter decrypter = buildAssertionDecrypter(config); LOGGER.debug("Built an instance of [{}]", decrypter.getClass().getName()); securityToken = decrypter.decryptData(encryptedData); } catch (final Exception e) { throw new IllegalArgumentException("Unable to decrypt security token", e); } } if (securityToken instanceof Assertion) { LOGGER.debug("Security token is an assertion."); assertion = Assertion.class.cast(securityToken); } if (assertion == null) { throw new IllegalArgumentException("Could not extract or decrypt an assertion based on the security token provided"); } LOGGER.debug("Extracted assertion successfully: [{}]", assertion); return assertion; } catch (final Exception ex) { LOGGER.warn(ex.getMessage()); return null; } } /** * validateSignature checks to see if the signature on an assertion is valid. * * @param assertion a provided assertion * @param wsFederationConfiguration WS-Fed configuration provided. * @return true if the assertion's signature is valid, otherwise false */ public boolean validateSignature(final Assertion assertion, final WsFederationConfiguration wsFederationConfiguration) { if (assertion == null) { LOGGER.warn("No assertion was provided to validate signatures"); return false; } boolean valid = false; if (assertion.getSignature() != null) { final SignaturePrevalidator validator = new SAMLSignatureProfileValidator(); try { validator.validate(assertion.getSignature()); final CriteriaSet criteriaSet = new CriteriaSet(); criteriaSet.add(new UsageCriterion(UsageType.SIGNING)); criteriaSet.add(new EntityRoleCriterion(IDPSSODescriptor.DEFAULT_ELEMENT_NAME)); criteriaSet.add(new ProtocolCriterion(SAMLConstants.SAML20P_NS)); criteriaSet.add(new EntityIdCriterion(wsFederationConfiguration.getIdentityProviderIdentifier())); try { final SignatureTrustEngine engine = buildSignatureTrustEngine(wsFederationConfiguration); valid = engine.validate(assertion.getSignature(), criteriaSet); } catch (final SecurityException e) { LOGGER.warn(e.getMessage(), e); } finally { if (!valid) { LOGGER.warn("Signature doesn't match any signing credential."); } } } catch (final SignatureException e) { LOGGER.warn("Failed to validate assertion signature", e); } } SamlUtils.logSamlObject(this.configBean, assertion); return valid; } /** * Build signature trust engine. * * @param wsFederationConfiguration the ws federation configuration * @return the signature trust engine */ private static SignatureTrustEngine buildSignatureTrustEngine(final WsFederationConfiguration wsFederationConfiguration) { try { final CredentialResolver resolver = new StaticCredentialResolver(wsFederationConfiguration.getSigningCertificates()); final KeyInfoCredentialResolver keyResolver = new StaticKeyInfoCredentialResolver(wsFederationConfiguration.getSigningCertificates()); return new ExplicitKeySignatureTrustEngine(resolver, keyResolver); } catch (final Exception e) { throw Throwables.propagate(e); } } public void setConfigBean(final OpenSamlConfigBean configBean) { this.configBean = configBean; } private static Credential getEncryptionCredential(final WsFederationConfiguration config) { try { // This will need to contain the private keypair in PEM format LOGGER.debug("Locating encryption credential private key [{}]", config.getEncryptionPrivateKey()); final BufferedReader br = new BufferedReader(new InputStreamReader( config.getEncryptionPrivateKey().getInputStream(), StandardCharsets.UTF_8)); Security.addProvider(new BouncyCastleProvider()); LOGGER.debug("Parsing credential private key"); final PEMParser pemParser = new PEMParser(br); final Object privateKeyPemObject = pemParser.readObject(); final JcaPEMKeyConverter converter = new JcaPEMKeyConverter().setProvider(new BouncyCastleProvider()); final KeyPair kp; if (privateKeyPemObject instanceof PEMEncryptedKeyPair) { LOGGER.debug("Encryption private key is an encrypted keypair"); final PEMEncryptedKeyPair ckp = (PEMEncryptedKeyPair) privateKeyPemObject; final PEMDecryptorProvider decProv = new JcePEMDecryptorProviderBuilder() .build(config.getEncryptionPrivateKeyPassword().toCharArray()); LOGGER.debug("Attempting to decrypt the encrypted keypair based on the provided encryption private key password"); kp = converter.getKeyPair(ckp.decryptKeyPair(decProv)); } else { LOGGER.debug("Extracting a keypair from the private key"); kp = converter.getKeyPair((PEMKeyPair) privateKeyPemObject); } final X509CertParser certParser = new X509CertParser(); // This is the certificate shared with ADFS in DER format, i.e certificate.crt LOGGER.debug("Locating encryption certificate [{}]", config.getEncryptionCertificate()); certParser.engineInit(config.getEncryptionCertificate().getInputStream()); LOGGER.debug("Invoking certificate engine to parse the certificate [{}]", config.getEncryptionCertificate()); final X509CertificateObject cert = (X509CertificateObject) certParser.engineRead(); LOGGER.debug("Creating final credential based on the certificate [{}] and the private key", cert.getIssuerDN()); return new BasicX509Credential(cert, kp.getPrivate()); } catch (final Exception e) { throw Throwables.propagate(e); } } private static Decrypter buildAssertionDecrypter(final WsFederationConfiguration config) { final List<EncryptedKeyResolver> list = new ArrayList<>(); list.add(new InlineEncryptedKeyResolver()); list.add(new EncryptedElementTypeEncryptedKeyResolver()); list.add(new SimpleRetrievalMethodEncryptedKeyResolver()); LOGGER.debug("Built a list of encrypted key resolvers: [{}]", list); final ChainingEncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver(list); LOGGER.debug("Building credential instance to decrypt data"); final Credential encryptionCredential = getEncryptionCredential(config); final KeyInfoCredentialResolver resolver = new StaticKeyInfoCredentialResolver(encryptionCredential); final Decrypter decrypter = new Decrypter(null, resolver, encryptedKeyResolver); decrypter.setRootInNewDocument(true); return decrypter; } public OpenSamlConfigBean getConfigBean() { return configBean; } }