//
// ========================================================================
// Copyright (c) 1995-2017 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.server;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.servlet.DispatcherType;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.io.LeakTrackingByteBufferPool;
import org.eclipse.jetty.io.MappedByteBufferPool;
import org.eclipse.jetty.server.handler.AbstractHandler;
import org.eclipse.jetty.toolchain.test.MavenTestingUtils;
import org.eclipse.jetty.toolchain.test.TestTracker;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.eclipse.jetty.util.thread.Scheduler;
import org.junit.After;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@RunWith(Parameterized.class)
public class ThreadStarvationTest
{
/** Size in bytes (1024*1024) of the buffer that WriteHandler writes on each loop iteration. */
final static int BUFFER_SIZE=1024*1024;
/** Number of times WriteHandler writes its buffer for one response. */
final static int BUFFERS=64;
/** Used by prepareServer as both the minimum and the maximum size of the server thread pool. */
final static int THREADS=5;
/** Number of concurrent client tasks the starvation tests run; two more than THREADS. */
final static int CLIENTS=THREADS+2;
// JUnit rule from the Jetty test toolchain.
// NOTE(review): presumably reports which test is running - confirm against TestTracker.
@Rule
public TestTracker tracker = new TestTracker();
/**
 * Creates the server-side connector for one parameter set of this test.
 * <p>
 * Implemented with lambdas in {@code params()}, hence the {@link FunctionalInterface}
 * annotation: the compiler then rejects any accidental second abstract method.
 */
@FunctionalInterface
interface ConnectorProvider
{
    /**
     * @param server the server the new connector belongs to
     * @param acceptors the acceptor count handed to the {@code ServerConnector} constructor
     * @param selectors the selector count handed to the {@code ServerConnector} constructor
     * @return a new connector; the caller is responsible for adding it to the server
     */
    ServerConnector newConnector(Server server, int acceptors, int selectors);
}
/**
 * Creates the client-side socket for one parameter set of this test.
 * <p>
 * Annotated as a {@link FunctionalInterface} because {@code params()} implements it with a
 * lambda; the compiler then rejects any accidental second abstract method.
 */
@FunctionalInterface
interface ClientSocketProvider
{
    /**
     * @param host the host to connect to
     * @param port the port to connect to
     * @return a socket for the given host and port
     * @throws IOException if the socket cannot be created
     */
    Socket newSocket(String host, int port) throws IOException;
}
/**
 * Supplies the parameter sets this class is run with.
 * <p>
 * Each entry is {@code { display name, ConnectorProvider, ClientSocketProvider }}: one entry
 * for plain HTTP and one for HTTPS using the {@code keystore} test resource.
 *
 * @return the list of constructor argument arrays
 */
@Parameterized.Parameters(name = "{0}")
public static List<Object[]> params()
{
    List<Object[]> parameterSets = new ArrayList<>();

    // HTTP
    ConnectorProvider plainConnectors = (server, acceptors, selectors) ->
        new ServerConnector(server, acceptors, selectors);
    ClientSocketProvider plainSockets = (host, port) -> new Socket(host, port);
    parameterSets.add(new Object[]{ "http", plainConnectors, plainSockets });

    // HTTPS/SSL/TLS
    ConnectorProvider tlsConnectors = (server, acceptors, selectors) ->
    {
        // The same keystore file serves as key store and trust store.
        String keystore = MavenTestingUtils.getTestResourcePath("keystore").toString();
        SslContextFactory ssl = new SslContextFactory();
        ssl.setKeyStorePath(keystore);
        ssl.setKeyStorePassword("storepwd");
        ssl.setKeyManagerPassword("keypwd");
        ssl.setTrustStorePath(keystore);
        ssl.setTrustStorePassword("storepwd");

        ByteBufferPool bufferPool = new LeakTrackingByteBufferPool(new MappedByteBufferPool.Tagged());
        HttpConnectionFactory httpFactory = new HttpConnectionFactory();
        ServerConnector tlsConnector = new ServerConnector(
            server,
            (Executor)null,
            (Scheduler)null,
            bufferPool,
            acceptors,
            selectors,
            AbstractConnectionFactory.getFactories(ssl, httpFactory));

        // Added only after the connector (and its factory chain) has been built.
        SecureRequestCustomizer customizer = new SecureRequestCustomizer();
        customizer.setSslSessionAttribute("SSL_SESSION");
        httpFactory.getHttpConfiguration().addCustomizer(customizer);
        return tlsConnector;
    };

    // Client side: a TLS context initialised with SslContextFactory.TRUST_ALL_CERTS.
    // The same context and an accept-all hostname verifier are also installed as the
    // JVM-wide HttpsURLConnection defaults.
    final SSLContext trustAllContext;
    try
    {
        HttpsURLConnection.setDefaultHostnameVerifier((hostname, session) -> true);
        trustAllContext = SSLContext.getInstance("TLS");
        trustAllContext.init(null, SslContextFactory.TRUST_ALL_CERTS, new java.security.SecureRandom());
        HttpsURLConnection.setDefaultSSLSocketFactory(trustAllContext.getSocketFactory());
    }
    catch (Exception e)
    {
        e.printStackTrace();
        throw new RuntimeException(e);
    }
    ClientSocketProvider tlsSockets = (host, port) ->
        trustAllContext.getSocketFactory().createSocket(host, port);
    parameterSets.add(new Object[]{ "https/ssl/tls", tlsConnectors, tlsSockets });

    return parameterSets;
}
// Creates the server connector; set from the constructor argument of the same name.
private final ConnectorProvider connectorProvider;
// Creates the client sockets the tests connect with; set from the constructor argument.
private final ClientSocketProvider clientSocketProvider;
// The three fields below are assigned by prepareServer(Handler) for each test.
// Server thread pool, with minimum and maximum size both set to THREADS.
private QueuedThreadPool _threadPool;
// Server under test; stopped by dispose().
private Server _server;
// Connector created by connectorProvider; tests read its local port to connect.
private ServerConnector _connector;
/**
 * Invoked by the {@link Parameterized} runner with one entry from {@code params()}.
 *
 * @param testType display name of the parameter set; not stored, it only appears in the
 *                 test name through the {@code "{0}"} pattern on {@code params()}
 * @param connectorProvider creates the server connector
 * @param clientSocketProvider creates the client sockets
 */
public ThreadStarvationTest(
    String testType,
    ConnectorProvider connectorProvider,
    ClientSocketProvider clientSocketProvider)
{
    this.clientSocketProvider = clientSocketProvider;
    this.connectorProvider = connectorProvider;
}
/**
 * Builds, but does not start, a server whose thread pool is fixed at {@code THREADS}
 * threads and whose single connector has one acceptor and one selector.
 * <p>
 * The thread pool, server and connector are stored in the corresponding fields.
 *
 * @param handler the handler installed on the server
 * @return the configured server, not yet started
 */
private Server prepareServer(Handler handler)
{
    QueuedThreadPool pool = new QueuedThreadPool();
    pool.setMinThreads(THREADS);
    pool.setMaxThreads(THREADS);
    pool.setDetailedDump(true);
    _threadPool = pool;

    Server server = new Server(pool);
    _server = server;

    // One acceptor, one selector.
    ServerConnector connector = connectorProvider.newConnector(server, 1, 1);
    _connector = connector;

    server.addConnector(connector);
    server.setHandler(handler);
    return server;
}
/**
 * Stops the server after each test, if one was prepared.
 *
 * @throws Exception if stopping the server fails
 */
@After
public void dispose() throws Exception
{
    // The field is only assigned inside prepareServer(Handler). If a test failed before
    // that point, skip the stop instead of adding a NullPointerException to the report.
    if (_server != null)
    {
        _server.stop();
    }
}
/**
 * Sends one request with a 10 byte body over a single connection and checks that
 * {@link ReadHandler} answers {@code 200 OK} and reports having read 10 bytes.
 */
@Test
public void testReadInput() throws Exception
{
    Server server = prepareServer(new ReadHandler());
    server.start();

    int port = _connector.getLocalPort();
    try (Socket socket = clientSocketProvider.newSocket("localhost", port))
    {
        socket.setSoTimeout(10000);
        OutputStream toServer = socket.getOutputStream();
        InputStream fromServer = socket.getInputStream();

        String headers =
            "GET / HTTP/1.0\r\n" +
            "Host: localhost\r\n" +
            "Content-Length: 10\r\n" +
            "\r\n";
        String body = "0123456789\r\n";
        toServer.write(("" + headers + body).getBytes(StandardCharsets.UTF_8));
        toServer.flush();

        // Read everything the server sends, then confirm the stream is at EOF.
        String reply = IO.toString(fromServer);
        assertEquals(-1, fromServer.read());

        assertThat(reply, containsString("200 OK"));
        assertThat(reply, containsString("Read Input 10"));
    }
}
/**
 * Runs {@code CLIENTS} concurrent clients (more than the {@code THREADS} server threads).
 * Each sends the request headers plus the first body byte, pauses, then sends the rest of
 * the 10 byte body. Every client must still receive a complete {@code 200 OK} response
 * from {@link ReadHandler}.
 */
@Test
public void testReadStarvation() throws Exception
{
    prepareServer(new ReadHandler());
    _server.start();

    // Every client performs exactly the same exchange, so a single task is reused.
    Callable<String> slowUpload = () ->
    {
        try (Socket socket = clientSocketProvider.newSocket("localhost", _connector.getLocalPort());
             OutputStream toServer = socket.getOutputStream();
             InputStream fromServer = socket.getInputStream())
        {
            socket.setSoTimeout(10000);

            // Headers and only the first of the ten body bytes.
            String partial = "" +
                "PUT / HTTP/1.0\r\n" +
                "host: localhost\r\n" +
                "content-length: 10\r\n" +
                "\r\n" +
                "1";
            toServer.write(partial.getBytes(StandardCharsets.UTF_8));
            toServer.flush();

            // Pause with the body incomplete, then send the remainder.
            TimeUnit.MILLISECONDS.sleep(1500);
            toServer.write("234567890\r\n".getBytes(StandardCharsets.UTF_8));
            toServer.flush();

            // Read the whole response and confirm the stream is at EOF.
            String reply = IO.toString(fromServer);
            assertEquals(-1, fromServer.read());
            return reply;
        }
    };

    List<Callable<String>> clientTasks = new ArrayList<>();
    int queued = 0;
    while (queued < CLIENTS)
    {
        clientTasks.add(slowUpload);
        queued++;
    }

    ExecutorService clientExecutors = Executors.newFixedThreadPool(CLIENTS);
    try
    {
        List<Future<String>> pending = clientExecutors.invokeAll(clientTasks, 60, TimeUnit.SECONDS);
        for (Future<String> result : pending)
        {
            String reply = result.get();
            assertThat(reply, containsString("200 OK"));
            assertThat(reply, containsString("Read Input 10"));
        }
    }
    finally
    {
        clientExecutors.shutdownNow();
    }
}
/**
 * Handler that reads the request body one byte at a time with blocking reads, up to the
 * declared content length, and replies with {@code "Read Input <count>"}.
 * <p>
 * Requests whose dispatcher type is not {@code REQUEST} get a 500 error.
 */
protected static class ReadHandler extends AbstractHandler
{
    @Override
    public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException
    {
        baseRequest.setHandled(true);

        if (request.getDispatcherType() != DispatcherType.REQUEST)
        {
            response.sendError(HttpStatus.INTERNAL_SERVER_ERROR_500);
            return;
        }

        response.setStatus(200);

        int expected = request.getContentLength();
        InputStream in = request.getInputStream();
        int consumed = 0;
        while (consumed < expected)
        {
            // read() returns -1 at end of stream. Stop there rather than spinning forever
            // when the client closes before sending the declared number of bytes; the
            // short count then shows up in the response.
            if (in.read() < 0)
            {
                break;
            }
            consumed++;
        }

        // Explicit charset so the bytes do not depend on the platform default.
        response.getOutputStream().write(("Read Input " + consumed + "\r\n").getBytes(StandardCharsets.UTF_8));
    }
}
/**
 * Runs {@code CLIENTS} concurrent clients (more than the {@code THREADS} server threads).
 * Each sends a GET, pauses before reading, then counts the {@code '!'} bytes in the
 * response. Every client must receive exactly {@code BUFFERS * BUFFER_SIZE} of them
 * from {@link WriteHandler}.
 */
@Test
public void testWriteStarvation() throws Exception
{
    prepareServer(new WriteHandler());
    _server.start();

    ExecutorService clientExecutors = Executors.newFixedThreadPool(CLIENTS);

    List<Callable<Long>> clientTasks = new ArrayList<>();

    for (int i = 0; i < CLIENTS; i++)
    {
        clientTasks.add(() ->
        {
            try (Socket client = clientSocketProvider.newSocket("localhost", _connector.getLocalPort());
                 OutputStream out = client.getOutputStream();
                 InputStream in = client.getInputStream())
            {
                client.setSoTimeout(30000);

                String request = "" +
                    "GET / HTTP/1.0\r\n" +
                    "host: localhost\r\n" +
                    "\r\n";

                // Write GET request
                out.write(request.getBytes(StandardCharsets.UTF_8));
                out.flush();

                // Delay reading the response.
                // NOTE(review): presumably so the server's writes stall on a client that
                // is not reading yet - confirm.
                TimeUnit.MILLISECONDS.sleep(1500);

                // Read Response, counting only the body marker byte
                long bodyCount = 0;
                int len; // InputStream.read returns an int
                byte[] buf = new byte[1024];

                while ((len = in.read(buf, 0, buf.length)) != -1)
                {
                    for (int x = 0; x < len; x++)
                    {
                        if (buf[x] == '!')
                        {
                            bodyCount++;
                        }
                    }
                }
                return bodyCount;
            }
        });
    }

    try
    {
        List<Future<Long>> responses = clientExecutors.invokeAll(clientTasks, 60, TimeUnit.SECONDS);

        // Widen before multiplying so the product is computed in long arithmetic and
        // cannot overflow int if the constants are ever increased.
        long expected = (long)BUFFERS * BUFFER_SIZE;

        for (Future<Long> responseFut : responses)
        {
            Long bodyCount = responseFut.get();
            assertThat(bodyCount.longValue(), is(expected));
        }
    }
    finally
    {
        clientExecutors.shutdownNow();
    }
}
/**
 * Handler that answers every request with status 200 and a body of {@code BUFFERS} writes
 * of a {@code BUFFER_SIZE} byte buffer filled with {@code '!'}, flushing after each write.
 */
protected static class WriteHandler extends AbstractHandler
{
    byte[] content = newContent();

    /**
     * Builds the response buffer, filled with a character that will not show up in an
     * HTTP response header, so a client can count body bytes without parsing the response.
     */
    private static byte[] newContent()
    {
        byte[] filled = new byte[BUFFER_SIZE];
        Arrays.fill(filled, (byte)'!');
        return filled;
    }

    @Override
    public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException
    {
        baseRequest.setHandled(true);
        response.setStatus(200);

        OutputStream body = response.getOutputStream();
        int remaining = BUFFERS;
        while (remaining > 0)
        {
            body.write(content);
            body.flush();
            remaining--;
        }
    }
}
}