package io.undertow.server.handlers.proxy;
import io.undertow.Undertow;
import io.undertow.server.HttpHandler;
import io.undertow.server.HttpServerExchange;
import io.undertow.server.ServerConnection;
import io.undertow.server.handlers.ResponseCodeHandler;
import io.undertow.testutils.DefaultServer;
import io.undertow.testutils.HttpClientUtils;
import io.undertow.testutils.ProxyIgnore;
import io.undertow.testutils.TestHttpClient;
import io.undertow.util.StatusCodes;
import org.apache.http.HttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.conn.PoolingClientConnectionManager;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import java.io.IOException;
import java.net.URI;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
@RunWith(DefaultServer.class)
@ProxyIgnore
public class LoadBalancerConnectionPoolingTestCase {
public static final int TTL = 2000;
private static Undertow undertow;
private static final Set<ServerConnection> activeConnections = Collections.newSetFromMap(new ConcurrentHashMap<>());
static final String host = DefaultServer.getHostAddress("default");
static int port = DefaultServer.getHostPort("default");
@BeforeClass
public static void before() throws Exception {
ProxyHandler proxyHandler = new ProxyHandler(new LoadBalancingProxyClient()
.setConnectionsPerThread(1)
.setSoftMaxConnectionsPerThread(0)
.setTtl(TTL)
.setMaxQueueSize(1000)
.addHost(new URI("http", null, host, port, null, null, null), "s1")
, 10000, ResponseCodeHandler.HANDLE_404);
// Default server uses 8 io threads which is hard to test against
undertow = Undertow.builder()
.setIoThreads(1)
.addHttpListener(port + 1, host)
.setHandler(proxyHandler)
.build();
undertow.start();
DefaultServer.setRootHandler(new HttpHandler() {
@Override
public void handleRequest(HttpServerExchange exchange) throws Exception {
final ServerConnection con = exchange.getConnection();
if(!activeConnections.contains(con)) {
System.out.println("added " + con);
activeConnections.add(con);
con.addCloseListener(new ServerConnection.CloseListener() {
@Override
public void closed(ServerConnection connection) {
System.out.println("Closed " + connection);
activeConnections.remove(connection);
}
});
}
}
});
}
@AfterClass
public static void after() {
undertow.stop();
}
@Test
public void shouldReduceConnectionPool() throws Exception {
ExecutorService executorService = Executors.newFixedThreadPool(10);
PoolingClientConnectionManager conman = new PoolingClientConnectionManager();
conman.setDefaultMaxPerRoute(20);
final TestHttpClient client = new TestHttpClient(conman);
int requests = 20;
final CountDownLatch latch = new CountDownLatch(requests);
long ttlStartExpire = TTL + System.currentTimeMillis();
try {
for (int i = 0; i < requests; ++i) {
executorService.submit(new Runnable() {
@Override
public void run() {
HttpGet get = new HttpGet("http://" + host + ":" + (port + 1));
try {
HttpResponse response = client.execute(get);
Assert.assertEquals(StatusCodes.OK, response.getStatusLine().getStatusCode());
HttpClientUtils.readResponse(response);
} catch (IOException e) {
throw new RuntimeException(e);
} finally {
latch.countDown();
}
}
});
}
if(!latch.await(2000, TimeUnit.MILLISECONDS)) {
Assert.fail();
}
} finally {
client.getConnectionManager().shutdown();
executorService.shutdown();
}
if(activeConnections.size() != 1) {
//if the test is slow this line could be hit after the expire time
//uncommon, but we guard against it to prevent intermittent failures
if(System.currentTimeMillis() < ttlStartExpire) {
Assert.fail("there should still be a connection");
}
}
long end = System.currentTimeMillis() + (TTL * 3);
while (!activeConnections.isEmpty() && System.currentTimeMillis() < end) {
Thread.sleep(100);
}
Assert.assertEquals(0, activeConnections.size());
}
}