/*
* Copyright 2012-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.embedded.undertow;
import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLHandshakeException;
import io.undertow.Undertow.Builder;
import io.undertow.servlet.api.DeploymentInfo;
import io.undertow.servlet.api.DeploymentManager;
import io.undertow.servlet.api.ServletContainer;
import org.apache.jasper.servlet.JspServlet;
import org.junit.Test;
import org.mockito.InOrder;
import org.springframework.boot.web.server.ErrorPage;
import org.springframework.boot.web.server.MimeMappings.Mapping;
import org.springframework.boot.web.server.PortInUseException;
import org.springframework.boot.web.servlet.ServletRegistrationBean;
import org.springframework.boot.web.servlet.server.AbstractServletWebServerFactory;
import org.springframework.boot.web.servlet.server.AbstractServletWebServerFactoryTests;
import org.springframework.boot.web.servlet.server.ExampleServlet;
import org.springframework.http.HttpStatus;
import org.springframework.test.util.ReflectionTestUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link UndertowServletWebServerFactory}.
*
* @author Ivan Sopov
* @author Andy Wilkinson
*/
public class UndertowServletWebServerFactoryTests
extends AbstractServletWebServerFactoryTests {
@Override
protected UndertowServletWebServerFactory getFactory() {
return new UndertowServletWebServerFactory(0);
}
@Test
public void errorPage404() throws Exception {
AbstractServletWebServerFactory factory = getFactory();
factory.addErrorPages(new ErrorPage(HttpStatus.NOT_FOUND, "/hello"));
this.webServer = factory.getWebServer(
new ServletRegistrationBean<>(new ExampleServlet(), "/hello"));
this.webServer.start();
assertThat(getResponse(getLocalUrl("/hello"))).isEqualTo("Hello World");
assertThat(getResponse(getLocalUrl("/not-found"))).isEqualTo("Hello World");
}
@Test
public void setNullBuilderCustomizersThrows() {
UndertowServletWebServerFactory factory = getFactory();
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Customizers must not be null");
factory.setBuilderCustomizers(null);
}
@Test
public void addNullAddBuilderCustomizersThrows() {
UndertowServletWebServerFactory factory = getFactory();
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Customizers must not be null");
factory.addBuilderCustomizers((UndertowBuilderCustomizer[]) null);
}
@Test
public void builderCustomizers() throws Exception {
UndertowServletWebServerFactory factory = getFactory();
UndertowBuilderCustomizer[] customizers = new UndertowBuilderCustomizer[4];
for (int i = 0; i < customizers.length; i++) {
customizers[i] = mock(UndertowBuilderCustomizer.class);
}
factory.setBuilderCustomizers(Arrays.asList(customizers[0], customizers[1]));
factory.addBuilderCustomizers(customizers[2], customizers[3]);
this.webServer = factory.getWebServer();
InOrder ordered = inOrder((Object[]) customizers);
for (UndertowBuilderCustomizer customizer : customizers) {
ordered.verify(customizer).customize((Builder) any());
}
}
@Test
public void setNullDeploymentInfoCustomizersThrows() {
UndertowServletWebServerFactory factory = getFactory();
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Customizers must not be null");
factory.setDeploymentInfoCustomizers(null);
}
@Test
public void addNullAddDeploymentInfoCustomizersThrows() {
UndertowServletWebServerFactory factory = getFactory();
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Customizers must not be null");
factory.addDeploymentInfoCustomizers((UndertowDeploymentInfoCustomizer[]) null);
}
@Test
public void deploymentInfo() throws Exception {
UndertowServletWebServerFactory factory = getFactory();
UndertowDeploymentInfoCustomizer[] customizers = new UndertowDeploymentInfoCustomizer[4];
for (int i = 0; i < customizers.length; i++) {
customizers[i] = mock(UndertowDeploymentInfoCustomizer.class);
}
factory.setDeploymentInfoCustomizers(
Arrays.asList(customizers[0], customizers[1]));
factory.addDeploymentInfoCustomizers(customizers[2], customizers[3]);
this.webServer = factory.getWebServer();
InOrder ordered = inOrder((Object[]) customizers);
for (UndertowDeploymentInfoCustomizer customizer : customizers) {
ordered.verify(customizer).customize((DeploymentInfo) any());
}
}
@Test
public void basicSslClasspathKeyStore() throws Exception {
testBasicSslWithKeyStore("classpath:test.jks");
}
@Test
public void defaultContextPath() throws Exception {
UndertowServletWebServerFactory factory = getFactory();
final AtomicReference<String> contextPath = new AtomicReference<>();
factory.addDeploymentInfoCustomizers(new UndertowDeploymentInfoCustomizer() {
@Override
public void customize(DeploymentInfo deploymentInfo) {
contextPath.set(deploymentInfo.getContextPath());
}
});
this.webServer = factory.getWebServer();
assertThat(contextPath.get()).isEqualTo("/");
}
@Test
public void useForwardHeaders() throws Exception {
UndertowServletWebServerFactory factory = getFactory();
factory.setUseForwardHeaders(true);
assertForwardHeaderIsUsed(factory);
}
@Test
public void eachFactoryUsesADiscreteServletContainer() {
assertThat(getServletContainerFromNewFactory())
.isNotEqualTo(getServletContainerFromNewFactory());
}
@Test
public void accessLogCanBeEnabled()
throws IOException, URISyntaxException, InterruptedException {
testAccessLog(null, null, "access_log.log");
}
@Test
public void accessLogCanBeCustomized()
throws IOException, URISyntaxException, InterruptedException {
testAccessLog("my_access.", "logz", "my_access.logz");
}
private void testAccessLog(String prefix, String suffix, String expectedFile)
throws IOException, URISyntaxException, InterruptedException {
UndertowServletWebServerFactory factory = getFactory();
factory.setAccessLogEnabled(true);
factory.setAccessLogPrefix(prefix);
factory.setAccessLogSuffix(suffix);
File accessLogDirectory = this.temporaryFolder.getRoot();
factory.setAccessLogDirectory(accessLogDirectory);
assertThat(accessLogDirectory.listFiles()).isEmpty();
this.webServer = factory.getWebServer(
new ServletRegistrationBean<>(new ExampleServlet(), "/hello"));
this.webServer.start();
assertThat(getResponse(getLocalUrl("/hello"))).isEqualTo("Hello World");
File accessLog = new File(accessLogDirectory, expectedFile);
awaitFile(accessLog);
assertThat(accessLogDirectory.listFiles()).contains(accessLog);
}
@Override
protected void addConnector(final int port, AbstractServletWebServerFactory factory) {
((UndertowServletWebServerFactory) factory)
.addBuilderCustomizers(new UndertowBuilderCustomizer() {
@Override
public void customize(Builder builder) {
builder.addHttpListener(port, "0.0.0.0");
}
});
}
@Test(expected = SSLHandshakeException.class)
public void sslRestrictedProtocolsEmptyCipherFailure() throws Exception {
testRestrictedSSLProtocolsAndCipherSuites(new String[] { "TLSv1.2" },
new String[] { "TLS_EMPTY_RENEGOTIATION_INFO_SCSV" });
}
@Test(expected = SSLHandshakeException.class)
public void sslRestrictedProtocolsECDHETLS1Failure() throws Exception {
testRestrictedSSLProtocolsAndCipherSuites(new String[] { "TLSv1" },
new String[] { "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256" });
}
@Test
public void sslRestrictedProtocolsECDHESuccess() throws Exception {
testRestrictedSSLProtocolsAndCipherSuites(new String[] { "TLSv1.2" },
new String[] { "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256" });
}
@Test
public void sslRestrictedProtocolsRSATLS12Success() throws Exception {
testRestrictedSSLProtocolsAndCipherSuites(new String[] { "TLSv1.2" },
new String[] { "TLS_RSA_WITH_AES_128_CBC_SHA256" });
}
@Test(expected = SSLHandshakeException.class)
public void sslRestrictedProtocolsRSATLS11Failure() throws Exception {
testRestrictedSSLProtocolsAndCipherSuites(new String[] { "TLSv1.1" },
new String[] { "TLS_RSA_WITH_AES_128_CBC_SHA256" });
}
@Override
protected JspServlet getJspServlet() {
return null; // Undertow does not support JSPs
}
private void awaitFile(File file) throws InterruptedException {
long end = System.currentTimeMillis() + 10000;
while (!file.exists() && System.currentTimeMillis() < end) {
Thread.sleep(100);
}
}
private ServletContainer getServletContainerFromNewFactory() {
UndertowServletWebServer container = (UndertowServletWebServer) getFactory()
.getWebServer();
try {
return ((DeploymentManager) ReflectionTestUtils.getField(container,
"manager")).getDeployment().getServletContainer();
}
finally {
container.stop();
}
}
@Override
protected Map<String, String> getActualMimeMappings() {
return ((DeploymentManager) ReflectionTestUtils.getField(this.webServer,
"manager")).getDeployment().getMimeExtensionMappings();
}
@Override
protected Collection<Mapping> getExpectedMimeMappings() {
// Unlike Tomcat and Jetty, Undertow performs a case-sensitive match on file
// extension so it has a mapping for "z" and "Z".
Set<Mapping> expectedMappings = new HashSet<>(super.getExpectedMimeMappings());
expectedMappings.add(new Mapping("Z", "application/x-compress"));
return expectedMappings;
}
@Override
protected Charset getCharset(Locale locale) {
DeploymentInfo info = ((DeploymentManager) ReflectionTestUtils
.getField(this.webServer, "manager")).getDeployment().getDeploymentInfo();
String charsetName = info.getLocaleCharsetMapping().get(locale.toString());
return (charsetName != null) ? Charset.forName(charsetName) : null;
}
@Override
protected void handleExceptionCausedByBlockedPort(RuntimeException ex,
int blockedPort) {
assertThat(ex).isInstanceOf(PortInUseException.class);
assertThat(((PortInUseException) ex).getPort()).isEqualTo(blockedPort);
}
}