package org.eclipse.jetty.servlets;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.io.IOException;
import java.net.Socket;
import javax.servlet.Filter;
import javax.servlet.Servlet;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.http.HttpURI;
import org.eclipse.jetty.servlet.FilterHolder;
import org.eclipse.jetty.testing.ServletTester;
import org.eclipse.jetty.util.IO;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.Test;
/**
 * Base class for DoS-filter tests that drive a {@link ServletTester} over raw sockets.
 * <p>
 * The filter class passed to {@link #startServer(Class)} is installed twice:
 * <ul>
 * <li>on {@code /ctx/dos/*} with the shared rate-limit init parameters, and</li>
 * <li>on {@code /ctx/timeout/*} with the same parameters plus {@code maxRequestMs}.</li>
 * </ul>
 * {@link #startServer(Class)} is not annotated here. Concrete subclasses must call it before the
 * tests run, for example from their own {@code @BeforeClass} method. The server is stopped by the
 * {@code @AfterClass} method {@link #stopServer()}.
 *
 * @version $Revision$ $Date$
 */
public abstract class AbstractDoSFilterTest
{
    private static ServletTester _tester;
    private static String _host;
    private static int _port;
    /** Value used for the timeout filter's {@code maxRequestMs} init parameter, in milliseconds. */
    private static final long _requestMaxTime = 200;
    private static FilterHolder _dosFilter;
    private static FilterHolder _timeoutFilter;

    /**
     * Creates and starts the test server with two instances of {@code filter}.
     *
     * @param filter the filter class under test
     * @throws Exception if the tester cannot be configured or started
     */
    public static void startServer(Class<? extends Filter> filter) throws Exception
    {
        _tester = new ServletTester();
        HttpURI uri = new HttpURI(_tester.createChannelConnector(true));
        _host = uri.getHost();
        _port = uri.getPort();

        _tester.setContextPath("/ctx");
        _tester.addServlet(TestServlet.class, "/*");

        _dosFilter = _tester.addFilter(filter, "/dos/*", 0);
        configureRateLimits(_dosFilter);

        _timeoutFilter = _tester.addFilter(filter, "/timeout/*", 0);
        configureRateLimits(_timeoutFilter);
        _timeoutFilter.setInitParameter("maxRequestMs", _requestMaxTime + "");

        _tester.start();
    }

    /** Applies the init parameters shared by both filter instances. */
    private static void configureRateLimits(FilterHolder holder)
    {
        holder.setInitParameter("maxRequestsPerSec", "4");
        holder.setInitParameter("delayMs", "200");
        holder.setInitParameter("throttledRequests", "1");
        holder.setInitParameter("waitMs", "10");
        holder.setInitParameter("throttleMs", "4000");
        holder.setInitParameter("remotePort", "false");
        holder.setInitParameter("insertHeaders", "true");
    }

    /** Stops the server started by {@link #startServer(Class)}. */
    @AfterClass
    public static void stopServer() throws Exception
    {
        _tester.stop();
    }

    /** Starts both filter instances before each test. */
    @Before
    public void startFilters() throws Exception
    {
        _dosFilter.start();
        _timeoutFilter.start();
    }

    /** Stops both filter instances after each test (reverse order of start). */
    @After
    public void stopFilters() throws Exception
    {
        _timeoutFilter.stop();
        _dosFilter.stop();
    }

    /**
     * Sends a sequence of raw HTTP requests over a single connection and returns everything read back.
     * <p>
     * {@code requests} is written {@code loops} times, sleeping {@code pause0} ms between writes
     * (not after the last one). The method then sleeps {@code pause1} ms and writes {@code request}.
     * Every caller in this class ends with a {@code Connection: close} request. {@code IO.toString}
     * is assumed to read until the server closes the connection.
     * <p>
     * If either payload targets {@code /unresponsive}, reading is held off for twice the
     * {@code maxRequestMs} value, so the server-side request timeout can expire first.
     *
     * @return the raw response bytes decoded as UTF-8
     */
    private String doRequests(String requests, int loops, long pause0, long pause1, String request) throws Exception
    {
        Socket socket = new Socket(_host, _port);
        try
        {
            socket.setSoTimeout(30000);
            for (int i = loops; i-- > 0;)
            {
                socket.getOutputStream().write(requests.getBytes("UTF-8"));
                socket.getOutputStream().flush();
                if (i > 0 && pause0 > 0)
                    Thread.sleep(pause0);
            }
            if (pause1 > 0)
                Thread.sleep(pause1);
            socket.getOutputStream().write(request.getBytes("UTF-8"));
            socket.getOutputStream().flush();

            // Check both payloads: testUnresponsiveClient sends its /unresponsive URI as the final
            // request, so checking only "requests" would never trigger the delayed read.
            if (requests.contains("/unresponsive") || request.contains("/unresponsive"))
            {
                // Don't read anything yet, forcing the request to time out on the server.
                Thread.sleep(_requestMaxTime * 2);
            }
            return IO.toString(socket.getInputStream(), "UTF-8");
        }
        finally
        {
            // Always release the connection, even if a write, sleep or read failed.
            socket.close();
        }
    }

    /**
     * Counts non-overlapping occurrences of {@code substring} in {@code responses}.
     *
     * @throws IllegalArgumentException if {@code substring} is empty (it would otherwise never advance)
     */
    private int count(String responses, String substring)
    {
        if (substring.length() == 0)
            throw new IllegalArgumentException("substring must not be empty");
        int count = 0;
        for (int i = responses.indexOf(substring); i >= 0; i = responses.indexOf(substring, i + substring.length()))
            count++;
        return count;
    }

    /**
     * Extracts the cookie ({@code name=value}) from the first {@code Set-Cookie} header in
     * {@code response}.
     * <p>
     * The value ends at the first {@code ';'} after the header name, or at the end of the header line
     * if there is no {@code ';'}. The search starts after the header, so a {@code ';'} in an earlier
     * header cannot truncate the value.
     *
     * @throws AssertionError if the response has no {@code Set-Cookie} header
     */
    private static String extractSessionCookie(String response)
    {
        String header = "Set-Cookie: ";
        int start = response.indexOf(header);
        if (start < 0)
            throw new AssertionError("No Set-Cookie header in response: " + response);
        start += header.length();
        int end = response.indexOf(';', start);
        if (end < 0)
            end = response.indexOf("\r\n", start);
        if (end < 0)
            end = response.length();
        return response.substring(start, end);
    }

    /**
     * Starts a thread that runs {@code doRequests(requests, 1, 0, 0, last)}. Any exception it throws
     * is stored in {@code failure[0]} so the test thread can report it after joining.
     */
    private Thread startBackgroundClient(final String requests, final String last, final Exception[] failure)
    {
        Thread thread = new Thread()
        {
            @Override
            public void run()
            {
                try
                {
                    doRequests(requests, 1, 0, 0, last);
                }
                catch (Exception e)
                {
                    failure[0] = e;
                }
            }
        };
        thread.start();
        return thread;
    }

    /** Waits for {@code thread} and rethrows any failure it recorded. */
    private static void joinAndRethrow(Thread thread, Exception[] failure) throws Exception
    {
        thread.join();
        // join() makes the background thread's write to failure[0] visible here.
        if (failure[0] != null)
            throw failure[0];
    }

    @Test
    public void testEvenLowRateIP() throws Exception
    {
        String request = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
        String last = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String responses = doRequests(request, 11, 300, 300, last);
        assertEquals(12, count(responses, "HTTP/1.1 200 OK"));
        assertEquals(0, count(responses, "DoSFilter:"));
    }

    @Test
    public void testBurstLowRateIP() throws Exception
    {
        String request = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
        String last = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String responses = doRequests(request + request + request + request, 2, 1100, 1100, last);
        assertEquals(9, count(responses, "HTTP/1.1 200 OK"));
        assertEquals(0, count(responses, "DoSFilter:"));
    }

    @Test
    public void testDelayedIP() throws Exception
    {
        String request = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
        String last = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String responses = doRequests(request + request + request + request + request, 2, 1100, 1100, last);
        assertEquals(11, count(responses, "HTTP/1.1 200 OK"));
        assertEquals(2, count(responses, "DoSFilter: delayed"));
    }

    @Test
    public void testThrottledIP() throws Exception
    {
        // Cause a delay, then sleep while holding a pass.
        String sleeper = "GET /ctx/dos/sleeper HTTP/1.1\r\nHost: localhost\r\n\r\n";
        String sleeperLast = "GET /ctx/dos/sleeper?sleep=2000 HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        Exception[] failure = new Exception[1];
        Thread other = startBackgroundClient(sleeper + sleeper + sleeper + sleeper, sleeperLast, failure);
        Thread.sleep(1500);

        String request = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
        String last = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String responses = doRequests(request + request + request + request, 1, 0, 0, last);
        // Report a background-client failure as the root cause before checking its side effects.
        joinAndRethrow(other, failure);

        assertEquals("200 OK responses", 5, count(responses, "HTTP/1.1 200 OK"));
        assertEquals("delayed responses", 1, count(responses, "DoSFilter: delayed"));
        assertEquals("throttled responses", 1, count(responses, "DoSFilter: throttled"));
        assertEquals("unavailable responses", 0, count(responses, "DoSFilter: unavailable"));
    }

    @Test
    public void testUnavailableIP() throws Exception
    {
        // Cause a delay, then sleep while holding a pass.
        String blocker = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
        String blockerLast = "GET /ctx/dos/test?sleep=5000 HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        Exception[] failure = new Exception[1];
        Thread other = startBackgroundClient(blocker + blocker + blocker + blocker, blockerLast, failure);
        Thread.sleep(500);

        String request = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
        String last = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String responses = doRequests(request + request + request + request, 1, 0, 0, last);
        joinAndRethrow(other, failure);

        assertEquals(4, count(responses, "HTTP/1.1 200 OK"));
        assertEquals(1, count(responses, "HTTP/1.1 503"));
        assertEquals(1, count(responses, "DoSFilter: delayed"));
        assertEquals(1, count(responses, "DoSFilter: throttled"));
        assertEquals(1, count(responses, "DoSFilter: unavailable"));
    }

    @Test
    public void testSessionTracking() throws Exception
    {
        // Get a session first.
        String requestSession = "GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String response = doRequests("", 1, 0, 0, requestSession);
        String sessionId = extractSessionCookie(response);

        // All other requests use this session.
        String request = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId + "\r\n\r\n";
        String last = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: " + sessionId + "\r\n\r\n";
        String responses = doRequests(request + request + request + request + request, 2, 1100, 1100, last);
        assertEquals(11, count(responses, "HTTP/1.1 200 OK"));
        assertEquals(2, count(responses, "DoSFilter: delayed"));
    }

    @Test
    public void testMultipleSessionTracking() throws Exception
    {
        // Get some session ids first.
        String requestSession = "GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\n\r\n";
        String closeRequest = "GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String response = doRequests(requestSession + requestSession, 1, 0, 0, closeRequest);
        // The servlet writes no body here, so each blank line ends one response's headers.
        String[] sessions = response.split("\r\n\r\n");
        String sessionId1 = extractSessionCookie(sessions[0]);
        String sessionId2 = extractSessionCookie(sessions[1]);

        // Alternate between sessions.
        String request1 = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId1 + "\r\n\r\n";
        String request2 = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId2 + "\r\n\r\n";
        String last = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: " + sessionId2 + "\r\n\r\n";

        // Ensure the sessions are no longer new; the responses themselves are not checked.
        doRequests(request1 + request2, 1, 1100, 1100, last);
        Thread.sleep(1000);

        String responses = doRequests(request1 + request2 + request1 + request2 + request1, 2, 1100, 1100, last);
        assertEquals(11, count(responses, "HTTP/1.1 200 OK"));
        assertEquals(0, count(responses, "DoSFilter: delayed"));

        // Alternate between sessions at a faster rate.
        responses = doRequests(request1 + request2 + request1 + request2 + request1, 2, 350, 550, last);
        assertEquals(11, count(responses, "HTTP/1.1 200 OK"));
        int delayedRequests = count(responses, "DoSFilter: delayed");
        assertTrue("delayedRequests: " + delayedRequests + " is not between 2 and 3", delayedRequests >= 2 && delayedRequests <= 3);
    }

    @Test
    public void testUnresponsiveClient() throws Exception
    {
        int numLines = 1000;
        String last = "GET /ctx/timeout/unresponsive?lines=" + numLines + " HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String responses = doRequests("", 0, 0, 0, last);
        // The request was expired and stopped before all the lines were written.
        int responseLines = count(responses, "Line:");
        assertTrue(responses.contains("DoSFilter: timeout"));
        assertTrue(responseLines > 0 && responseLines < numLines);
    }

    /**
     * Servlet behind both filters. It responds to these query parameters:
     * <ul>
     * <li>{@code session}: if present, creates an HTTP session;</li>
     * <li>{@code sleep=N}: sleeps N milliseconds before responding;</li>
     * <li>{@code lines=N}: writes N {@code "Line: i"} lines, flushing each one and pausing 10 ms
     * between them.</li>
     * </ul>
     */
    public static class TestServlet extends HttpServlet implements Servlet
    {
        @Override
        protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
        {
            // Must be set before any output: after flushBuffer() commits the response, the
            // content type can no longer be changed.
            response.setContentType("text/plain");

            if (request.getParameter("session") != null)
                request.getSession(true);

            String sleep = request.getParameter("sleep");
            if (sleep != null)
            {
                try
                {
                    Thread.sleep(Long.parseLong(sleep));
                }
                catch (InterruptedException ignored)
                {
                    // An interrupted sleep just ends the wait early; the request is still answered.
                }
            }

            String lines = request.getParameter("lines");
            if (lines != null)
            {
                int count = Integer.parseInt(lines);
                for (int i = 0; i < count; ++i)
                {
                    response.getWriter().append("Line: " + i + "\n");
                    response.flushBuffer();
                    try
                    {
                        Thread.sleep(10);
                    }
                    catch (InterruptedException ignored)
                    {
                        // The handling thread was interrupted: stop producing output instead of
                        // writing the remaining lines into an abandoned response.
                        break;
                    }
                }
            }
        }
    }
}