/*
* Copyright 2012-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.servlet.context;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.EnumSet;
import java.util.Properties;
import javax.servlet.DispatcherType;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.Servlet;
import javax.servlet.ServletContext;
import javax.servlet.ServletContextListener;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InOrder;
import org.mockito.MockitoAnnotations;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.factory.BeanCreationException;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.ConstructorArgumentValues;
import org.springframework.beans.factory.config.Scope;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.boot.web.context.ServerPortInfoApplicationContextInitializer;
import org.springframework.boot.web.servlet.DelegatingFilterProxyRegistrationBean;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.boot.web.servlet.ServletContextInitializer;
import org.springframework.boot.web.servlet.ServletRegistrationBean;
import org.springframework.boot.web.servlet.server.MockServletWebServerFactory;
import org.springframework.context.ApplicationContextException;
import org.springframework.context.ApplicationListener;
import org.springframework.context.support.AbstractApplicationContext;
import org.springframework.context.support.PropertySourcesPlaceholderConfigurer;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockFilterConfig;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.web.context.ServletContextAware;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.request.SessionScope;
import org.springframework.web.filter.GenericFilterBean;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atMost;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.withSettings;
/**
* Tests for {@link ServletWebServerApplicationContext}.
*
* @author Phillip Webb
* @author Stephane Nicoll
*/
public class ServletWebServerApplicationContextTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
private ServletWebServerApplicationContext context;
@Captor
private ArgumentCaptor<Filter> filterCaptor;
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
this.context = new ServletWebServerApplicationContext();
}
@After
public void cleanup() {
this.context.close();
}
@Test
public void startRegistrations() throws Exception {
addWebServerFactoryBean();
this.context.refresh();
MockServletWebServerFactory factory = getWebServerFactory();
// Ensure that the context has been setup
assertThat(this.context.getServletContext())
.isEqualTo(factory.getServletContext());
verify(factory.getServletContext()).setAttribute(
WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE,
this.context);
// Ensure WebApplicationContextUtils.registerWebApplicationScopes was called
assertThat(this.context.getBeanFactory()
.getRegisteredScope(WebApplicationContext.SCOPE_SESSION))
.isInstanceOf(SessionScope.class);
// Ensure WebApplicationContextUtils.registerEnvironmentBeans was called
assertThat(this.context
.containsBean(WebApplicationContext.SERVLET_CONTEXT_BEAN_NAME)).isTrue();
}
@Test
public void doesNotRegistersShutdownHook() throws Exception {
// See gh-314 for background. We no longer register the shutdown hook
// since it is really the callers responsibility. The shutdown hook could
// also be problematic in a classic WAR deployment.
addWebServerFactoryBean();
this.context.refresh();
Field shutdownHookField = AbstractApplicationContext.class
.getDeclaredField("shutdownHook");
shutdownHookField.setAccessible(true);
Object shutdownHook = shutdownHookField.get(this.context);
assertThat(shutdownHook).isNull();
}
@Test
public void ServletWebServerInitializedEventPublished() throws Exception {
addWebServerFactoryBean();
this.context.registerBeanDefinition("listener",
new RootBeanDefinition(MockListener.class));
this.context.refresh();
ServletWebServerInitializedEvent event = this.context.getBean(MockListener.class)
.getEvent();
assertThat(event).isNotNull();
assertThat(event.getSource().getPort() >= 0).isTrue();
assertThat(event.getApplicationContext()).isEqualTo(this.context);
}
@Test
public void localPortIsAvailable() throws Exception {
addWebServerFactoryBean();
new ServerPortInfoApplicationContextInitializer().initialize(this.context);
this.context.refresh();
ConfigurableEnvironment environment = this.context.getEnvironment();
assertThat(environment.containsProperty("local.server.port")).isTrue();
assertThat(environment.getProperty("local.server.port")).isEqualTo("8080");
}
@Test
public void stopOnClose() throws Exception {
addWebServerFactoryBean();
this.context.refresh();
MockServletWebServerFactory factory = getWebServerFactory();
this.context.close();
verify(factory.getWebServer()).stop();
}
@Test
public void cannotSecondRefresh() throws Exception {
addWebServerFactoryBean();
this.context.refresh();
this.thrown.expect(IllegalStateException.class);
this.context.refresh();
}
@Test
public void servletContextAwareBeansAreInjected() throws Exception {
addWebServerFactoryBean();
ServletContextAware bean = mock(ServletContextAware.class);
this.context.registerBeanDefinition("bean", beanDefinition(bean));
this.context.refresh();
verify(bean).setServletContext(getWebServerFactory().getServletContext());
}
@Test
public void missingServletWebServerFactory() throws Exception {
this.thrown.expect(ApplicationContextException.class);
this.thrown.expectMessage(
"Unable to start ServletWebServerApplicationContext due to missing "
+ "ServletWebServerFactory bean");
this.context.refresh();
}
@Test
public void tooManyWebServerFactories() throws Exception {
addWebServerFactoryBean();
this.context.registerBeanDefinition("webServerFactory2",
new RootBeanDefinition(MockServletWebServerFactory.class));
this.thrown.expect(ApplicationContextException.class);
this.thrown.expectMessage(
"Unable to start ServletWebServerApplicationContext due to "
+ "multiple ServletWebServerFactory beans");
this.context.refresh();
}
@Test
public void singleServletBean() throws Exception {
addWebServerFactoryBean();
Servlet servlet = mock(Servlet.class);
this.context.registerBeanDefinition("servletBean", beanDefinition(servlet));
this.context.refresh();
MockServletWebServerFactory factory = getWebServerFactory();
verify(factory.getServletContext()).addServlet("servletBean", servlet);
verify(factory.getRegisteredServlet(0).getRegistration()).addMapping("/");
}
@Test
public void orderedBeanInsertedCorrectly() throws Exception {
addWebServerFactoryBean();
OrderedFilter filter = new OrderedFilter();
this.context.registerBeanDefinition("filterBean", beanDefinition(filter));
FilterRegistrationBean<Filter> registration = new FilterRegistrationBean<>();
registration.setFilter(mock(Filter.class));
registration.setOrder(100);
this.context.registerBeanDefinition("filterRegistrationBean",
beanDefinition(registration));
this.context.refresh();
MockServletWebServerFactory factory = getWebServerFactory();
verify(factory.getServletContext()).addFilter("filterBean", filter);
verify(factory.getServletContext()).addFilter("object", registration.getFilter());
assertThat(factory.getRegisteredFilter(0).getFilter()).isEqualTo(filter);
}
@Test
public void multipleServletBeans() throws Exception {
addWebServerFactoryBean();
Servlet servlet1 = mock(Servlet.class,
withSettings().extraInterfaces(Ordered.class));
given(((Ordered) servlet1).getOrder()).willReturn(1);
Servlet servlet2 = mock(Servlet.class,
withSettings().extraInterfaces(Ordered.class));
given(((Ordered) servlet2).getOrder()).willReturn(2);
this.context.registerBeanDefinition("servletBean2", beanDefinition(servlet2));
this.context.registerBeanDefinition("servletBean1", beanDefinition(servlet1));
this.context.refresh();
MockServletWebServerFactory factory = getWebServerFactory();
ServletContext servletContext = factory.getServletContext();
InOrder ordered = inOrder(servletContext);
ordered.verify(servletContext).addServlet("servletBean1", servlet1);
ordered.verify(servletContext).addServlet("servletBean2", servlet2);
verify(factory.getRegisteredServlet(0).getRegistration())
.addMapping("/servletBean1/");
verify(factory.getRegisteredServlet(1).getRegistration())
.addMapping("/servletBean2/");
}
@Test
public void multipleServletBeansWithMainDispatcher() throws Exception {
addWebServerFactoryBean();
Servlet servlet1 = mock(Servlet.class,
withSettings().extraInterfaces(Ordered.class));
given(((Ordered) servlet1).getOrder()).willReturn(1);
Servlet servlet2 = mock(Servlet.class,
withSettings().extraInterfaces(Ordered.class));
given(((Ordered) servlet2).getOrder()).willReturn(2);
this.context.registerBeanDefinition("servletBean2", beanDefinition(servlet2));
this.context.registerBeanDefinition("dispatcherServlet",
beanDefinition(servlet1));
this.context.refresh();
MockServletWebServerFactory factory = getWebServerFactory();
ServletContext servletContext = factory.getServletContext();
InOrder ordered = inOrder(servletContext);
ordered.verify(servletContext).addServlet("dispatcherServlet", servlet1);
ordered.verify(servletContext).addServlet("servletBean2", servlet2);
verify(factory.getRegisteredServlet(0).getRegistration()).addMapping("/");
verify(factory.getRegisteredServlet(1).getRegistration())
.addMapping("/servletBean2/");
}
@Test
public void servletAndFilterBeans() throws Exception {
addWebServerFactoryBean();
Servlet servlet = mock(Servlet.class);
Filter filter1 = mock(Filter.class,
withSettings().extraInterfaces(Ordered.class));
given(((Ordered) filter1).getOrder()).willReturn(1);
Filter filter2 = mock(Filter.class,
withSettings().extraInterfaces(Ordered.class));
given(((Ordered) filter2).getOrder()).willReturn(2);
this.context.registerBeanDefinition("servletBean", beanDefinition(servlet));
this.context.registerBeanDefinition("filterBean2", beanDefinition(filter2));
this.context.registerBeanDefinition("filterBean1", beanDefinition(filter1));
this.context.refresh();
MockServletWebServerFactory factory = getWebServerFactory();
ServletContext servletContext = factory.getServletContext();
InOrder ordered = inOrder(servletContext);
verify(factory.getServletContext()).addServlet("servletBean", servlet);
verify(factory.getRegisteredServlet(0).getRegistration()).addMapping("/");
ordered.verify(factory.getServletContext()).addFilter("filterBean1", filter1);
ordered.verify(factory.getServletContext()).addFilter("filterBean2", filter2);
verify(factory.getRegisteredFilter(0).getRegistration()).addMappingForUrlPatterns(
EnumSet.of(DispatcherType.REQUEST), false, "/*");
verify(factory.getRegisteredFilter(1).getRegistration()).addMappingForUrlPatterns(
EnumSet.of(DispatcherType.REQUEST), false, "/*");
}
@Test
public void servletContextInitializerBeans() throws Exception {
addWebServerFactoryBean();
ServletContextInitializer initializer1 = mock(ServletContextInitializer.class,
withSettings().extraInterfaces(Ordered.class));
given(((Ordered) initializer1).getOrder()).willReturn(1);
ServletContextInitializer initializer2 = mock(ServletContextInitializer.class,
withSettings().extraInterfaces(Ordered.class));
given(((Ordered) initializer2).getOrder()).willReturn(2);
this.context.registerBeanDefinition("initializerBean2",
beanDefinition(initializer2));
this.context.registerBeanDefinition("initializerBean1",
beanDefinition(initializer1));
this.context.refresh();
ServletContext servletContext = getWebServerFactory().getServletContext();
InOrder ordered = inOrder(initializer1, initializer2);
ordered.verify(initializer1).onStartup(servletContext);
ordered.verify(initializer2).onStartup(servletContext);
}
@Test
public void servletContextListenerBeans() throws Exception {
addWebServerFactoryBean();
ServletContextListener initializer = mock(ServletContextListener.class);
this.context.registerBeanDefinition("initializerBean",
beanDefinition(initializer));
this.context.refresh();
ServletContext servletContext = getWebServerFactory().getServletContext();
verify(servletContext).addListener(initializer);
}
@Test
public void unorderedServletContextInitializerBeans() throws Exception {
addWebServerFactoryBean();
ServletContextInitializer initializer1 = mock(ServletContextInitializer.class);
ServletContextInitializer initializer2 = mock(ServletContextInitializer.class);
this.context.registerBeanDefinition("initializerBean2",
beanDefinition(initializer2));
this.context.registerBeanDefinition("initializerBean1",
beanDefinition(initializer1));
this.context.refresh();
ServletContext servletContext = getWebServerFactory().getServletContext();
verify(initializer1).onStartup(servletContext);
verify(initializer2).onStartup(servletContext);
}
@Test
public void servletContextInitializerBeansDoesNotSkipServletsAndFilters()
throws Exception {
addWebServerFactoryBean();
ServletContextInitializer initializer = mock(ServletContextInitializer.class);
Servlet servlet = mock(Servlet.class);
Filter filter = mock(Filter.class);
this.context.registerBeanDefinition("initializerBean",
beanDefinition(initializer));
this.context.registerBeanDefinition("servletBean", beanDefinition(servlet));
this.context.registerBeanDefinition("filterBean", beanDefinition(filter));
this.context.refresh();
ServletContext servletContext = getWebServerFactory().getServletContext();
verify(initializer).onStartup(servletContext);
verify(servletContext).addServlet(anyString(), (Servlet) any());
verify(servletContext).addFilter(anyString(), (Filter) any());
}
@Test
public void servletContextInitializerBeansSkipsRegisteredServletsAndFilters()
throws Exception {
addWebServerFactoryBean();
Servlet servlet = mock(Servlet.class);
Filter filter = mock(Filter.class);
ServletRegistrationBean<Servlet> initializer = new ServletRegistrationBean<>(
servlet, "/foo");
this.context.registerBeanDefinition("initializerBean",
beanDefinition(initializer));
this.context.registerBeanDefinition("servletBean", beanDefinition(servlet));
this.context.registerBeanDefinition("filterBean", beanDefinition(filter));
this.context.refresh();
ServletContext servletContext = getWebServerFactory().getServletContext();
verify(servletContext, atMost(1)).addServlet(anyString(), (Servlet) any());
verify(servletContext, atMost(1)).addFilter(anyString(), (Filter) any());
}
@Test
public void filterRegistrationBeansSkipsRegisteredFilters() throws Exception {
addWebServerFactoryBean();
Filter filter = mock(Filter.class);
FilterRegistrationBean<Filter> initializer = new FilterRegistrationBean<>(filter);
this.context.registerBeanDefinition("initializerBean",
beanDefinition(initializer));
this.context.registerBeanDefinition("filterBean", beanDefinition(filter));
this.context.refresh();
ServletContext servletContext = getWebServerFactory().getServletContext();
verify(servletContext, atMost(1)).addFilter(anyString(), (Filter) any());
}
@Test
public void delegatingFilterProxyRegistrationBeansSkipsTargetBeanNames()
throws Exception {
addWebServerFactoryBean();
DelegatingFilterProxyRegistrationBean initializer = new DelegatingFilterProxyRegistrationBean(
"filterBean");
this.context.registerBeanDefinition("initializerBean",
beanDefinition(initializer));
BeanDefinition filterBeanDefinition = beanDefinition(
new IllegalStateException("Create FilterBean Failure"));
filterBeanDefinition.setLazyInit(true);
this.context.registerBeanDefinition("filterBean", filterBeanDefinition);
this.context.refresh();
ServletContext servletContext = getWebServerFactory().getServletContext();
verify(servletContext, atMost(1)).addFilter(anyString(),
this.filterCaptor.capture());
// Up to this point the filterBean should not have been created, calling
// the delegate proxy will trigger creation and an exception
this.thrown.expect(BeanCreationException.class);
this.thrown.expectMessage("Create FilterBean Failure");
this.filterCaptor.getValue().init(new MockFilterConfig());
this.filterCaptor.getValue().doFilter(new MockHttpServletRequest(),
new MockHttpServletResponse(), new MockFilterChain());
}
@Test
public void postProcessWebServerFactory() throws Exception {
RootBeanDefinition beanDefinition = new RootBeanDefinition(
MockServletWebServerFactory.class);
MutablePropertyValues pv = new MutablePropertyValues();
pv.add("port", "${port}");
beanDefinition.setPropertyValues(pv);
this.context.registerBeanDefinition("webServerFactory", beanDefinition);
PropertySourcesPlaceholderConfigurer propertySupport = new PropertySourcesPlaceholderConfigurer();
Properties properties = new Properties();
properties.put("port", 8080);
propertySupport.setProperties(properties);
this.context.registerBeanDefinition("propertySupport",
beanDefinition(propertySupport));
this.context.refresh();
assertThat(getWebServerFactory().getWebServer().getPort()).isEqualTo(8080);
}
@Test
public void doesNotReplaceExistingScopes() throws Exception { // gh-2082
Scope scope = mock(Scope.class);
ConfigurableListableBeanFactory factory = this.context.getBeanFactory();
factory.registerScope(WebApplicationContext.SCOPE_REQUEST, scope);
factory.registerScope(WebApplicationContext.SCOPE_SESSION, scope);
addWebServerFactoryBean();
this.context.refresh();
assertThat(factory.getRegisteredScope(WebApplicationContext.SCOPE_REQUEST))
.isSameAs(scope);
assertThat(factory.getRegisteredScope(WebApplicationContext.SCOPE_SESSION))
.isSameAs(scope);
}
private void addWebServerFactoryBean() {
this.context.registerBeanDefinition("webServerFactory",
new RootBeanDefinition(MockServletWebServerFactory.class));
}
public MockServletWebServerFactory getWebServerFactory() {
return this.context.getBean(MockServletWebServerFactory.class);
}
private BeanDefinition beanDefinition(Object bean) {
RootBeanDefinition beanDefinition = new RootBeanDefinition();
beanDefinition.setBeanClass(getClass());
beanDefinition.setFactoryMethodName("getBean");
ConstructorArgumentValues constructorArguments = new ConstructorArgumentValues();
constructorArguments.addGenericArgumentValue(bean);
beanDefinition.setConstructorArgumentValues(constructorArguments);
return beanDefinition;
}
public static <T> T getBean(T object) {
if (object instanceof RuntimeException) {
throw (RuntimeException) object;
}
return object;
}
public static class MockListener
implements ApplicationListener<ServletWebServerInitializedEvent> {
private ServletWebServerInitializedEvent event;
@Override
public void onApplicationEvent(ServletWebServerInitializedEvent event) {
this.event = event;
}
public ServletWebServerInitializedEvent getEvent() {
return this.event;
}
}
@Order(10)
protected static class OrderedFilter extends GenericFilterBean {
@Override
public void doFilter(ServletRequest request, ServletResponse response,
FilterChain chain) throws IOException, ServletException {
}
}
}