package org.springframework.security.oauth.consumer.client;
import java.io.IOException;
import java.net.URI;
import java.util.Map;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.security.oauth.consumer.OAuthConsumerSupport;
import org.springframework.security.oauth.consumer.OAuthConsumerToken;
import org.springframework.security.oauth.consumer.OAuthSecurityContext;
import org.springframework.security.oauth.consumer.OAuthSecurityContextHolder;
import org.springframework.security.oauth.consumer.OAuthSecurityContextImpl;
import org.springframework.security.oauth.consumer.ProtectedResourceDetails;
/**
 * Request factory that extends all http requests with the OAuth credentials for a specific protected resource.
 * <p>
 * Request creation is handed to a delegate {@link ClientHttpRequestFactory}; this class only adds the OAuth
 * parameters, either as an <code>Authorization</code> header or as the query string of the request URI,
 * depending on {@link ProtectedResourceDetails#isAcceptsAuthorizationHeader()}.
 *
 * @author Ryan Heaton
 */
public class OAuthClientHttpRequestFactory implements ClientHttpRequestFactory {

    private final ClientHttpRequestFactory delegate;
    private final ProtectedResourceDetails resource;
    private final OAuthConsumerSupport support;
    private Map<String, String> additionalOAuthParameters;

    /**
     * Creates a factory that decorates the requests of the given delegate with OAuth credentials.
     *
     * @param delegate The factory that actually creates the requests.
     * @param resource The protected resource the requests are made for.
     * @param support  The consumer support used to build the OAuth query string or authorization header.
     * @throws IllegalArgumentException If any of the arguments is <code>null</code>.
     */
    public OAuthClientHttpRequestFactory(ClientHttpRequestFactory delegate, ProtectedResourceDetails resource, OAuthConsumerSupport support) {
        // Validate before assigning so that a half-initialized instance is never observable.
        if (delegate == null) {
            throw new IllegalArgumentException("A delegate must be supplied for an OAuthClientHttpRequestFactory.");
        }
        if (resource == null) {
            throw new IllegalArgumentException("A resource must be supplied for an OAuthClientHttpRequestFactory.");
        }
        // createRequest(...) always dereferences the support, so fail here rather than with an NPE at request time.
        if (support == null) {
            throw new IllegalArgumentException("A support must be supplied for an OAuthClientHttpRequestFactory.");
        }
        this.delegate = delegate;
        this.resource = resource;
        this.support = support;
    }

    /**
     * Creates a request through the delegate factory and adds the OAuth credentials for the protected resource.
     * <p>
     * The access token is looked up by the resource id in the current {@link OAuthSecurityContext}; when there is
     * no current context, or it holds no token for the resource, <code>null</code> is passed on to the support.
     * If the resource does not accept the authorization header, everything from the first <code>'?'</code> of the
     * URI onwards is replaced by the OAuth query string before the request is created. Otherwise the request is
     * created for the URI as given and an <code>Authorization</code> header is added to it. Finally, any additional
     * request headers configured on the resource are added.
     *
     * @param uri        The URI to create a request for.
     * @param httpMethod The HTTP method to execute.
     * @return The request created by the delegate, carrying the OAuth credentials.
     * @throws IOException If the delegate fails to create the request or the URI cannot be converted to a URL.
     */
    public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException {
        OAuthSecurityContext context = OAuthSecurityContextHolder.getContext();
        if (context == null) {
            context = new OAuthSecurityContextImpl();
        }

        Map<String, OAuthConsumerToken> accessTokens = context.getAccessTokens();
        OAuthConsumerToken accessToken = accessTokens == null ? null : accessTokens.get(this.resource.getId());

        boolean useAuthHeader = this.resource.isAcceptsAuthorizationHeader();
        if (!useAuthHeader) {
            // The OAuth parameters travel in the query string: drop the existing query and append the one
            // built by the support.
            // NOTE(review): presumably the support carries the original query parameters over into the
            // string it returns - confirm against the OAuthConsumerSupport implementation.
            String queryString = this.support.getOAuthQueryString(this.resource, accessToken, uri.toURL(), httpMethod.name(), this.additionalOAuthParameters);
            String uriValue = String.valueOf(uri);
            int queryStart = uriValue.indexOf('?');
            String base = queryStart >= 0 ? uriValue.substring(0, queryStart) : uriValue;
            uri = URI.create(base + "?" + queryString);
        }

        ClientHttpRequest req = this.delegate.createRequest(uri, httpMethod);

        if (useAuthHeader) {
            String authHeader = this.support.getAuthorizationHeader(this.resource, accessToken, uri.toURL(), httpMethod.name(), this.additionalOAuthParameters);
            req.getHeaders().add("Authorization", authHeader);
        }

        Map<String, String> additionalHeaders = this.resource.getAdditionalRequestHeaders();
        if (additionalHeaders != null) {
            for (Map.Entry<String, String> header : additionalHeaders.entrySet()) {
                req.getHeaders().add(header.getKey(), header.getValue());
            }
        }

        return req;
    }

    /**
     * Any additional OAuth parameters to send with the OAuth request.
     * <p>
     * The map is returned by reference, not as a copy.
     *
     * @return Any additional OAuth parameters to send with the OAuth request, or <code>null</code> if none were set.
     */
    public Map<String, String> getAdditionalOAuthParameters() {
        return additionalOAuthParameters;
    }

    /**
     * Any additional OAuth parameters to send with the OAuth request.
     * <p>
     * The map is stored by reference, not copied.
     *
     * @param additionalOAuthParameters Any additional OAuth parameters to send with the OAuth request.
     */
    public void setAdditionalOAuthParameters(Map<String, String> additionalOAuthParameters) {
        this.additionalOAuthParameters = additionalOAuthParameters;
    }
}