package org.apereo.cas.support.oauth.validator; import org.apache.commons.lang3.BooleanUtils; import org.apache.commons.lang3.StringUtils; import org.apache.http.NameValuePair; import org.apache.http.client.utils.URIBuilder; import org.apereo.cas.authentication.AuthenticationServiceSelectionStrategy; import org.apereo.cas.authentication.principal.Service; import org.apereo.cas.authentication.principal.ServiceFactory; import org.apereo.cas.authentication.principal.WebApplicationService; import org.apereo.cas.services.RegisteredService; import org.apereo.cas.services.ServicesManager; import org.apereo.cas.support.oauth.OAuth20Constants; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.core.Ordered; import java.util.Optional; /** * This is {@link OAuth20AuthenticationServiceSelectionStrategy}. * * @author Misagh Moayyed * @since 5.0.0 */ public class OAuth20AuthenticationServiceSelectionStrategy implements AuthenticationServiceSelectionStrategy { private static final long serialVersionUID = 8517547235465666978L; private static final Logger LOGGER = LoggerFactory.getLogger(OAuth20AuthenticationServiceSelectionStrategy.class); private final ServicesManager servicesManager; private final ServiceFactory<WebApplicationService> webApplicationServiceFactory; private final String callbackUrl; private int order = Ordered.HIGHEST_PRECEDENCE; public OAuth20AuthenticationServiceSelectionStrategy(final ServicesManager servicesManager, final ServiceFactory<WebApplicationService> webApplicationServiceFactory, final String callbackUrl) { this.servicesManager = servicesManager; this.webApplicationServiceFactory = webApplicationServiceFactory; this.callbackUrl = callbackUrl; } @Override public Service resolveServiceFrom(final Service service) { final Optional<NameValuePair> clientId = resolveClientIdFromService(service); final Optional<NameValuePair> redirectUri = resolveRedirectUri(service); if (clientId.isPresent() && redirectUri.isPresent()) { return this.webApplicationServiceFactory.createService(redirectUri.get().getValue()); } return service; } private static Optional<NameValuePair> resolveClientIdFromService(final Service service) { try { final URIBuilder builder = new URIBuilder(service.getId()); return builder.getQueryParams().stream().filter(p -> p.getName().equals(OAuth20Constants.CLIENT_ID)).findFirst(); } catch (final Exception e) { LOGGER.error(e.getMessage()); } return Optional.empty(); } private static Optional<NameValuePair> resolveRedirectUri(final Service service) { try { final URIBuilder builder = new URIBuilder(service.getId()); return builder.getQueryParams().stream().filter(p -> p.getName().equals(OAuth20Constants.REDIRECT_URI)).findFirst(); } catch (final Exception e) { LOGGER.error(e.getMessage()); } return Optional.empty(); } @Override public boolean supports(final Service service) { final RegisteredService svc = this.servicesManager.findServiceBy(service); final boolean res = svc != null && service.getId().startsWith(this.callbackUrl); LOGGER.debug("Authentication request is{}identified as an OAuth request", BooleanUtils.toString(res, StringUtils.EMPTY, " not ")); return res; } @Override public int getOrder() { return order; } }