package io.dropwizard.jersey.filter; import com.codahale.metrics.MetricRegistry; import com.google.common.collect.ImmutableMap; import io.dropwizard.jersey.AbstractJerseyTest; import io.dropwizard.jersey.DropwizardResourceConfig; import org.glassfish.jersey.server.ResourceConfig; import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.servlet.ServletProperties; import org.glassfish.jersey.test.DeploymentContext; import org.glassfish.jersey.test.ServletDeploymentContext; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.spi.TestContainerException; import org.glassfish.jersey.test.spi.TestContainerFactory; import org.junit.Before; import org.junit.Test; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.ws.rs.client.Entity; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import java.io.IOException; import java.util.Map; import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class AllowedMethodsFilterTest extends AbstractJerseyTest { private static final int DISALLOWED_STATUS_CODE = Response.Status.METHOD_NOT_ALLOWED.getStatusCode(); private static final int OK_STATUS_CODE = Response.Status.OK.getStatusCode(); private final HttpServletRequest request = mock(HttpServletRequest.class); private final HttpServletResponse response = mock(HttpServletResponse.class); private final FilterChain chain = mock(FilterChain.class); private final FilterConfig config = mock(FilterConfig.class); private final AllowedMethodsFilter filter = new AllowedMethodsFilter(); @Before public void setUpFilter() { filter.init(config); } @Override protected TestContainerFactory getTestContainerFactory() throws TestContainerException { return new GrizzlyWebTestContainerFactory(); } @Override protected DeploymentContext configureDeployment() { final ResourceConfig rc = DropwizardResourceConfig.forTesting(new MetricRegistry()); final Map<String, String> filterParams = ImmutableMap.of( AllowedMethodsFilter.ALLOWED_METHODS_PARAM, "GET,POST"); return ServletDeploymentContext.builder(rc) .addFilter(AllowedMethodsFilter.class, "allowedMethodsFilter", filterParams) .initParam(ServletProperties.JAXRS_APPLICATION_CLASS, DropwizardResourceConfig.class.getName()) .initParam(ServerProperties.PROVIDER_CLASSNAMES, DummyResource.class.getName()) .build(); } private int getResponseStatusForRequestMethod(String method, boolean includeEntity) { final Response resourceResponse = includeEntity ? target("/ping").request().method(method, Entity.entity("", MediaType.TEXT_PLAIN)) : target("/ping").request().method(method); try { return resourceResponse.getStatus(); } finally { resourceResponse.close(); } } @Test public void testGetRequestAllowed() { assertEquals(OK_STATUS_CODE, getResponseStatusForRequestMethod("GET", false)); } @Test public void testPostRequestAllowed() { assertEquals(OK_STATUS_CODE, getResponseStatusForRequestMethod("POST", true)); } @Test public void testPutRequestBlocked() { assertEquals(DISALLOWED_STATUS_CODE, getResponseStatusForRequestMethod("PUT", true)); } @Test public void testDeleteRequestBlocked() { assertEquals(DISALLOWED_STATUS_CODE, getResponseStatusForRequestMethod("DELETE", false)); } @Test public void testTraceRequestBlocked() { assertEquals(DISALLOWED_STATUS_CODE, getResponseStatusForRequestMethod("TRACE", false)); } @Test public void allowsAllowedMethod() throws Exception { when(request.getMethod()).thenReturn("GET"); filter.doFilter(request, response, chain); verify(chain).doFilter(request, response); } @Test public void blocksDisallowedMethod() throws Exception { when(request.getMethod()).thenReturn("TRACE"); filter.doFilter(request, response, chain); verify(chain, never()).doFilter(request, response); } @Test public void disallowedMethodCausesMethodNotAllowedResponse() throws IOException, ServletException { when(request.getMethod()).thenReturn("TRACE"); filter.doFilter(request, response, chain); verify(response).sendError(eq(DISALLOWED_STATUS_CODE)); } }