/* * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.security.openid; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.*; import java.net.URI; import java.util.Collections; import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.junit.Before; import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler; public class OpenIDAuthenticationFilterTests { OpenIDAuthenticationFilter filter; private static final String REDIRECT_URL = "http://www.example.com/redirect"; private static final String CLAIMED_IDENTITY_URL = "http://www.example.com/identity"; private static final String REQUEST_PATH = "/login/openid"; private static final String FILTER_PROCESS_URL = "http://localhost:8080" + REQUEST_PATH; private static final String DEFAULT_TARGET_URL = FILTER_PROCESS_URL; @Before public void setUp() throws Exception { filter = new OpenIDAuthenticationFilter(); filter.setConsumer(new MockOpenIDConsumer(REDIRECT_URL)); SavedRequestAwareAuthenticationSuccessHandler successHandler = new SavedRequestAwareAuthenticationSuccessHandler(); filter.setAuthenticationSuccessHandler(new SavedRequestAwareAuthenticationSuccessHandler()); successHandler.setDefaultTargetUrl(DEFAULT_TARGET_URL); filter.setAuthenticationManager(new AuthenticationManager() { public Authentication authenticate(Authentication a) { return a; } }); filter.afterPropertiesSet(); } @Test public void testFilterOperation() throws Exception { MockHttpServletRequest req = new MockHttpServletRequest(); req.setServletPath(REQUEST_PATH); req.setRequestURI(REQUEST_PATH); req.setServerPort(8080); MockHttpServletResponse response = new MockHttpServletResponse(); req.setParameter("openid_identifier", " " + CLAIMED_IDENTITY_URL); req.setRemoteHost("www.example.com"); filter.setConsumer(new MockOpenIDConsumer() { public String beginConsumption(HttpServletRequest req, String claimedIdentity, String returnToUrl, String realm) throws OpenIDConsumerException { assertThat(claimedIdentity).isEqualTo(CLAIMED_IDENTITY_URL); assertThat(returnToUrl).isEqualTo(DEFAULT_TARGET_URL); assertThat(realm).isEqualTo("http://localhost:8080/"); return REDIRECT_URL; } }); FilterChain fc = mock(FilterChain.class); filter.doFilter(req, response, fc); assertThat(response.getRedirectedUrl()).isEqualTo(REDIRECT_URL); // Filter chain shouldn't proceed verify(fc, never()).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } /** * Tests that the filter encodes any query parameters on the return_to URL. */ @Test public void encodesUrlParameters() throws Exception { // Arbitrary parameter name and value that will both need to be encoded: String paramName = "foo&bar"; String paramValue = "http://example.com/path?a=b&c=d"; MockHttpServletRequest req = new MockHttpServletRequest("GET", REQUEST_PATH); req.addParameter(paramName, paramValue); filter.setReturnToUrlParameters(Collections.singleton(paramName)); URI returnTo = new URI(filter.buildReturnToUrl(req)); String query = returnTo.getRawQuery(); assertThat(count(query, '=')).isEqualTo(1); assertThat(count(query, '&')).isEqualTo(0); } /** * Counts the number of occurrences of {@code c} in {@code s}. */ private static int count(String s, char c) { int count = 0; for (char ch : s.toCharArray()) { if (c == ch) { count += 1; } } return count; } }