package org.apereo.cas.support.saml.web.idp.profile;

import com.google.common.base.Throwables;
import net.shibboleth.utilities.java.support.net.URLBuilder;
import net.shibboleth.utilities.java.support.xml.ParserPool;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apereo.cas.CasProtocolConstants;
import org.apereo.cas.authentication.AuthenticationSystemSupport;
import org.apereo.cas.authentication.principal.Service;
import org.apereo.cas.authentication.principal.ServiceFactory;
import org.apereo.cas.authentication.principal.WebApplicationService;
import org.apereo.cas.services.RegexRegisteredService;
import org.apereo.cas.services.RegisteredService;
import org.apereo.cas.services.ServicesManager;
import org.apereo.cas.services.UnauthorizedServiceException;
import org.apereo.cas.support.saml.OpenSamlConfigBean;
import org.apereo.cas.support.saml.SamlException;
import org.apereo.cas.support.saml.SamlIdPConstants;
import org.apereo.cas.support.saml.SamlIdPUtils;
import org.apereo.cas.support.saml.SamlProtocolConstants;
import org.apereo.cas.support.saml.SamlUtils;
import org.apereo.cas.support.saml.services.SamlRegisteredService;
import org.apereo.cas.support.saml.services.idp.metadata.SamlRegisteredServiceServiceProviderMetadataFacade;
import org.apereo.cas.support.saml.services.idp.metadata.cache.SamlRegisteredServiceCachingMetadataResolver;
import org.apereo.cas.support.saml.web.idp.profile.builders.SamlProfileObjectBuilder;
import org.apereo.cas.support.saml.web.idp.profile.builders.enc.BaseSamlObjectSigner;
import org.apereo.cas.support.saml.web.idp.profile.builders.enc.SamlObjectSignatureValidator;
import org.apereo.cas.util.EncodingUtils;
import org.apereo.cas.web.support.WebUtils;
import org.jasig.cas.client.authentication.AuthenticationRedirectStrategy;
import org.jasig.cas.client.authentication.DefaultAuthenticationRedirectStrategy;
import org.jasig.cas.client.util.CommonUtils;
import org.jasig.cas.client.validation.Assertion;
import org.opensaml.core.xml.util.XMLObjectSupport;
import org.opensaml.messaging.context.MessageContext;
import org.opensaml.messaging.decoder.servlet.BaseHttpServletRequestXMLMessageDecoder;
import org.opensaml.saml.common.SAMLException;
import org.opensaml.saml.common.SAMLObject;
import org.opensaml.saml.common.SignableSAMLObject;
import org.opensaml.saml.common.binding.SAMLBindingSupport;
import org.opensaml.saml.saml2.core.AuthnContextClassRef;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.servlet.ModelAndView;

import javax.annotation.PostConstruct;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.ByteArrayInputStream;
import java.io.StringWriter;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

/**
 * A parent controller to handle SAML requests.
 * Specific profile endpoints are handled by extensions.
 * This parent provides the necessary operations for profile endpoint
 * controllers to respond to endpoints.
 *
 * @author Misagh Moayyed
 * @since 5.0.0
 */
@Controller
public abstract class AbstractSamlProfileHandlerController {
    private static final Logger LOGGER = LoggerFactory.getLogger(AbstractSamlProfileHandlerController.class);

    /**
     * Authentication support to handle credentials and authn subsystem calls.
     */
    protected AuthenticationSystemSupport authenticationSystemSupport;

    /**
     * The SAML object signer.
     */
    protected BaseSamlObjectSigner samlObjectSigner;

    /**
     * Signature validator used to verify signed SAML profile requests.
     */
    protected SamlObjectSignatureValidator samlObjectSignatureValidator;

    /**
     * The parser pool handed to message decoders.
     */
    protected ParserPool parserPool;

    /**
     * Callback service; assigned by {@link #initialize()}.
     */
    protected Service callbackService;

    /**
     * The services manager.
     */
    protected ServicesManager servicesManager;

    /**
     * The web application service factory.
     */
    protected ServiceFactory<WebApplicationService> webApplicationServiceFactory;

    /**
     * The SAML registered service caching metadata resolver.
     */
    protected SamlRegisteredServiceCachingMetadataResolver samlRegisteredServiceCachingMetadataResolver;

    /**
     * The OpenSAML config bean.
     */
    protected OpenSamlConfigBean configBean;

    /**
     * The response builder.
     */
    protected SamlProfileObjectBuilder<? extends SAMLObject> responseBuilder;

    /**
     * Maps authentication contexts to what CAS can support.
     * Each entry is expected in the form {@code requestedClass->mappedValue}.
     */
    protected Set<String> authenticationContextClassMappings = new TreeSet<>();

    /**
     * Server prefix; the callback service id is built by appending to it.
     **/
    protected String serverPrefix;

    /**
     * Server name.
     **/
    protected String serverName;

    /**
     * Name of the request parameter that carries the mapped authn context.
     **/
    protected String authenticationContextRequestParameter;

    /**
     * Server login URL.
     **/
    protected String loginUrl;

    /**
     * Server logout URL.
     **/
    protected String logoutUrl;

    /**
     * Force signed logout requests.
     **/
    protected boolean forceSignedLogoutRequests;

    /**
     * Disable single logout callbacks.
     **/
    protected boolean singleLogoutCallbacksDisabled;

    /**
     * Instantiates a new Abstract saml profile handler controller.
     *
     * @param samlObjectSigner                             the saml object signer
     * @param parserPool                                   the parser pool
     * @param authenticationSystemSupport                  the authentication system support
     * @param servicesManager                              the services manager
     * @param webApplicationServiceFactory                 the web application service factory
     * @param samlRegisteredServiceCachingMetadataResolver the saml registered service caching metadata resolver
     * @param configBean                                   the config bean
     * @param responseBuilder                              the response builder
     * @param authenticationContextClassMappings           the authentication context class mappings
     * @param serverPrefix                                 the server prefix
     * @param serverName                                   the server name
     * @param authenticationContextRequestParameter        the authentication context request parameter
     * @param loginUrl                                     the login url
     * @param logoutUrl                                    the logout url
     * @param forceSignedLogoutRequests                    the force signed logout requests
     * @param singleLogoutCallbacksDisabled                the single logout callbacks disabled
     * @param samlObjectSignatureValidator                 the saml object signature validator
     */
    public AbstractSamlProfileHandlerController(final BaseSamlObjectSigner samlObjectSigner,
                                                final ParserPool parserPool,
                                                final AuthenticationSystemSupport authenticationSystemSupport,
                                                final ServicesManager servicesManager,
                                                final ServiceFactory<WebApplicationService> webApplicationServiceFactory,
                                                final SamlRegisteredServiceCachingMetadataResolver samlRegisteredServiceCachingMetadataResolver,
                                                final OpenSamlConfigBean configBean,
                                                final SamlProfileObjectBuilder<? extends SAMLObject> responseBuilder,
                                                final Set<String> authenticationContextClassMappings,
                                                final String serverPrefix,
                                                final String serverName,
                                                final String authenticationContextRequestParameter,
                                                final String loginUrl,
                                                final String logoutUrl,
                                                final boolean forceSignedLogoutRequests,
                                                final boolean singleLogoutCallbacksDisabled,
                                                final SamlObjectSignatureValidator samlObjectSignatureValidator) {
        this.samlObjectSigner = samlObjectSigner;
        this.parserPool = parserPool;
        this.servicesManager = servicesManager;
        this.webApplicationServiceFactory = webApplicationServiceFactory;
        this.samlRegisteredServiceCachingMetadataResolver = samlRegisteredServiceCachingMetadataResolver;
        this.configBean = configBean;
        this.responseBuilder = responseBuilder;
        this.authenticationContextClassMappings = authenticationContextClassMappings;
        this.serverPrefix = serverPrefix;
        this.serverName = serverName;
        this.authenticationContextRequestParameter = authenticationContextRequestParameter;
        this.loginUrl = loginUrl;
        this.logoutUrl = logoutUrl;
        this.forceSignedLogoutRequests = forceSignedLogoutRequests;
        this.singleLogoutCallbacksDisabled = singleLogoutCallbacksDisabled;
        this.authenticationSystemSupport = authenticationSystemSupport;
        this.samlObjectSignatureValidator = samlObjectSignatureValidator;
    }

    /**
     * Post-construction hook. Registers the SAML2 SSO POST callback endpoint
     * as a service and remembers it in {@link #callbackService}.
     * Extensions may override this to perform additional initialization.
     */
    @PostConstruct
    protected void initialize() {
        this.callbackService = registerCallback(SamlIdPConstants.ENDPOINT_SAML2_SSO_PROFILE_POST_CALLBACK);
    }

    /**
     * Gets saml metadata adaptor for service.
     *
     * @param registeredService the registered service
     * @param authnRequest      the authn request
     * @return the saml metadata adaptor for service
     */
    protected Optional<SamlRegisteredServiceServiceProviderMetadataFacade> getSamlMetadataFacadeFor(
            final SamlRegisteredService registeredService, final AuthnRequest authnRequest) {
        return SamlRegisteredServiceServiceProviderMetadataFacade.get(
                this.samlRegisteredServiceCachingMetadataResolver, registeredService, authnRequest);
    }

    /**
     * Gets saml metadata adaptor for service.
     *
     * @param registeredService the registered service
     * @param entityId          the entity id
     * @return the saml metadata adaptor for service
     */
    protected Optional<SamlRegisteredServiceServiceProviderMetadataFacade> getSamlMetadataFacadeFor(
            final SamlRegisteredService registeredService, final String entityId) {
        return SamlRegisteredServiceServiceProviderMetadataFacade.get(
                this.samlRegisteredServiceCachingMetadataResolver, registeredService, entityId);
    }

    /**
     * Locates the registered service for the given id and verifies that it
     * exists, that access to it is allowed, and that it is a SAML service.
     *
     * @param serviceId the service id
     * @return the matching SAML registered service
     * @throws UnauthorizedServiceException if the id is blank, no service matches, access is
     *                                      denied, or the match is not a SAML registered service
     */
    protected SamlRegisteredService verifySamlRegisteredService(final String serviceId) {
        if (StringUtils.isBlank(serviceId)) {
            throw new UnauthorizedServiceException(UnauthorizedServiceException.CODE_UNAUTHZ_SERVICE,
                    "Could not verify/locate SAML registered service since no serviceId is provided");
        }
        LOGGER.debug("Checking service access in CAS service registry for [{}]", serviceId);
        final RegisteredService registeredService =
                this.servicesManager.findServiceBy(this.webApplicationServiceFactory.createService(serviceId));
        if (registeredService == null || !registeredService.getAccessStrategy().isServiceAccessAllowed()) {
            LOGGER.warn("[{}] is not found in the registry or service access is denied. "
                    + "Ensure service is registered in service registry", serviceId);
            throw new UnauthorizedServiceException(UnauthorizedServiceException.CODE_UNAUTHZ_SERVICE);
        }

        if (registeredService instanceof SamlRegisteredService) {
            final SamlRegisteredService samlRegisteredService = (SamlRegisteredService) registeredService;
            LOGGER.debug("Located SAML service in the registry as [{}] with the metadata location of [{}]",
                    samlRegisteredService.getServiceId(), samlRegisteredService.getMetadataLocation());
            return samlRegisteredService;
        }
        LOGGER.error("CAS has found a match for service [{}] in registry but the match is not defined as a SAML service",
                serviceId);
        throw new UnauthorizedServiceException(UnauthorizedServiceException.CODE_UNAUTHZ_SERVICE);
    }

    /**
     * Builds the callback service for the given endpoint and, if no existing
     * service matches it, saves a regex-based registered service for it and
     * reloads the services manager.
     *
     * @param callbackUrl the callback url, appended to the server prefix
     * @return the callback service
     */
    protected Service registerCallback(final String callbackUrl) {
        final Service callback = this.webApplicationServiceFactory.createService(
                this.serverPrefix.concat(callbackUrl.concat(".+")));
        if (!this.servicesManager.matchesExistingService(callback)) {
            LOGGER.debug("Initializing callback service [{}]", callback);

            final RegexRegisteredService service = new RegexRegisteredService();
            // Clear the sign bit instead of Math.abs(): Math.abs(Long.MIN_VALUE) is still negative.
            service.setId(new SecureRandom().nextLong() & Long.MAX_VALUE);
            service.setEvaluationOrder(0);
            service.setName(service.getClass().getSimpleName());
            service.setDescription("SAML Authentication Request");
            service.setServiceId(callback.getId());

            LOGGER.debug("Saving callback service [{}] into the registry", service);
            this.servicesManager.save(service);
            this.servicesManager.load();
        }
        return callback;
    }

    /**
     * Retrieves the SAML authentication request from the http request by
     * base64-decoding the SAML request parameter and unmarshalling the result.
     *
     * @param request the request
     * @return the authn request
     * @throws Exception if the parameter is blank or the content cannot be unmarshalled
     */
    protected AuthnRequest retrieveSamlAuthenticationRequestFromHttpRequest(final HttpServletRequest request)
            throws Exception {
        LOGGER.debug("Retrieving authentication request from scope");
        final String requestValue = request.getParameter(SamlProtocolConstants.PARAMETER_SAML_REQUEST);
        if (StringUtils.isBlank(requestValue)) {
            throw new IllegalArgumentException("SAML request could not be determined from the authentication request");
        }
        final byte[] decodedRequest = EncodingUtils.decodeBase64(requestValue.getBytes(StandardCharsets.UTF_8));
        return (AuthnRequest) XMLObjectSupport.unmarshallFromInputStream(
                this.configBean.getParserPool(), new ByteArrayInputStream(decodedRequest));
    }

    /**
     * Decodes the SAML object carried by the http request using the given decoder.
     *
     * @param request the request
     * @param decoder the decoder
     * @param clazz   the type the decoded message is expected to be assignable to
     * @return the decoded saml object paired with the decoder's message context
     */
    protected Pair<? extends SignableSAMLObject, MessageContext> decodeSamlContextFromHttpRequest(
            final HttpServletRequest request,
            final BaseHttpServletRequestXMLMessageDecoder decoder,
            final Class<? extends SignableSAMLObject> clazz) {
        LOGGER.info("Received SAML profile request [{}]", request.getRequestURI());

        try {
            decoder.setHttpServletRequest(request);
            decoder.setParserPool(this.parserPool);
            decoder.initialize();
            decoder.decode();

            final MessageContext messageContext = decoder.getMessageContext();
            final SignableSAMLObject object = (SignableSAMLObject) messageContext.getMessage();

            if (object == null) {
                throw new SAMLException("No " + clazz.getName()
                        + " could be found in this request context. Decoder has failed.");
            }

            if (!clazz.isAssignableFrom(object.getClass())) {
                throw new ClassCastException("SAML object [" + object.getClass().getName()
                        + "] type does not match " + clazz);
            }

            LOGGER.debug("Decoded SAML object [{}] from http request", object.getElementQName());
            return Pair.of(object, messageContext);
        } catch (final Exception e) {
            throw Throwables.propagate(e);
        }
    }

    /**
     * Log cas validation assertion.
     *
     * @param assertion the assertion
     */
    protected void logCasValidationAssertion(final Assertion assertion) {
        LOGGER.info("CAS Assertion Valid: [{}]", assertion.isValid());
        LOGGER.debug("CAS Assertion Principal: [{}]", assertion.getPrincipal().getName());
        LOGGER.debug("CAS Assertion AuthN Date: [{}]", assertion.getAuthenticationDate());
        LOGGER.debug("CAS Assertion ValidFrom Date: [{}]", assertion.getValidFromDate());
        LOGGER.debug("CAS Assertion ValidUntil Date: [{}]", assertion.getValidUntilDate());
        LOGGER.debug("CAS Assertion Attributes: [{}]", assertion.getAttributes());
        LOGGER.debug("CAS Assertion Principal Attributes: [{}]", assertion.getPrincipal().getAttributes());
    }

    /**
     * Redirects the request to the login url for authentication, with the
     * callback service url as the service parameter.
     *
     * @param pair     the authn request and its message context
     * @param request  the request
     * @param response the response
     * @throws Exception the exception
     */
    protected void issueAuthenticationRequestRedirect(final Pair<? extends SignableSAMLObject, MessageContext> pair,
                                                      final HttpServletRequest request,
                                                      final HttpServletResponse response) throws Exception {
        final AuthnRequest authnRequest = AuthnRequest.class.cast(pair.getLeft());
        final String serviceUrl = constructServiceUrl(request, response, pair);
        LOGGER.debug("Created service url [{}]", serviceUrl);

        final String initialUrl = CommonUtils.constructRedirectUrl(this.loginUrl,
                CasProtocolConstants.PARAMETER_SERVICE, serviceUrl,
                authnRequest.isForceAuthn(), authnRequest.isPassive());

        final String urlToRedirectTo = buildRedirectUrlByRequestedAuthnContext(initialUrl, authnRequest, request);

        LOGGER.debug("Redirecting SAML authN request to [{}]", urlToRedirectTo);
        final AuthenticationRedirectStrategy authenticationRedirectStrategy = new DefaultAuthenticationRedirectStrategy();
        authenticationRedirectStrategy.redirect(request, response, urlToRedirectTo);
    }

    /**
     * Builds the redirect url by the requested authn context. When the request
     * asks for an authn context class that has a configured mapping, the mapped
     * value is appended to the url under the authn context request parameter;
     * otherwise the initial url is returned as is.
     *
     * @param initialUrl   the initial url
     * @param authnRequest the authn request
     * @param request      the request
     * @return the redirect url
     */
    protected String buildRedirectUrlByRequestedAuthnContext(final String initialUrl,
                                                             final AuthnRequest authnRequest,
                                                             final HttpServletRequest request) {
        if (authnRequest.getRequestedAuthnContext() == null
                || this.authenticationContextClassMappings == null
                || this.authenticationContextClassMappings.isEmpty()) {
            return initialUrl;
        }

        final Map<String, String> mappings = new TreeMap<>();
        for (final String mapping : this.authenticationContextClassMappings) {
            final String[] bits = mapping.split("->");
            if (bits.length < 2) {
                // A blank entry or one without "->" has no target value; skip it instead of failing the request.
                LOGGER.warn("Ignoring malformed authentication context class mapping [{}]; "
                        + "expected format is [source->target]", mapping);
                continue;
            }
            mappings.put(bits[0], bits[1]);
        }

        return authnRequest.getRequestedAuthnContext().getAuthnContextClassRefs().stream()
                .map(AuthnContextClassRef::getAuthnContextClassRef)
                // TreeMap rejects null keys, so guard before the lookup.
                .filter(clazz -> clazz != null && mappings.containsKey(clazz))
                .findFirst()
                .map(clazz -> initialUrl + '&' + this.authenticationContextRequestParameter + '=' + mappings.get(clazz))
                .orElse(initialUrl);
    }

    /**
     * Constructs the service url that points at the callback service and
     * carries the entity id, the base64-encoded authn request and the relay
     * state as query parameters.
     *
     * @param request  the request
     * @param response the response
     * @param pair     the authn request and its message context
     * @return the service url
     * @throws SamlException if the url cannot be built
     */
    protected String constructServiceUrl(final HttpServletRequest request,
                                         final HttpServletResponse response,
                                         final Pair<? extends SignableSAMLObject, MessageContext> pair)
            throws SamlException {
        final AuthnRequest authnRequest = AuthnRequest.class.cast(pair.getLeft());
        final MessageContext messageContext = pair.getRight();

        try (StringWriter writer = SamlUtils.transformSamlObject(this.configBean, authnRequest)) {
            final URLBuilder builder = new URLBuilder(this.callbackService.getId());
            builder.getQueryParams().add(
                    new net.shibboleth.utilities.java.support.collection.Pair<>(SamlProtocolConstants.PARAMETER_ENTITY_ID,
                            SamlIdPUtils.getIssuerFromSamlRequest(authnRequest)));

            final String samlRequest = EncodingUtils.encodeBase64(writer.toString().getBytes(StandardCharsets.UTF_8));
            builder.getQueryParams().add(
                    new net.shibboleth.utilities.java.support.collection.Pair<>(SamlProtocolConstants.PARAMETER_SAML_REQUEST,
                            samlRequest));
            builder.getQueryParams().add(
                    new net.shibboleth.utilities.java.support.collection.Pair<>(SamlProtocolConstants.PARAMETER_SAML_RELAY_STATE,
                            SAMLBindingSupport.getRelayState(messageContext)));
            final String url = builder.buildURL();

            LOGGER.debug("Built service callback url [{}]", url);
            return CommonUtils.constructServiceUrl(request, response, url, this.serverName,
                    CasProtocolConstants.PARAMETER_SERVICE, CasProtocolConstants.PARAMETER_TICKET, false);
        } catch (final Exception e) {
            throw new SamlException(e.getMessage(), e);
        }
    }

    /**
     * Initiates the authentication request: verifies it and then issues the
     * redirect for authentication.
     *
     * @param pair     the authn request and its message context
     * @param response the response
     * @param request  the request
     * @throws Exception the exception
     */
    protected void initiateAuthenticationRequest(final Pair<? extends SignableSAMLObject, MessageContext> pair,
                                                 final HttpServletResponse response,
                                                 final HttpServletRequest request) throws Exception {
        verifySamlAuthenticationRequest(pair, request);
        issueAuthenticationRequestRedirect(pair, request, response);
    }

    /**
     * Verifies the saml authentication request: the issuer must resolve to an
     * authorized SAML registered service, metadata must be available for it,
     * and the request signature (if any) must check out.
     *
     * @param authenticationContext the authn request and its message context
     * @param request               the request
     * @return the registered service paired with its metadata facade
     * @throws Exception the exception
     */
    protected Pair<SamlRegisteredService, SamlRegisteredServiceServiceProviderMetadataFacade> verifySamlAuthenticationRequest(
            final Pair<? extends SignableSAMLObject, MessageContext> authenticationContext,
            final HttpServletRequest request) throws Exception {
        final AuthnRequest authnRequest = AuthnRequest.class.cast(authenticationContext.getKey());
        final String issuer = SamlIdPUtils.getIssuerFromSamlRequest(authnRequest);
        LOGGER.debug("Located issuer [{}] from authentication request", issuer);

        final SamlRegisteredService registeredService = verifySamlRegisteredService(issuer);
        LOGGER.debug("Fetching saml metadata adaptor for [{}]", issuer);
        // Go through the shared helper, as buildSamlResponse does, so both paths resolve metadata the same way.
        final Optional<SamlRegisteredServiceServiceProviderMetadataFacade> adaptor =
                getSamlMetadataFacadeFor(registeredService, authnRequest);

        if (!adaptor.isPresent()) {
            LOGGER.warn("No metadata could be found for [{}]", issuer);
            throw new UnauthorizedServiceException(UnauthorizedServiceException.CODE_UNAUTHZ_SERVICE,
                    "Cannot find metadata linked to " + issuer);
        }

        final SamlRegisteredServiceServiceProviderMetadataFacade facade = adaptor.get();
        verifyAuthenticationContextSignature(authenticationContext, request, authnRequest, facade);
        SamlUtils.logSamlObject(this.configBean, authnRequest);
        return Pair.of(registeredService, facade);
    }

    /**
     * Verifies the authentication context signature. An unsigned request is
     * rejected when the service provider metadata states that authn requests
     * are signed; a signed request is handed to the signature validator.
     *
     * @param authenticationContext the authn request and its message context
     * @param request               the request
     * @param authnRequest          the authn request
     * @param adaptor               the metadata facade of the service provider
     * @throws Exception the exception
     */
    protected void verifyAuthenticationContextSignature(
            final Pair<? extends SignableSAMLObject, MessageContext> authenticationContext,
            final HttpServletRequest request,
            final AuthnRequest authnRequest,
            final SamlRegisteredServiceServiceProviderMetadataFacade adaptor) throws Exception {
        final MessageContext ctx = authenticationContext.getValue();
        if (!SAMLBindingSupport.isMessageSigned(ctx)) {
            LOGGER.debug("The authentication context is not signed");
            if (adaptor.isAuthnRequestsSigned()) {
                LOGGER.error("Metadata for [{}] says authentication requests are signed, yet authentication request is not",
                        adaptor.getEntityId());
                throw new SAMLException("AuthN request is not signed but should be");
            }
            LOGGER.debug("Authentication request is not signed, so there is no need to verify its signature.");
        } else {
            LOGGER.debug("The authentication context is signed; Proceeding to validate signatures...");
            this.samlObjectSignatureValidator.verifySamlProfileRequestIfNeeded(authnRequest, adaptor, request, ctx);
        }
    }

    /**
     * Builds the saml response for the authn request using the response builder.
     *
     * @param response              the response
     * @param request               the request
     * @param authenticationContext the authn request and its message context
     * @param casAssertion          the cas assertion
     * @param binding               the binding
     * @throws UnauthorizedServiceException if the issuer is not an authorized SAML service
     *                                      or no metadata can be found for it
     */
    protected void buildSamlResponse(final HttpServletResponse response,
                                     final HttpServletRequest request,
                                     final Pair<AuthnRequest, MessageContext> authenticationContext,
                                     final Assertion casAssertion,
                                     final String binding) {
        final AuthnRequest authnRequest = authenticationContext.getKey();
        final String issuer = SamlIdPUtils.getIssuerFromSamlRequest(authnRequest);
        LOGGER.debug("Located issuer [{}] from authentication context", issuer);

        final SamlRegisteredService registeredService = verifySamlRegisteredService(issuer);

        LOGGER.debug("Located SAML metadata for [{}]", registeredService);
        final Optional<SamlRegisteredServiceServiceProviderMetadataFacade> adaptor =
                getSamlMetadataFacadeFor(registeredService, authnRequest);

        if (!adaptor.isPresent()) {
            throw new UnauthorizedServiceException(UnauthorizedServiceException.CODE_UNAUTHZ_SERVICE,
                    "Cannot find metadata linked to " + issuer);
        }

        final SamlRegisteredServiceServiceProviderMetadataFacade facade = adaptor.get();
        LOGGER.debug("Preparing SAML response for [{}]", facade.getEntityId());
        this.responseBuilder.build(authnRequest, request, response, casAssertion, registeredService, facade, binding);
        LOGGER.info("Built the SAML response for [{}]", facade.getEntityId());
    }

    /**
     * Handles unauthorized service exceptions raised by handler methods by
     * logging the cause and producing the unauthorized error view.
     *
     * @param req the req
     * @param ex  the ex
     * @return the model and view
     */
    @ExceptionHandler(UnauthorizedServiceException.class)
    public ModelAndView handleUnauthorizedServiceException(final HttpServletRequest req, final Exception ex) {
        // Not every throw site logs before throwing; record the reason here so it is never lost.
        LOGGER.warn("Unauthorized service access for request [{}]: [{}]", req.getRequestURI(), ex.getMessage());
        return WebUtils.produceUnauthorizedErrorView();
    }
}