/* * Copyright 2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.social.security.provider; import java.util.HashSet; import java.util.Set; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.social.connect.Connection; import org.springframework.social.connect.support.OAuth2ConnectionFactory; import org.springframework.social.oauth2.AccessGrant; import org.springframework.social.oauth2.OAuth2Parameters; import org.springframework.social.security.SocialAuthenticationRedirectException; import org.springframework.social.security.SocialAuthenticationToken; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.client.RestClientException; /** * @author Stefan Fussennegger * @param <S> The provider's API type. */ public class OAuth2AuthenticationService<S> extends AbstractSocialAuthenticationService<S> { protected final Log logger = LogFactory.getLog(getClass()); private OAuth2ConnectionFactory<S> connectionFactory; private Set<String> returnToUrlParameters; private String defaultScope = ""; public OAuth2AuthenticationService(OAuth2ConnectionFactory<S> connectionFactory) { setConnectionFactory(connectionFactory); } public OAuth2ConnectionFactory<S> getConnectionFactory() { return connectionFactory; } public void setConnectionFactory(OAuth2ConnectionFactory<S> connectionFactory) { this.connectionFactory = connectionFactory; } public void setReturnToUrlParameters(Set<String> returnToUrlParameters) { Assert.notNull(returnToUrlParameters, "returnToUrlParameters cannot be null"); this.returnToUrlParameters = returnToUrlParameters; } public Set<String> getReturnToUrlParameters() { if (returnToUrlParameters == null) { returnToUrlParameters = new HashSet<String>(); } return returnToUrlParameters; } /** * @param defaultScope OAuth scope to use, i.e. requested permissions */ public void setDefaultScope(String defaultScope) { this.defaultScope = defaultScope; } public void afterPropertiesSet() throws Exception { super.afterPropertiesSet(); Assert.notNull(getConnectionFactory(), "connectionFactory"); } public SocialAuthenticationToken getAuthToken(HttpServletRequest request, HttpServletResponse response) throws SocialAuthenticationRedirectException { String code = request.getParameter("code"); if (!StringUtils.hasText(code)) { OAuth2Parameters params = new OAuth2Parameters(); params.setRedirectUri(buildReturnToUrl(request)); setScope(request, params); params.add("state", generateState(connectionFactory, request)); addCustomParameters(params); throw new SocialAuthenticationRedirectException(getConnectionFactory().getOAuthOperations().buildAuthenticateUrl(params)); } else if (StringUtils.hasText(code)) { try { String returnToUrl = buildReturnToUrl(request); AccessGrant accessGrant = getConnectionFactory().getOAuthOperations().exchangeForAccess(code, returnToUrl, null); // TODO avoid API call if possible (auth using token would be fine) Connection<S> connection = getConnectionFactory().createConnection(accessGrant); return new SocialAuthenticationToken(connection, null); } catch (RestClientException e) { logger.debug("failed to exchange for access", e); return null; } } else { return null; } } private String generateState(OAuth2ConnectionFactory<?> connectionFactory, HttpServletRequest request) { final String state = request.getParameter("state"); return (state != null) ? state : connectionFactory.generateState(); } protected String buildReturnToUrl(HttpServletRequest request) { StringBuffer sb = getProxyHeaderAwareRequestURL(request); sb.append("?"); for (String name : getReturnToUrlParameters()) { // Assume for simplicity that there is only one value String value = request.getParameter(name); if (value == null) { continue; } sb.append(name).append("=").append(value).append("&"); } sb.setLength(sb.length() - 1); // strip trailing ? or & return sb.toString(); } protected StringBuffer getProxyHeaderAwareRequestURL(HttpServletRequest request) { String host = request.getHeader("Host"); if (StringUtils.isEmpty(host)) { return request.getRequestURL(); } StringBuffer sb = new StringBuffer(); String schemeHeader = request.getHeader("X-Forwarded-Proto"); String portHeader = request.getHeader("X-Forwarded-Port"); String scheme = StringUtils.isEmpty(schemeHeader) ? "http" : schemeHeader; String port = StringUtils.isEmpty(portHeader) ? "80" : portHeader; if (scheme.equals("http") && port.equals("80")){ port = ""; } if (scheme.equals("https") && port.equals("443")){ port = ""; } sb.append(scheme); sb.append("://"); sb.append(host); if (StringUtils.hasLength(port)){ sb.append(":"); sb.append(port); } sb.append(request.getRequestURI()); return sb; } private void setScope(HttpServletRequest request, OAuth2Parameters params) { String requestedScope = request.getParameter("scope"); if (StringUtils.hasLength(requestedScope)) { params.setScope(requestedScope); } else if (StringUtils.hasLength(defaultScope)) { params.setScope(defaultScope); } } protected void addCustomParameters(OAuth2Parameters params) { } }