package org.apereo.cas.support.saml; import net.shibboleth.utilities.java.support.resolver.CriteriaSet; import org.apache.commons.lang3.StringUtils; import org.apereo.cas.services.RegisteredService; import org.apereo.cas.services.ServicesManager; import org.apereo.cas.support.saml.services.SamlRegisteredService; import org.apereo.cas.support.saml.services.idp.metadata.SamlRegisteredServiceServiceProviderMetadataFacade; import org.apereo.cas.support.saml.services.idp.metadata.cache.SamlRegisteredServiceCachingMetadataResolver; import org.opensaml.core.criterion.EntityIdCriterion; import org.opensaml.messaging.context.MessageContext; import org.opensaml.saml.common.messaging.context.SAMLEndpointContext; import org.opensaml.saml.common.messaging.context.SAMLPeerEntityContext; import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.criterion.BindingCriterion; import org.opensaml.saml.criterion.EntityRoleCriterion; import org.opensaml.saml.metadata.resolver.ChainingMetadataResolver; import org.opensaml.saml.metadata.resolver.MetadataResolver; import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.RequestAbstractType; import org.opensaml.saml.saml2.metadata.AssertionConsumerService; import org.opensaml.saml.saml2.metadata.Endpoint; import org.opensaml.saml.saml2.metadata.EntityDescriptor; import org.opensaml.saml.saml2.metadata.SPSSODescriptor; import org.opensaml.saml.saml2.metadata.impl.AssertionConsumerServiceBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; /** * This is {@link SamlIdPUtils}. * * @author Misagh Moayyed * @since 5.0.0 */ public final class SamlIdPUtils { private static final Logger LOGGER = LoggerFactory.getLogger(SamlIdPUtils.class); private SamlIdPUtils() { } /** * Prepare peer entity saml endpoint. * * @param outboundContext the outbound context * @param adaptor the adaptor * @param binding the binding * @throws SamlException the saml exception */ public static void preparePeerEntitySamlEndpointContext(final MessageContext outboundContext, final SamlRegisteredServiceServiceProviderMetadataFacade adaptor, final String binding) throws SamlException { if (!adaptor.containsAssertionConsumerServices()) { throw new SamlException("No assertion consumer service could be found for entity " + adaptor.getEntityId()); } final SAMLPeerEntityContext peerEntityContext = outboundContext.getSubcontext(SAMLPeerEntityContext.class, true); if (peerEntityContext == null) { throw new SamlException("SAMLPeerEntityContext could not be defined for entity " + adaptor.getEntityId()); } final SAMLEndpointContext endpointContext = peerEntityContext.getSubcontext(SAMLEndpointContext.class, true); if (endpointContext == null) { throw new SamlException("SAMLEndpointContext could not be defined for entity " + adaptor.getEntityId()); } final Endpoint endpoint = adaptor.getAssertionConsumerService(binding); if (StringUtils.isBlank(endpoint.getBinding()) || StringUtils.isBlank(endpoint.getLocation())) { throw new SamlException("Assertion consumer service does not define a binding or location for " + adaptor.getEntityId()); } LOGGER.debug("Configured peer entity endpoint to be [{}] with binding [{}]", endpoint.getLocation(), endpoint.getBinding()); endpointContext.setEndpoint(endpoint); } /** * Gets chaining metadata resolver for all saml services. * * @param servicesManager the services manager * @param entityID the entity id * @param resolver the resolver * @return the chaining metadata resolver for all saml services */ public static MetadataResolver getMetadataResolverForAllSamlServices(final ServicesManager servicesManager, final String entityID, final SamlRegisteredServiceCachingMetadataResolver resolver) { try { final Collection<RegisteredService> registeredServices = servicesManager.findServiceBy(SamlRegisteredService.class::isInstance); final List<MetadataResolver> resolvers; final ChainingMetadataResolver chainingMetadataResolver = new ChainingMetadataResolver(); resolvers = registeredServices.stream() .filter(SamlRegisteredService.class::isInstance) .map(SamlRegisteredService.class::cast) .map(s -> SamlRegisteredServiceServiceProviderMetadataFacade.get(resolver, s, entityID)) .filter(Optional::isPresent) .map(Optional::get) .map(SamlRegisteredServiceServiceProviderMetadataFacade::getMetadataResolver) .collect(Collectors.toList()); LOGGER.debug("Located [{}] metadata resolvers to match against [{}]", resolvers, entityID); chainingMetadataResolver.setResolvers(resolvers); chainingMetadataResolver.setId(entityID); chainingMetadataResolver.initialize(); return chainingMetadataResolver; } catch (final Exception e) { throw new RuntimeException(new SamlException(e.getMessage(), e)); } } /** * Gets assertion consumer service for. * * @param authnRequest the authn request * @param servicesManager the services manager * @param resolver the resolver * @return the assertion consumer service for */ public static AssertionConsumerService getAssertionConsumerServiceFor(final AuthnRequest authnRequest, final ServicesManager servicesManager, final SamlRegisteredServiceCachingMetadataResolver resolver) { try { final AssertionConsumerService acs = new AssertionConsumerServiceBuilder().buildObject(); if (authnRequest.getAssertionConsumerServiceIndex() != null) { final String issuer = getIssuerFromSamlRequest(authnRequest); final MetadataResolver samlResolver = getMetadataResolverForAllSamlServices(servicesManager, issuer, resolver); final CriteriaSet criteriaSet = new CriteriaSet(); criteriaSet.add(new EntityIdCriterion(issuer)); criteriaSet.add(new EntityRoleCriterion(SPSSODescriptor.DEFAULT_ELEMENT_NAME)); criteriaSet.add(new BindingCriterion(Arrays.asList(SAMLConstants.SAML2_POST_BINDING_URI))); final Iterable<EntityDescriptor> it = samlResolver.resolve(criteriaSet); it.forEach(entityDescriptor -> { final SPSSODescriptor spssoDescriptor = entityDescriptor.getSPSSODescriptor(SAMLConstants.SAML20P_NS); final List<AssertionConsumerService> acsEndpoints = spssoDescriptor.getAssertionConsumerServices(); if (acsEndpoints.isEmpty()) { throw new RuntimeException("Metadata resolved for entity id " + issuer + " has no defined ACS endpoints"); } final int acsIndex = authnRequest.getAssertionConsumerServiceIndex(); if (acsIndex + 1 > acsEndpoints.size()) { throw new RuntimeException("AssertionConsumerService index specified in the request " + acsIndex + " is invalid " + "since the total endpoints available to " + issuer + " is " + acsEndpoints.size()); } final AssertionConsumerService foundAcs = acsEndpoints.get(acsIndex); acs.setBinding(foundAcs.getBinding()); acs.setLocation(foundAcs.getLocation()); acs.setResponseLocation(foundAcs.getResponseLocation()); acs.setIndex(acsIndex); }); } else { acs.setBinding(authnRequest.getProtocolBinding()); acs.setLocation(authnRequest.getAssertionConsumerServiceURL()); acs.setResponseLocation(authnRequest.getAssertionConsumerServiceURL()); acs.setIndex(0); acs.setIsDefault(Boolean.TRUE); } LOGGER.debug("Resolved AssertionConsumerService from the request is [{}]", acs); if (StringUtils.isBlank(acs.getBinding())) { throw new SamlException("AssertionConsumerService has no protocol binding defined"); } if (StringUtils.isBlank(acs.getLocation()) && StringUtils.isBlank(acs.getResponseLocation())) { throw new SamlException("AssertionConsumerService has no location or response location defined"); } return acs; } catch (final Exception e) { throw new RuntimeException(new SamlException(e.getMessage(), e)); } } /** * Gets issuer from saml request. * * @param request the request * @return the issuer from saml request */ public static String getIssuerFromSamlRequest(final RequestAbstractType request) { return request.getIssuer().getValue(); } }