package org.springframework.security.oauth2.client;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.io.IOException;
import java.net.URI;
import java.util.Collections;
import java.util.Date;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.oauth2.client.http.AccessTokenRequiredException;
import org.springframework.security.oauth2.client.resource.BaseOAuth2ProtectedResourceDetails;
import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException;
import org.springframework.security.oauth2.client.token.AccessTokenProvider;
import org.springframework.security.oauth2.client.token.AccessTokenRequest;
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2RefreshToken;
import org.springframework.web.client.RequestCallback;
import org.springframework.web.client.ResponseExtractor;
import org.springframework.web.util.UriTemplate;
/**
 * Tests for {@link OAuth2RestTemplate}: how the access token is attached to outgoing requests (header or
 * query parameter), how the token value is encoded when appended to a URI, and how the template reacts when
 * a token is missing, expired or cannot be obtained.
 *
 * @author Ryan Heaton
 * @author Dave Syer
 */
public class OAuth2RestTemplateTests {

    private BaseOAuth2ProtectedResourceDetails resource;

    private OAuth2RestTemplate restTemplate;

    private final AccessTokenProvider accessTokenProvider = Mockito.mock(AccessTokenProvider.class);

    /** Mock request handed out by the stub request factory in the retry test. */
    private ClientHttpRequest request;

    /** Headers returned by the mock {@link #request}. */
    private HttpHeaders headers;

    /**
     * Builds a fresh template around a resource whose token name is {@code bearer_token}, and prepares a mock
     * request whose execution yields a mock response with status {@link HttpStatus#OK}.
     */
    @Before
    public void open() throws Exception {
        resource = new BaseOAuth2ProtectedResourceDetails();
        // Facebook and older specs:
        resource.setTokenName("bearer_token");

        restTemplate = new OAuth2RestTemplate(resource);
        restTemplate.setAccessTokenProvider(accessTokenProvider);

        headers = new HttpHeaders();
        request = Mockito.mock(ClientHttpRequest.class);
        Mockito.when(request.getHeaders()).thenReturn(headers);

        ClientHttpResponse response = Mockito.mock(ClientHttpResponse.class);
        Mockito.when(response.getStatusCode()).thenReturn(HttpStatus.OK);
        Mockito.when(request.execute()).thenReturn(response);
    }

    /**
     * A token with a custom token type must be sent in the {@code Authorization} header prefixed by that type.
     */
    @Test
    public void testNonBearerToken() throws Exception {
        DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken("12345");
        token.setTokenType("MINE");
        restTemplate.getOAuth2ClientContext().setAccessToken(token);

        ClientHttpRequest http = restTemplate.createRequest(URI.create("https://nowhere.com/api/crap"),
                HttpMethod.GET);

        String auth = http.getHeaders().getFirst("Authorization");
        // Fail with an assertion (not a NullPointerException) if the header was never written.
        assertNotNull("Authorization header should be set", auth);
        assertTrue(auth.startsWith("MINE "));
    }

    /**
     * A custom {@link OAuth2RequestAuthenticator} installed on the template is the one that decorates the
     * request: the header it writes must be present on the created request.
     */
    @Test
    public void testCustomAuthenticator() throws Exception {
        DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken("12345");
        token.setTokenType("MINE");
        restTemplate.setAuthenticator(new OAuth2RequestAuthenticator() {
            @Override
            public void authenticate(OAuth2ProtectedResourceDetails resource, OAuth2ClientContext clientContext,
                    ClientHttpRequest req) {
                String tokenType = clientContext.getAccessToken().getTokenType();
                req.getHeaders().set("X-Authorization", tokenType + " Nah-nah-na-nah-nah");
            }
        });
        restTemplate.getOAuth2ClientContext().setAccessToken(token);

        ClientHttpRequest http = restTemplate.createRequest(URI.create("https://nowhere.com/api/crap"),
                HttpMethod.GET);

        assertEquals("MINE Nah-nah-na-nah-nah", http.getHeaders().getFirst("X-Authorization"));
    }

    /**
     * The token is appended with {@code &} when the URI already carries a query string.
     */
    @Test
    public void testAppendQueryParameter() throws Exception {
        OAuth2AccessToken token = new DefaultOAuth2AccessToken("12345");
        URI appended = restTemplate.appendQueryParameter(
                URI.create("https://graph.facebook.com/search?type=checkin"), token);
        assertEquals("https://graph.facebook.com/search?type=checkin&bearer_token=12345", appended.toString());
    }

    /**
     * The token starts a new query string ({@code ?}) when the URI has no query parameters yet.
     */
    @Test
    public void testAppendQueryParameterWithNoExistingParameters() throws Exception {
        OAuth2AccessToken token = new DefaultOAuth2AccessToken("12345");
        URI appended = restTemplate.appendQueryParameter(URI.create("https://graph.facebook.com/search"), token);
        assertEquals("https://graph.facebook.com/search?bearer_token=12345", appended.toString());
    }

    /**
     * A token value containing {@code /} is percent-encoded exactly once (no double encoding).
     */
    @Test
    public void testDoubleEncodingOfParameterValue() throws Exception {
        OAuth2AccessToken token = new DefaultOAuth2AccessToken("1/qIxxx");
        URI appended = restTemplate.appendQueryParameter(URI.create("https://graph.facebook.com/search"), token);
        assertEquals("https://graph.facebook.com/search?bearer_token=1%2FqIxxx", appended.toString());
    }

    /**
     * An already-encoded query parameter (produced by expanding a {@link UriTemplate}) is left untouched when
     * the token is appended, i.e. it is not encoded a second time.
     */
    @Test
    public void testNonEncodingOfUriTemplate() throws Exception {
        OAuth2AccessToken token = new DefaultOAuth2AccessToken("12345");
        URI expanded = new UriTemplate("https://graph.facebook.com/fql?q={q}").expand("[q: fql]");
        URI appended = restTemplate.appendQueryParameter(expanded, token);
        assertEquals("https://graph.facebook.com/fql?q=%5Bq:%20fql%5D&bearer_token=12345", appended.toString());
    }

    /**
     * For a URI with a fragment, the token goes into the query string and the fragment stays at the end.
     */
    @Test
    public void testFragmentUri() throws Exception {
        OAuth2AccessToken token = new DefaultOAuth2AccessToken("1234");
        URI appended = restTemplate.appendQueryParameter(URI.create("https://graph.facebook.com/search#foo"),
                token);
        assertEquals("https://graph.facebook.com/search?bearer_token=1234#foo", appended.toString());
    }

    /**
     * Encoding of an access token value passed in protected requests (ref: SECOAUTH-90): a value with several
     * characters that need escaping is encoded exactly once.
     */
    @Test
    public void testDoubleEncodingOfAccessTokenValue() throws Exception {
        // Fictitious token value with many characters to encode.
        OAuth2AccessToken token = new DefaultOAuth2AccessToken("1 qI+x:y=z");
        URI appended = restTemplate.appendQueryParameter(URI.create("https://graph.facebook.com/search"), token);
        assertEquals("https://graph.facebook.com/search?bearer_token=1+qI%2Bx%3Ay%3Dz", appended.toString());
    }

    /**
     * With no token set in the client context, an {@link AccessTokenRequiredException} raised while creating
     * the request is expected to reach the caller.
     */
    @Test(expected = AccessTokenRequiredException.class)
    public void testNoRetryAccessDeniedExceptionForNoExistingToken() throws Exception {
        restTemplate.setAccessTokenProvider(new StubAccessTokenProvider());
        restTemplate.setRequestFactory(new ClientHttpRequestFactory() {
            @Override
            public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException {
                throw new AccessTokenRequiredException(resource);
            }
        });

        restTemplate.doExecute(new URI("http://foo"), HttpMethod.GET, new NullRequestCallback(),
                new SimpleResponseExtractor());
    }

    /**
     * With a token already in the client context, a single {@link AccessTokenRequiredException} from the
     * request factory must not fail the call: the factory hands out the mock request on the next attempt and
     * the extractor's result is returned.
     */
    @Test
    public void testRetryAccessDeniedException() throws Exception {
        final AtomicBoolean failed = new AtomicBoolean(false);
        restTemplate.getOAuth2ClientContext().setAccessToken(new DefaultOAuth2AccessToken("TEST"));
        restTemplate.setAccessTokenProvider(new StubAccessTokenProvider());
        restTemplate.setRequestFactory(new ClientHttpRequestFactory() {
            @Override
            public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException {
                // Fail the first attempt only; later attempts receive the mock request.
                if (!failed.get()) {
                    failed.set(true);
                    throw new AccessTokenRequiredException(resource);
                }
                return request;
            }
        });

        Boolean result = restTemplate.doExecute(new URI("http://foo"), HttpMethod.GET, new NullRequestCallback(),
                new SimpleResponseExtractor());

        assertTrue("First attempt should have been rejected", failed.get());
        assertTrue(result);
    }

    /**
     * An expired token held in the client context is replaced by a token obtained from the provider.
     */
    @Test
    public void testNewTokenAcquiredIfExpired() throws Exception {
        DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken("TEST");
        token.setExpiration(new Date(System.currentTimeMillis() - 1000));
        restTemplate.getOAuth2ClientContext().setAccessToken(token);
        restTemplate.setAccessTokenProvider(new StubAccessTokenProvider());

        OAuth2AccessToken newToken = restTemplate.getAccessToken();

        assertNotNull(newToken);
        assertTrue("Expired token should have been replaced", !token.equals(newToken));
    }

    /**
     * When the stored token is expired and the provider answers with a
     * {@link UserRedirectRequiredException}, the exception reaches the caller and the stale token is cleared
     * from the client context.
     */
    @Test
    public void testTokenIsResetIfInvalid() throws Exception {
        DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken("TEST");
        token.setExpiration(new Date(System.currentTimeMillis() - 1000));
        restTemplate.getOAuth2ClientContext().setAccessToken(token);
        restTemplate.setAccessTokenProvider(new StubAccessTokenProvider() {
            @Override
            public OAuth2AccessToken obtainAccessToken(OAuth2ProtectedResourceDetails details,
                    AccessTokenRequest parameters) throws UserRedirectRequiredException, AccessDeniedException {
                throw new UserRedirectRequiredException("http://foo.com", Collections.<String, String> emptyMap());
            }
        });

        try {
            restTemplate.getAccessToken();
            // Any normal return (null or not) means the redirect exception was swallowed.
            fail("Expected UserRedirectRequiredException");
        }
        catch (UserRedirectRequiredException expected) {
            // planned
        }

        // context token should be reset as it clearly is invalid at this point
        assertNull(restTemplate.getOAuth2ClientContext().getAccessToken());
    }

    /** Response extractor that ignores the response and always reports {@code true}. */
    private static final class SimpleResponseExtractor implements ResponseExtractor<Boolean> {

        @Override
        public Boolean extractData(ClientHttpResponse response) throws IOException {
            return true;
        }

    }

    /** Request callback that leaves the request untouched. */
    private static class NullRequestCallback implements RequestCallback {

        @Override
        public void doWithRequest(ClientHttpRequest request) throws IOException {
        }

    }

    /**
     * Token provider stub: supports every resource, never supports refresh, and always hands out a new token
     * with the value {@code FOO}.
     */
    private static class StubAccessTokenProvider implements AccessTokenProvider {

        @Override
        public OAuth2AccessToken obtainAccessToken(OAuth2ProtectedResourceDetails details,
                AccessTokenRequest parameters) throws UserRedirectRequiredException, AccessDeniedException {
            return new DefaultOAuth2AccessToken("FOO");
        }

        @Override
        public boolean supportsRefresh(OAuth2ProtectedResourceDetails resource) {
            return false;
        }

        @Override
        public OAuth2AccessToken refreshAccessToken(OAuth2ProtectedResourceDetails resource,
                OAuth2RefreshToken refreshToken, AccessTokenRequest request) throws UserRedirectRequiredException {
            return null;
        }

        @Override
        public boolean supportsResource(OAuth2ProtectedResourceDetails resource) {
            return true;
        }

    }

}