package org.apereo.cas.support.oauth.util; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectWriter; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.apache.http.HttpStatus; import org.apereo.cas.services.RegisteredService; import org.apereo.cas.services.ServicesManager; import org.apereo.cas.services.UnauthorizedServiceException; import org.apereo.cas.support.oauth.OAuth20GrantTypes; import org.apereo.cas.support.oauth.OAuth20ResponseTypes; import org.apereo.cas.support.oauth.OAuth20Constants; import org.apereo.cas.support.oauth.services.OAuthRegisteredService; import org.pac4j.core.context.J2EContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.view.RedirectView; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.io.PrintWriter; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.stream.Collectors; import static org.apereo.cas.support.oauth.OAuth20Constants.BASE_OAUTH20_URL; /** * This class has some usefull methods to output data in plain text, * handle redirects, add parameter in url or find the right provider. * * @author Jerome Leleu * @since 3.5.0 */ public final class OAuth20Utils { private static final Logger LOGGER = LoggerFactory.getLogger(OAuth20Utils.class); private static final ObjectWriter WRITER = new ObjectMapper().findAndRegisterModules().writer().withDefaultPrettyPrinter(); private OAuth20Utils() { } /** * Write to the output this error text and return a null view. * * @param response http response * @param error error message * @return a null view */ public static ModelAndView writeTextError(final HttpServletResponse response, final String error) { return OAuth20Utils.writeText(response, OAuth20Constants.ERROR + '=' + error, HttpStatus.SC_BAD_REQUEST); } /** * Write to the output the text and return a null view. * * @param response http response * @param text output text * @param status status code * @return a null view */ public static ModelAndView writeText(final HttpServletResponse response, final String text, final int status) { try (PrintWriter printWriter = response.getWriter()) { response.setStatus(status); printWriter.print(text); } catch (final IOException e) { LOGGER.error("Failed to write to response", e); } return null; } /** * Return a view which is a redirection to an url. * * @param url redirect url * @return A view which is a redirection to an url */ public static ModelAndView redirectTo(final String url) { return new ModelAndView(new RedirectView(url)); } /** * Locate the requested instance of {@link OAuthRegisteredService} by the given clientId. * * @param servicesManager the service registry DAO instance. * @param clientId the client id by which the {@link OAuthRegisteredService} is to be located. * @return null, or the located {@link OAuthRegisteredService} instance in the service registry. */ public static OAuthRegisteredService getRegisteredOAuthService(final ServicesManager servicesManager, final String clientId) { final Collection<RegisteredService> services = servicesManager.getAllServices(); return (OAuthRegisteredService) services.stream() .filter(OAuthRegisteredService.class::isInstance) .filter(s -> OAuthRegisteredService.class.cast(s).getClientId().equals(clientId)) .findFirst() .orElse(null); } /** * Gets attributes. * * @param attributes the attributes * @param context the context * @return the attributes */ public static Map<String, Object> getRequestParameters(final Collection<String> attributes, final HttpServletRequest context) { return attributes.stream() .filter(a -> StringUtils.isNotBlank(context.getParameter(a))) .map(m -> { final String[] values = context.getParameterValues(m); final Collection<String> valuesSet = new LinkedHashSet<>(); if (values != null && values.length > 0) { Arrays.stream(values).forEach(v -> valuesSet.addAll(Arrays.stream(v.split(" ")).collect(Collectors.toSet()))); } return Pair.of(m, valuesSet); }) .collect(Collectors.toMap(Pair::getKey, Pair::getValue)); } /** * Gets requested scopes. * * @param context the context * @return the requested scopes */ public static Collection<String> getRequestedScopes(final J2EContext context) { return getRequestedScopes(context.getRequest()); } /** * Gets requested scopes. * * @param context the context * @return the requested scopes */ public static Collection<String> getRequestedScopes(final HttpServletRequest context) { final Map<String, Object> map = getRequestParameters(Arrays.asList(OAuth20Constants.SCOPE), context); if (map == null || map.isEmpty()) { return Collections.emptyList(); } return (Collection<String>) map.get(OAuth20Constants.SCOPE); } /** * Produce unauthorized error view model and view. * * @return the model and view */ public static ModelAndView produceUnauthorizedErrorView() { return produceErrorView(new UnauthorizedServiceException(UnauthorizedServiceException.CODE_UNAUTHZ_SERVICE, StringUtils.EMPTY)); } /** * Produce error view model and view. * * @param e the e * @return the model and view */ public static ModelAndView produceErrorView(final Exception e) { final Map model = new HashMap<>(); model.put("rootCauseException", e); return new ModelAndView(OAuth20Constants.ERROR_VIEW, model); } /** * Cas oauth callback url. * * @param serverPrefixUrl the server prefix url * @return the string */ public static String casOAuthCallbackUrl(final String serverPrefixUrl) { return serverPrefixUrl.concat(BASE_OAUTH20_URL + '/' + OAuth20Constants.CALLBACK_AUTHORIZE_URL); } /** * Jsonify string. * * @param map the map * @return the string */ public static String jsonify(final Map map) { try { return WRITER.writeValueAsString(map); } catch (final Exception e) { throw new IllegalArgumentException(e.getMessage(), e); } } /** * Check the grant type against an expected grant type. * * @param type the given grant type * @param expectedType the expected grant type * @return whether the grant type is the expected one */ public static boolean isGrantType(final String type, final OAuth20GrantTypes expectedType) { return expectedType.name().equalsIgnoreCase(type); } /** * Check the response type against an expected response type. * * @param type the given response type * @param expectedType the expected response type * @return whether the response type is the expected one */ public static boolean isResponseType(final String type, final OAuth20ResponseTypes expectedType) { return expectedType.getType().equalsIgnoreCase(type); } }