package com.wesabe.servlet.tests; import static org.hamcrest.CoreMatchers.*; import static org.junit.Assert.*; import static org.junit.matchers.JUnitMatchers.*; import static org.mockito.Matchers.*; import static org.mockito.Mockito.*; import java.util.Collections; import java.util.Enumeration; import java.util.List; import javax.servlet.RequestDispatcher; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; import org.junit.Before; import org.junit.Test; import org.junit.experimental.runners.Enclosed; import org.junit.runner.RunWith; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.wesabe.servlet.BadRequestException; import com.wesabe.servlet.SafeRequest; @RunWith(Enclosed.class) public class SafeRequestTest { private static <E> List<E> enumerationToList(Enumeration<E> enumeration) { List<E> items = Lists.newLinkedList(); while (enumeration.hasMoreElements()) { items.add(enumeration.nextElement()); } return items; } private static abstract class Context { protected SafeRequest request; protected HttpServletRequest servletRequest; public void setup() throws Exception { this.servletRequest = mock(HttpServletRequest.class); this.request = new SafeRequest(servletRequest); } } public static class Getting_The_Method extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itNormalizesTheMethodName() throws Exception { when(servletRequest.getMethod()).thenReturn("get"); assertThat(request.getMethod(), is("GET")); verify(servletRequest).getMethod(); } @Test public void itThrowsABadRequestExceptionIfTheMethodIsInvalid() throws Exception { when(servletRequest.getMethod()).thenReturn("poop"); try { request.getMethod(); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } } public static class Getting_The_Scheme extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itNormalizesTheScheme() throws Exception { when(servletRequest.getScheme()).thenReturn("http"); assertThat(request.getScheme(), is("http")); verify(servletRequest).getScheme(); } @Test public void itThrowsABadRequestExceptionIfTheSchemeIsInvalid() throws Exception { when(servletRequest.getScheme()).thenReturn("poop"); try { request.getScheme(); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } } public static class Getting_The_Server_Port extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itNormalizesTheServerPort() throws Exception { when(servletRequest.getServerPort()).thenReturn(80); assertThat(request.getServerPort(), is(80)); verify(servletRequest).getServerPort(); } @Test public void itThrowsABadRequestExceptionIfTheSchemeIsInvalid() throws Exception { when(servletRequest.getServerPort()).thenReturn(1112228888); try { request.getServerPort(); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } } public static class Getting_A_Date_Header extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itReturnsAnyParsedDateHeader() throws Exception { when(servletRequest.getDateHeader("Last-Modified")).thenReturn(80L); assertThat(request.getDateHeader("Last-Modified"), is(80L)); verify(servletRequest).getDateHeader("Last-Modified"); } @Test public void itWrapsAFailedParseInABadRequestException() throws Exception { when(servletRequest.getDateHeader("Last-Modified")).thenThrow(new IllegalArgumentException("no")); try { request.getDateHeader("Last-Modified"); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); assertThat(e.getCause(), is(IllegalArgumentException.class)); } } } public static class Getting_An_Int_Header extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itReturnsAnyParsedIntHeader() throws Exception { when(servletRequest.getIntHeader("Age")).thenReturn(200); assertThat(request.getIntHeader("Age"), is(200)); verify(servletRequest).getIntHeader("Age"); } @Test public void itWrapsAFailedParseInABadRequestException() throws Exception { when(servletRequest.getIntHeader("Age")).thenThrow(new IllegalArgumentException("no")); try { request.getIntHeader("Age"); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); assertThat(e.getCause(), is(IllegalArgumentException.class)); } } } public static class Getting_The_Server_Name extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itNormalizesTheServerName() throws Exception { when(servletRequest.getServerName()).thenReturn("example.com"); assertThat(request.getServerName(), is("example.com")); verify(servletRequest).getServerName(); } @Test public void itThrowsABadRequestExceptionIfTheSchemeIsInvalid() throws Exception { when(servletRequest.getServerName()).thenReturn("blah\0.com"); try { request.getServerName(); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } } public static class Getting_The_Request_Dispatcher extends Context { private RequestDispatcher dispatcher; @Before @Override public void setup() throws Exception { super.setup(); this.dispatcher = mock(RequestDispatcher.class); } @Test public void itPassesThroughIfPathStartsWithWebInf() throws Exception { when(servletRequest.getRequestDispatcher("WEB-INF/thing")).thenReturn(dispatcher); assertThat(request.getRequestDispatcher("WEB-INF/thing"), is(dispatcher)); verify(servletRequest).getRequestDispatcher("WEB-INF/thing"); } @Test public void itReturnsNullIfPathDoesNotStartWithWebInf() throws Exception { when(servletRequest.getRequestDispatcher(anyString())).thenReturn(dispatcher); assertThat(request.getRequestDispatcher("../WEB-INF/thing"), is(nullValue())); verify(servletRequest, never()).getRequestDispatcher(anyString()); } } public static class Getting_A_List_Of_Header_Names extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itEnumeratesValidHeaders() throws Exception { when(servletRequest.getHeaderNames()).thenReturn(Collections.enumeration(ImmutableList.of("Accept", "User-Agent"))); assertThat(enumerationToList(request.getHeaderNames()), hasItems("Accept", "User-Agent")); } @Test public void itThrowsABadRequestExceptionOnMalformedHeaders() throws Exception { when(servletRequest.getHeaderNames()).thenReturn(Collections.enumeration(ImmutableList.of("Accept", "Age\0DEATH"))); try { request.getHeaderNames(); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } } public static class Getting_A_Header_Value extends Context { @Before @Override public void setup() throws Exception { super.setup(); when(servletRequest.getHeader("Accept")).thenReturn("application/json"); when(servletRequest.getHeader("User-Agent")).thenReturn("MAL\0\0ICE"); } @Test public void itPassesValidHeadersStraightThrough() throws Exception { assertThat(request.getHeader("Accept"), is("application/json")); } @Test public void itThrowsABadRequestExceptionOnInvalidValues() throws Exception { try { request.getHeader("User-Agent"); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } @Test public void itThrowsAnIllegalArgumentExceptionWhenAskedForTheValueOfAMalformedHeader() throws Exception { try { request.getHeader("User-Agent\0"); fail("should have thrown an IllegalArgumentException, but didn't"); } catch (IllegalArgumentException e) { assertTrue(true); } } } public static class Getting_A_List_Of_Header_Values extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itEnumeratesValidHeaderValues() throws Exception { when(servletRequest.getHeaders("Accept")).thenReturn(Collections.enumeration(ImmutableList.of("application/json", "application/xml"))); assertThat(enumerationToList(request.getHeaders("Accept")), hasItems("application/json", "application/xml")); } @Test public void itThrowsABadRequestExceptionOnMalformedHeaderValues() throws Exception { when(servletRequest.getHeaders("Accept")).thenReturn(Collections.enumeration(ImmutableList.of("application/json", "Age\0DEATH"))); try { request.getHeaders("Accept"); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } @Test public void itThrowsAnIllegalArgumentExceptionWhenAskedForTheValuesOfAMalformedHeader() throws Exception { try { request.getHeaders("User-Agent\0"); fail("should have thrown an IllegalArgumentException, but didn't"); } catch (IllegalArgumentException e) { assertTrue(true); } } } public static class Getting_A_List_Of_Cookies extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itReturnsAnEmptyArrayIsCookiesAreNull() throws Exception { when(servletRequest.getCookies()).thenReturn(null); assertThat(request.getCookies().length, is(0)); } @Test public void itReturnsAnArrayOfValidCookies() throws Exception { when(servletRequest.getCookies()).thenReturn(new Cookie[] { new Cookie("sessionid", "blorp") }); assertThat(request.getCookies()[0].getName(), is("sessionid")); assertThat(request.getCookies()[0].getValue(), is("blorp")); } @Test public void itThrowsABadRequestExceptionWithInvalidCookies() throws Exception { final Cookie badCookie = mock(Cookie.class); when(badCookie.getName()).thenReturn("\0\0"); when(servletRequest.getCookies()).thenReturn(new Cookie[] { badCookie }); try { request.getCookies(); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } } public static class Getting_The_Request_URI extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itNormalizesTheURI() throws Exception { when(servletRequest.getRequestURI()).thenReturn("/blah%7d"); assertThat(request.getRequestURI(), is("/blah%7D")); } @Test public void itThrowsABadRequestExceptionWithAnInvalidRequestURI() throws Exception { when(servletRequest.getRequestURI()).thenReturn("/blah%ee"); try { request.getRequestURI(); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } } public static class Getting_The_QueryString extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itPassesNullThrough() throws Exception { when(servletRequest.getQueryString()).thenReturn(null); assertThat(request.getQueryString(), is(nullValue())); } @Test public void itNormalizesTheQueryString() throws Exception { when(servletRequest.getQueryString()).thenReturn("j=blah%7d"); assertThat(request.getQueryString(), is("j=blah%7D")); } @Test public void itThrowsABadRequestExceptionWithAnInvalidQueryString() throws Exception { when(servletRequest.getQueryString()).thenReturn("j=blah%ee"); try { request.getQueryString(); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } } public static class Getting_A_List_Of_Param_Names extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itEnumeratesValidHeaders() throws Exception { when(servletRequest.getParameterNames()).thenReturn(Collections.enumeration(ImmutableList.of("dingo", "woo"))); assertThat(enumerationToList(request.getParameterNames()), hasItems("dingo", "woo")); } @Test public void itThrowsABadRequestExceptionOnMalformedHeaders() throws Exception { when(servletRequest.getParameterNames()).thenReturn(Collections.enumeration(ImmutableList.of("dingo", "poison\0"))); try { request.getParameterNames(); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } } public static class Getting_A_Param_Value extends Context { @Before @Override public void setup() throws Exception { super.setup(); when(servletRequest.getParameter("dingo")).thenReturn("woo"); when(servletRequest.getParameter("malice")).thenReturn("MAL\0\0ICE"); } @Test public void itPassesValidParametersStraightThrough() throws Exception { assertThat(request.getParameter("dingo"), is("woo")); } @Test public void itThrowsABadRequestExceptionOnInvalidValues() throws Exception { try { request.getParameter("malice"); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } @Test public void itThrowsAnIllegalArgumentExceptionWhenAskedForTheValueOfAMalformedParamName() throws Exception { try { request.getParameter("weird\0"); fail("should have thrown an IllegalArgumentException, but didn't"); } catch (IllegalArgumentException e) { assertTrue(true); } } } public static class Getting_An_Array_Of_Param_Values extends Context { @Before @Override public void setup() throws Exception { super.setup(); when(servletRequest.getParameterValues("dingo")).thenReturn(new String[] { "woo" }); when(servletRequest.getParameterValues("malice")).thenReturn(new String[] { "MAL\0\0ICE" }); } @Test public void itPassesValidParametersStraightThrough() throws Exception { assertThat(request.getParameterValues("dingo").length, is(1)); assertThat(request.getParameterValues("dingo")[0], is("woo")); } @Test public void itThrowsABadRequestExceptionOnInvalidValues() throws Exception { try { request.getParameterValues("malice"); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } @Test public void itThrowsAnIllegalArgumentExceptionWhenAskedForTheValueOfAMalformedParamName() throws Exception { try { request.getParameterValues("weird\0"); fail("should have thrown an IllegalArgumentException, but didn't"); } catch (IllegalArgumentException e) { assertTrue(true); } } } public static class Getting_A_Map_Of_Param_Values extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itPassesValidParametersStraightThrough() throws Exception { when(servletRequest.getParameterMap()).thenReturn(ImmutableMap.of("dingo", new String[] { "woo" })); assertThat(request.getParameterMap().keySet(), hasItem("dingo")); } @Test public void itThrowsABadRequestExceptionOnInvalidValues() throws Exception { when(servletRequest.getParameterMap()).thenReturn(ImmutableMap.of("malice", new String[] { "MAL\0ICE" })); try { request.getParameterMap(); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } } public static class Getting_The_Requested_Session_Id extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itPassesValidSessionIdsThrough() throws Exception { when(servletRequest.getRequestedSessionId()).thenReturn("AHAHAHAHAHAHAHA"); assertThat(request.getRequestedSessionId(), is("AHAHAHAHAHAHAHA")); } @Test public void itThrowsABadRequestExceptionOnInvalidValues() throws Exception { when(servletRequest.getRequestedSessionId()).thenReturn("AHAHAHAH\0\0AHAHAHA"); try { request.getRequestedSessionId(); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } } public static class Getting_The_Path_Info extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itPassesValidPathsThrough() throws Exception { when(servletRequest.getPathInfo()).thenReturn("whee"); assertThat(request.getPathInfo(), is("whee")); } @Test public void itThrowsABadRequestExceptionOnInvalidValues() throws Exception { when(servletRequest.getPathInfo()).thenReturn("AHAHAHAH\0\0AHAHAHA"); try { request.getPathInfo(); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } } public static class Getting_The_Context_Path extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itPassesValidPathsThrough() throws Exception { when(servletRequest.getContextPath()).thenReturn("whee"); assertThat(request.getContextPath(), is("whee")); } @Test public void itThrowsABadRequestExceptionOnInvalidValues() throws Exception { when(servletRequest.getContextPath()).thenReturn("AHAHAHAH\0\0AHAHAHA"); try { request.getContextPath(); fail("should have thrown a BadRequestException, but didn't"); } catch (BadRequestException e) { assertThat(e.getBadRequest(), is(sameInstance(servletRequest))); } } } public static class Getting_A_Human_Readable_Representation extends Context { @Before @Override public void setup() throws Exception { super.setup(); } @Test public void itPassesThrough() throws Exception { assertThat(request.toString(), is(servletRequest.toString())); } } }