package com.googlecode.mycontainer.commons.servlet;
import java.io.BufferedReader;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.regex.Pattern;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.googlecode.mycontainer.commons.util.ContentUtil;
import com.googlecode.mycontainer.commons.util.JsonUtil;
public class ServletUtil {
@SuppressWarnings("unchecked")
public static Map<String, Object> getAttributes(HttpServletRequest req) {
Map<String, Object> ret = new HashMap<String, Object>();
Enumeration<String> en = req.getAttributeNames();
while (en.hasMoreElements()) {
String key = en.nextElement();
Object value = req.getAttribute((String) key);
ret.put(key, value);
}
return ret;
}
public static String getUserPath(HttpServletRequest request) {
String servetPath = request.getServletPath();
return getUserPath(request, "^" + servetPath);
}
public static String getUserPath(HttpServletRequest request, String ignore) {
String contextPath = request.getContextPath();
String requestURI = request.getRequestURI();
String starts = "";
if (!contextPath.equals("/")) {
starts += contextPath;
}
String ret = requestURI.substring(starts.length());
ret = ret.replaceAll(ignore, "");
return ret;
}
public static void checkMethods(HttpServletRequest request,
HttpServletResponse resp, String msg, String... alloweds) {
for (String allowed : alloweds) {
if (allowed.equals(request.getMethod())) {
return;
}
}
sendUnsupportedMethod(resp, msg, alloweds);
}
public static void sendUnsupportedMethod(HttpServletResponse resp,
String msg, String... alloweds) {
for (String allowed : alloweds) {
resp.addHeader("Allow", allowed);
}
sendError(resp, HttpServletResponse.SC_METHOD_NOT_ALLOWED, msg);
}
public static void sendError(HttpServletResponse resp, int code, String msg) {
try {
if (msg == null) {
resp.sendError(code);
} else {
resp.sendError(code, msg);
}
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public static String checkParameter(HttpServletRequest req,
HttpServletResponse resp, String name, String... requires) {
String value = req.getParameter(name);
if (value == null) {
value = "";
}
value = value.trim();
for (String require : requires) {
if (require.equals(value)) {
return (require.length() == 0 ? requires[0] : require);
}
}
StringBuilder sb = new StringBuilder();
sb.append("parameter required '").append(name).append("'");
if (requires.length > 0) {
sb.append(" with one of these: ").append(Arrays.toString(requires));
} else {
sb.append(" with any value");
}
sb.append(", but was: '").append(value).append("'");
throw new RuntimeException(sb.toString());
}
public static <T> List<T> getParameters(HttpServletRequest req,
String name, Class<T> clazz) {
try {
String[] values = req.getParameterValues(name);
if (values == null) {
return new ArrayList<T>();
}
ArrayList<T> ret = new ArrayList<T>(values.length);
for (String value : values) {
Constructor<T> cons = clazz
.getConstructor(new Class[] { String.class });
T parsed = cons.newInstance(value);
ret.add(parsed);
}
return ret;
} catch (SecurityException e) {
throw new RuntimeException(e);
} catch (IllegalArgumentException e) {
throw new RuntimeException(e);
} catch (NoSuchMethodException e) {
throw new RuntimeException(e);
} catch (InstantiationException e) {
throw new RuntimeException(e);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
} catch (InvocationTargetException e) {
throw new RuntimeException(e);
}
}
@SuppressWarnings("unchecked")
public static Map<String, List<String>> getHeaders(
Map<String, List<String>> ret, HttpServletRequest request) {
if (ret == null) {
ret = new HashMap<String, List<String>>();
}
Enumeration<String> names = request.getHeaderNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
Enumeration<String> headers = request.getHeaders(name);
while (headers.hasMoreElements()) {
String header = headers.nextElement();
List<String> list = ret.get(name);
if (list == null) {
list = new ArrayList<String>();
ret.put(name, list);
}
list.add(header);
}
}
return null;
}
@SuppressWarnings("unchecked")
public static Map<String, List<String>> getParameters(
Map<String, List<String>> ret, HttpServletRequest request) {
if (ret == null) {
ret = new HashMap<String, List<String>>();
}
Set<Entry<String, String[]>> entries = request.getParameterMap()
.entrySet();
for (Entry<String, String[]> entry : entries) {
String key = entry.getKey();
String[] values = entry.getValue();
List<String> list = ret.get(key);
if (list == null) {
list = new ArrayList<String>();
ret.put(key, list);
}
list.addAll(Arrays.asList(values));
}
return ret;
}
public static void setHeaders(HttpServletResponse response,
Map<String, List<String>> headers) {
Set<Entry<String, List<String>>> set = headers.entrySet();
for (Entry<String, List<String>> entry : set) {
String key = entry.getKey();
List<String> values = entry.getValue();
for (String value : values) {
if (value != null) {
response.addHeader(key, value);
}
}
}
}
public static void write(HttpServletResponse response, char[] array) {
try {
response.getWriter().write(array);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public static void write(HttpServletResponse response, byte[] content) {
try {
response.getOutputStream().write(content);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public static JsonElement readJson(HttpServletRequest req) {
try {
checkContentType(req);
BufferedReader reader = req.getReader();
JsonElement json = JsonUtil.parse(reader);
return json;
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public static void checkContentType(HttpServletRequest req) {
String contentType = ContentUtil.getMediaType(req.getContentType());
if (contentType != null && !contentType.equals("application/json")) {
throw new RuntimeException("Content-Type must be application/json");
}
String charset = req.getCharacterEncoding();
if (charset == null) {
throw new RuntimeException("charset is required in Content-Type");
}
}
public static void write(HttpServletResponse resp, JsonElement obj) {
try {
if (obj instanceof JsonObject) {
JsonObject json = (JsonObject) obj;
JsonElement lastModified = json.get("_lastModified");
if (JsonUtil.t(lastModified)) {
resp.setDateHeader("Last-Modified",
lastModified.getAsLong());
}
}
resp.setContentType("application/json");
resp.setCharacterEncoding("UTF-8");
resp.getWriter().write(obj == null ? "null" : obj.toString());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public static JsonElement paramJson(HttpServletRequest req, String name) {
String value = param(req, name);
return (value == null ? null : JsonUtil.parse(value));
}
public static String param(HttpServletRequest req, String name) {
String ret = req.getParameter(name);
if (ret != null) {
ret = ret.trim();
}
return (ret == null || ret.length() == 0 ? null : ret);
}
public static JsonArray paramJsons(HttpServletRequest req, String name) {
String[] values = req.getParameterValues(name);
if (values == null || values.length == 0) {
return null;
}
JsonArray ret = new JsonArray();
for (String str : values) {
ret.add(JsonUtil.parse(str));
}
return ret;
}
public static void writeJson(HttpServletResponse resp, Object value) {
write(resp, JsonUtil.createBasic(value));
}
public static String getCookie(HttpServletRequest req, String name) {
Cookie[] cookies = req.getCookies();
if (cookies == null) {
return null;
}
for (Cookie cookie : cookies) {
if (name.equals(cookie.getName())) {
return cookie.getValue();
}
}
return null;
}
}