/* * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.security.web.context; import static org.assertj.core.api.Assertions.*; import static org.mockito.Mockito.*; import java.io.IOException; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import org.junit.After; import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextImpl; public class SecurityContextPersistenceFilterTests { TestingAuthenticationToken testToken = new TestingAuthenticationToken("someone", "passwd", "ROLE_A"); @After public void clearContext() { SecurityContextHolder.clearContext(); } @Test public void contextIsClearedAfterChainProceeds() throws Exception { final FilterChain chain = mock(FilterChain.class); final MockHttpServletRequest request = new MockHttpServletRequest(); final MockHttpServletResponse response = new MockHttpServletResponse(); SecurityContextPersistenceFilter filter = new SecurityContextPersistenceFilter(); SecurityContextHolder.getContext().setAuthentication(testToken); filter.doFilter(request, response, chain); verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); } @Test public void contextIsStillClearedIfExceptionIsThrowByFilterChain() throws Exception { final FilterChain chain = mock(FilterChain.class); final MockHttpServletRequest request = new MockHttpServletRequest(); final MockHttpServletResponse response = new MockHttpServletResponse(); SecurityContextPersistenceFilter filter = new SecurityContextPersistenceFilter(); SecurityContextHolder.getContext().setAuthentication(testToken); doThrow(new IOException()).when(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); try { filter.doFilter(request, response, chain); fail("IOException should have been thrown"); } catch (IOException expected) { } assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); } @Test public void loadedContextContextIsCopiedToSecurityContextHolderAndUpdatedContextIsStored() throws Exception { final MockHttpServletRequest request = new MockHttpServletRequest(); final MockHttpServletResponse response = new MockHttpServletResponse(); final TestingAuthenticationToken beforeAuth = new TestingAuthenticationToken( "someoneelse", "passwd", "ROLE_B"); final SecurityContext scBefore = new SecurityContextImpl(); final SecurityContext scExpectedAfter = new SecurityContextImpl(); scExpectedAfter.setAuthentication(testToken); scBefore.setAuthentication(beforeAuth); final SecurityContextRepository repo = mock(SecurityContextRepository.class); SecurityContextPersistenceFilter filter = new SecurityContextPersistenceFilter( repo); when(repo.loadContext(any(HttpRequestResponseHolder.class))).thenReturn(scBefore); final FilterChain chain = new FilterChain() { public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { assertThat(SecurityContextHolder.getContext().getAuthentication()).isEqualTo(beforeAuth); // Change the context here SecurityContextHolder.setContext(scExpectedAfter); } }; filter.doFilter(request, response, chain); verify(repo).saveContext(scExpectedAfter, request, response); } @Test public void filterIsNotAppliedAgainIfFilterAppliedAttributeIsSet() throws Exception { final FilterChain chain = mock(FilterChain.class); final MockHttpServletRequest request = new MockHttpServletRequest(); final MockHttpServletResponse response = new MockHttpServletResponse(); SecurityContextPersistenceFilter filter = new SecurityContextPersistenceFilter( mock(SecurityContextRepository.class)); request.setAttribute(SecurityContextPersistenceFilter.FILTER_APPLIED, Boolean.TRUE); filter.doFilter(request, response, chain); verify(chain).doFilter(request, response); } @Test public void sessionIsEagerlyCreatedWhenConfigured() throws Exception { final FilterChain chain = mock(FilterChain.class); final MockHttpServletRequest request = new MockHttpServletRequest(); final MockHttpServletResponse response = new MockHttpServletResponse(); SecurityContextPersistenceFilter filter = new SecurityContextPersistenceFilter(); filter.setForceEagerSessionCreation(true); filter.doFilter(request, response, chain); assertThat(request.getSession(false)).isNotNull(); } @Test public void nullSecurityContextRepoDoesntSaveContextOrCreateSession() throws Exception { final FilterChain chain = mock(FilterChain.class); final MockHttpServletRequest request = new MockHttpServletRequest(); final MockHttpServletResponse response = new MockHttpServletResponse(); SecurityContextRepository repo = new NullSecurityContextRepository(); SecurityContextPersistenceFilter filter = new SecurityContextPersistenceFilter( repo); filter.doFilter(request, response, chain); assertThat(repo.containsContext(request)).isFalse(); assertThat(request.getSession(false)).isNull(); } }