/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.wicket.protocol.http; import static org.hamcrest.CoreMatchers.is; import javax.servlet.http.HttpServletRequest; import org.apache.wicket.RestartResponseException; import org.apache.wicket.protocol.http.CsrfPreventionRequestCycleListener.CsrfAction; import org.apache.wicket.request.IRequestHandler; import org.apache.wicket.request.component.IRequestablePage; import org.apache.wicket.request.http.WebRequest; import org.apache.wicket.util.tester.WicketTestCase; import org.junit.Before; import org.junit.Test; /** * Test cases for the CsrfPreventionRequestCycleListener. FirstPage has a link that when clicked * should render SecondPage. */ public class CsrfPreventionRequestCycleListenerTest extends WicketTestCase { /** * Sets up the test cases. Installs the CSRF listener and renders the FirstPage. */ @Before public void startWithFirstPageRender() { WebApplication application = tester.getApplication(); csrfListener = new MockCsrfPreventionRequestCycleListener(); setErrorCode(errorCode); setErrorMessage(errorMessage); application.getRequestCycleListeners().add(csrfListener); // Rendering a page is allowed, regardless of Origin (this allows external links into your // website to function) tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "https://google.com/"); tester.startPage(FirstPage.class); tester.assertRenderedPage(FirstPage.class); } /** Tests that disabling the CSRF listener doesn't check Origin headers. */ @Test public void disabledListenerDoesntCheckAnything() { csrfEnabled = false; tester.clickLink("link"); assertOriginsNotChecked(); tester.assertRenderedPage(SecondPage.class); } /** Tests that disabling the CSRF listener doesn't check Origin headers. */ @Test public void disabledListenerDoesntCheckMismatchedOrigin() { csrfEnabled = false; tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://malicioussite.com/"); tester.clickLink("link"); assertOriginsNotChecked(); tester.assertRenderedPage(SecondPage.class); } /** Tests the default setting of aborting a missing Origin. */ @Test public void withoutOriginAllowed() { csrfListener.setNoOriginAction(CsrfAction.ALLOW); tester.addRequestHeader(WebRequest.HEADER_ORIGIN, null); tester.clickLink("link"); assertConflictingOriginsRequestAllowed(); } /** Tests the alternative action of suppressing a request without Origin header */ @Test public void withoutOriginSuppressed() { csrfListener.setNoOriginAction(CsrfAction.SUPPRESS); tester.addRequestHeader(WebRequest.HEADER_ORIGIN, null); tester.clickLink("link"); tester.assertRenderedPage(FirstPage.class); assertConflictingOriginsRequestSuppressed(); } /** Tests the alternative action of aborting a request without Origin header */ @Test public void withoutOriginAborted() { csrfListener.setNoOriginAction(CsrfAction.ABORT); tester.addRequestHeader(WebRequest.HEADER_ORIGIN, null); tester.clickLink("link"); assertConflictingOriginsRequestAborted(); } /** Tests when the Origin header matches the request. */ @Test public void matchingOriginsAllowed() { csrfListener.setConflictingOriginAction(CsrfAction.ALLOW); tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://localhost/"); tester.clickLink("link"); assertOriginsMatched(); tester.assertRenderedPage(SecondPage.class); } /** Tests when the default action is changed to ALLOW when origins conflict. */ @Test public void conflictingOriginsAllowed() { csrfListener.setConflictingOriginAction(CsrfAction.ALLOW); tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://example.com/"); tester.clickLink("link"); assertConflictingOriginsRequestAllowed(); tester.assertRenderedPage(SecondPage.class); } /** Tests when the default action is changed to SUPPRESS when origins conflict. */ @Test public void conflictingOriginsSuppressed() { tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://example.com/"); csrfListener.setConflictingOriginAction(CsrfAction.SUPPRESS); tester.clickLink("link"); assertConflictingOriginsRequestSuppressed(); tester.assertRenderedPage(FirstPage.class); } /** Tests the default action to ABORT when origins conflict. */ @Test public void conflictingOriginsAborted() { tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://example.com/"); tester.clickLink("link"); assertConflictingOriginsRequestAborted(); } /** Tests custom error code/message when the default action is ABORT. */ @Test public void conflictingOriginsAbortedWith401Unauhorized() { setErrorCode(401); setErrorMessage("NOT AUTHORIZED"); tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://example.com/"); csrfListener.setNoOriginAction(CsrfAction.ABORT); tester.clickLink("link"); assertConflictingOriginsRequestAborted(); } /** Tests whitelisting for conflicting origins. */ @Test public void conflictingButWhitelistedOriginAllowed() { csrfListener.setConflictingOriginAction(CsrfAction.ALLOW); csrfListener.addAcceptedOrigin("example.com"); tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://example.com/"); tester.clickLink("link"); assertOriginsWhitelisted(); tester.assertRenderedPage(SecondPage.class); } /** Tests whitelisting with conflicting subdomain origin. */ @Test public void conflictingButWhitelistedSubdomainOriginAllowed() { csrfListener.addAcceptedOrigin("example.com"); csrfListener.setConflictingOriginAction(CsrfAction.ALLOW); tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://foo.example.com/"); tester.clickLink("link"); tester.assertRenderedPage(SecondPage.class); assertOriginsWhitelisted(); } /** * Tests when the listener is disabled for a specific page (by overriding * {@link CsrfPreventionRequestCycleListener#isChecked(IRequestablePage)}) */ @Test public void conflictingOriginPageNotCheckedAllowed() { tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://example.com/"); csrfListener.setConflictingOriginAction(CsrfAction.ABORT); // disable the check for this page checkPage = false; tester.clickLink("link"); assertConflictingOriginsRequestAllowed(); tester.assertRenderedPage(SecondPage.class); } /** Tests overriding the onSuppressed method for a conflicting origin. */ @Test public void conflictingOriginSuppressedCallsCustomHandler() { // redirect to third page to ensure we are not suppressed to the first page, nor that the // request was not suppressed and the second page was rendered erroneously Runnable thirdPageRedirect = new Runnable() { @Override public void run() { throw new RestartResponseException(new ThirdPage()); } }; setSuppressHandler(thirdPageRedirect); csrfListener.setConflictingOriginAction(CsrfAction.SUPPRESS); tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://example.com/"); tester.clickLink("link"); assertConflictingOriginsRequestSuppressed(); tester.assertRenderedPage(ThirdPage.class); } /** Tests overriding the onAllowed method for a conflicting origin. */ @Test public void conflictingOriginAllowedCallsCustomHandler() { // redirect to third page to ensure we are not suppressed to the first page, nor that the // request was not allowed and the second page was rendered erroneously Runnable thirdPageRedirect = new Runnable() { @Override public void run() { throw new RestartResponseException(new ThirdPage()); } }; setAllowHandler(thirdPageRedirect); csrfListener.setConflictingOriginAction(CsrfAction.ALLOW); tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://example.com/"); tester.clickLink("link"); assertConflictingOriginsRequestAllowed(); tester.assertRenderedPage(ThirdPage.class); } /** Tests overriding the onAborted method for a conflicting origin. */ @Test public void conflictingOriginAbortedCallsCustomHandler() { // redirect to third page to ensure we are not suppressed to the first page, nor that the // request was not aborted and the second page was rendered erroneously Runnable thirdPageRedirect = new Runnable() { @Override public void run() { throw new RestartResponseException(new ThirdPage()); } }; setAbortHandler(thirdPageRedirect); tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://example.com/"); csrfListener.setConflictingOriginAction(CsrfAction.ABORT); tester.clickLink("link"); // have to check manually, as the assert checks the error code (which is not set due to our // custom handler) if (!aborted) throw new AssertionError("Request was not aborted"); tester.assertRenderedPage(ThirdPage.class); } /** Tests whether a different port, but same scheme and hostname is considered a conflict. */ @Test public void differentPortOriginAborted() { tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://localhost:8080"); csrfListener.setConflictingOriginAction(CsrfAction.ABORT); tester.clickLink("link"); assertConflictingOriginsRequestAborted(); } /** Tests whether a different scheme, but same port and hostname is considered a conflict. */ @Test public void differentSchemeOriginAborted() { tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "https://localhost"); csrfListener.setConflictingOriginAction(CsrfAction.ABORT); tester.clickLink("link"); assertConflictingOriginsRequestAborted(); } /** Tests whether only the hostname is considered when matching the Origin header. */ @Test public void longerOriginAllowed() { tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://localhost/supercalifragilisticexpialidocious"); csrfListener.setConflictingOriginAction(CsrfAction.ABORT); tester.clickLink("link"); assertOriginsMatched(); tester.assertRenderedPage(SecondPage.class); } /** Tests whether AJAX Links are checked through the CSRF listener */ @Test public void simulatedCsrfAttackThroughAjaxIsPrevented() { csrfListener.setConflictingOriginAction(CsrfAction.ABORT); // first render a page in the user's session tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://localhost"); tester.startPage(ThirdPage.class); assertOriginsNotChecked(); tester.assertRenderedPage(ThirdPage.class); // then click on a link from another external page tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://attacker.com/"); tester.clickLink("link", true); assertConflictingOriginsRequestAborted(); } /** Tests whether AJAX Links are checked through the CSRF listener */ @Test public void simulatedCsrfAttackIsSuppressed() { csrfListener.setConflictingOriginAction(CsrfAction.SUPPRESS); // first render a page in the user's session tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://localhost"); tester.startPage(ThirdPage.class); assertOriginsNotChecked(); tester.assertRenderedPage(ThirdPage.class); // then click on a link from another external page tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://attacker.com/"); tester.clickLink("link", true); assertConflictingOriginsRequestSuppressed(); tester.assertRenderedPage(ThirdPage.class); } /** Tests whether form submits are checked through the CSRF listener */ @Test public void simulatedCsrfAttackOnFormIsSuppressed() { csrfListener.setConflictingOriginAction(CsrfAction.SUPPRESS); // first render a page in the user's session tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://localhost"); tester.startPage(ThirdPage.class); assertOriginsNotChecked(); tester.assertRenderedPage(ThirdPage.class); // then click on a link from another external page tester.addRequestHeader(WebRequest.HEADER_ORIGIN, "http://attacker.com/"); tester.submitForm("form"); assertConflictingOriginsRequestSuppressed(); tester.assertRenderedPage(ThirdPage.class); } /* * Infrastructure code for these test cases starts here. */ /** The listener under test */ private CsrfPreventionRequestCycleListener csrfListener; /** Flag for enabling/disabling the CSRF listener */ private boolean csrfEnabled = true; /** Flag for enabling/disabling the page check of the CSRF listener */ private boolean checkPage = true; /** Value for reporting the error code when the request was aborted */ private int errorCode = 400; /** Value for reporting the error message when the request was aborted */ private String errorMessage = "BAD REQUEST"; /** Checks for asserting the functionality of the CSRF listener */ private boolean matched, whitelisted, aborted, allowed, suppressed; /** * Manner to override the default check whether the current request handler should be checked * for CSRF attacks. */ private Predicate<IRequestHandler> customRequestHandlerCheck; /** * Handlers for specific tests (ensures that the listener calls the right handler in the right * circumstance. */ private Runnable abortHandler, allowHandler, suppressHandler, matchedHandler, whitelistHandler; private void setErrorCode(int errorCode) { this.errorCode = errorCode; csrfListener.setErrorCode(errorCode); } private void setCustomRequestHandlerCheck(Predicate<IRequestHandler> check) { this.customRequestHandlerCheck = check; } private void setErrorMessage(String errorMessage) { this.errorMessage = errorMessage; csrfListener.setErrorMessage(errorMessage); } private void setAbortHandler(Runnable abortHandler) { this.abortHandler = abortHandler; } private void setAllowHandler(Runnable allowHandler) { this.allowHandler = allowHandler; } private void setSuppressHandler(Runnable suppressHandler) { this.suppressHandler = suppressHandler; } private void setWhitelistHandler(Runnable whitelistHandler) { this.whitelistHandler = whitelistHandler; } private void setMatchedHandler(Runnable matchedHandler) { this.matchedHandler = matchedHandler; } /** * Asserts that the origins were checked, and found matching. */ private void assertOriginsMatched() { if (!matched) throw new AssertionError("Origins were not matched"); } /** * Asserts that the origins were not checked, because the origin was on the whitelist. */ private void assertOriginsWhitelisted() { if (!whitelisted) throw new AssertionError("Origins were not whitelisted"); } /** * Asserts that the origins were checked, found conflicting, had an action "ABORTED" and returns * a HTTP error. */ private void assertConflictingOriginsRequestAborted() { if (!aborted) throw new AssertionError("Request was not aborted"); assertThat("Response error code", tester.getLastResponse().getStatus(), is(errorCode)); assertThat("Response error message", tester.getLastResponse().getErrorMessage(), is(errorMessage)); } /** * Asserts that the origins were checked, found conflicting and had an action "SUPPRESS". */ private void assertConflictingOriginsRequestSuppressed() { if (!suppressed) throw new AssertionError("Request was not suppressed"); } /** * Asserts that the origins were checked, found conflicting and had an action "ALLOWED". */ private void assertConflictingOriginsRequestAllowed() { if (!allowed) throw new AssertionError("Request was not allowed"); } /** * Asserts that the origins were checked and found non-conflicting. */ private void assertOriginsCheckedButNotConflicting() { if (aborted) throw new AssertionError("Origin was checked and aborted"); if (suppressed) throw new AssertionError("Origin was checked and suppressed"); if (allowed) throw new AssertionError("Origin was checked and allowed"); if (whitelisted) throw new AssertionError("Origin was whitelisted"); if (!matched) throw new AssertionError("Origin was not checked"); } /** * Asserts that no check was performed at all. */ private void assertOriginsNotChecked() { if (aborted) throw new AssertionError("Request was checked and aborted"); if (suppressed) throw new AssertionError("Request was checked and suppressed"); if (allowed) throw new AssertionError("Request was checked and allowed"); if (whitelisted) throw new AssertionError("Origin was whitelisted"); if (matched) throw new AssertionError("Origin was checked and matched"); } private final class MockCsrfPreventionRequestCycleListener extends CsrfPreventionRequestCycleListener { @Override protected boolean isEnabled() { return csrfEnabled; } @Override protected boolean isChecked(IRequestHandler handler) { if (customRequestHandlerCheck != null) return customRequestHandlerCheck.apply(handler); return super.isChecked(handler); } @Override protected boolean isChecked(IRequestablePage targetedPage) { return checkPage; } @Override protected void onAborted(HttpServletRequest containerRequest, String origin, IRequestablePage page) { aborted = true; if (abortHandler != null) abortHandler.run(); } @Override protected void onAllowed(HttpServletRequest containerRequest, String origin, IRequestablePage page) { allowed = true; if (allowHandler != null) allowHandler.run(); } @Override protected void onSuppressed(HttpServletRequest containerRequest, String origin, IRequestablePage page) { suppressed = true; if (suppressHandler != null) suppressHandler.run(); } @Override protected void onMatchingOrigin(HttpServletRequest containerRequest, String origin, IRequestablePage page) { matched = true; if (matchedHandler != null) matchedHandler.run(); } @Override protected void onWhitelisted(HttpServletRequest containerRequest, String origin, IRequestablePage page) { whitelisted = true; if (whitelistHandler != null) whitelistHandler.run(); } } // Remove when migration to Java 8 is completed private interface Predicate<T> { boolean apply(T t); } }