/*
* Copyright 2013-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package org.springframework.security.oauth2.provider.token;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.LinkedHashSet;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.Before;
import org.junit.Test;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.oauth2.common.DefaultExpiringOAuth2RefreshToken;
import org.springframework.security.oauth2.common.ExpiringOAuth2RefreshToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2RefreshToken;
import org.springframework.security.oauth2.common.exceptions.InvalidGrantException;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.oauth2.config.annotation.builders.InMemoryClientDetailsServiceBuilder;
import org.springframework.security.oauth2.provider.ClientDetails;
import org.springframework.security.oauth2.provider.ClientDetailsService;
import org.springframework.security.oauth2.provider.ClientRegistrationException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.RequestTokenFactory;
import org.springframework.security.oauth2.provider.TokenRequest;
import org.springframework.security.oauth2.provider.client.BaseClientDetails;
/**
* @author Dave Syer
*
*/
public abstract class AbstractDefaultTokenServicesTests {
private DefaultTokenServices services;
private TokenStore tokenStore;
@Before
public void setUp() throws Exception {
tokenStore = createTokenStore();
services = new DefaultTokenServices();
configureTokenServices(services);
}
@Test
public void testClientSpecificRefreshTokenExpiry() throws Exception {
getTokenServices().setRefreshTokenValiditySeconds(1000);
getTokenServices().setClientDetailsService(new ClientDetailsService() {
public ClientDetails loadClientByClientId(String clientId) throws OAuth2Exception {
BaseClientDetails client = new BaseClientDetails();
client.setRefreshTokenValiditySeconds(100);
client.setAuthorizedGrantTypes(Arrays.asList("authorization_code", "refresh_token"));
return client;
}
});
OAuth2AccessToken accessToken = getTokenServices().createAccessToken(createAuthentication());
DefaultExpiringOAuth2RefreshToken refreshToken = (DefaultExpiringOAuth2RefreshToken) accessToken
.getRefreshToken();
Date expectedExpiryDate = new Date(System.currentTimeMillis() + 102 * 1000L);
assertTrue(expectedExpiryDate.after(refreshToken.getExpiration()));
}
@Test(expected = InvalidTokenException.class)
public void testClientInvalidated() throws Exception {
final AtomicBoolean deleted = new AtomicBoolean();
getTokenServices().setClientDetailsService(new ClientDetailsService() {
public ClientDetails loadClientByClientId(String clientId) throws OAuth2Exception {
if (deleted.get()) {
throw new ClientRegistrationException("No such client: " + clientId);
}
BaseClientDetails client = new BaseClientDetails();
client.setRefreshTokenValiditySeconds(100);
client.setAuthorizedGrantTypes(Arrays.asList("authorization_code", "refresh_token"));
return client;
}
});
OAuth2AccessToken token = getTokenServices().createAccessToken(createAuthentication());
deleted.set(true);
OAuth2Authentication authentication = getTokenServices().loadAuthentication(token.getValue());
assertNotNull(authentication.getOAuth2Request());
}
@Test(expected = InvalidGrantException.class)
public void testRefreshedTokenInvalidWithWrongClient() throws Exception {
ExpiringOAuth2RefreshToken expectedExpiringRefreshToken = (ExpiringOAuth2RefreshToken) getTokenServices()
.createAccessToken(createAuthentication()).getRefreshToken();
TokenRequest tokenRequest = new TokenRequest(Collections.singletonMap("client_id", "wrong"), "wrong", null,
null);
OAuth2AccessToken refreshedAccessToken = getTokenServices()
.refreshAccessToken(expectedExpiringRefreshToken.getValue(), tokenRequest);
assertEquals("[read]", refreshedAccessToken.getScope().toString());
}
@Test
public void testRefreshedTokenHasNarrowedScopes() throws Exception {
ExpiringOAuth2RefreshToken expectedExpiringRefreshToken = (ExpiringOAuth2RefreshToken) getTokenServices()
.createAccessToken(createAuthentication()).getRefreshToken();
TokenRequest tokenRequest = new TokenRequest(Collections.singletonMap("client_id", "id"), "id",
Collections.singleton("read"), null);
OAuth2AccessToken refreshedAccessToken = getTokenServices()
.refreshAccessToken(expectedExpiringRefreshToken.getValue(), tokenRequest);
assertEquals("[read]", refreshedAccessToken.getScope().toString());
}
@Test
public void testRefreshTokenRequestHasRefreshFlag() throws Exception {
ExpiringOAuth2RefreshToken expectedExpiringRefreshToken = (ExpiringOAuth2RefreshToken) getTokenServices()
.createAccessToken(createAuthentication()).getRefreshToken();
TokenRequest tokenRequest = new TokenRequest(Collections.singletonMap("client_id", "id"), "id",
Collections.singleton("read"), null);
final AtomicBoolean called = new AtomicBoolean(false);
getTokenServices().setTokenEnhancer(new TokenEnhancer() {
@Override
public OAuth2AccessToken enhance(OAuth2AccessToken accessToken, OAuth2Authentication authentication) {
assertTrue(authentication.getOAuth2Request().isRefresh());
called.set(true);
return accessToken;
}
});
getTokenServices().refreshAccessToken(expectedExpiringRefreshToken.getValue(), tokenRequest);
assertTrue(called.get());
}
@Test
public void testRefreshTokenNonExpiring() throws Exception {
ClientDetailsService clientDetailsService = new InMemoryClientDetailsServiceBuilder().withClient("id")
.refreshTokenValiditySeconds(0).authorizedGrantTypes("refresh_token").and().build();
DefaultTokenServices tokenServices = getTokenServices();
tokenServices.setClientDetailsService(clientDetailsService);
OAuth2RefreshToken refreshToken = tokenServices.createAccessToken(createAuthentication())
.getRefreshToken();
assertNotNull(refreshToken);
assertFalse(refreshToken instanceof ExpiringOAuth2RefreshToken);
}
@Test
public void testTokenRevoked() throws Exception {
OAuth2Authentication authentication = createAuthentication();
OAuth2AccessToken original = getTokenServices().createAccessToken(authentication);
getTokenStore().removeAccessToken(original);
assertEquals(0, getTokenStore().findTokensByClientId(authentication.getOAuth2Request().getClientId()).size());
}
@Test
public void testUnlimitedTokenExpiry() throws Exception {
getTokenServices().setAccessTokenValiditySeconds(0);
OAuth2AccessToken accessToken = getTokenServices().createAccessToken(createAuthentication());
assertEquals(0, accessToken.getExpiresIn());
assertEquals(null, accessToken.getExpiration());
}
@Test
public void testDefaultTokenExpiry() throws Exception {
getTokenServices().setAccessTokenValiditySeconds(100);
OAuth2AccessToken accessToken = getTokenServices().createAccessToken(createAuthentication());
assertTrue(100 >= accessToken.getExpiresIn());
}
@Test
public void testClientSpecificTokenExpiry() throws Exception {
getTokenServices().setAccessTokenValiditySeconds(1000);
getTokenServices().setClientDetailsService(new ClientDetailsService() {
public ClientDetails loadClientByClientId(String clientId) throws OAuth2Exception {
BaseClientDetails client = new BaseClientDetails();
client.setAccessTokenValiditySeconds(100);
return client;
}
});
OAuth2AccessToken accessToken = getTokenServices().createAccessToken(createAuthentication());
assertTrue(100 >= accessToken.getExpiresIn());
}
@Test
public void testRefreshedTokenHasScopes() throws Exception {
ExpiringOAuth2RefreshToken expectedExpiringRefreshToken = (ExpiringOAuth2RefreshToken) getTokenServices()
.createAccessToken(createAuthentication()).getRefreshToken();
TokenRequest tokenRequest = new TokenRequest(Collections.singletonMap("client_id", "id"), "id", null, null);
OAuth2AccessToken refreshedAccessToken = getTokenServices()
.refreshAccessToken(expectedExpiringRefreshToken.getValue(), tokenRequest);
assertEquals("[read, write]", refreshedAccessToken.getScope().toString());
}
@Test
public void testRefreshedTokenNotExpiring() throws Exception {
getTokenServices().setRefreshTokenValiditySeconds(0);
OAuth2RefreshToken expectedExpiringRefreshToken = getTokenServices().createAccessToken(createAuthentication())
.getRefreshToken();
assertFalse(expectedExpiringRefreshToken instanceof DefaultExpiringOAuth2RefreshToken);
}
@Test
public void testRevokedTokenNotAvailable() throws Exception {
OAuth2Authentication authentication = createAuthentication();
OAuth2AccessToken token = getTokenServices().createAccessToken(authentication);
getTokenServices().revokeToken(token.getValue());
Collection<OAuth2AccessToken> tokens = getTokenStore().findTokensByClientIdAndUserName(
authentication.getOAuth2Request().getClientId(), authentication.getUserAuthentication().getName());
assertFalse(tokens.contains(token));
assertTrue(tokens.isEmpty());
}
protected void configureTokenServices(DefaultTokenServices services) throws Exception {
services.setTokenStore(tokenStore);
services.setSupportRefreshToken(true);
services.afterPropertiesSet();
}
protected abstract TokenStore createTokenStore();
protected OAuth2Authentication createAuthentication() {
return new OAuth2Authentication(
RequestTokenFactory.createOAuth2Request(null, "id", null, false,
new LinkedHashSet<String>(Arrays.asList("read", "write")), null, null, null, null),
new TestAuthentication("test2", false));
}
protected TokenStore getTokenStore() {
return tokenStore;
}
protected DefaultTokenServices getTokenServices() {
return services;
}
protected static class TestAuthentication extends AbstractAuthenticationToken {
private static final long serialVersionUID = 1L;
private String principal;
public TestAuthentication(String name, boolean authenticated) {
super(null);
setAuthenticated(authenticated);
this.principal = name;
}
public Object getCredentials() {
return null;
}
public Object getPrincipal() {
return this.principal;
}
}
}