package io.undertow.server.handlers.proxy;
import io.undertow.Handlers;
import io.undertow.Undertow;
import io.undertow.server.HttpHandler;
import io.undertow.server.HttpServerExchange;
import io.undertow.server.handlers.PathHandler;
import io.undertow.server.handlers.ResponseCodeHandler;
import io.undertow.testutils.TestHttpClient;
import org.apache.http.HttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.junit.After;
import org.junit.Test;
import java.io.IOException;
import java.net.ServerSocket;
import java.net.URI;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
public class ProxyPathHandlingTest {
private final TargetServer targetServer = new TargetServer();
private final ProxyServer proxyServer = new ProxyServer(targetServer.uri);
@After
public void cleanup() {
targetServer.stop();
proxyServer.stop();
}
@Test
public void prefixRootToRoot() throws Exception {
proxyServer.proxyPrefixPath("/", "/");
isProxied("", "/");
isProxied("/", "/");
isProxied("/foo", "/foo");
}
@Test
public void prefixRootToPath() throws Exception {
proxyServer.proxyPrefixPath("/", "/path");
isProxied("", "/path/");
isProxied("/", "/path/");
isProxied("/foo", "/path/foo");
}
@Test
public void prefixPathToPath() throws Exception {
proxyServer.proxyPrefixPath("/path", "/path");
isProxied("/path", "/path");
isProxied("/path/", "/path/");
isProxied("/path/foo", "/path/foo");
isNotProxied("");
isNotProxied("/");
isNotProxied("/foo");
}
@Test
public void prefixPathToRoot() throws Exception {
proxyServer.proxyPrefixPath("/path", "/");
isProxied("/path", "/");
isProxied("/path/", "/");
isNotProxied("");
isNotProxied("/");
isNotProxied("/foo");
}
@Test
public void prefixPathSlashToRoot() throws Exception {
proxyServer.proxyPrefixPath("/path/", "/");
isProxied("/path", "/");
isProxied("/path/", "/");
isNotProxied("");
isNotProxied("/");
isNotProxied("/foo");
}
@Test
public void exactRootToRoot() throws Exception {
proxyServer.proxyExactPath("/", "/");
isProxied("", "/");
isProxied("/", "/");
isNotProxied("/foo");
}
@Test
public void exactRootToPath() throws Exception {
proxyServer.proxyExactPath("/", "/path");
isProxied("", "/path");
isProxied("/", "/path");
isNotProxied("/foo");
}
@Test
public void exactRootToPathSlash() throws Exception {
proxyServer.proxyExactPath("/", "/path/");
isProxied("", "/path/");
isProxied("/", "/path/");
isNotProxied("/foo");
}
@Test
public void exactPathToRoot() throws Exception {
proxyServer.proxyExactPath("/path", "/");
isProxied("/path", "/");
isProxied("/path/", "/");
isNotProxied("");
isNotProxied("/");
isNotProxied("/foo");
isNotProxied("/path/foo");
}
@Test
public void exactPathSlashToRoot() throws Exception {
proxyServer.proxyExactPath("/path/", "/");
isProxied("/path", "/");
isProxied("/path/", "/");
isNotProxied("");
isNotProxied("/");
isNotProxied("/foo");
isNotProxied("/path/foo");
}
@Test
public void exactPathToPath() throws Exception {
proxyServer.proxyExactPath("/path", "/path");
isProxied("/path", "/path");
isProxied("/path/", "/path");
isNotProxied("");
isNotProxied("/");
isNotProxied("/foo");
isNotProxied("/path/foo");
}
@Test
public void exactPathToPathSlash() throws Exception {
proxyServer.proxyExactPath("/path", "/path/");
isProxied("/path", "/path/");
isProxied("/path/", "/path/");
isNotProxied("");
isNotProxied("/");
isNotProxied("/foo");
isNotProxied("/path/foo");
}
private void isProxied(String requestPath, String expectedTargetPath) throws IOException {
assertEquals(200, httpGet(requestPath));
assertEquals(expectedTargetPath, targetServer.gotRequest(true));
}
private void isNotProxied(String requestPath) throws IOException {
assertEquals(404, httpGet(requestPath));
assertNull(targetServer.gotRequest(false));
}
private int httpGet(String path) throws IOException {
TestHttpClient http = new TestHttpClient();
HttpResponse response = http.execute(new HttpGet(proxyServer.uri + path));
return response.getStatusLine().getStatusCode();
}
private static class ProxyServer {
private final int port = FreePort.find();
private final Undertow server;
private final PathHandler pathHandler = Handlers.path();
final String uri = "http://localhost:" + port;
private final String targetUri;
ProxyServer(String targetUri) {
this.targetUri = targetUri;
server = Undertow.builder()
.addHttpListener(port, "0.0.0.0")
.setHandler(pathHandler)
.build();
server.start();
}
void proxyPrefixPath(String proxyPath, String targetPath) {
pathHandler.addPrefixPath(proxyPath, proxyHandler(targetPath));
}
void proxyExactPath(String proxyPath, String targetPath) {
pathHandler.addExactPath(proxyPath, proxyHandler(targetPath));
}
void stop() {
server.stop();
}
private HttpHandler proxyHandler(String targetPath) {
return new ProxyHandler(
new SimpleProxyClientProvider(URI.create(targetUri + targetPath)),
ResponseCodeHandler.HANDLE_404);
}
}
private static class TargetServer {
private final int port = FreePort.find();
private final Undertow server;
final String uri = "http://localhost:" + port;
private final LinkedBlockingQueue<String> gotRequestQueue = new LinkedBlockingQueue<>();
TargetServer() {
server = Undertow.builder()
.addHttpListener(port, "0.0.0.0")
.setHandler(new HttpHandler() {
@Override
public void handleRequest(HttpServerExchange exchange) throws Exception {
gotRequestQueue.add(URI.create(exchange.getRequestURL()).getPath());
}
})
.build();
server.start();
}
void stop() {
server.stop();
}
String gotRequest(boolean wait) {
String url = null;
try {
url = gotRequestQueue.poll( wait ? 10000 : 10, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
return url;
}
}
private static class FreePort {
static int find() {
int port = 0;
while (port == 0) {
ServerSocket socket = null;
try {
socket = new ServerSocket(0);
port = socket.getLocalPort();
} catch (IOException e) {
throw new RuntimeException("Failed finding free port", e);
} finally {
try {
if (socket != null) socket.close();
} catch (IOException ignore) {
}
}
}
return port;
}
}
}