/* * Copyright 2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.social.connect.web; import java.io.IOException; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.social.ApiException; import org.springframework.social.InsufficientPermissionException; import org.springframework.social.NotAuthorizedException; import org.springframework.social.OperationNotPermittedException; import org.springframework.social.UserIdSource; import org.springframework.social.connect.ConnectionFactory; import org.springframework.social.connect.ConnectionRepository; import org.springframework.social.connect.UsersConnectionRepository; import org.springframework.util.Assert; import org.springframework.web.filter.GenericFilterBean; /** * <p>Servlet filter that intercepts Spring Social {@link ApiException}s thrown in the course of a request and attempts to reconcile any connection-related * problems by deleting the stale/revoked connection and walking the user through the connection process to obtain a new connection.</p> * * <p>This filter handles the exceptions via the following flow:</p> * * 
<ul>
 * <li>If an {@link ApiException} signalling a stale/revoked connection is thrown, redirects to /connect/{provider ID}?reconnect=true</li>
 * <li>Handles its own redirect to /connect/{provider ID}?reconnect=true and converts the request to a POST request to {@link ConnectController} to kick off the authorization flow.</li>
 * </ul>
 *
 * @since 1.1.0
 *
 * @author Craig Walls
 */
public class ReconnectFilter extends GenericFilterBean {

	private static final Log logger = LogFactory.getLog(ReconnectFilter.class);

	private static final String CONNECT_PATH = "/connect/";

	private static final int CONNECT_PATH_LENGTH = CONNECT_PATH.length();

	private static final String RECONNECT_PARAMETER = "reconnect";

	private static final String RECONNECT_PARAMETER_EQUALS_TRUE = "?" + RECONNECT_PARAMETER + "=true";

	private static final String SCOPE_PARAMETER_EQUALS = "&scope=";

	private static final String POST = "POST";

	private static final String GET = "GET";

	private final ThrowableAnalyzer throwableAnalyzer = new ThrowableAnalyzer();

	private final UsersConnectionRepository usersConnectionRepository;

	private final UserIdSource userIdSource;

	/**
	 * Creates an instance of {@link ReconnectFilter}.
	 * @param usersConnectionRepository a {@link UsersConnectionRepository} used to create a {@link ConnectionRepository} for the current user.
	 * @param userIdSource an instance of {@link UserIdSource} to obtain the current user's ID, which is used to look up that user's {@link ConnectionRepository}.
	 */
	public ReconnectFilter(UsersConnectionRepository usersConnectionRepository, UserIdSource userIdSource) {
		Assert.notNull(usersConnectionRepository, "UsersConnectionRepository cannot be null");
		Assert.notNull(userIdSource, "UserIdSource cannot be null");
		this.usersConnectionRepository = usersConnectionRepository;
		this.userIdSource = userIdSource;
	}

	/**
	 * Filters the request.
	 * <p>If {@link #shouldPerformRefreshPostRequest(HttpServletRequest)} returns true, the current user's connections to the
	 * provider named in the path are removed and the request is passed on disguised as a POST request (so that
	 * {@link ConnectController} starts a new authorization flow). Otherwise the request is passed down the chain and any
	 * exception thrown by the chain is examined by {@link #handleExceptionFromFilterChain}.</p>
	 * <p>Requests/responses that are not HTTP are passed through unchanged, since they can be neither redirected nor converted.</p>
	 */
	public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
		if (!(request instanceof HttpServletRequest) || !(response instanceof HttpServletResponse)) {
			chain.doFilter(request, response);
			return;
		}
		HttpServletRequest httpRequest = (HttpServletRequest) request;
		HttpServletResponse httpResponse = (HttpServletResponse) response;

		if (shouldPerformRefreshPostRequest(httpRequest)) {
			if (logger.isDebugEnabled()) {
				logger.debug("Removing stale/revoked connection.");
			}
			String providerId = getProviderIdFromRequest(httpRequest);
			String currentUserId = userIdSource.getUserId();
			usersConnectionRepository.createConnectionRepository(currentUserId).removeConnections(providerId);

			if (logger.isDebugEnabled()) {
				logger.debug("Initiating refresh request.");
			}
			HttpServletRequest newRequest = new ReconnectionPostRequest(httpRequest);
			chain.doFilter(newRequest, httpResponse);
		} else {
			// Pass request through filter chain and handle any exceptions that come out of it.
			try {
				if (logger.isDebugEnabled()) {
					logger.debug("Processing request");
				}
				chain.doFilter(httpRequest, httpResponse);
			} catch (IOException e) {
				if (logger.isDebugEnabled()) {
					logger.debug("IOException: " + e.getMessage());
				}
				throw e;
			} catch (Exception e) {
				handleExceptionFromFilterChain(e, httpRequest, httpResponse);
			}
		}
	}

	// subclassing hooks

	/**
	 * Returns the URL to redirect to if it is determined that a connection needs to be renewed.
	 * By default, the filter will redirect to /connect/{provider ID} with a "reconnect" query parameter and, when the
	 * exception is an {@link InsufficientPermissionException}, a URL-encoded "scope" parameter naming the required permission.
	 * This filter also handles GET requests to that same path before submitting a POST request to {@link ConnectController} for authorization.
	 * May be overridden by a subclass to handle other flows, such as redirecting to a page that informs the user that a new connection is needed.
	 * @param request The HTTP request that triggered the exception.
	 * @param apiException The {@link ApiException}.
	 * @return the URL to redirect to if a connection needs to be renewed.
	 */
	protected String getRefreshUrl(HttpServletRequest request, ApiException apiException) {
		StringBuilder url = new StringBuilder()
				.append(request.getContextPath())
				.append(CONNECT_PATH)
				.append(apiException.getProviderId())
				.append(RECONNECT_PARAMETER_EQUALS_TRUE);
		String scopeNeeded = getRequiredScope(apiException);
		if (scopeNeeded != null) {
			// The scope originates from the provider's error response; encode it so characters such as
			// spaces, '&' or '#' cannot break (or inject into) the redirect's query string.
			url.append(SCOPE_PARAMETER_EQUALS).append(urlEncode(scopeNeeded));
		}
		return url.toString();
	}

	/**
	 * Determines whether or not the handled request should be converted to a POST request to {@link ConnectController} for authorization.
	 * By default, will return true if the request is a GET request for /connect/{provider ID} and there is a "reconnect" query parameter.
	 * May be overridden by a subclass to consider other criteria in deciding whether or not to convert the request.
	 * @param request the handled request.
	 * @return true if the request should be converted to a POST request to {@link ConnectController}.
	 */
	protected boolean shouldPerformRefreshPostRequest(HttpServletRequest request) {
		String servletPath = request.getServletPath();
		return GET.equalsIgnoreCase(request.getMethod())
				&& servletPath != null
				&& servletPath.startsWith(CONNECT_PATH)
				&& request.getParameter(RECONNECT_PARAMETER) != null;
	}

	// private helpers

	/** Returns the permission named by an {@link InsufficientPermissionException}, or null for any other exception. */
	private String getRequiredScope(ApiException apiException) {
		return apiException instanceof InsufficientPermissionException ? ((InsufficientPermissionException) apiException).getRequiredPermission() : null;
	}

	/** Extracts the provider ID from a servlet path of the form /connect/{provider ID}, stripping any slashes. */
	private String getProviderIdFromRequest(HttpServletRequest httpRequest) {
		return httpRequest.getServletPath().substring(CONNECT_PATH_LENGTH).replace("/", "");
	}

	/** URL-encodes a query parameter value as UTF-8. */
	private static String urlEncode(String value) {
		try {
			return java.net.URLEncoder.encode(value, "UTF-8");
		} catch (java.io.UnsupportedEncodingException e) {
			// Every Java platform is required to support UTF-8, so this cannot happen in practice.
			throw new IllegalStateException("UTF-8 encoding is not supported", e);
		}
	}

	/**
	 * Redirects to the refresh URL if an {@link ApiException} in the cause chain of {@code e} indicates that the connection
	 * is no longer authorized ({@link NotAuthorizedException} or {@link OperationNotPermittedException}) and the response
	 * can still be redirected; otherwise rethrows the original exception.
	 */
	private void handleExceptionFromFilterChain(Exception e, HttpServletRequest httpRequest, HttpServletResponse httpResponse) throws IOException, ServletException {
		Throwable cause = throwableAnalyzer.getFirstThrowableOfType(ApiException.class, throwableAnalyzer.determineCauseChain(e));
		if (cause instanceof ApiException) {
			ApiException apiException = (ApiException) cause;
			if (logger.isDebugEnabled()) {
				logger.debug("API Exception: " + apiException.getMessage());
			}
			if (apiException instanceof NotAuthorizedException || apiException instanceof OperationNotPermittedException) {
				if (httpResponse.isCommitted()) {
					// sendRedirect() would throw IllegalStateException and mask the original exception.
					if (logger.isDebugEnabled()) {
						logger.debug("Response already committed; unable to redirect for refresh of " + apiException.getProviderId() + " connection.");
					}
				} else {
					if (logger.isDebugEnabled()) {
						logger.debug("Redirecting for refresh of " + apiException.getProviderId() + " connection.");
					}
					httpResponse.sendRedirect(getRefreshUrl(httpRequest, apiException));
					return;
				}
			}
		}

		if (e instanceof ServletException) {
			throw (ServletException) e;
		}
		if (e instanceof RuntimeException) {
			throw (RuntimeException) e;
		}

		// Wrap other Exceptions in a generic RuntimeException. This should never happen because
		// IOException is handled by the caller and doFilter declares no other checked exceptions.
		throw new RuntimeException(e);
	}

	/*
	 * Request wrapper that converts existing request into a POST request to ConnectController.
	 * Static because it never needs access to the enclosing filter instance.
	 */
	private static final class ReconnectionPostRequest extends HttpServletRequestWrapper {

		private ReconnectionPostRequest(HttpServletRequest request) {
			super(request);
		}

		@Override
		public String getMethod() {
			return POST;
		}

	}

}