/*
 * Copyright 2012-2017 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.boot.web.servlet.server;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.FilenameFilter;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.net.InetSocketAddress;
import java.net.MalformedURLException;
import java.net.ServerSocket;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.charset.Charset;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.zip.GZIPInputStream;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLException;
import javax.servlet.GenericServlet;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;

import org.apache.http.client.HttpClient;
import org.apache.http.client.entity.InputStreamFactory;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.conn.ssl.TrustSelfSignedStrategy;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.protocol.HttpContext;
import org.apache.http.ssl.SSLContextBuilder;
import org.apache.http.ssl.TrustStrategy;
import org.apache.jasper.EmbeddedServletOptions;
import org.apache.jasper.servlet.JspServlet;
import org.junit.After;
import org.junit.Assume;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.rules.TemporaryFolder;
import org.mockito.InOrder;

import org.springframework.boot.ApplicationHome;
import org.springframework.boot.ApplicationTemp;
import org.springframework.boot.testutil.InternalOutputCapture;
import org.springframework.boot.web.server.Compression;
import org.springframework.boot.web.server.ErrorPage;
import org.springframework.boot.web.server.MimeMappings;
import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.server.Ssl.ClientAuth;
import org.springframework.boot.web.server.SslStoreProvider;
import org.springframework.boot.web.server.WebServer;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.boot.web.servlet.ServletContextInitializer;
import org.springframework.boot.web.servlet.ServletRegistrationBean;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.SocketUtils;
import org.springframework.util.StreamUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

/**
 * Base for testing classes that extend {@link AbstractServletWebServerFactory}.
 *
 * @author Phillip Webb
 * @author Greg Turnquist
 * @author Andy Wilkinson
 */
public abstract class AbstractServletWebServerFactoryTests {

	@Rule
	public ExpectedException thrown = ExpectedException.none();

	@Rule
	public TemporaryFolder temporaryFolder = new TemporaryFolder();

	@Rule
	public InternalOutputCapture output = new InternalOutputCapture();

	/** Server under test; stopped by {@link #tearDown()} when a test assigned it. */
	protected WebServer webServer;

	private final HttpClientContext httpClientContext = HttpClientContext.create();

	/** Stops the server started by the test, ignoring any failure while stopping. */
	@After
	public void tearDown() {
		if (this.webServer != null) {
			try {
				this.webServer.stop();
			}
			catch (Exception ex) {
				// Ignore
			}
		}
	}

	/** A registered servlet answers on its mapped path once the server is started. */
	@Test
	public void startServlet() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		this.webServer = factory.getWebServer(exampleServletRegistration());
		this.webServer.start();
		assertThat(getResponse(getLocalUrl("/hello"))).isEqualTo("Hello World");
	}

	/** A second start() keeps the port and logs "started on port" only once. */
	@Test
	public void startCalledTwice() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		this.webServer = factory.getWebServer(exampleServletRegistration());
		this.webServer.start();
		int port = this.webServer.getPort();
		this.webServer.start();
		assertThat(this.webServer.getPort()).isEqualTo(port);
		assertThat(getResponse(getLocalUrl("/hello"))).isEqualTo("Hello World");
		assertThat(this.output.toString()).containsOnlyOnce("started on port");
	}

	/** A second stop() must not throw. */
	@Test
	public void stopCalledTwice() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		this.webServer = factory.getWebServer(exampleServletRegistration());
		this.webServer.start();
		this.webServer.stop();
		this.webServer.stop();
	}

	/** With port -1 the started server reports a negative port. */
	@Test
	public void emptyServerWhenPortIsMinusOne() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		factory.setPort(-1);
		this.webServer = factory.getWebServer(exampleServletRegistration());
		this.webServer.start();
		assertThat(this.webServer.getPort()).isLessThan(0); // Jetty is -2
	}

	/** After stop() a request to the previously used port fails with an IOException. */
	@Test
	public void stopServlet() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		this.webServer = factory.getWebServer(exampleServletRegistration());
		this.webServer.start();
		int port = this.webServer.getPort();
		this.webServer.stop();
		this.thrown.expect(IOException.class);
		String response = getResponse(getLocalUrl(port, "/hello"));
		// Only reached if the request unexpectedly succeeded
		throw new RuntimeException(
				"Unexpected response on port " + port + " : " + response);
	}

	/** A registered filter is applied to the servlet response. */
	@Test
	public void startServletAndFilter() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		this.webServer = factory.getWebServer(exampleServletRegistration(),
				new FilterRegistrationBean<>(new ExampleFilter()));
		this.webServer.start();
		assertThat(getResponse(getLocalUrl("/hello"))).isEqualTo("[Hello World]");
	}

	/** start() does not return before a slow initializer has completed. */
	@Test
	public void startBlocksUntilReadyToServe() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		final Date[] date = new Date[1];
		this.webServer = factory.getWebServer(new ServletContextInitializer() {

			@Override
			public void onStartup(ServletContext servletContext)
					throws ServletException {
				try {
					Thread.sleep(500);
					date[0] = new Date();
				}
				catch (InterruptedException ex) {
					// Restore the interrupt flag before translating the exception
					Thread.currentThread().interrupt();
					throw new ServletException(ex);
				}
			}

		});
		this.webServer.start();
		assertThat(date[0]).isNotNull();
	}

	/** A load-on-startup servlet is initialized by start(), not by getWebServer(). */
	@Test
	public void loadOnStartAfterContextIsInitialized() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		final InitCountingServlet servlet = new InitCountingServlet();
		this.webServer = factory.getWebServer(new ServletContextInitializer() {

			@Override
			public void onStartup(ServletContext servletContext)
					throws ServletException {
				servletContext.addServlet("test", servlet).setLoadOnStartup(1);
			}

		});
		assertThat(servlet.getInitCount()).isEqualTo(0);
		this.webServer.start();
		assertThat(servlet.getInitCount()).isEqualTo(1);
	}

	/** The server listens on, and reports, an explicitly configured port. */
	@Test
	public void specificPort() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		int specificPort = SocketUtils.findAvailableTcpPort(41000);
		factory.setPort(specificPort);
		this.webServer = factory.getWebServer(exampleServletRegistration());
		this.webServer.start();
		assertThat(getResponse("http://localhost:" + specificPort + "/hello"))
				.isEqualTo("Hello World");
		assertThat(this.webServer.getPort()).isEqualTo(specificPort);
	}

	/** The servlet is reachable beneath a configured context path. */
	@Test
	public void specificContextRoot() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		factory.setContextPath("/say");
		this.webServer = factory.getWebServer(exampleServletRegistration());
		this.webServer.start();
		assertThat(getResponse(getLocalUrl("/say/hello"))).isEqualTo("Hello World");
	}

	/** A context path without a leading slash is rejected. */
	@Test
	public void contextPathMustStartWithSlash() throws Exception {
		this.thrown.expect(IllegalArgumentException.class);
		this.thrown.expectMessage("ContextPath must start with '/' and not end with '/'");
		getFactory().setContextPath("missingslash");
	}

	/** A context path with a trailing slash is rejected. */
	@Test
	public void contextPathMustNotEndWithSlash() throws Exception {
		this.thrown.expect(IllegalArgumentException.class);
		this.thrown.expectMessage("ContextPath must start with '/' and not end with '/'");
		getFactory().setContextPath("extraslash/");
	}

	/** "/" is rejected as a context path; the message asks for an empty string. */
	@Test
	public void contextRootPathMustNotBeSlash() throws Exception {
		this.thrown.expect(IllegalArgumentException.class);
		this.thrown.expectMessage(
				"Root ContextPath must be specified using an empty string");
		getFactory().setContextPath("/");
	}

	/**
	 * Initializers passed to getWebServer() run first, then those set on the factory,
	 * then those added to the factory.
	 */
	@Test
	public void multipleConfigurations() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		ServletContextInitializer[] initializers = new ServletContextInitializer[6];
		for (int i = 0; i < initializers.length; i++) {
			initializers[i] = mock(ServletContextInitializer.class);
		}
		factory.setInitializers(Arrays.asList(initializers[2], initializers[3]));
		factory.addInitializers(initializers[4], initializers[5]);
		this.webServer = factory.getWebServer(initializers[0], initializers[1]);
		this.webServer.start();
		// Array index order is the expected invocation order
		InOrder ordered = inOrder((Object[]) initializers);
		for (ServletContextInitializer initializer : initializers) {
			ordered.verify(initializer).onStartup((ServletContext) any());
		}
	}

	/** Files in the configured document root are served. */
	@Test
	public void documentRoot() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		addTestTxtFile(factory);
		this.webServer = factory.getWebServer();
		this.webServer.start();
		assertThat(getResponse(getLocalUrl("/test.txt"))).isEqualTo("test");
	}

	/** A custom MIME mapping determines the Content-Type of a served file. */
	@Test
	public void mimeType() throws Exception {
		FileCopyUtils.copy("test",
				new FileWriter(this.temporaryFolder.newFile("test.xxcss")));
		AbstractServletWebServerFactory factory = getFactory();
		factory.setDocumentRoot(this.temporaryFolder.getRoot());
		MimeMappings mimeMappings = new MimeMappings();
		mimeMappings.add("xxcss", "text/css");
		factory.setMimeMappings(mimeMappings);
		this.webServer = factory.getWebServer();
		this.webServer.start();
		// try-with-resources so the response is closed even if the assertion fails
		try (ClientHttpResponse response = getClientResponse(
				getLocalUrl("/test.xxcss"))) {
			assertThat(response.getHeaders().getContentType().toString())
					.isEqualTo("text/css");
		}
	}

	/** A request to /bang is answered by the error page mapped to /hello. */
	@Test
	public void errorPage() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		factory.addErrorPages(new ErrorPage(HttpStatus.INTERNAL_SERVER_ERROR, "/hello"));
		this.webServer = factory.getWebServer(exampleServletRegistration(),
				errorServletRegistration());
		this.webServer.start();
		assertThat(getResponse(getLocalUrl("/hello"))).isEqualTo("Hello World");
		assertThat(getResponse(getLocalUrl("/bang"))).isEqualTo("Hello World");
	}

	/** Same as {@link #errorPage()} but using PUT requests. */
	@Test
	public void errorPageFromPutRequest() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		factory.addErrorPages(new ErrorPage(HttpStatus.INTERNAL_SERVER_ERROR, "/hello"));
		this.webServer = factory.getWebServer(exampleServletRegistration(),
				errorServletRegistration());
		this.webServer.start();
		assertThat(getResponse(getLocalUrl("/hello"), HttpMethod.PUT))
				.isEqualTo("Hello World");
		assertThat(getResponse(getLocalUrl("/bang"), HttpMethod.PUT))
				.isEqualTo("Hello World");
	}

	/** SSL works with a key store referenced by a classpath: location. */
	@Test
	public void basicSslFromClassPath() throws Exception {
		testBasicSslWithKeyStore("classpath:test.jks");
	}

	/** SSL works with a key store referenced by a file system path. */
	@Test
	public void basicSslFromFileSystem() throws Exception {
		testBasicSslWithKeyStore("src/test/resources/test.jks");
	}

	/** With SSL configured but disabled, an https request fails with SSLException. */
	@Test
	public void sslDisabled() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		Ssl ssl = getSsl(null, "password", "classpath:test.jks");
		ssl.setEnabled(false);
		factory.setSsl(ssl);
		this.webServer = factory.getWebServer(
				new ServletRegistrationBean<>(new ExampleServlet(true, false), "/hello"));
		this.webServer.start();
		SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
				new SSLContextBuilder()
						.loadTrustMaterial(null, new TrustSelfSignedStrategy()).build());
		HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory)
				.build();
		HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(
				httpClient);
		this.thrown.expect(SSLException.class);
		getResponse(getLocalUrl("https", "/hello"), requestFactory);
	}

	/** Over SSL the example servlet's response contains "scheme=https". */
	@Test
	public void sslGetScheme() throws Exception { // gh-2232
		AbstractServletWebServerFactory factory = getFactory();
		factory.setSsl(getSsl(null, "password", "src/test/resources/test.jks"));
		this.webServer = factory.getWebServer(
				new ServletRegistrationBean<>(new ExampleServlet(true, false), "/hello"));
		this.webServer.start();
		SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
				new SSLContextBuilder()
						.loadTrustMaterial(null, new TrustSelfSignedStrategy()).build());
		HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory)
				.build();
		HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(
				httpClient);
		assertThat(getResponse(getLocalUrl("https", "/hello"), requestFactory))
				.contains("scheme=https");
	}

	/**
	 * With key alias "test-alias" configured, a client using a trust strategy built
	 * for serial number "77e7c302" can complete an https request.
	 */
	@Test
	public void sslKeyAlias() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		Ssl ssl = getSsl(null, "password", "test-alias", "src/test/resources/test.jks");
		factory.setSsl(ssl);
		ServletRegistrationBean<ExampleServlet> registration = new ServletRegistrationBean<>(
				new ExampleServlet(true, false), "/hello");
		this.webServer = factory.getWebServer(registration);
		this.webServer.start();
		TrustStrategy trustStrategy = new SerialNumberValidatingTrustSelfSignedStrategy(
				"77e7c302");
		SSLContext sslContext = new SSLContextBuilder()
				.loadTrustMaterial(null, trustStrategy).build();
		HttpClient httpClient = HttpClients.custom()
				.setSSLSocketFactory(new SSLConnectionSocketFactory(sslContext)).build();
		String response = getResponse(getLocalUrl("https", "/hello"),
				new HttpComponentsClientHttpRequestFactory(httpClient));
		assertThat(response).contains("scheme=https");
	}

	/** Without configuration no Server header is sent over SSL. */
	@Test
	public void serverHeaderIsDisabledByDefaultWhenUsingSsl() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		factory.setSsl(getSsl(null, "password", "src/test/resources/test.jks"));
		this.webServer = factory.getWebServer(
				new ServletRegistrationBean<>(new ExampleServlet(true, false), "/hello"));
		this.webServer.start();
		SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
				new SSLContextBuilder()
						.loadTrustMaterial(null, new TrustSelfSignedStrategy()).build());
		HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory)
				.build();
		try (ClientHttpResponse response = getClientResponse(
				getLocalUrl("https", "/hello"), HttpMethod.GET,
				new HttpComponentsClientHttpRequestFactory(httpClient))) {
			assertThat(response.getHeaders().get("Server")).isNullOrEmpty();
		}
	}

	/** A configured server header is sent over SSL. */
	@Test
	public void serverHeaderCanBeCustomizedWhenUsingSsl() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		factory.setServerHeader("MyServer");
		factory.setSsl(getSsl(null, "password", "src/test/resources/test.jks"));
		this.webServer = factory.getWebServer(
				new ServletRegistrationBean<>(new ExampleServlet(true, false), "/hello"));
		this.webServer.start();
		SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
				new SSLContextBuilder()
						.loadTrustMaterial(null, new TrustSelfSignedStrategy()).build());
		HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory)
				.build();
		try (ClientHttpResponse response = getClientResponse(
				getLocalUrl("https", "/hello"), HttpMethod.GET,
				new HttpComponentsClientHttpRequestFactory(httpClient))) {
			assertThat(response.getHeaders().get("Server")).containsExactly("MyServer");
		}
	}

	/**
	 * Starts an SSL server backed by the given key store and checks that a static
	 * file can be fetched over https by a client trusting self-signed certificates.
	 * @param keyStore the key store location handed to the {@link Ssl} configuration
	 * @throws Exception if the server cannot be started or the request fails
	 */
	protected final void testBasicSslWithKeyStore(String keyStore) throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		addTestTxtFile(factory);
		factory.setSsl(getSsl(null, "password", keyStore));
		this.webServer = factory.getWebServer();
		this.webServer.start();
		SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
				new SSLContextBuilder()
						.loadTrustMaterial(null, new TrustSelfSignedStrategy()).build());
		HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory)
				.build();
		HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(
				httpClient);
		assertThat(getResponse(getLocalUrl("https", "/test.txt"), requestFactory))
				.isEqualTo("test");
	}

	/** PKCS12 key and trust stores work with mandatory client authentication. */
	@Test
	public void pkcs12KeyStoreAndTrustStore() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		addTestTxtFile(factory);
		factory.setSsl(getSsl(ClientAuth.NEED, null, "classpath:test.p12",
				"classpath:test.p12", null, null));
		this.webServer = factory.getWebServer();
		this.webServer.start();
		KeyStore keyStore = KeyStore.getInstance("pkcs12");
		// Close the key store stream once loaded instead of leaking the handle
		try (InputStream stream = new FileInputStream(
				new File("src/test/resources/test.p12"))) {
			keyStore.load(stream, "secret".toCharArray());
		}
		SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
				new SSLContextBuilder()
						.loadTrustMaterial(null, new TrustSelfSignedStrategy())
						.loadKeyMaterial(keyStore, "secret".toCharArray()).build());
		HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory)
				.build();
		HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(
				httpClient);
		assertThat(getResponse(getLocalUrl("https", "/test.txt"), requestFactory))
				.isEqualTo("test");
	}

	/** With ClientAuth.NEED a client presenting a certificate is served. */
	@Test
	public void sslNeedsClientAuthenticationSucceedsWithClientCertificate()
			throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		addTestTxtFile(factory);
		factory.setSsl(getSsl(ClientAuth.NEED, "password", "classpath:test.jks",
				"classpath:test.jks", null, null));
		this.webServer = factory.getWebServer();
		this.webServer.start();
		KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
		// Close the key store stream once loaded instead of leaking the handle
		try (InputStream stream = new FileInputStream(
				new File("src/test/resources/test.jks"))) {
			keyStore.load(stream, "secret".toCharArray());
		}
		SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
				new SSLContextBuilder()
						.loadTrustMaterial(null, new TrustSelfSignedStrategy())
						.loadKeyMaterial(keyStore, "password".toCharArray()).build());
		HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory)
				.build();
		HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(
				httpClient);
		assertThat(getResponse(getLocalUrl("https", "/test.txt"), requestFactory))
				.isEqualTo("test");
	}

	/** With ClientAuth.NEED a client without a certificate gets an IOException. */
	@Test(expected = IOException.class)
	public void sslNeedsClientAuthenticationFailsWithoutClientCertificate()
			throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		addTestTxtFile(factory);
		factory.setSsl(getSsl(ClientAuth.NEED, "password", "classpath:test.jks"));
		this.webServer = factory.getWebServer();
		this.webServer.start();
		SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
				new SSLContextBuilder()
						.loadTrustMaterial(null, new TrustSelfSignedStrategy()).build());
		HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory)
				.build();
		HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(
				httpClient);
		getResponse(getLocalUrl("https", "/test.txt"), requestFactory);
	}

	/** With ClientAuth.WANT a client presenting a certificate is served. */
	@Test
	public void sslWantsClientAuthenticationSucceedsWithClientCertificate()
			throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		addTestTxtFile(factory);
		factory.setSsl(getSsl(ClientAuth.WANT, "password", "classpath:test.jks"));
		this.webServer = factory.getWebServer();
		this.webServer.start();
		KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
		// Close the key store stream once loaded instead of leaking the handle
		try (InputStream stream = new FileInputStream(
				new File("src/test/resources/test.jks"))) {
			keyStore.load(stream, "secret".toCharArray());
		}
		SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
				new SSLContextBuilder()
						.loadTrustMaterial(null, new TrustSelfSignedStrategy())
						.loadKeyMaterial(keyStore, "password".toCharArray()).build());
		HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory)
				.build();
		HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(
				httpClient);
		assertThat(getResponse(getLocalUrl("https", "/test.txt"), requestFactory))
				.isEqualTo("test");
	}

	/** With ClientAuth.WANT a client without a certificate is still served. */
	@Test
	public void sslWantsClientAuthenticationSucceedsWithoutClientCertificate()
			throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		addTestTxtFile(factory);
		factory.setSsl(getSsl(ClientAuth.WANT, "password", "classpath:test.jks"));
		this.webServer = factory.getWebServer();
		this.webServer.start();
		SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
				new SSLContextBuilder()
						.loadTrustMaterial(null, new TrustSelfSignedStrategy()).build());
		HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory)
				.build();
		HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(
				httpClient);
		assertThat(getResponse(getLocalUrl("https", "/test.txt"), requestFactory))
				.isEqualTo("test");
	}

	/** Key and trust stores are obtained from a configured {@link SslStoreProvider}. */
	@Test
	public void sslWithCustomSslStoreProvider() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		addTestTxtFile(factory);
		Ssl ssl = new Ssl();
		ssl.setClientAuth(ClientAuth.NEED);
		ssl.setKeyPassword("password");
		factory.setSsl(ssl);
		SslStoreProvider sslStoreProvider = mock(SslStoreProvider.class);
		given(sslStoreProvider.getKeyStore()).willReturn(loadStore());
		given(sslStoreProvider.getTrustStore()).willReturn(loadStore());
		factory.setSslStoreProvider(sslStoreProvider);
		this.webServer = factory.getWebServer();
		this.webServer.start();
		KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
		// Close the key store stream once loaded instead of leaking the handle
		try (InputStream stream = new FileInputStream(
				new File("src/test/resources/test.jks"))) {
			keyStore.load(stream, "secret".toCharArray());
		}
		SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
				new SSLContextBuilder()
						.loadTrustMaterial(null, new TrustSelfSignedStrategy())
						.loadKeyMaterial(keyStore, "password".toCharArray()).build());
		HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory)
				.build();
		HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(
				httpClient);
		assertThat(getResponse(getLocalUrl("https", "/test.txt"), requestFactory))
				.isEqualTo("test");
		verify(sslStoreProvider).getKeyStore();
		verify(sslStoreProvider).getTrustStore();
	}

	/** No JSP servlet is found when JSP registration is switched off. */
	@Test
	public void disableJspServletRegistration() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		factory.getJsp().setRegistered(false);
		this.webServer = factory.getWebServer();
		assertThat(getJspServlet()).isNull();
	}

	/** A request for a class file path is answered with 404. */
	@Test
	public void cannotReadClassPathFiles() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		this.webServer = factory.getWebServer(exampleServletRegistration());
		this.webServer.start();
		try (ClientHttpResponse response = getClientResponse(
				getLocalUrl("/org/springframework/boot/SpringApplication.class"))) {
			assertThat(response.getStatusCode()).isEqualTo(HttpStatus.NOT_FOUND);
		}
	}

	/**
	 * Creates an {@link Ssl} configuration with only a key store.
	 * @param clientAuth the client authentication mode (may be {@code null})
	 * @param keyPassword the key password, or {@code null} to leave it unset
	 * @param keyStore the key store location, or {@code null} to leave it unset
	 * @return the SSL configuration
	 */
	protected Ssl getSsl(ClientAuth clientAuth, String keyPassword, String keyStore) {
		return getSsl(clientAuth, keyPassword, keyStore, null, null, null);
	}

	/** As {@link #getSsl(ClientAuth, String, String)} with an additional key alias. */
	private Ssl getSsl(ClientAuth clientAuth, String keyPassword, String keyAlias,
			String keyStore) {
		return getSsl(clientAuth, keyPassword, keyAlias, keyStore, null, null, null);
	}

	/** Variant without a key alias but with trust store, protocols and ciphers. */
	private Ssl getSsl(ClientAuth clientAuth, String keyPassword, String keyStore,
			String trustStore, String[] supportedProtocols, String[] ciphers) {
		return getSsl(clientAuth, keyPassword, null, keyStore, trustStore,
				supportedProtocols, ciphers);
	}

	/**
	 * Builds an {@link Ssl} configuration. Every {@code null} argument other than
	 * {@code clientAuth} leaves the matching property untouched. Both stores use the
	 * password "secret" and a type derived from the file name.
	 */
	private Ssl getSsl(ClientAuth clientAuth, String keyPassword, String keyAlias,
			String keyStore, String trustStore, String[] supportedProtocols,
			String[] ciphers) {
		Ssl ssl = new Ssl();
		ssl.setClientAuth(clientAuth);
		if (keyPassword != null) {
			ssl.setKeyPassword(keyPassword);
		}
		if (keyAlias != null) {
			ssl.setKeyAlias(keyAlias);
		}
		if (keyStore != null) {
			ssl.setKeyStore(keyStore);
			ssl.setKeyStorePassword("secret");
			ssl.setKeyStoreType(getStoreType(keyStore));
		}
		if (trustStore != null) {
			ssl.setTrustStore(trustStore);
			ssl.setTrustStorePassword("secret");
			ssl.setTrustStoreType(getStoreType(trustStore));
		}
		if (ciphers != null) {
			ssl.setCiphers(ciphers);
		}
		if (supportedProtocols != null) {
			ssl.setEnabledProtocols(supportedProtocols);
		}
		return ssl;
	}

	/**
	 * Starts an SSL server restricted to the given protocols and ciphers and checks
	 * that an https request still succeeds.
	 * @param protocols the enabled protocols, or {@code null} to leave them unset
	 * @param ciphers the enabled ciphers, or {@code null} to leave them unset
	 * @throws Exception if the server cannot be started or the request fails
	 */
	protected void testRestrictedSSLProtocolsAndCipherSuites(String[] protocols,
			String[] ciphers) throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		factory.setSsl(getSsl(null, "password", "src/test/resources/test.jks", null,
				protocols, ciphers));
		this.webServer = factory.getWebServer(
				new ServletRegistrationBean<>(new ExampleServlet(true, false), "/hello"));
		this.webServer.start();
		SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
				new SSLContextBuilder()
						.loadTrustMaterial(null, new TrustSelfSignedStrategy()).build());
		HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory)
				.build();
		HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(
				httpClient);
		assertThat(getResponse(getLocalUrl("https", "/hello"), requestFactory))
				.contains("scheme=https");
	}

	/** Returns "pkcs12" for locations ending in ".p12", otherwise {@code null}. */
	private String getStoreType(String keyStore) {
		return (keyStore.endsWith(".p12") ? "pkcs12" : null);
	}

	/** The factory's default session timeout equals 30 * 60. */
	@Test
	public void defaultSessionTimeout() throws Exception {
		assertThat(getFactory().getSessionTimeout()).isEqualTo(30 * 60);
	}

	/**
	 * Session state carries over between requests and across a server restart. Each
	 * response has the form {@code a:b}; the first part of a response must equal the
	 * second part of the one before it.
	 */
	@Test
	public void persistSession() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		factory.setPersistSession(true);
		this.webServer = factory.getWebServer(sessionServletRegistration());
		this.webServer.start();
		String s1 = getResponse(getLocalUrl("/session"));
		String s2 = getResponse(getLocalUrl("/session"));
		this.webServer.stop();
		this.webServer = factory.getWebServer(sessionServletRegistration());
		this.webServer.start();
		String s3 = getResponse(getLocalUrl("/session"));
		String message = "Session error s1=" + s1 + " s2=" + s2 + " s3=" + s3;
		assertThat(s2.split(":")[0]).as(message).isEqualTo(s1.split(":")[1]);
		assertThat(s3.split(":")[0]).as(message).isEqualTo(s2.split(":")[1]);
	}

	/** Stopping the server leaves at least one entry in the session store directory. */
	@Test
	public void persistSessionInSpecificSessionStoreDir() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		File sessionStoreDir = this.temporaryFolder.newFolder();
		factory.setPersistSession(true);
		factory.setSessionStoreDir(sessionStoreDir);
		this.webServer = factory.getWebServer(sessionServletRegistration());
		this.webServer.start();
		getResponse(getLocalUrl("/session"));
		this.webServer.stop();
		File[] dirContents = sessionStoreDir.listFiles(new FilenameFilter() {

			@Override
			public boolean accept(File dir, String name) {
				return !(".".equals(name) || "..".equals(name));
			}

		});
		// listFiles may return null; isNotEmpty() reports that as an assertion failure
		assertThat(dirContents).isNotEmpty();
	}

	/** Without configuration the store is "servlet-sessions" in the application temp dir. */
	@Test
	public void getValidSessionStoreWhenSessionStoreNotSet() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		File dir = factory.getValidSessionStoreDir(false);
		assertThat(dir.getName()).isEqualTo("servlet-sessions");
		assertThat(dir.getParentFile()).isEqualTo(new ApplicationTemp().getDir());
	}

	/** A relative session store directory is resolved against the application home. */
	@Test
	public void getValidSessionStoreWhenSessionStoreIsRelative() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		factory.setSessionStoreDir(new File("sessions"));
		File dir = factory.getValidSessionStoreDir(false);
		assertThat(dir.getName()).isEqualTo("sessions");
		assertThat(dir.getParentFile()).isEqualTo(new ApplicationHome().getDir());
	}

	/** A session store location that is a regular file is rejected. */
	@Test
	public void getValidSessionStoreWhenSessionStoreReferencesFile() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		factory.setSessionStoreDir(this.temporaryFolder.newFile());
		this.thrown.expect(IllegalStateException.class);
		this.thrown.expectMessage("points to a file");
		factory.getValidSessionStoreDir(false);
	}

	/** A 10000 character response is compressed. */
	@Test
	public void compression() throws Exception {
		assertThat(doTestCompression(10000, null, null)).isTrue();
	}

	/** A 100 character response is not compressed. */
	@Test
	public void noCompressionForSmallResponse() throws Exception {
		assertThat(doTestCompression(100, null, null)).isFalse();
	}

	/** No compression when the configured MIME types do not include the response's. */
	@Test
	public void noCompressionForMimeType() throws Exception {
		String[] mimeTypes = new String[] { "text/html", "text/xml", "text/css" };
		assertThat(doTestCompression(10000, mimeTypes, null)).isFalse();
	}

	/** No compression when the client's user agent is excluded. */
	@Test
	public void noCompressionForUserAgent() throws Exception {
		assertThat(doTestCompression(10000, null, new String[] { "testUserAgent" }))
				.isFalse();
	}

	/** The response of {@code ExampleServlet(false, true)} is compressed. */
	@Test
	public void compressionWithoutContentSizeHeader() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		Compression compression = new Compression();
		compression.setEnabled(true);
		factory.setCompression(compression);
		this.webServer = factory.getWebServer(
				new ServletRegistrationBean<>(new ExampleServlet(false, true), "/hello"));
		this.webServer.start();
		TestGzipInputStreamFactory inputStreamFactory = new TestGzipInputStreamFactory();
		Map<String, InputStreamFactory> contentDecoderMap = Collections
				.singletonMap("gzip", (InputStreamFactory) inputStreamFactory);
		getResponse(getLocalUrl("/hello"),
				new HttpComponentsClientHttpRequestFactory(HttpClientBuilder.create()
						.setContentDecoderRegistry(contentDecoderMap).build()));
		assertThat(inputStreamFactory.wasCompressionUsed()).isTrue();
	}

	/** Configured MIME mappings and the expected mappings match in both directions. */
	@Test
	public void mimeMappingsAreCorrectlyConfigured() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		this.webServer = factory.getWebServer();
		Map<String, String> configuredMimeMappings = getActualMimeMappings();
		Set<Entry<String, String>> entrySet = configuredMimeMappings.entrySet();
		Collection<MimeMappings.Mapping> expectedMimeMappings = getExpectedMimeMappings();
		for (Entry<String, String> entry : entrySet) {
			assertThat(expectedMimeMappings)
					.contains(new MimeMappings.Mapping(entry.getKey(), entry.getValue()));
		}
		for (MimeMappings.Mapping mapping : expectedMimeMappings) {
			assertThat(configuredMimeMappings).containsEntry(mapping.getExtension(),
					mapping.getMimeType());
		}
		assertThat(configuredMimeMappings.size()).isEqualTo(expectedMimeMappings.size());
	}

	/** {@code ServletContext.getResource("/")} is non-null during startup. */
	@Test
	public void rootServletContextResource() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		final AtomicReference<URL> rootResource = new AtomicReference<>();
		this.webServer = factory.getWebServer(new ServletContextInitializer() {

			@Override
			public void onStartup(ServletContext servletContext)
					throws ServletException {
				try {
					rootResource.set(servletContext.getResource("/"));
				}
				catch (MalformedURLException ex) {
					throw new ServletException(ex);
				}
			}

		});
		this.webServer.start();
		assertThat(rootResource.get()).isNotNull();
	}

	/** A configured server header is sent with the response. */
	@Test
	public void customServerHeader() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		factory.setServerHeader("MyServer");
		this.webServer = factory.getWebServer(exampleServletRegistration());
		this.webServer.start();
		try (ClientHttpResponse response = getClientResponse(getLocalUrl("/hello"))) {
			assertThat(response.getHeaders().getFirst("server")).isEqualTo("MyServer");
		}
	}

	/** Without configuration no server header is sent. */
	@Test
	public void serverHeaderIsDisabledByDefault() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		this.webServer = factory.getWebServer(exampleServletRegistration());
		this.webServer.start();
		try (ClientHttpResponse response = getClientResponse(getLocalUrl("/hello"))) {
			assertThat(response.getHeaders().getFirst("server")).isNull();
		}
	}

	/**
	 * Starting on a blocked port must fail with a RuntimeException, which is passed
	 * to {@link #handleExceptionCausedByBlockedPort(RuntimeException, int)}.
	 */
	@Test
	public void portClashOfPrimaryConnectorResultsInPortInUseException()
			throws IOException {
		doWithBlockedPort(new BlockedPortAction() {

			@Override
			public void run(int port) {
				try {
					AbstractServletWebServerFactory factory = getFactory();
					factory.setPort(port);
					AbstractServletWebServerFactoryTests.this.webServer = factory
							.getWebServer();
					AbstractServletWebServerFactoryTests.this.webServer.start();
					fail();
				}
				catch (RuntimeException ex) {
					handleExceptionCausedByBlockedPort(ex, port);
				}
			}

		});
	}

	/**
	 * As the primary connector test, but the blocked port is used by an additional
	 * connector while the primary connector uses a free port.
	 */
	@Test
	public void portClashOfSecondaryConnectorResultsInPortInUseException()
			throws IOException {
		doWithBlockedPort(new BlockedPortAction() {

			@Override
			public void run(int port) {
				try {
					AbstractServletWebServerFactory factory = getFactory();
					factory.setPort(SocketUtils.findAvailableTcpPort(40000));
					addConnector(port, factory);
					AbstractServletWebServerFactoryTests.this.webServer = factory
							.getWebServer();
					AbstractServletWebServerFactoryTests.this.webServer.start();
					fail();
				}
				catch (RuntimeException ex) {
					handleExceptionCausedByBlockedPort(ex, port);
				}
			}

		});
	}

	/** A configured locale mapping is applied; an unconfigured locale has none. */
	@Test
	public void localeCharsetMappingsAreConfigured() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		Map<Locale, Charset> mappings = new HashMap<>();
		mappings.put(Locale.GERMAN, Charset.forName("UTF-8"));
		factory.setLocaleCharsetMappings(mappings);
		this.webServer = factory.getWebServer();
		assertThat(getCharset(Locale.GERMAN).toString()).isEqualTo("UTF-8");
		assertThat(getCharset(Locale.ITALIAN)).isNull();
	}

	/** JSP init parameters reach the JSP servlet; skipped when there is none. */
	@Test
	public void jspServletInitParameters() throws Exception {
		Map<String, String> initParameters = new HashMap<>();
		initParameters.put("a", "alpha");
		AbstractServletWebServerFactory factory = getFactory();
		factory.getJsp().setInitParameters(initParameters);
		this.webServer = factory.getWebServer();
		Assume.assumeThat(getJspServlet(), notNullValue());
		JspServlet jspServlet = getJspServlet();
		assertThat(jspServlet.getInitParameter("a")).isEqualTo("alpha");
	}

	/** The JSP servlet's options report development off; skipped when there is none. */
	@Test
	public void jspServletIsNotInDevelopmentModeByDefault() throws Exception {
		AbstractServletWebServerFactory factory = getFactory();
		this.webServer = factory.getWebServer();
		Assume.assumeThat(getJspServlet(), notNullValue());
		JspServlet jspServlet = getJspServlet();
		EmbeddedServletOptions options = (EmbeddedServletOptions) ReflectionTestUtils
				.getField(jspServlet, "options");
		assertThat(options.getDevelopment()).isEqualTo(false);
	}

	/**
	 * Adds a connector for the given port to the factory. Used by the secondary
	 * connector port clash test.
	 * @param port the port for the additional connector
	 * @param factory the factory to add the connector to
	 */
	protected abstract void addConnector(int port,
			AbstractServletWebServerFactory factory);

	/**
	 * Called by the port clash tests with the exception raised while starting.
	 * @param ex the exception thrown by getWebServer() or start()
	 * @param blockedPort the port that was blocked
	 */
	protected abstract void handleExceptionCausedByBlockedPort(RuntimeException ex,
			int blockedPort);

	/**
	 * Serves a generated file with compression enabled and fetches it as user agent
	 * "testUserAgent" through a client with a recording gzip decoder.
	 * @return whether the recording decoder saw compression being used
	 */
	private boolean doTestCompression(int contentSize, String[] mimeTypes,
			String[] excludedUserAgents) throws Exception {
		String testContent = setUpFactoryForCompression(contentSize, mimeTypes,
				excludedUserAgents);
		TestGzipInputStreamFactory inputStreamFactory = new TestGzipInputStreamFactory();
		Map<String, InputStreamFactory> contentDecoderMap = Collections
				.singletonMap("gzip", (InputStreamFactory) inputStreamFactory);
		String response = getResponse(getLocalUrl("/test.txt"),
				new HttpComponentsClientHttpRequestFactory(
						HttpClientBuilder.create().setUserAgent("testUserAgent")
								.setContentDecoderRegistry(contentDecoderMap).build()));
		assertThat(response).isEqualTo(testContent);
		return inputStreamFactory.wasCompressionUsed();
	}

	/**
	 * Writes a test.txt of {@code contentSize} 'F' characters to the temporary
	 * folder, uses that folder as document root and starts a server with
	 * compression enabled.
	 * @param contentSize the number of characters in the generated file
	 * @param mimeTypes the MIME types to compress, or {@code null} for the defaults
	 * @param excludedUserAgents the user agents to exclude, or {@code null} for none
	 * @return the generated file content
	 * @throws Exception if the file cannot be written or the server fails to start
	 */
	protected String setUpFactoryForCompression(int contentSize, String[] mimeTypes,
			String[] excludedUserAgents) throws Exception {
		char[] chars = new char[contentSize];
		Arrays.fill(chars, 'F');
		String testContent = new String(chars);
		AbstractServletWebServerFactory factory = getFactory();
		FileCopyUtils.copy(testContent,
				new FileWriter(this.temporaryFolder.newFile("test.txt")));
		factory.setDocumentRoot(this.temporaryFolder.getRoot());
		Compression compression = new Compression();
		compression.setEnabled(true);
		if (mimeTypes != null) {
			compression.setMimeTypes(mimeTypes);
		}
		if (excludedUserAgents != null) {
			compression.setExcludedUserAgents(excludedUserAgents);
		}
		factory.setCompression(compression);
		this.webServer = factory.getWebServer();
		this.webServer.start();
		return testContent;
	}

	/**
	 * Returns the MIME mappings actually configured on the server, keyed by
	 * extension with the MIME type as value.
	 * @return the configured mappings
	 */
	protected abstract Map<String, String> getActualMimeMappings();

	/**
	 * Returns the MIME mappings the server is expected to have; the defaults unless
	 * overridden.
	 * @return the expected mappings
	 */
	protected Collection<MimeMappings.Mapping> getExpectedMimeMappings() {
		return MimeMappings.DEFAULT.getAll();
	}

	/**
	 * Returns the charset the server maps to the given locale.
	 * @param locale the locale to look up
	 * @return the mapped charset, or {@code null} when the locale has no mapping
	 */
	protected abstract Charset getCharset(Locale locale);

	/** Writes a test.txt containing "test" and makes its folder the document root. */
	private void addTestTxtFile(AbstractServletWebServerFactory factory)
			throws IOException {
		FileCopyUtils.copy("test",
				new FileWriter(this.temporaryFolder.newFile("test.txt")));
		factory.setDocumentRoot(this.temporaryFolder.getRoot());
	}

	/** Builds an http URL for the path on the running server's port. */
	protected String getLocalUrl(String resourcePath) {
		return getLocalUrl("http", resourcePath);
	}

	/** Builds a URL with the given scheme for the path on the running server's port. */
	protected String getLocalUrl(String scheme, String resourcePath) {
		return scheme + "://localhost:" + this.webServer.getPort() + resourcePath;
	}

	/** Builds an http URL for the path on an explicit port. */
	protected String getLocalUrl(int port, String resourcePath) {
		return "http://localhost:" + port + resourcePath;
	}

	/** Issues a GET request; delegates to the {@link HttpMethod} variant. */
	protected String getResponse(String url, String... headers)
			throws IOException, URISyntaxException {
		return getResponse(url, HttpMethod.GET, headers);
	}

	protected String getResponse(String url, HttpMethod method, String...
headers) throws IOException, URISyntaxException { ClientHttpResponse response = getClientResponse(url, method, headers); try { return StreamUtils.copyToString(response.getBody(), Charset.forName("UTF-8")); } finally { response.close(); } } protected String getResponse(String url, HttpComponentsClientHttpRequestFactory requestFactory, String... headers) throws IOException, URISyntaxException { return getResponse(url, HttpMethod.GET, requestFactory, headers); } protected String getResponse(String url, HttpMethod method, HttpComponentsClientHttpRequestFactory requestFactory, String... headers) throws IOException, URISyntaxException { ClientHttpResponse response = getClientResponse(url, method, requestFactory, headers); try { return StreamUtils.copyToString(response.getBody(), Charset.forName("UTF-8")); } finally { response.close(); } } protected ClientHttpResponse getClientResponse(String url, String... headers) throws IOException, URISyntaxException { return getClientResponse(url, HttpMethod.GET, headers); } protected ClientHttpResponse getClientResponse(String url, HttpMethod method, String... headers) throws IOException, URISyntaxException { return getClientResponse(url, method, new HttpComponentsClientHttpRequestFactory() { @Override protected HttpContext createHttpContext(HttpMethod httpMethod, URI uri) { return AbstractServletWebServerFactoryTests.this.httpClientContext; } }, headers); } protected ClientHttpResponse getClientResponse(String url, HttpMethod method, HttpComponentsClientHttpRequestFactory requestFactory, String... 
headers) throws IOException, URISyntaxException { ClientHttpRequest request = requestFactory.createRequest(new URI(url), method); request.getHeaders().add("Cookie", "JSESSIONID=" + "123"); for (String header : headers) { String[] parts = header.split(":"); request.getHeaders().add(parts[0], parts[1]); } ClientHttpResponse response = request.execute(); return response; } protected void assertForwardHeaderIsUsed(ServletWebServerFactory factory) throws IOException, URISyntaxException { this.webServer = factory.getWebServer( new ServletRegistrationBean<>(new ExampleServlet(true, false), "/hello")); this.webServer.start(); assertThat(getResponse(getLocalUrl("/hello"), "X-Forwarded-For:140.211.11.130")) .contains("remoteaddr=140.211.11.130"); } protected abstract AbstractServletWebServerFactory getFactory(); protected abstract org.apache.jasper.servlet.JspServlet getJspServlet() throws Exception; protected ServletContextInitializer exampleServletRegistration() { return new ServletRegistrationBean<>(new ExampleServlet(), "/hello"); } @SuppressWarnings("serial") private ServletContextInitializer errorServletRegistration() { ServletRegistrationBean<ExampleServlet> bean = new ServletRegistrationBean<>( new ExampleServlet() { @Override public void service(ServletRequest request, ServletResponse response) throws ServletException, IOException { throw new RuntimeException("Planned"); } }, "/bang"); bean.setName("error"); return bean; } protected final ServletContextInitializer sessionServletRegistration() { ServletRegistrationBean<ExampleServlet> bean = new ServletRegistrationBean<>( new ExampleServlet() { @Override public void service(ServletRequest request, ServletResponse response) throws ServletException, IOException { HttpSession session = ((HttpServletRequest) request) .getSession(true); long value = System.currentTimeMillis(); Object existing = session.getAttribute("boot"); session.setAttribute("boot", value); PrintWriter writer = response.getWriter(); 
writer.append(String.valueOf(existing) + ":" + value); } }, "/session"); bean.setName("session"); return bean; } protected final void doWithBlockedPort(BlockedPortAction action) throws IOException { int port = SocketUtils.findAvailableTcpPort(40000); ServerSocket serverSocket = new ServerSocket(); for (int i = 0; i < 10; i++) { try { serverSocket.bind(new InetSocketAddress(port)); break; } catch (Exception ex) { } } try { action.run(port); } finally { serverSocket.close(); } } private KeyStore loadStore() throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException { KeyStore keyStore = KeyStore.getInstance("JKS"); Resource resource = new ClassPathResource("test.jks"); InputStream inputStream = resource.getInputStream(); try { keyStore.load(inputStream, "secret".toCharArray()); return keyStore; } finally { inputStream.close(); } } private class TestGzipInputStreamFactory implements InputStreamFactory { private final AtomicBoolean requested = new AtomicBoolean(false); @Override public InputStream create(InputStream in) throws IOException { if (this.requested.get()) { throw new IllegalStateException( "On deflated InputStream already requested"); } this.requested.set(true); return new GZIPInputStream(in); } public boolean wasCompressionUsed() { return this.requested.get(); } } @SuppressWarnings("serial") private static class InitCountingServlet extends GenericServlet { private int initCount; @Override public void init() throws ServletException { this.initCount++; } @Override public void service(ServletRequest req, ServletResponse res) throws ServletException, IOException { } public int getInitCount() { return this.initCount; } }; public interface BlockedPortAction { void run(int port); } /** * {@link TrustSelfSignedStrategy} that also validates certificate serial number. 
*/ private static final class SerialNumberValidatingTrustSelfSignedStrategy extends TrustSelfSignedStrategy { private final String serialNumber; private SerialNumberValidatingTrustSelfSignedStrategy(String serialNumber) { this.serialNumber = serialNumber; } @Override public boolean isTrusted(X509Certificate[] chain, String authType) throws CertificateException { String hexSerialNumber = chain[0].getSerialNumber().toString(16); boolean isMatch = hexSerialNumber.equals(this.serialNumber); return super.isTrusted(chain, authType) && isMatch; } } }