package org.springframework.security.oauth.consumer.filter; import static org.junit.Assert.assertEquals; import static org.mockito.Matchers.anyObject; import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.io.IOException; import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.junit.After; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; import org.springframework.security.oauth.common.OAuthProviderParameter; import org.springframework.security.oauth.consumer.AccessTokenRequiredException; import org.springframework.security.oauth.consumer.BaseProtectedResourceDetails; import org.springframework.security.oauth.consumer.OAuthConsumerSupport; import org.springframework.security.oauth.consumer.OAuthConsumerToken; import org.springframework.security.oauth.consumer.OAuthSecurityContextHolder; import org.springframework.security.oauth.consumer.ProtectedResourceDetails; import org.springframework.security.oauth.consumer.rememberme.NoOpOAuthRememberMeServices; import org.springframework.security.oauth.consumer.rememberme.OAuthRememberMeServices; import org.springframework.security.oauth.consumer.token.OAuthConsumerTokenServices; import org.springframework.security.web.RedirectStrategy; /** * @author Ryan Heaton */ @RunWith(MockitoJUnitRunner.class) public class OAuthConsumerContextFilterTests { @Mock private ProtectedResourceDetails details; @Mock private HttpServletRequest request; @Mock private HttpServletResponse response; @Mock private FilterChain filterChain; @Mock private OAuthConsumerTokenServices tokenServices; @Mock private OAuthConsumerSupport support; /** * tests getting the user authorization redirect URL. */ @Test public void testGetUserAuthorizationRedirectURL() throws Exception { OAuthConsumerContextFilter filter = new OAuthConsumerContextFilter(); OAuthConsumerToken token = new OAuthConsumerToken(); token.setResourceId("resourceId"); token.setValue("mytoken"); when(details.getUserAuthorizationURL()).thenReturn("http://user-auth/context?with=some&queryParams"); when(details.isUse10a()).thenReturn(false); assertEquals( "http://user-auth/context?with=some&queryParams&oauth_token=mytoken&oauth_callback=urn%3A%2F%2Fcallback%3Fwith%3Dsome%26query%3Dparams", filter.getUserAuthorizationRedirectURL(details, token, "urn://callback?with=some&query=params")); when(details.getUserAuthorizationURL()).thenReturn("http://user-auth/context?with=some&queryParams"); when(details.isUse10a()).thenReturn(true); assertEquals("http://user-auth/context?with=some&queryParams&oauth_token=mytoken", filter.getUserAuthorizationRedirectURL(details, token, "urn://callback?with=some&query=params")); } /** * tests the filter. */ @Test public void testDoFilter() throws Exception { final OAuthRememberMeServices rememberMeServices = new NoOpOAuthRememberMeServices(); final BaseProtectedResourceDetails resource = new BaseProtectedResourceDetails(); resource.setId("dep1"); OAuthConsumerContextFilter filter = new OAuthConsumerContextFilter() { @Override protected String getCallbackURL(HttpServletRequest request) { return "urn:callback"; } @Override protected String getUserAuthorizationRedirectURL(ProtectedResourceDetails details, OAuthConsumerToken requestToken, String callbackURL) { return callbackURL + "&" + requestToken.getResourceId(); } }; filter.setRedirectStrategy(new RedirectStrategy() { public void sendRedirect(HttpServletRequest request, HttpServletResponse response, String url) throws IOException { response.sendRedirect(url); } }); filter.setTokenServices(tokenServices); filter.setConsumerSupport(support); filter.setRememberMeServices(rememberMeServices); doThrow(new AccessTokenRequiredException(resource)).when(filterChain).doFilter(request, response); when(tokenServices.getToken("dep1")).thenReturn(null); when(request.getParameter("oauth_verifier")).thenReturn(null); when(response.encodeRedirectURL("urn:callback")).thenReturn("urn:callback?query"); OAuthConsumerToken token = new OAuthConsumerToken(); token.setAccessToken(false); token.setResourceId(resource.getId()); when(support.getUnauthorizedRequestToken("dep1", "urn:callback?query")).thenReturn(token); filter.doFilter(request, response, filterChain); verify(filterChain).doFilter(request, response); verify(tokenServices).storeToken("dep1", token); verify(response).sendRedirect("urn:callback?query&dep1"); verify(request,times(2)).setAttribute(anyString(), anyObject()); reset(request,response,filterChain); doThrow(new AccessTokenRequiredException(resource)).when(filterChain).doFilter(request, response); when(tokenServices.getToken("dep1")).thenReturn(token); when(request.getParameter(OAuthProviderParameter.oauth_verifier.toString())).thenReturn("verifier"); OAuthConsumerToken accessToken = new OAuthConsumerToken(); when(support.getAccessToken(token, "verifier")).thenReturn(accessToken); when(response.isCommitted()).thenReturn(false); filter.doFilter(request, response, filterChain); verify(filterChain,times(2)).doFilter(request, response); verify(tokenServices).removeToken("dep1"); verify(tokenServices).storeToken("dep1", accessToken); verify(request,times(2)).setAttribute(anyString(), anyObject()); } @After public void tearDown() throws Exception { OAuthSecurityContextHolder.setContext(null); } }