/* * Copyright 2012-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.security.oauth2.provider.token.store.redis; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.springframework.data.redis.connection.jedis.JedisConnection; import org.springframework.data.redis.connection.jedis.JedisConnectionFactory; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken; import org.springframework.security.oauth2.common.DefaultOAuth2RefreshToken; import org.springframework.security.oauth2.common.OAuth2AccessToken; import org.springframework.security.oauth2.common.OAuth2RefreshToken; import org.springframework.security.oauth2.provider.OAuth2Authentication; import org.springframework.security.oauth2.provider.OAuth2Request; import org.springframework.security.oauth2.provider.RequestTokenFactory; import java.util.*; import static org.mockito.Matchers.any; import static org.mockito.Mockito.*; /** * * @author Joe Grandja * */ public class RedisTokenStoreMockTests { private JedisConnectionFactory connectionFactory; private JedisConnection connection; private RedisTokenStore tokenStore; private OAuth2Request request; private TestingAuthenticationToken authentication; @Before public void setUp() throws Exception { connectionFactory = mock(JedisConnectionFactory.class); connection = mock(JedisConnection.class); when(connectionFactory.getConnection()).thenReturn(connection); tokenStore = new RedisTokenStore(connectionFactory); Set<String> scope = new LinkedHashSet<String>(Arrays.asList("read", "write", "read-write")); request = RequestTokenFactory.createOAuth2Request("clientId", false, scope); authentication = new TestingAuthenticationToken("user", "password"); } // gh-572 @Test public void storeRefreshTokenRemoveRefreshTokenVerifyKeysRemoved() { OAuth2RefreshToken oauth2RefreshToken = new DefaultOAuth2RefreshToken("refresh-token-" + UUID.randomUUID()); OAuth2Authentication oauth2Authentication = new OAuth2Authentication(request, authentication); tokenStore.storeRefreshToken(oauth2RefreshToken, oauth2Authentication); ArgumentCaptor<byte[]> keyArgs = ArgumentCaptor.forClass(byte[].class); verify(connection, times(2)).set(keyArgs.capture(), any(byte[].class)); tokenStore.removeRefreshToken(oauth2RefreshToken); for (byte[] key : keyArgs.getAllValues()) { verify(connection).del(key); } } // gh-572 @Test public void storeAccessTokenWithoutRefreshTokenRemoveAccessTokenVerifyKeysRemoved() { OAuth2AccessToken oauth2AccessToken = new DefaultOAuth2AccessToken("access-token-" + UUID.randomUUID()); OAuth2Authentication oauth2Authentication = new OAuth2Authentication(request, authentication); List<Object> results = Arrays.<Object>asList("access-token".getBytes(), "authentication".getBytes()); when(connection.closePipeline()).thenReturn(results); RedisTokenStoreSerializationStrategy serializationStrategy = new JdkSerializationStrategy(); serializationStrategy = spy(serializationStrategy); when(serializationStrategy.deserialize(any(byte[].class), eq(OAuth2Authentication.class))).thenReturn(oauth2Authentication); tokenStore.setSerializationStrategy(serializationStrategy); tokenStore.storeAccessToken(oauth2AccessToken, oauth2Authentication); ArgumentCaptor<byte[]> setKeyArgs = ArgumentCaptor.forClass(byte[].class); verify(connection, times(3)).set(setKeyArgs.capture(), any(byte[].class)); ArgumentCaptor<byte[]> rPushKeyArgs = ArgumentCaptor.forClass(byte[].class); verify(connection, times(2)).rPush(rPushKeyArgs.capture(), any(byte[].class)); tokenStore.removeAccessToken(oauth2AccessToken); for (byte[] key : setKeyArgs.getAllValues()) { verify(connection).del(key); } for (byte[] key : rPushKeyArgs.getAllValues()) { verify(connection).lRem(eq(key), eq(1L), any(byte[].class)); } } }