package com.wesabe.servlet;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import javax.servlet.RequestDispatcher;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.wesabe.servlet.normalizers.*;
/**
 * An {@link HttpServletRequestWrapper} which runs the values read from the
 * wrapped request through the matching normalizer before returning them.
 * <p>
 * Request data rejected by a normalizer (a {@link ValidationException}) is
 * reported as a {@link BadRequestException}. A header or parameter
 * <em>name</em> supplied by the caller which is rejected by its normalizer is
 * reported as an {@link IllegalArgumentException} instead.
 * <p>
 * Arrays obtained from the wrapped request are never modified; normalized
 * values are always returned in freshly allocated arrays.
 * <p>
 * {@link #getRequestURL()} and {@link #getServletPath()} currently return the
 * wrapped request's values without any normalization.
 */
public class SafeRequest extends HttpServletRequestWrapper {
    private static final String REQUEST_DISPATCHER_PATH_PREFIX = "WEB-INF";
    private static final MethodNormalizer METHOD_NORMALIZER = new MethodNormalizer();
    private static final SchemeNormalizer SCHEME_NORMALIZER = new SchemeNormalizer();
    private static final PortNormalizer PORT_NORMALIZER = new PortNormalizer();
    private static final HostnameNormalizer HOSTNAME_NORMALIZER = new HostnameNormalizer();
    private static final HeaderNameNormalizer HEADER_NAME_NORMALIZER = new HeaderNameNormalizer();
    private static final HeaderValueNormalizer HEADER_VALUE_NORMALIZER = new HeaderValueNormalizer();
    private static final CookieNormalizer COOKIE_NORMALIZER = new CookieNormalizer();
    private static final UriNormalizer URI_NORMALIZER = new UriNormalizer();
    private static final QueryStringNormalizer QUERY_STRING_NORMALIZER = new QueryStringNormalizer();
    private static final ParameterNameNormalizer PARAM_NAME_NORMALIZER = new ParameterNameNormalizer();
    private static final ParameterValueNormalizer PARAM_VALUE_NORMALIZER = new ParameterValueNormalizer();
    private static final SessionIdNormalizer SESSION_ID_NORMALIZER = new SessionIdNormalizer();
    private static final PathInfoNormalizer PATH_INFO_NORMALIZER = new PathInfoNormalizer();
    private static final ContextPathNormalizer CONTEXT_PATH_NORMALIZER = new ContextPathNormalizer();

    /** The wrapped request; also passed to every {@link BadRequestException}. */
    private final HttpServletRequest request;

    /**
     * Creates a new wrapper around the given request.
     *
     * @param request the request to wrap
     */
    public SafeRequest(HttpServletRequest request) {
        super(request);
        this.request = request;
    }

    /**
     * Returns the normalized context path.
     *
     * @throws BadRequestException if the context path fails normalization
     */
    @Override
    public String getContextPath() {
        try {
            return CONTEXT_PATH_NORMALIZER.normalize(super.getContextPath());
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the request's cookies, each passed through the cookie
     * normalizer.
     *
     * @return a new array of normalized cookies; an empty array (never
     *         {@code null}) if the wrapped request returns {@code null}
     * @throws BadRequestException if any cookie fails normalization
     */
    @Override
    public Cookie[] getCookies() {
        final Cookie[] rawCookies = super.getCookies();
        if (rawCookies == null) {
            return new Cookie[0];
        }

        try {
            // Normalize into a fresh array so the wrapped request's own array
            // is left untouched, even if a later cookie is rejected.
            final Cookie[] cookies = new Cookie[rawCookies.length];
            for (int i = 0; i < rawCookies.length; i++) {
                cookies[i] = COOKIE_NORMALIZER.normalize(rawCookies[i]);
            }
            return cookies;
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the named header as a date. The header name is handed to the
     * wrapped request as-is.
     *
     * @throws BadRequestException if the wrapped request throws an
     *         {@link IllegalArgumentException} for the header
     */
    @Override
    public long getDateHeader(String name) {
        try {
            return super.getDateHeader(name);
        } catch (IllegalArgumentException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the normalized value of the named header.
     *
     * @throws IllegalArgumentException if {@code name} fails header name
     *         normalization
     * @throws BadRequestException if the header's value fails normalization
     */
    @Override
    public String getHeader(String name) {
        final String validName = getValidHeaderName(name);
        try {
            return HEADER_VALUE_NORMALIZER.normalize(super.getHeader(validName));
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the normalized names of all headers of this request.
     *
     * @return the normalized names; an empty enumeration if the wrapped
     *         request returns {@code null}
     * @throws BadRequestException if any header name fails normalization
     */
    @Override
    public Enumeration<String> getHeaderNames() {
        final Enumeration<?> rawNames = super.getHeaderNames();
        if (rawNames == null) {
            // The servlet API allows a null result here; avoid the NPE.
            return Collections.enumeration(Collections.<String>emptyList());
        }

        try {
            final List<String> names = Lists.newArrayList();
            while (rawNames.hasMoreElements()) {
                names.add(HEADER_NAME_NORMALIZER.normalize((String) rawNames.nextElement()));
            }
            return Collections.enumeration(names);
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the normalized values of the named header.
     *
     * @return the normalized values; an empty enumeration if the wrapped
     *         request returns {@code null}
     * @throws IllegalArgumentException if {@code name} fails header name
     *         normalization
     * @throws BadRequestException if any value fails normalization
     */
    @Override
    public Enumeration<String> getHeaders(String name) {
        final String validName = getValidHeaderName(name);
        final Enumeration<?> rawValues = super.getHeaders(validName);
        if (rawValues == null) {
            // The servlet API allows a null result here; avoid the NPE.
            return Collections.enumeration(Collections.<String>emptyList());
        }

        try {
            final List<String> values = Lists.newArrayList();
            while (rawValues.hasMoreElements()) {
                final String rawValue = (String) rawValues.nextElement();
                values.add(HEADER_VALUE_NORMALIZER.normalize(rawValue));
            }
            return Collections.enumeration(values);
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Normalizes a caller-supplied header name, reporting a rejected name as
     * an {@link IllegalArgumentException} rather than a bad request.
     */
    private String getValidHeaderName(String name) {
        try {
            return HEADER_NAME_NORMALIZER.normalize(name);
        } catch (ValidationException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Returns the named header as an int. The header name is handed to the
     * wrapped request as-is.
     *
     * @throws BadRequestException if the wrapped request throws an
     *         {@link IllegalArgumentException} (which includes
     *         {@link NumberFormatException}) for the header
     */
    @Override
    public int getIntHeader(String name) {
        try {
            return super.getIntHeader(name);
        } catch (IllegalArgumentException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the normalized HTTP method.
     *
     * @throws BadRequestException if the method fails normalization
     */
    @Override
    public String getMethod() {
        try {
            return METHOD_NORMALIZER.normalize(super.getMethod());
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the normalized value of the named parameter.
     *
     * @throws IllegalArgumentException if {@code name} fails parameter name
     *         normalization
     * @throws BadRequestException if the parameter's value fails normalization
     */
    @Override
    public String getParameter(String name) {
        final String validName = getValidParameterName(name);
        try {
            return PARAM_VALUE_NORMALIZER.normalize(super.getParameter(validName));
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Normalizes a caller-supplied parameter name, reporting a rejected name
     * as an {@link IllegalArgumentException} rather than a bad request.
     */
    private String getValidParameterName(String name) {
        try {
            return PARAM_NAME_NORMALIZER.normalize(name);
        } catch (ValidationException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Returns a new array holding the normalized form of each given parameter
     * value; the given array is not modified.
     */
    private static String[] normalizeParameterValues(String[] rawValues) throws ValidationException {
        final String[] values = new String[rawValues.length];
        for (int i = 0; i < rawValues.length; i++) {
            values[i] = PARAM_VALUE_NORMALIZER.normalize(rawValues[i]);
        }
        return values;
    }

    /**
     * Returns a new map, in the wrapped request's iteration order, of
     * normalized parameter names to new arrays of normalized parameter values.
     *
     * @throws BadRequestException if any name or value fails normalization
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        try {
            final Map<?, ?> rawMap = super.getParameterMap();
            final Map<String, String[]> map = Maps.newLinkedHashMap();
            for (Entry<?, ?> parameter : rawMap.entrySet()) {
                final String validName = PARAM_NAME_NORMALIZER.normalize((String) parameter.getKey());
                // Copy rather than normalize in place: the value arrays belong
                // to the wrapped request.
                map.put(validName, normalizeParameterValues((String[]) parameter.getValue()));
            }
            return map;
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the normalized names of all parameters of this request.
     *
     * @throws BadRequestException if any parameter name fails normalization
     */
    @Override
    public Enumeration<String> getParameterNames() {
        try {
            final List<String> names = Lists.newArrayList();
            final Enumeration<?> rawNames = super.getParameterNames();
            while (rawNames.hasMoreElements()) {
                names.add(PARAM_NAME_NORMALIZER.normalize((String) rawNames.nextElement()));
            }
            return Collections.enumeration(names);
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the normalized values of the named parameter.
     *
     * @return a new array of normalized values, or {@code null} if the
     *         wrapped request returns {@code null} for the parameter
     * @throws IllegalArgumentException if {@code name} fails parameter name
     *         normalization
     * @throws BadRequestException if any value fails normalization
     */
    @Override
    public String[] getParameterValues(String name) {
        final String[] rawValues = super.getParameterValues(getValidParameterName(name));
        if (rawValues == null) {
            // The servlet API returns null for a parameter which does not
            // exist; pass that through instead of failing on the array access.
            return null;
        }

        try {
            return normalizeParameterValues(rawValues);
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the normalized extra path info.
     *
     * @throws BadRequestException if the path info fails normalization
     */
    @Override
    public String getPathInfo() {
        try {
            return PATH_INFO_NORMALIZER.normalize(super.getPathInfo());
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the normalized query string.
     *
     * @throws BadRequestException if the query string fails normalization
     */
    @Override
    public String getQueryString() {
        try {
            return QUERY_STRING_NORMALIZER.normalize(super.getQueryString());
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the wrapped request's dispatcher for the given path, but only
     * for paths which begin with {@code WEB-INF}.
     *
     * @return the dispatcher, or {@code null} if {@code path} is {@code null}
     *         or does not begin with {@code WEB-INF}
     */
    @Override
    public RequestDispatcher getRequestDispatcher(String path) {
        if (path != null && path.startsWith(REQUEST_DISPATCHER_PATH_PREFIX)) {
            return request.getRequestDispatcher(path);
        }
        return null;
    }

    /**
     * Returns the normalized session ID specified by the client.
     *
     * @throws BadRequestException if the session ID fails normalization
     */
    @Override
    public String getRequestedSessionId() {
        try {
            return SESSION_ID_NORMALIZER.normalize(super.getRequestedSessionId());
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the normalized request URI.
     *
     * @throws BadRequestException if the URI fails normalization
     */
    @Override
    public String getRequestURI() {
        try {
            return URI_NORMALIZER.normalize(super.getRequestURI());
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the wrapped request's URL, <b>without</b> any normalization.
     */
    @Override
    public StringBuffer getRequestURL() {
        // REVIEW coda@wesabe.com -- Apr 13, 2009: Figure out how best to filter HttpServletRequest#getRequestURL().
        // ESAPI just punts on the issue -- should we automatically assemble it ourselves?
        return super.getRequestURL();
    }

    /**
     * Returns the normalized scheme.
     *
     * @throws BadRequestException if the scheme fails normalization
     */
    @Override
    public String getScheme() {
        try {
            return SCHEME_NORMALIZER.normalize(super.getScheme());
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the normalized server name.
     *
     * @throws BadRequestException if the server name fails normalization
     */
    @Override
    public String getServerName() {
        try {
            return HOSTNAME_NORMALIZER.normalize(super.getServerName());
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the normalized server port.
     *
     * @throws BadRequestException if the port fails normalization
     */
    @Override
    public int getServerPort() {
        try {
            return PORT_NORMALIZER.normalize(super.getServerPort());
        } catch (ValidationException e) {
            throw new BadRequestException(request, e);
        }
    }

    /**
     * Returns the wrapped request's servlet path, <b>without</b> any
     * normalization.
     */
    @Override
    public String getServletPath() {
        // REVIEW coda@wesabe.com -- Apr 6, 2009: Figure out what servlet path normalization means
        return super.getServletPath();
    }

    /** Returns the wrapped request's string representation. */
    @Override
    public String toString() {
        return request.toString();
    }
}