//
// ========================================================================
// Copyright (c) 1995-2017 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.server.handler;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import java.io.IOException;
import java.net.Socket;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.LocalConnector;
import org.eclipse.jetty.server.NetworkConnector;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
public class ThreadLimitHandlerTest
{
private Server _server;
private NetworkConnector _connector;
private LocalConnector _local;
@Before
public void before()
throws Exception
{
_server = new Server();
_connector = new ServerConnector(_server);
_local = new LocalConnector(_server);
_server.setConnectors(new Connector[] { _local,_connector });
}
@After
public void after()
throws Exception
{
_server.stop();
}
@Test
public void testNoForwardHeaders() throws Exception
{
AtomicReference<String> last = new AtomicReference<>();
ThreadLimitHandler handler = new ThreadLimitHandler(null,false)
{
@Override
protected int getThreadLimit(String ip)
{
last.set(ip);
return super.getThreadLimit(ip);
}
};
handler.setHandler(new AbstractHandler()
{
@Override
public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException
{
baseRequest.setHandled(true);
response.setStatus(HttpStatus.OK_200);
}
});
_server.setHandler(handler);
_server.start();
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\n\r\n");
Assert.assertThat(last.get(),Matchers.is("0.0.0.0"));
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nX-Forwarded-For: 1.2.3.4\r\n\r\n");
Assert.assertThat(last.get(),Matchers.is("0.0.0.0"));
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nForwarded: for=1.2.3.4\r\n\r\n");
Assert.assertThat(last.get(),Matchers.is("0.0.0.0"));
}
@Test
public void testXForwardForHeaders() throws Exception
{
AtomicReference<String> last = new AtomicReference<>();
ThreadLimitHandler handler = new ThreadLimitHandler("X-Forwarded-For")
{
@Override
protected int getThreadLimit(String ip)
{
last.set(ip);
return super.getThreadLimit(ip);
}
};
_server.setHandler(handler);
_server.start();
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\n\r\n");
Assert.assertThat(last.get(),Matchers.is("0.0.0.0"));
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nX-Forwarded-For: 1.2.3.4\r\n\r\n");
Assert.assertThat(last.get(),Matchers.is("1.2.3.4"));
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nForwarded: for=1.2.3.4\r\n\r\n");
Assert.assertThat(last.get(),Matchers.is("0.0.0.0"));
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nX-Forwarded-For: 1.1.1.1\r\nX-Forwarded-For: 6.6.6.6,1.2.3.4\r\nForwarded: for=1.2.3.4\r\n\r\n");
Assert.assertThat(last.get(),Matchers.is("1.2.3.4"));
}
@Test
public void testForwardHeaders() throws Exception
{
AtomicReference<String> last = new AtomicReference<>();
ThreadLimitHandler handler = new ThreadLimitHandler("Forwarded")
{
@Override
protected int getThreadLimit(String ip)
{
last.set(ip);
return super.getThreadLimit(ip);
}
};
_server.setHandler(handler);
_server.start();
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\n\r\n");
Assert.assertThat(last.get(),Matchers.is("0.0.0.0"));
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nX-Forwarded-For: 1.2.3.4\r\n\r\n");
Assert.assertThat(last.get(),Matchers.is("0.0.0.0"));
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nForwarded: for=1.2.3.4\r\n\r\n");
Assert.assertThat(last.get(),Matchers.is("1.2.3.4"));
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nX-Forwarded-For: 1.1.1.1\r\nForwarded: for=6.6.6.6; for=1.2.3.4\r\nX-Forwarded-For: 6.6.6.6\r\nForwarded: proto=https\r\n\r\n");
Assert.assertThat(last.get(),Matchers.is("1.2.3.4"));
}
@Test
public void testLimit() throws Exception
{
ThreadLimitHandler handler = new ThreadLimitHandler("Forwarded");
handler.setThreadLimit(4);
AtomicInteger count = new AtomicInteger(0);
AtomicInteger total = new AtomicInteger(0);
CountDownLatch latch = new CountDownLatch(1);
handler.setHandler(new AbstractHandler()
{
@Override
public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException
{
baseRequest.setHandled(true);
response.setStatus(HttpStatus.OK_200);
if ("/other".equals(target))
return;
try
{
count.incrementAndGet();
total.incrementAndGet();
latch.await();
}
catch (InterruptedException e)
{
throw new ServletException(e);
}
finally
{
count.decrementAndGet();
}
}
});
_server.setHandler(handler);
_server.start();
Socket[] client = new Socket[10];
for (int i=0;i<client.length;i++)
{
client[i]=new Socket("127.0.0.1",_connector.getLocalPort());
client[i].getOutputStream().write(("GET /"+i+" HTTP/1.0\r\nForwarded: for=1.2.3.4\r\n\r\n").getBytes());
client[i].getOutputStream().flush();
}
long wait = System.nanoTime() + TimeUnit.SECONDS.toNanos(10);
while(count.get()<4 && System.nanoTime()<wait)
Thread.sleep(1);
assertThat(count.get(),is(4));
// check that other requests are not blocked
assertThat(_local.getResponse("GET /other HTTP/1.0\r\nForwarded: for=6.6.6.6\r\n\r\n"),Matchers.containsString(" 200 OK"));
// let the other requests go
latch.countDown();
while(total.get()<10 && System.nanoTime()<wait)
Thread.sleep(10);
assertThat(total.get(),is(10));
while(count.get()>0 && System.nanoTime()<wait)
Thread.sleep(10);
assertThat(count.get(),is(0));
}
}