/*************************GO-LICENSE-START********************************* * Copyright 2014 ThoughtWorks, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *************************GO-LICENSE-END***********************************/ package com.thoughtworks.go.server.web; import java.io.IOException; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import com.thoughtworks.go.presentation.FlashMessageModel; import org.junit.Before; import org.junit.Test; import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpSession; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.core.Is.is; import static org.hamcrest.core.IsNot.not; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; public class FlashLoadingFilterIntegrationTest { private FlashLoadingFilter filter; private FlashMessageModel flash; private FlashMessageService service; private String messageKey; @Before public void setUp() throws ServletException { service = new FlashMessageService(); filter = new FlashLoadingFilter(); filter.init(null); flash = null; messageKey = null; } @Test public void shouldInitializeFlashIfNotPresent() throws IOException, ServletException { MockHttpServletRequest req = new MockHttpServletRequest(); MockHttpServletResponse res = new MockHttpServletResponse(); FilterChain filterChain = new MockFilterChain() { @Override public void doFilter(ServletRequest request, ServletResponse response) { messageKey = service.add(new FlashMessageModel("my message", "error")); flash = service.get(messageKey); } }; assertThat(messageKey, is(nullValue())); filter.doFilter(req, res, filterChain); assertThat(messageKey, not(nullValue())); assertThat(flash.toString(), is("my message")); assertThat(flash.getFlashClass(), is("error")); } @Test public void shouldLoadExistingFlashFromSession() throws IOException, ServletException { MockHttpServletRequest req = new MockHttpServletRequest(); MockHttpSession session = new MockHttpSession(); FlashMessageService.Flash oldFlash = new FlashMessageService.Flash(); oldFlash.put("my_key", new FlashMessageModel("my other message", "warning")); session.putValue(FlashLoadingFilter.FLASH_SESSION_KEY, oldFlash); req.setSession(session); MockHttpServletResponse res = new MockHttpServletResponse(); FilterChain filterChain = new MockFilterChain() { @Override public void doFilter(ServletRequest request, ServletResponse response) { flash = service.get("my_key"); } }; filter.doFilter(req, res, filterChain); assertThat(flash.toString(), is("my other message")); assertThat(flash.getFlashClass(), is("warning")); } @Test public void shouldClearThreadContext() throws IOException, ServletException { MockHttpServletRequest req = new MockHttpServletRequest(); MockHttpServletResponse res = new MockHttpServletResponse(); FilterChain filterChain = new MockFilterChain() { @Override public void doFilter(ServletRequest request, ServletResponse response) { messageKey = service.add(new FlashMessageModel("my message", "error")); flash = service.get(messageKey); } }; filter.doFilter(req, res, filterChain); assertThat(flash.toString(), is("my message")); try { service.get(messageKey); fail("attempt to load flash message should fail, as no thread local is cleared out"); } catch (Exception e) { assertThat(e.getMessage(), is("No flash context found, this call should only be made within a request.")); } } @Test public void shouldClearThreadContextInCaseOfExceptionAsWell() throws IOException, ServletException { MockHttpServletRequest req = new MockHttpServletRequest(); MockHttpServletResponse res = new MockHttpServletResponse(); FilterChain filterChain = new MockFilterChain() { @Override public void doFilter(ServletRequest request, ServletResponse response) { messageKey = service.add(new FlashMessageModel("my message", "error")); flash = service.get(messageKey); throw new RuntimeException("exception here"); } }; try { filter.doFilter(req, res, filterChain); fail("exception gobbled"); } catch (Exception e) { assertThat(e.getMessage(), is("exception here")); } assertThat(flash.toString(), is("my message")); try { service.get(messageKey); fail("attempt to load flash message should fail, as no thread local is cleared out"); } catch (RuntimeException e) { assertThat(e.getMessage(), is("No flash context found, this call should only be made within a request.")); } } @Test public void shouldFailForNonHttpReqeusts() throws IOException, ServletException { try { filter.doFilter(mock(ServletRequest.class), mock(ServletResponse.class), new MockFilterChain()); fail("should not process non HTTP requests"); } catch (Exception e) { assertThat(e.getMessage(), containsString("cannot be cast to javax.servlet.http.HttpServletRequest")); } } }