/*
* Copyright 2012-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.servlet.server;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import javax.servlet.Filter;
import javax.servlet.FilterRegistration;
import javax.servlet.RequestDispatcher;
import javax.servlet.Servlet;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRegistration;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.springframework.boot.web.server.WebServer;
import org.springframework.boot.web.server.WebServerException;
import org.springframework.boot.web.servlet.ServletContextInitializer;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
/**
* Mock {@link ServletWebServerFactory}.
*
* @author Phillip Webb
* @author Andy Wilkinson
*/
public class MockServletWebServerFactory extends AbstractServletWebServerFactory {

	private MockServletWebServer webServer;

	/**
	 * Create a new {@link MockServletWebServer}, wrapped in a Mockito {@code spy},
	 * using the merged initializers and the configured port. The created server is
	 * retained and can later be obtained from {@link #getWebServer()}.
	 * @param initializers additional initializers to run against the mock
	 * {@link ServletContext}
	 * @return the (spied) mock web server
	 */
	@Override
	public WebServer getWebServer(ServletContextInitializer... initializers) {
		this.webServer = spy(
				new MockServletWebServer(mergeInitializers(initializers), getPort()));
		return this.webServer;
	}

	/**
	 * Return the most recently created web server.
	 * @return the web server, or {@code null} if none has been created yet
	 */
	public MockServletWebServer getWebServer() {
		return this.webServer;
	}

	/**
	 * Return the mock {@link ServletContext} of the most recently created web server.
	 * @return the servlet context, or {@code null} if no web server has been created
	 * (or if the server has been stopped, since {@code stop()} discards the context)
	 */
	public ServletContext getServletContext() {
		return getWebServer() == null ? null : getWebServer().servletContext;
	}

	/**
	 * Return the servlet registered at the given position.
	 * @param index the zero-based registration index
	 * @return the registered servlet, or {@code null} if no web server has been
	 * created
	 * @throws IndexOutOfBoundsException if no servlet was registered at that index
	 */
	public RegisteredServlet getRegisteredServlet(int index) {
		return getWebServer() == null ? null
				: getWebServer().getRegisteredServlets().get(index);
	}

	/**
	 * Return the filter registered at the given position.
	 * @param index the zero-based registration index
	 * @return the registered filter, or {@code null} if no web server has been
	 * created
	 * @throws IndexOutOfBoundsException if no filter was registered at that index
	 */
	public RegisteredFilter getRegisteredFilter(int index) {
		return getWebServer() == null ? null
				: getWebServer().getRegisteredFilters().get(index);
	}

	/**
	 * {@link WebServer} backed by a Mockito-mocked {@link ServletContext}. Servlets
	 * and filters added to the context are recorded. The supplied initializers are
	 * run against the mock context when the server is constructed.
	 */
	public static class MockServletWebServer implements WebServer {

		private ServletContext servletContext;

		private final ServletContextInitializer[] initializers;

		private final List<RegisteredServlet> registeredServlets = new ArrayList<>();

		private final List<RegisteredFilter> registeredFilters = new ArrayList<>();

		private final int port;

		/**
		 * Create a mock web server and immediately run the given initializers against
		 * its mock {@link ServletContext}.
		 * @param initializers the initializers to invoke
		 * @param port the port reported by {@link #getPort()}
		 * @throws RuntimeException wrapping any {@link ServletException} thrown by an
		 * initializer
		 */
		public MockServletWebServer(ServletContextInitializer[] initializers, int port) {
			this.initializers = initializers;
			this.port = port;
			initialize();
		}

		private void initialize() {
			try {
				this.servletContext = mock(ServletContext.class);
				// Record each servlet added to the context and hand back a mock
				// registration that callers can configure.
				given(this.servletContext.addServlet(anyString(), (Servlet) any()))
						.willAnswer(new Answer<ServletRegistration.Dynamic>() {
							@Override
							public ServletRegistration.Dynamic answer(
									InvocationOnMock invocation) throws Throwable {
								RegisteredServlet registeredServlet = new RegisteredServlet(
										invocation.<Servlet>getArgument(1));
								MockServletWebServer.this.registeredServlets
										.add(registeredServlet);
								return registeredServlet.getRegistration();
							}
						});
				// Same for filters.
				given(this.servletContext.addFilter(anyString(), (Filter) any()))
						.willAnswer(new Answer<FilterRegistration.Dynamic>() {
							@Override
							public FilterRegistration.Dynamic answer(
									InvocationOnMock invocation) throws Throwable {
								RegisteredFilter registeredFilter = new RegisteredFilter(
										invocation.<Filter>getArgument(1));
								MockServletWebServer.this.registeredFilters
										.add(registeredFilter);
								return registeredFilter.getRegistration();
							}
						});
				// Init parameters are kept in a simple map so that values set by the
				// initializers can be read back through the mock context.
				final Map<String, String> initParameters = new HashMap<>();
				given(this.servletContext.setInitParameter(anyString(), anyString()))
						.will(new Answer<Void>() {
							@Override
							public Void answer(InvocationOnMock invocation)
									throws Throwable {
								initParameters.put(invocation.getArgument(0),
										invocation.getArgument(1));
								// NOTE(review): setInitParameter returns boolean; this
								// null is presumably mapped to false by Mockito.
								return null;
							}
						});
				// Build a fresh snapshot per call: a single Enumeration created at
				// stubbing time would be empty (the map is empty then) and would be
				// exhausted after its first use.
				given(this.servletContext.getInitParameterNames())
						.willAnswer(new Answer<Enumeration<String>>() {
							@Override
							public Enumeration<String> answer(InvocationOnMock invocation)
									throws Throwable {
								return Collections.enumeration(
										new ArrayList<String>(initParameters.keySet()));
							}
						});
				given(this.servletContext.getInitParameter(anyString()))
						.willAnswer(new Answer<String>() {
							@Override
							public String answer(InvocationOnMock invocation)
									throws Throwable {
								return initParameters.get(invocation.getArgument(0));
							}
						});
				given(this.servletContext.getAttributeNames())
						.willReturn(MockServletWebServer.<String>emptyEnumeration());
				given(this.servletContext.getNamedDispatcher("default"))
						.willReturn(mock(RequestDispatcher.class));
				for (ServletContextInitializer initializer : this.initializers) {
					initializer.onStartup(this.servletContext);
				}
			}
			catch (ServletException ex) {
				throw new RuntimeException(ex);
			}
		}

		/**
		 * Return an enumeration with no elements.
		 * @param <T> the element type
		 * @return a shared, always-empty enumeration
		 */
		@SuppressWarnings("unchecked")
		public static <T> Enumeration<T> emptyEnumeration() {
			// Safe: the shared instance never yields any element.
			return (Enumeration<T>) EmptyEnumeration.EMPTY_ENUMERATION;
		}

		/**
		 * No-op: the mock server has nothing to start.
		 */
		@Override
		public void start() throws WebServerException {
		}

		/**
		 * Discard the mock servlet context and forget all recorded servlet and filter
		 * registrations.
		 */
		@Override
		public void stop() {
			this.servletContext = null;
			this.registeredServlets.clear();
			this.registeredFilters.clear();
		}

		/**
		 * Return the servlets registered so far, in registration order.
		 * @return a new array of the registered servlets
		 */
		public Servlet[] getServlets() {
			Servlet[] servlets = new Servlet[this.registeredServlets.size()];
			for (int i = 0; i < servlets.length; i++) {
				servlets[i] = this.registeredServlets.get(i).getServlet();
			}
			return servlets;
		}

		/**
		 * Return the recorded servlet registrations.
		 * @return the live (mutable) list of registered servlets
		 */
		public List<RegisteredServlet> getRegisteredServlets() {
			return this.registeredServlets;
		}

		/**
		 * Return the recorded filter registrations.
		 * @return the live (mutable) list of registered filters
		 */
		public List<RegisteredFilter> getRegisteredFilters() {
			return this.registeredFilters;
		}

		@Override
		public int getPort() {
			return this.port;
		}

		/**
		 * Enumeration that never has any elements.
		 */
		private static class EmptyEnumeration<E> implements Enumeration<E> {

			static final EmptyEnumeration<Object> EMPTY_ENUMERATION = new EmptyEnumeration<>();

			@Override
			public boolean hasMoreElements() {
				return false;
			}

			@Override
			public E nextElement() {
				throw new NoSuchElementException();
			}

		}

	}

	/**
	 * A {@link Servlet} added to the mock {@link ServletContext}, together with the
	 * mock {@link ServletRegistration.Dynamic} that was returned for it.
	 */
	public static class RegisteredServlet {

		private final Servlet servlet;

		private final ServletRegistration.Dynamic registration;

		public RegisteredServlet(Servlet servlet) {
			this.servlet = servlet;
			this.registration = mock(ServletRegistration.Dynamic.class);
		}

		public ServletRegistration.Dynamic getRegistration() {
			return this.registration;
		}

		public Servlet getServlet() {
			return this.servlet;
		}

	}

	/**
	 * A {@link Filter} added to the mock {@link ServletContext}, together with the
	 * mock {@link FilterRegistration.Dynamic} that was returned for it.
	 */
	public static class RegisteredFilter {

		private final Filter filter;

		private final FilterRegistration.Dynamic registration;

		public RegisteredFilter(Filter filter) {
			this.filter = filter;
			this.registration = mock(FilterRegistration.Dynamic.class);
		}

		public FilterRegistration.Dynamic getRegistration() {
			return this.registration;
		}

		public Filter getFilter() {
			return this.filter;
		}

	}

}