package gcom.util.web;
import java.io.IOException;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
/**
 * Servlet filter that lets at most one filtered request per HTTP session run at a time.
 *
 * <p>While a filtered request is in process for a session, one further request for the same
 * session may be parked in a single-slot queue. A newer arrival replaces the parked request; the
 * replaced request returns from {@link #doFilter} without invoking the rest of the chain. A parked
 * request proceeds when the running request finishes, or once its maximum wait time has elapsed
 * (in which case it runs concurrently with the request that is still running).
 *
 * <p>Init parameters. Every value is a regular expression that must match the whole
 * {@link HttpServletRequest#getRequestURI() request URI}, which includes the context path and is
 * not URL-decoded.
 * <ul>
 * <li>{@code excludePattern...}: any parameter whose name starts with {@code excludePattern}.
 * Matching requests bypass this filter entirely.</li>
 * <li>{@code maxWaitMilliseconds.<n>} or {@code maxWaitMilliseconds.<n>.<suffix>}: matching
 * requests wait at most {@code n} milliseconds, and {@code 0} means no time limit. The suffix is
 * ignored; it lets several patterns share the same duration.</li>
 * </ul>
 * Requests matching no wait pattern wait at most {@code DEFAULT_DURATION} (5000) milliseconds.
 *
 * <p>NOTE(review): the lock object and request objects are stored as session attributes and are
 * not serializable. Confirm this filter is not used with distributable or persisted sessions.
 */
public class RequestControlFilter implements Filter {

    /** Init-parameter name prefix for URI patterns that bypass the filter. */
    private final static String EXCLUDE_PATTERN_PREFIX = "excludePattern";

    /** Init-parameter name prefix for {@code maxWaitMilliseconds.<n>[.<suffix>]} parameters. */
    private final static String MAX_WAIT_PREFIX = "maxWaitMilliseconds.";

    /** The session attribute key for the request currently being processed. */
    private final static String REQUEST_IN_PROCESS = "RequestControlFilter.requestInProcess";

    /** The session attribute key for the request currently waiting in the queue. */
    private final static String REQUEST_QUEUE = "RequestControlFilter.requestQueue";

    /** The session attribute key for the per-session synchronization object. */
    private final static String SYNC_OBJECT_KEY = "RequestControlFilter.sessionSync";

    /** The default maximum number of milliseconds a request waits. */
    private final static long DEFAULT_DURATION = 5000;

    /** A list of Pattern objects that match paths to exclude. */
    private LinkedList excludePatterns;

    /** A map from Pattern to max wait duration in milliseconds (Long objects). */
    private HashMap maxWaitDurations;

    /**
     * Parses the exclude patterns and max-wait patterns from the init parameters.
     *
     * @param config the filter configuration supplying the init parameters
     * @throws ServletException if a pattern is not a valid regular expression, or if a
     *     {@code maxWaitMilliseconds} duration is not a non-negative integer
     */
    public void init(FilterConfig config) throws ServletException {
        // Build into locals so the fields are only published once fully populated.
        LinkedList patterns = new LinkedList();
        HashMap durations = new HashMap();
        Enumeration names = config.getInitParameterNames();
        while (names.hasMoreElements()) {
            String paramName = (String) names.nextElement();
            String paramValue = config.getInitParameter(paramName);
            if (paramName.startsWith(EXCLUDE_PATTERN_PREFIX)) {
                // Compile once here rather than on every request.
                patterns.add(compilePattern(paramName, paramValue));
            } else if (paramName.startsWith(MAX_WAIT_PREFIX)) {
                Long duration = parseMaxWait(paramName);
                durations.put(compilePattern(paramName, paramValue), duration);
            }
        }
        excludePatterns = patterns;
        maxWaitDurations = durations;
    }

    /** No resources to release. */
    public void destroy() {
    }

    /**
     * Processes the request, serializing it with other filtered requests of the same session.
     *
     * <p>Excluded requests go straight down the chain. Otherwise, if another request is in
     * process for the session, this request is queued and waits (see
     * {@link #waitForRelease}). If it is replaced in the queue, or the waiting thread is
     * interrupted, the method returns without invoking the chain.
     */
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        HttpServletRequest httpRequest = (HttpServletRequest) request;

        // Check exclusion first so excluded requests (e.g. static resources) do not create a
        // session as a side effect.
        if (!isFilteredRequest(httpRequest)) {
            chain.doFilter(request, response);
            return;
        }

        // Capture the session and its lock once. Re-fetching them after the chain would
        // create a new session if the chain invalidated this one.
        HttpSession session = httpRequest.getSession();
        Object lock = getSynchronizationObject(session);

        synchronized (lock) {
            if (isRequestInProcess(session)) {
                enqueueRequest(session, lock, httpRequest);
                if (!waitForRelease(session, lock, httpRequest)) {
                    return;
                }
            }
            // Mark this request as in process so later requests for the session wait for it.
            setRequestInProgress(session, httpRequest);
        }

        // Release the session even if the rest of the chain throws.
        try {
            chain.doFilter(request, response);
        } finally {
            releaseQueuedRequest(session, lock, httpRequest);
        }
    }

    /**
     * Returns the synchronization object stored in the session, creating and storing it on
     * first use. The method is class-level synchronized so that concurrent first requests of a
     * session agree on a single object.
     */
    private static synchronized Object getSynchronizationObject(HttpSession session) {
        Object syncObj = session.getAttribute(SYNC_OBJECT_KEY);
        if (syncObj == null) {
            syncObj = new Object();
            session.setAttribute(SYNC_OBJECT_KEY, syncObj);
        }
        return syncObj;
    }

    /** Records {@code request} as the one currently in process for the session. */
    private static void setRequestInProgress(HttpSession session, HttpServletRequest request) {
        session.setAttribute(REQUEST_IN_PROCESS, request);
    }

    /**
     * Clears the in-process marker if it still belongs to {@code request} and wakes any waiter.
     *
     * <p>Runs in a {@code finally} block, so it must not throw over an exception from the chain.
     * An invalidated session throws {@link IllegalStateException} from {@code getAttribute}
     * (per the {@link HttpSession} contract). That case is tolerated, and any waiter is still
     * notified so it does not sit out its whole timeout.
     */
    private static void releaseQueuedRequest(HttpSession session, Object lock,
            HttpServletRequest request) {
        synchronized (lock) {
            try {
                if (session.getAttribute(REQUEST_IN_PROCESS) != request) {
                    // Another request has taken over, e.g. a waiter whose wait timed out.
                    // It will release the session itself.
                    return;
                }
                session.removeAttribute(REQUEST_IN_PROCESS);
            } catch (IllegalStateException sessionInvalidated) {
                // The chain invalidated the session, which discards its attributes; there is
                // nothing left to clear. Fall through to wake any waiter.
            }
            lock.notifyAll();
        }
    }

    /** Returns whether some request is marked as in process for the session. */
    private static boolean isRequestInProcess(HttpSession session) {
        return session.getAttribute(REQUEST_IN_PROCESS) != null;
    }

    /**
     * Waits, while holding {@code lock}, until one of three things happens: the in-process
     * request for the session finishes, this request is replaced in the queue by a newer one,
     * or this request's maximum wait time elapses.
     *
     * <p>A zero maximum wait time means there is no time limit.
     *
     * @return {@code true} if this request should be processed now. This happens when it was
     *     released, or when it timed out while still queued. Returns {@code false} if it was
     *     replaced, or if the waiting thread was interrupted (the interrupt flag is restored).
     */
    private boolean waitForRelease(HttpSession session, Object lock, HttpServletRequest request) {
        long maxWait = getMaxWaitTime(request);
        long start = System.currentTimeMillis();
        try {
            // Loop to guard against spurious wake-ups. A single wait() could return early and
            // let this request run concurrently with the one in process.
            while (isRequestInProcess(session) && session.getAttribute(REQUEST_QUEUE) == request) {
                long timeout = 0L; // Object.wait(0) blocks until notified.
                if (maxWait > 0) {
                    // Subtracting elapsed time avoids the overflow of computing start + maxWait.
                    timeout = maxWait - (System.currentTimeMillis() - start);
                    if (timeout <= 0) {
                        break; // Waited long enough; proceed even though a request is running.
                    }
                }
                lock.wait(timeout);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            dequeue(session, request);
            return false;
        }
        // Still queued means this request may run. Clearing the slot also stops the session
        // from holding on to this request object.
        return dequeue(session, request);
    }

    /**
     * Removes {@code request} from the session's queue slot if it still occupies it.
     *
     * @return {@code true} if {@code request} was the queued request
     */
    private static boolean dequeue(HttpSession session, HttpServletRequest request) {
        if (session.getAttribute(REQUEST_QUEUE) != request) {
            return false;
        }
        session.removeAttribute(REQUEST_QUEUE);
        return true;
    }

    /**
     * Puts {@code request} in the session's single queue slot, replacing any previous occupant,
     * and wakes the previous waiter so it can discover that it was replaced.
     */
    private static void enqueueRequest(HttpSession session, Object lock,
            HttpServletRequest request) {
        session.setAttribute(REQUEST_QUEUE, request);
        lock.notifyAll();
    }

    /**
     * Returns the configured maximum wait in milliseconds for the request's URI.
     *
     * <p>If several patterns match, which one wins is unspecified, because
     * {@link HashMap} iteration order is unspecified. Falls back to {@code DEFAULT_DURATION}.
     */
    private long getMaxWaitTime(HttpServletRequest request) {
        String path = request.getRequestURI();
        for (Iterator it = maxWaitDurations.entrySet().iterator(); it.hasNext();) {
            Map.Entry entry = (Map.Entry) it.next();
            Pattern pattern = (Pattern) entry.getKey();
            if (pattern.matcher(path).matches()) {
                return ((Long) entry.getValue()).longValue();
            }
        }
        return DEFAULT_DURATION;
    }

    /** Returns {@code false} if the request URI fully matches any exclude pattern. */
    private boolean isFilteredRequest(HttpServletRequest request) {
        String path = request.getRequestURI();
        for (Iterator it = excludePatterns.iterator(); it.hasNext();) {
            Matcher m = ((Pattern) it.next()).matcher(path);
            if (m.matches()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Compiles an init-parameter regex, reporting a syntax error as a {@link ServletException}
     * that names the offending parameter.
     */
    private static Pattern compilePattern(String paramName, String regex)
            throws ServletException {
        try {
            return Pattern.compile(regex);
        } catch (PatternSyntaxException e) {
            throw new ServletException(
                    "Invalid regular expression in init parameter '" + paramName + "'", e);
        }
    }

    /**
     * Extracts {@code n} from a parameter name of the form
     * {@code maxWaitMilliseconds.<n>[.<suffix>]}.
     *
     * @throws ServletException if {@code n} is not an integer or is negative. A negative value
     *     would otherwise make {@link Object#wait(long)} throw at request time.
     */
    private static Long parseMaxWait(String paramName) throws ServletException {
        String durationString = paramName.substring(MAX_WAIT_PREFIX.length());
        int endDuration = durationString.indexOf('.');
        if (endDuration != -1) {
            durationString = durationString.substring(0, endDuration);
        }
        Long duration;
        try {
            duration = Long.valueOf(durationString);
        } catch (NumberFormatException e) {
            throw new ServletException(
                    "Invalid wait duration in init parameter '" + paramName + "'", e);
        }
        if (duration.longValue() < 0) {
            throw new ServletException(
                    "Negative wait duration in init parameter '" + paramName + "'");
        }
        return duration;
    }
}