/* * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.web.servlet.support; import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; import javax.servlet.Servlet; import javax.servlet.ServletException; import javax.servlet.ServletRegistration; import org.junit.Test; import org.springframework.mock.web.test.MockServletContext; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.support.StaticWebApplicationContext; import org.springframework.web.servlet.DispatcherServlet; import static org.junit.Assert.*; /** * Test case for {@link AbstractDispatcherServletInitializer}. 
*
 * @author Arjen Poutsma
 */
public class DispatcherServletInitializerTests {

	private static final String SERVLET_NAME = "myservlet";

	private static final String ROLE_NAME = "role";

	private static final String SERVLET_MAPPING = "/myservlet";

	/** Servlet context whose {@code addServlet} records into {@link #servlets} and {@link #registrations}. */
	private final MockServletContext servletContext = new MyMockServletContext();

	private final AbstractDispatcherServletInitializer initializer = new MyDispatcherServletInitializer();

	/** Servlets passed to {@code addServlet}, keyed by servlet name, in registration order. */
	private final Map<String, Servlet> servlets = new LinkedHashMap<>(2);

	/** Registrations returned from {@code addServlet}, keyed by servlet name, in registration order. */
	private final Map<String, MockServletRegistration> registrations = new LinkedHashMap<>(2);


	/**
	 * Runs {@code onStartup} and verifies that exactly one {@link MyDispatcherServlet} was
	 * registered under {@link #SERVLET_NAME}, that its application context contains the
	 * {@code "bean"} bean of type {@link MyBean}, and that its registration has the expected
	 * mapping, a load-on-startup value of 1, and the run-as role {@link #ROLE_NAME}.
	 */
	@Test
	public void register() throws ServletException {
		initializer.onStartup(servletContext);

		assertEquals(1, servlets.size());
		assertNotNull(servlets.get(SERVLET_NAME));

		DispatcherServlet servlet = (DispatcherServlet) servlets.get(SERVLET_NAME);
		assertEquals(MyDispatcherServlet.class, servlet.getClass());
		// Named "wac" so it does not shadow the servletContext field, which is a different type.
		WebApplicationContext wac = servlet.getWebApplicationContext();

		assertTrue(wac.containsBean("bean"));
		assertTrue(wac.getBean("bean") instanceof MyBean);

		assertEquals(1, registrations.size());
		assertNotNull(registrations.get(SERVLET_NAME));

		MockServletRegistration registration = registrations.get(SERVLET_NAME);
		assertEquals(Collections.singleton(SERVLET_MAPPING), registration.getMappings());
		assertEquals(1, registration.getLoadOnStartup());
		assertEquals(ROLE_NAME, registration.getRunAsRole());
	}


	/**
	 * {@link MockServletContext} that captures each servlet added via {@code addServlet}
	 * together with a fresh {@link MockServletRegistration}. It is a non-static inner class
	 * so that it can write into the enclosing test's {@code servlets} and
	 * {@code registrations} maps.
	 */
	private class MyMockServletContext extends MockServletContext {

		@Override
		public ServletRegistration.Dynamic addServlet(String servletName, Servlet servlet) {
			servlets.put(servletName, servlet);
			MockServletRegistration registration = new MockServletRegistration();
			registrations.put(servletName, registration);
			return registration;
		}
	}


	/**
	 * Initializer under test. It supplies a custom servlet name, servlet subclass, servlet
	 * application context, mapping, and registration customization, and creates no root
	 * application context.
	 */
	private static class MyDispatcherServletInitializer extends AbstractDispatcherServletInitializer {

		@Override
		protected String getServletName() {
			return SERVLET_NAME;
		}

		/** Returns a {@link MyDispatcherServlet} so the test can check which servlet type was registered. */
		@Override
		protected DispatcherServlet createDispatcherServlet(WebApplicationContext servletAppContext) {
			return new MyDispatcherServlet(servletAppContext);
		}

		/** Builds a static context holding a single {@link MyBean} singleton named {@code "bean"}. */
		@Override
		protected WebApplicationContext createServletApplicationContext() {
			StaticWebApplicationContext wac = new StaticWebApplicationContext();
			wac.registerSingleton("bean", MyBean.class);
			return wac;
		}

		@Override
		protected String[] getServletMappings() {
			return new String[] { SERVLET_MAPPING };
		}

		/** Sets the run-as role from the shared constant that {@code register()} asserts against. */
		@Override
		protected void customizeRegistration(ServletRegistration.Dynamic registration) {
			registration.setRunAsRole(ROLE_NAME);
		}

		/** Returns {@code null}: this test does not supply a root application context. */
		@Override
		protected WebApplicationContext createRootApplicationContext() {
			return null;
		}
	}


	/** Marker bean registered in the servlet application context. */
	private static class MyBean {
	}


	/** {@link DispatcherServlet} subclass used to verify that {@code createDispatcherServlet} is honored. */
	@SuppressWarnings("serial")
	private static class MyDispatcherServlet extends DispatcherServlet {

		public MyDispatcherServlet(WebApplicationContext webApplicationContext) {
			super(webApplicationContext);
		}
	}

}