/* * Copyright 2002-2013 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.security.web.csrf; import java.io.IOException; import java.util.Arrays; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.assertj.core.api.AbstractObjectAssert; import org.assertj.core.api.ObjectAssert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.util.matcher.RequestMatcher; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; /** * @author Rob Winch * */ @RunWith(MockitoJUnitRunner.class) public class CsrfFilterTests { @Mock private RequestMatcher requestMatcher; @Mock private CsrfTokenRepository tokenRepository; @Mock private FilterChain filterChain; @Mock private AccessDeniedHandler deniedHandler; private MockHttpServletRequest request; private MockHttpServletResponse response; private CsrfToken token; private CsrfFilter filter; @Before public void setup() { this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue"); resetRequestResponse(); this.filter = createCsrfFilter(this.tokenRepository); } private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) { CsrfFilter filter = new CsrfFilter(repository); filter.setRequireCsrfProtectionMatcher(this.requestMatcher); filter.setAccessDeniedHandler(this.deniedHandler); return filter; } private void resetRequestResponse() { this.request = new MockHttpServletRequest(); this.response = new MockHttpServletResponse(); } @Test(expected = IllegalArgumentException.class) public void constructorNullRepository() { new CsrfFilter(null); } // SEC-2276 @Test public void doFilterDoesNotSaveCsrfTokenUntilAccessed() throws ServletException, IOException { this.filter = createCsrfFilter(new LazyCsrfTokenRepository(this.tokenRepository)); when(this.requestMatcher.matches(this.request)).thenReturn(false); when(this.tokenRepository.generateToken(this.request)).thenReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); CsrfToken attrToken = (CsrfToken) this.request .getAttribute(this.token.getParameterName()); // no CsrfToken should have been saved yet verify(this.tokenRepository, times(0)).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); verify(this.filterChain).doFilter(this.request, this.response); // access the token attrToken.getToken(); // now the CsrfToken should have been saved verify(this.tokenRepository).saveToken(eq(this.token), any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException { when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyZeroInteractions(this.filterChain); } @Test public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException { when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyZeroInteractions(this.filterChain); } @Test public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException { when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyZeroInteractions(this.filterChain); } @Test public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() throws ServletException, IOException { when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyZeroInteractions(this.filterChain); } @Test public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException { when(this.requestMatcher.matches(this.request)).thenReturn(false); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyZeroInteractions(this.deniedHandler); } @Test public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException { when(this.requestMatcher.matches(this.request)).thenReturn(false); when(this.tokenRepository.generateToken(this.request)).thenReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); assertToken(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); assertToken(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyZeroInteractions(this.deniedHandler); } @Test public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException { when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyZeroInteractions(this.deniedHandler); } @Test public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() throws ServletException, IOException { when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyZeroInteractions(this.deniedHandler); } @Test public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException { when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyZeroInteractions(this.deniedHandler); verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException { when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.generateToken(this.request)).thenReturn(this.token); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); assertToken(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); assertToken(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); // LazyCsrfTokenRepository requires the response as an attribute assertThat(this.request.getAttribute(HttpServletResponse.class.getName())) .isEqualTo(this.response); verify(this.filterChain).doFilter(this.request, this.response); verify(this.tokenRepository).saveToken(this.token, this.request, this.response); verifyZeroInteractions(this.deniedHandler); } @Test public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException { this.filter = new CsrfFilter(this.tokenRepository); this.filter.setAccessDeniedHandler(this.deniedHandler); for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) { resetRequestResponse(); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.request.setMethod(method); this.filter.doFilter(this.request, this.response, this.filterChain); verify(this.filterChain).doFilter(this.request, this.response); verifyZeroInteractions(this.deniedHandler); } } /** * SEC-2292 Should not allow other cases through since spec states HTTP method is case * sensitive http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.1 * @throws Exception if an error occurs * */ @Test public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception { this.filter = new CsrfFilter(this.tokenRepository); this.filter.setAccessDeniedHandler(this.deniedHandler); for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) { resetRequestResponse(); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.request.setMethod(method); this.filter.doFilter(this.request, this.response, this.filterChain); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyZeroInteractions(this.filterChain); } } @Test public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException { this.filter = new CsrfFilter(this.tokenRepository); this.filter.setAccessDeniedHandler(this.deniedHandler); for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) { resetRequestResponse(); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.request.setMethod(method); this.filter.doFilter(this.request, this.response, this.filterChain); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyZeroInteractions(this.filterChain); } } @Test public void doFilterDefaultAccessDenied() throws ServletException, IOException { this.filter = new CsrfFilter(this.tokenRepository); this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher); when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.request.getAttribute(this.token.getParameterName())) .isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())) .isEqualTo(this.token); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); verifyZeroInteractions(this.filterChain); } @Test(expected = IllegalArgumentException.class) public void setRequireCsrfProtectionMatcherNull() { this.filter.setRequireCsrfProtectionMatcher(null); } @Test(expected = IllegalArgumentException.class) public void setAccessDeniedHandlerNull() { this.filter.setAccessDeniedHandler(null); } private static final CsrfTokenAssert assertToken(Object token) { return new CsrfTokenAssert((CsrfToken) token); } private static class CsrfTokenAssert extends AbstractObjectAssert<CsrfTokenAssert, CsrfToken> { /** * Creates a new </code>{@link ObjectAssert}</code>. * * @param actual the target to verify. */ protected CsrfTokenAssert(CsrfToken actual) { super(actual, CsrfTokenAssert.class); } public CsrfTokenAssert isEqualTo(CsrfToken expected) { assertThat(this.actual.getHeaderName()).isEqualTo(expected.getHeaderName()); assertThat(this.actual.getParameterName()) .isEqualTo(expected.getParameterName()); assertThat(this.actual.getToken()).isEqualTo(expected.getToken()); return this; } } }