package sparklr.common;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import java.nio.charset.Charset;
import java.util.Map;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.security.crypto.codec.Base64;
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
 * Integration tests for the {@code refresh_token} grant.
 * <p>
 * The flow is as follows:
 * <ol>
 * <li>Obtain an access token with the {@code password} grant.</li>
 * <li>Exchange its refresh token for a new access token.</li>
 * <li>Check that the new token is accepted and the old one is rejected by a protected resource.</li>
 * </ol>
 * Subclasses are expected to supply the {@code http} helper and {@code tokenPath()}, presumably via
 * {@code AbstractIntegrationTests} (not visible here).
 *
 * @author Dave Syer
 */
public abstract class AbstractRefreshTokenSupportTests extends AbstractIntegrationTests {

	/** Client used for every token request in this test. */
	private static final String TRUSTED_CLIENT_ID = "my-trusted-client";

	/** Explicit charset so Basic credentials do not depend on the platform default encoding. */
	private static final Charset UTF_8 = Charset.forName("UTF-8");

	/**
	 * Tests a happy-day flow of the refresh token provider.
	 * <p>
	 * It obtains a token, refreshes it, and checks that the refreshed token differs from the
	 * original. It then delegates to {@link #verifyAccessTokens}.
	 */
	@Test
	public void testHappyDay() throws Exception {
		OAuth2AccessToken accessToken = getAccessToken("read write", TRUSTED_CLIENT_ID);
		// now use the refresh token to get a new access token.
		assertNotNull(accessToken.getRefreshToken());
		OAuth2AccessToken newAccessToken = refreshAccessToken(accessToken.getRefreshToken().getValue());
		assertFalse(newAccessToken.getValue().equals(accessToken.getValue()));
		verifyAccessTokens(accessToken, newAccessToken);
	}

	/**
	 * Verifies that the refreshed token grants access and the original token no longer does.
	 * Subclasses may override this to match a different token-store policy.
	 *
	 * @param oldAccessToken the token issued before the refresh
	 * @param newAccessToken the token issued by the refresh
	 */
	protected void verifyAccessTokens(OAuth2AccessToken oldAccessToken, OAuth2AccessToken newAccessToken) {
		// make sure the new access token can be used.
		verifyTokenResponse(newAccessToken.getValue(), HttpStatus.OK);
		// make sure the old access token isn't valid anymore.
		verifyTokenResponse(oldAccessToken.getValue(), HttpStatus.UNAUTHORIZED);
	}

	/**
	 * Calls {@code /admin/beans} with the given bearer token and asserts the response status.
	 *
	 * @param accessToken the raw access token value
	 * @param status the expected HTTP status
	 */
	protected void verifyTokenResponse(String accessToken, HttpStatus status) {
		HttpHeaders headers = new HttpHeaders();
		headers.set("Authorization", String.format("%s %s", OAuth2AccessToken.BEARER_TYPE, accessToken));
		assertEquals(status, http.getStatusCode("/admin/beans", headers));
	}

	/**
	 * Exchanges a refresh token for a new access token.
	 * <p>
	 * The request asks for the {@code read} scope only, which is a subset of the {@code read write}
	 * scope requested in {@link #testHappyDay()}.
	 *
	 * @param refreshToken the refresh token value
	 * @return the newly issued access token
	 */
	private OAuth2AccessToken refreshAccessToken(String refreshToken) {
		MultiValueMap<String, String> formData = new LinkedMultiValueMap<String, String>();
		formData.add("grant_type", "refresh_token");
		formData.add("client_id", TRUSTED_CLIENT_ID);
		formData.add("refresh_token", refreshToken);
		formData.add("scope", "read");
		return postForAccessToken(formData, getTokenHeaders(TRUSTED_CLIENT_ID));
	}

	/**
	 * Obtains an access token using the {@code password} grant for user {@code user}.
	 *
	 * @param scope the space-separated scopes to request
	 * @param clientId the client id, or {@code null} to send no client credentials
	 * @return the issued access token
	 */
	private OAuth2AccessToken getAccessToken(String scope, String clientId) throws Exception {
		return postForAccessToken(getTokenFormData(scope, clientId), getTokenHeaders(clientId));
	}

	/**
	 * POSTs a form to the token endpoint and converts the response body into an access token.
	 * <p>
	 * Before converting, it asserts a 200 response whose {@code Cache-Control} header contains
	 * {@code no-store}.
	 *
	 * @param formData the token request parameters
	 * @param headers the request headers (client authentication)
	 * @return the parsed access token
	 */
	private OAuth2AccessToken postForAccessToken(MultiValueMap<String, String> formData, HttpHeaders headers) {
		@SuppressWarnings("rawtypes")
		ResponseEntity<Map> response = http.postForMap(tokenPath(), headers, formData);
		assertEquals(HttpStatus.OK, response.getStatusCode());
		String cacheControl = response.getHeaders().getFirst("Cache-Control");
		// Null-safe so a missing header fails the assertion rather than throwing a NullPointerException.
		assertTrue("Wrong cache control: " + cacheControl,
				cacheControl != null && cacheControl.contains("no-store"));
		@SuppressWarnings("unchecked")
		OAuth2AccessToken accessToken = DefaultOAuth2AccessToken.valueOf(response.getBody());
		return accessToken;
	}

	/**
	 * Builds the token-request headers.
	 * <p>
	 * When a client id is given, an HTTP Basic {@code Authorization} header is added using
	 * {@code clientId + ":"}, i.e. an empty client secret.
	 *
	 * @param clientId the client id, or {@code null} for no Authorization header
	 * @return the headers
	 */
	private HttpHeaders getTokenHeaders(String clientId) {
		HttpHeaders headers = new HttpHeaders();
		if (clientId != null) {
			byte[] credentials = (clientId + ":").getBytes(UTF_8);
			headers.set("Authorization", "Basic " + new String(Base64.encode(credentials), UTF_8));
		}
		return headers;
	}

	/**
	 * Builds the form parameters for a {@code password} grant with the fixed test user
	 * {@code user}/{@code password}.
	 *
	 * @param scope the space-separated scopes to request
	 * @param clientId the client id to include, or {@code null} to omit it
	 * @return the form data
	 */
	private MultiValueMap<String, String> getTokenFormData(String scope, String clientId) {
		MultiValueMap<String, String> formData = new LinkedMultiValueMap<String, String>();
		formData.add("grant_type", "password");
		if (clientId != null) {
			formData.add("client_id", clientId);
		}
		formData.add("scope", scope);
		formData.add("username", "user");
		formData.add("password", "password");
		return formData;
	}
}