package org.springframework.security.oauth2.client; import java.io.IOException; import java.io.UnsupportedEncodingException; import java.net.URI; import java.net.URISyntaxException; import java.net.URLEncoder; import java.util.Arrays; import org.springframework.http.HttpMethod; import org.springframework.http.client.ClientHttpRequest; import org.springframework.security.oauth2.client.http.AccessTokenRequiredException; import org.springframework.security.oauth2.client.http.OAuth2ErrorHandler; import org.springframework.security.oauth2.client.resource.OAuth2AccessDeniedException; import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails; import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException; import org.springframework.security.oauth2.client.token.AccessTokenProvider; import org.springframework.security.oauth2.client.token.AccessTokenProviderChain; import org.springframework.security.oauth2.client.token.AccessTokenRequest; import org.springframework.security.oauth2.client.token.grant.client.ClientCredentialsAccessTokenProvider; import org.springframework.security.oauth2.client.token.grant.code.AuthorizationCodeAccessTokenProvider; import org.springframework.security.oauth2.client.token.grant.implicit.ImplicitAccessTokenProvider; import org.springframework.security.oauth2.client.token.grant.password.ResourceOwnerPasswordAccessTokenProvider; import org.springframework.security.oauth2.common.AuthenticationScheme; import org.springframework.security.oauth2.common.OAuth2AccessToken; import org.springframework.security.oauth2.common.exceptions.InvalidTokenException; import org.springframework.web.client.RequestCallback; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.ResponseExtractor; import org.springframework.web.client.RestClientException; import org.springframework.web.client.RestTemplate; /** * Rest template that is able to make OAuth2-authenticated REST requests with the credentials of the provided resource. * * @author Ryan Heaton * @author Dave Syer */ public class OAuth2RestTemplate extends RestTemplate implements OAuth2RestOperations { private final OAuth2ProtectedResourceDetails resource; private AccessTokenProvider accessTokenProvider = new AccessTokenProviderChain(Arrays.<AccessTokenProvider> asList( new AuthorizationCodeAccessTokenProvider(), new ImplicitAccessTokenProvider(), new ResourceOwnerPasswordAccessTokenProvider(), new ClientCredentialsAccessTokenProvider())); private OAuth2ClientContext context; private boolean retryBadAccessTokens = true; private OAuth2RequestAuthenticator authenticator = new DefaultOAuth2RequestAuthenticator(); public OAuth2RestTemplate(OAuth2ProtectedResourceDetails resource) { this(resource, new DefaultOAuth2ClientContext()); } public OAuth2RestTemplate(OAuth2ProtectedResourceDetails resource, OAuth2ClientContext context) { super(); if (resource == null) { throw new IllegalArgumentException("An OAuth2 resource must be supplied."); } this.resource = resource; this.context = context; setErrorHandler(new OAuth2ErrorHandler(resource)); } /** * Strategy for extracting an Authorization header from an access token and the request details. Defaults to the * simple form "TOKEN_TYPE TOKEN_VALUE". * * @param authenticator the authenticator to use */ public void setAuthenticator(OAuth2RequestAuthenticator authenticator) { this.authenticator = authenticator; } /** * Flag to determine whether a request that has an existing access token, and which then leads to an * AccessTokenRequiredException should be retried (immediately, once). Useful if the remote server doesn't recognize * an old token which is stored in the client, but is happy to re-grant it. * * @param retryBadAccessTokens the flag to set (default true) */ public void setRetryBadAccessTokens(boolean retryBadAccessTokens) { this.retryBadAccessTokens = retryBadAccessTokens; } @Override public void setErrorHandler(ResponseErrorHandler errorHandler) { if (!(errorHandler instanceof OAuth2ErrorHandler)) { errorHandler = new OAuth2ErrorHandler(errorHandler, resource); } super.setErrorHandler(errorHandler); } @Override public OAuth2ProtectedResourceDetails getResource() { return resource; } @Override protected ClientHttpRequest createRequest(URI uri, HttpMethod method) throws IOException { OAuth2AccessToken accessToken = getAccessToken(); AuthenticationScheme authenticationScheme = resource.getAuthenticationScheme(); if (AuthenticationScheme.query.equals(authenticationScheme) || AuthenticationScheme.form.equals(authenticationScheme)) { uri = appendQueryParameter(uri, accessToken); } ClientHttpRequest req = super.createRequest(uri, method); if (AuthenticationScheme.header.equals(authenticationScheme)) { authenticator.authenticate(resource, getOAuth2ClientContext(), req); } return req; } @Override protected <T> T doExecute(URI url, HttpMethod method, RequestCallback requestCallback, ResponseExtractor<T> responseExtractor) throws RestClientException { OAuth2AccessToken accessToken = context.getAccessToken(); RuntimeException rethrow = null; try { return super.doExecute(url, method, requestCallback, responseExtractor); } catch (AccessTokenRequiredException e) { rethrow = e; } catch (OAuth2AccessDeniedException e) { rethrow = e; } catch (InvalidTokenException e) { // Don't reveal the token value in case it is logged rethrow = new OAuth2AccessDeniedException("Invalid token for client=" + getClientId()); } if (accessToken != null && retryBadAccessTokens) { context.setAccessToken(null); try { return super.doExecute(url, method, requestCallback, responseExtractor); } catch (InvalidTokenException e) { // Don't reveal the token value in case it is logged rethrow = new OAuth2AccessDeniedException("Invalid token for client=" + getClientId()); } } throw rethrow; } /** * @return the client id for this resource. */ private String getClientId() { return resource.getClientId(); } /** * Acquire or renew an access token for the current context if necessary. This method will be called automatically * when a request is executed (and the result is cached), but can also be called as a standalone method to * pre-populate the token. * * @return an access token */ public OAuth2AccessToken getAccessToken() throws UserRedirectRequiredException { OAuth2AccessToken accessToken = context.getAccessToken(); if (accessToken == null || accessToken.isExpired()) { try { accessToken = acquireAccessToken(context); } catch (UserRedirectRequiredException e) { context.setAccessToken(null); // No point hanging onto it now accessToken = null; String stateKey = e.getStateKey(); if (stateKey != null) { Object stateToPreserve = e.getStateToPreserve(); if (stateToPreserve == null) { stateToPreserve = "NONE"; } context.setPreservedState(stateKey, stateToPreserve); } throw e; } } return accessToken; } /** * @return the context for this template */ public OAuth2ClientContext getOAuth2ClientContext() { return context; } protected OAuth2AccessToken acquireAccessToken(OAuth2ClientContext oauth2Context) throws UserRedirectRequiredException { AccessTokenRequest accessTokenRequest = oauth2Context.getAccessTokenRequest(); if (accessTokenRequest == null) { throw new AccessTokenRequiredException( "No OAuth 2 security context has been established. Unable to access resource '" + this.resource.getId() + "'.", resource); } // Transfer the preserved state from the (longer lived) context to the current request. String stateKey = accessTokenRequest.getStateKey(); if (stateKey != null) { accessTokenRequest.setPreservedState(oauth2Context.removePreservedState(stateKey)); } OAuth2AccessToken existingToken = oauth2Context.getAccessToken(); if (existingToken != null) { accessTokenRequest.setExistingToken(existingToken); } OAuth2AccessToken accessToken = null; accessToken = accessTokenProvider.obtainAccessToken(resource, accessTokenRequest); if (accessToken == null || accessToken.getValue() == null) { throw new IllegalStateException( "Access token provider returned a null access token, which is illegal according to the contract."); } oauth2Context.setAccessToken(accessToken); return accessToken; } protected URI appendQueryParameter(URI uri, OAuth2AccessToken accessToken) { try { // TODO: there is some duplication with UriUtils here. Probably unavoidable as long as this // method signature uses URI not String. String query = uri.getRawQuery(); // Don't decode anything here String queryFragment = resource.getTokenName() + "=" + URLEncoder.encode(accessToken.getValue(), "UTF-8"); if (query == null) { query = queryFragment; } else { query = query + "&" + queryFragment; } // first form the URI without query and fragment parts, so that it doesn't re-encode some query string chars // (SECOAUTH-90) URI update = new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), uri.getPort(), uri.getPath(), null, null); // now add the encoded query string and the then fragment StringBuffer sb = new StringBuffer(update.toString()); sb.append("?"); sb.append(query); if (uri.getFragment() != null) { sb.append("#"); sb.append(uri.getFragment()); } return new URI(sb.toString()); } catch (URISyntaxException e) { throw new IllegalArgumentException("Could not parse URI", e); } catch (UnsupportedEncodingException e) { throw new IllegalArgumentException("Could not encode URI", e); } } public void setAccessTokenProvider(AccessTokenProvider accessTokenProvider) { this.accessTokenProvider = accessTokenProvider; } }