/* * Copyright 2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.social.security; import static org.junit.Assert.*; import static org.mockito.Mockito.*; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.Set; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockFilterConfig; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockServletContext; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.userdetails.User; import org.springframework.security.web.authentication.NullRememberMeServices; import org.springframework.social.UserIdSource; import org.springframework.social.connect.Connection; import org.springframework.social.connect.ConnectionData; import org.springframework.social.connect.ConnectionFactory; import org.springframework.social.connect.ConnectionRepository; import org.springframework.social.connect.UsersConnectionRepository; import org.springframework.social.security.provider.SocialAuthenticationService; import org.springframework.social.security.provider.SocialAuthenticationService.ConnectionCardinality; import org.springframework.social.security.test.DummyConnection; import org.springframework.social.security.test.MockConnectionFactory; public class SocialAuthenticationFilterTest { @Before @After public void clean() { SecurityContextHolder.getContext().setAuthentication(null); } @SuppressWarnings("unchecked") @Test public void testExplicitAuth() throws Exception { FilterTestEnv env = new FilterTestEnv("GET", "/auth", null); env.filter.setFilterProcessesUrl(env.req.getRequestURI()); env.filter.setPostLoginUrl("/success"); ConnectionFactory<Object> factory = mock(MockConnectionFactory.class); when(factory.getProviderId()).thenReturn("mock"); env.req.setRequestURI(env.req.getRequestURI() + "/" + factory.getProviderId()); SocialAuthenticationService<Object> authService = mock(SocialAuthenticationService.class); when(authService.getConnectionCardinality()).thenReturn(ConnectionCardinality.ONE_TO_ONE); when(authService.getConnectionFactory()).thenReturn(factory); when(authService.getAuthToken(env.req, env.res)).thenReturn(env.auth); env.addAuthService(authService); when(env.authManager.authenticate(env.auth)).thenReturn(env.authSuccess); assertNull(SecurityContextHolder.getContext().getAuthentication()); env.doFilter(); assertNotNull(SecurityContextHolder.getContext().getAuthentication()); assertEquals("/success", env.res.getRedirectedUrl()); } @Test public void testFailedAuth_slashRegister() throws Exception { FilterTestEnv env = new FilterTestEnv("GET", "/auth", "/register"); testFailedAuth(env); } @Test public void testFailedAuth_register() throws Exception { FilterTestEnv env = new FilterTestEnv("GET", "/auth", "register"); testFailedAuth(env); } @Test public void testFailedAuth_fullyQualifiedUrlRegister() throws Exception { FilterTestEnv env = new FilterTestEnv("GET", "/auth", "http://localhost/register"); testFailedAuth(env); } @SuppressWarnings("unchecked") private void testFailedAuth(FilterTestEnv env) throws Exception { env.filter.setFilterProcessesUrl(env.req.getRequestURI()); env.filter.setPostLoginUrl("/success"); ConnectionFactory<Object> factory = mock(MockConnectionFactory.class); when(factory.getProviderId()).thenReturn("mock"); env.req.setRequestURI(env.req.getRequestURI() + "/" + factory.getProviderId()); SocialAuthenticationService<Object> authService = mock(SocialAuthenticationService.class); when(authService.getConnectionCardinality()).thenReturn(ConnectionCardinality.ONE_TO_ONE); when(authService.getConnectionFactory()).thenReturn(factory); when(authService.getAuthToken(env.req, env.res)).thenReturn(env.auth); env.addAuthService(authService); when(env.authManager.authenticate(env.auth)).thenThrow(new BadCredentialsException("Failed")); assertNull(SecurityContextHolder.getContext().getAuthentication()); env.doFilter(); assertNull(SecurityContextHolder.getContext().getAuthentication()); assertEquals("http://localhost/register", env.res.getRedirectedUrl()); } @SuppressWarnings("unchecked") @Test public void addConnection() { UsersConnectionRepository usersConnectionRepository = mock(UsersConnectionRepository.class); SocialAuthenticationFilter filter = new SocialAuthenticationFilter(null, null, usersConnectionRepository, null); SocialAuthenticationService<Object> authService = mock(SocialAuthenticationService.class); ConnectionRepository connectionRepository = mock(ConnectionRepository.class); ConnectionFactory<Object> connectionFactory = mock(MockConnectionFactory.class); MockHttpServletRequest request = new MockHttpServletRequest(); ConnectionData data = new ConnectionData("dummyprovider", "1234", null, null, null, null, null, null, null); String userId = "joe"; DummyConnection<Object> connection = DummyConnection.dummy(data.getProviderId(), userId); when(usersConnectionRepository.findUserIdsConnectedTo(data.getProviderId(), set(data.getProviderUserId()))).thenReturn(empty(String.class)); when(usersConnectionRepository.createConnectionRepository(userId)).thenReturn(connectionRepository); when(authService.getConnectionCardinality()).thenReturn(ConnectionCardinality.ONE_TO_ONE); when(authService.getConnectionFactory()).thenReturn(connectionFactory); when(authService.getConnectionAddedRedirectUrl(request, connection)).thenReturn("/redirect"); when(connectionFactory.createConnection(data)).thenReturn(connection); Connection<?> addedConnection = filter.addConnection(authService, userId, data); assertNotNull(addedConnection); assertSame(connection, addedConnection); verify(connectionRepository).addConnection(connection); } @SuppressWarnings("unchecked") @Test public void addConnection_authenticated() throws Exception { FilterTestEnv env = new FilterTestEnv("GET", "/auth", null); env.filter.setFilterProcessesUrl(env.req.getRequestURI()); env.filter.setPostLoginUrl("/success"); env.filter.setConnectionAddedRedirectUrl("/added"); env.filter.setConnectionAddingFailureRedirectUrl("/add-failed"); Connection<?> connection = env.auth.getConnection(); ConnectionData data = connection.createData(); String userId = "joe"; ConnectionFactory<Object> factory = mock(MockConnectionFactory.class); when(factory.getProviderId()).thenReturn("mock"); when(factory.createConnection(data)).thenReturn((Connection<Object>) connection); env.req.setRequestURI(env.req.getRequestURI() + "/" + factory.getProviderId()); SocialAuthenticationService<Object> authService = mock(SocialAuthenticationService.class); when(authService.getConnectionCardinality()).thenReturn(ConnectionCardinality.ONE_TO_ONE); when(authService.getConnectionFactory()).thenReturn(factory); when(authService.getAuthToken(env.req, env.res)).thenReturn(env.auth); env.addAuthService(authService); when(env.userIdSource.getUserId()).thenReturn(userId); when(env.usersConnectionRepository.findUserIdsConnectedTo(data.getProviderId(), set(data.getProviderUserId()))).thenReturn(empty(String.class)); // fallback to default /added when(authService.getConnectionAddedRedirectUrl(env.req, connection)).thenReturn(null); // already authenticated SecurityContextHolder.getContext().setAuthentication(env.authSuccess); env.doFilter(); // still authenticated assertSame(env.authSuccess, SecurityContextHolder.getContext().getAuthentication()); assertEquals("/added", env.res.getRedirectedUrl()); verify(env.connectionRepository).addConnection(env.auth.getConnection()); } @SuppressWarnings("unchecked") @Test public void addConnection_authenticated_failed() throws Exception { FilterTestEnv env = new FilterTestEnv("GET", "/auth", null); env.filter.setFilterProcessesUrl(env.req.getRequestURI()); env.filter.setPostLoginUrl("/success"); env.filter.setConnectionAddedRedirectUrl("/added"); env.filter.setConnectionAddingFailureRedirectUrl("/add-failed"); Connection<?> connection = env.auth.getConnection(); ConnectionData data = connection.createData(); String userId = "joe"; ConnectionFactory<Object> factory = mock(MockConnectionFactory.class); when(factory.getProviderId()).thenReturn("mock"); when(factory.createConnection(data)).thenReturn((Connection<Object>) connection); env.req.setRequestURI(env.req.getRequestURI() + "/" + factory.getProviderId()); SocialAuthenticationService<Object> authService = mock(SocialAuthenticationService.class); when(authService.getConnectionCardinality()).thenReturn(ConnectionCardinality.ONE_TO_ONE); when(authService.getConnectionFactory()).thenReturn(factory); when(authService.getAuthToken(env.req, env.res)).thenReturn(env.auth); env.addAuthService(authService); when(env.userIdSource.getUserId()).thenReturn(userId); // already connected when(env.usersConnectionRepository.findUserIdsConnectedTo(data.getProviderId(), set(data.getProviderUserId()))).thenReturn(set(userId)); // already authenticated SecurityContextHolder.getContext().setAuthentication(env.authSuccess); env.doFilter(); // still authenticated assertSame(env.authSuccess, SecurityContextHolder.getContext().getAuthentication()); assertEquals("/add-failed", env.res.getRedirectedUrl()); verify(env.connectionRepository, times(0)).addConnection(env.auth.getConnection()); } private static <T> Set<T> empty(Class<T> cls) { return Collections.emptySet(); } private static Set<String> set(String ... o) { return Collections.unmodifiableSet(new HashSet<String>(Arrays.asList(o))); } private static class FilterTestEnv { private final SocialAuthenticationFilter filter; private final MockServletContext context; private final MockHttpServletRequest req; private final MockHttpServletResponse res; private final MockFilterChain chain; private final MockFilterConfig config = new MockFilterConfig(); private final SocialAuthenticationToken auth; private final SocialAuthenticationToken authSuccess; private final AuthenticationManager authManager; private final UserIdSource userIdSource; private final UsersConnectionRepository usersConnectionRepository; private final ConnectionRepository connectionRepository; private FilterTestEnv(String method, String requestURI, String signupUrl) { context = new MockServletContext(); req = new MockHttpServletRequest(context, method, requestURI); res = new MockHttpServletResponse(); chain = new MockFilterChain(); authManager = mock(AuthenticationManager.class); userIdSource = mock(UserIdSource.class); usersConnectionRepository = mock(UsersConnectionRepository.class); connectionRepository = mock(ConnectionRepository.class); filter = new SocialAuthenticationFilter(authManager, userIdSource, usersConnectionRepository, new SocialAuthenticationServiceRegistry()); filter.setServletContext(context); filter.setRememberMeServices(new NullRememberMeServices()); filter.setSignupUrl(signupUrl); when(filter.getUsersConnectionRepository().createConnectionRepository(Mockito.anyString())).thenReturn(connectionRepository); auth = new SocialAuthenticationToken(DummyConnection.dummy("provider", "user"), null); Collection<? extends GrantedAuthority> authorities = Collections.emptyList(); User user = new SocialUser("foo", "bar", authorities); authSuccess = new SocialAuthenticationToken(DummyConnection.dummy("provider", "user"), user, null, authorities); } private void addAuthService(SocialAuthenticationService<?> authenticationService) { ((SocialAuthenticationServiceRegistry)filter.getAuthServiceLocator()).addAuthenticationService(authenticationService); } private void doFilter() throws Exception { filter.init(config); filter.doFilter(req, res, chain); filter.destroy(); } } }