package org.nutz.mvc.view;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.http.HttpServletResponse;
import org.nutz.castor.Castors;
import org.nutz.http.Http;
import org.nutz.lang.Encoding;
import org.nutz.lang.Lang;
import org.nutz.lang.Streams;
import org.nutz.lang.Strings;
import org.nutz.log.Log;
import org.nutz.log.Logs;
/**
 * Mutable holder for an HTTP response: status code, status text, headers and
 * body bytes. The content can be filled from a raw response text
 * ({@link #updateBy(String)}), from a map ({@link #update(Map)}) or piecewise,
 * and is finally written to a servlet response by
 * {@link #render(HttpServletResponse)}.
 */
public class HttpServerResponse implements Cloneable {

    private static final Log log = Logs.get();

    private int statusCode;

    private String statusText;

    private Map<String, String> header;

    private byte[] body;

    /**
     * Creates an empty response with no status, no body and an empty header map.
     */
    public HttpServerResponse() {
        this.header = new HashMap<String, String>();
    }

    /**
     * Creates a copy of this response.
     * <p>
     * The header map is copied into a fresh map. The body array is shared by
     * reference between the original and the copy.
     *
     * @return a new instance carrying the same status, headers and body
     */
    @Override
    public HttpServerResponse clone() {
        HttpServerResponse re = new HttpServerResponse();
        re.statusCode = statusCode;
        re.statusText = statusText;
        re.header = new HashMap<String, String>();
        if (header != null) {
            re.header.putAll(header);
        }
        re.body = body;
        return re;
    }

    private static final Pattern _P = Pattern.compile("^HTTP/1.\\d\\s+(\\d+)(\\s+(.*))?$");

    /**
     * Updates this response from a text.
     * <p>
     * A text starting with {@code "HTTP/1."} is parsed as a raw response:
     * status line, header lines, a blank line, then the body. Lines may be
     * terminated by LF or CRLF; the header section may also simply end at the
     * end of the text. Any other text is stored as the body as is, and the
     * status code defaults to 200 if none was set before.
     *
     * @param str
     *            the text to read; must not be null
     * @throws RuntimeException
     *             if the status line or a header line is malformed
     */
    public void updateBy(String str) {
        // A text starting with HTTP/1.x is taken as a full raw response
        if (str.startsWith("HTTP/1.")) {
            parseRawResponse(str);
        }
        // Anything else is a plain body, HTTP 200 unless a code was set before
        else {
            if (statusCode <= 0) {
                this.updateCode(200, null);
            }
            this.body = utf8Bytes(str);
        }
    }

    /**
     * Parses status line, headers and body out of a raw response text.
     */
    private void parseRawResponse(String str) {
        int len = str.length();

        // Status line: everything up to the first line break, or the whole
        // text when there is no line break at all
        int eol = str.indexOf('\n');
        String sStatus = Strings.trim(eol < 0 ? str : str.substring(0, eol));
        Matcher m = _P.matcher(sStatus);
        if (!m.find()) {
            throw Lang.makeThrow("invalid HTTP status line: %s", sStatus);
        }
        statusCode = Integer.parseInt(m.group(1));
        statusText = Strings.trim(m.group(3));
        if (Strings.isBlank(statusText)) {
            statusText = Http.getStatusText(statusCode);
        }

        // Header lines, until a blank line or the end of the text
        int pos = eol < 0 ? len : eol + 1;
        while (pos < len) {
            int end = str.indexOf('\n', pos);
            // The last line may come without a trailing line break
            String line = Strings.trim(str.substring(pos, end < 0 ? len : end));
            // Move to the next line
            pos = end < 0 ? len : end + 1;

            // Trimming removes a trailing '\r', so a CRLF blank line ends
            // the header section just like a bare LF one
            if (Strings.isBlank(line)) {
                break;
            }

            // Split the line into name and value
            int p2 = line.indexOf(':');
            if (p2 < 0) {
                throw Lang.makeThrow("invalid HTTP header line: %s", line);
            }
            String key = Strings.trim(line.substring(0, p2));
            String val = Strings.trim(line.substring(p2 + 1));
            header.put(key, val);
        }

        // Whatever remains is the body
        if (pos < len) {
            this.body = utf8Bytes(str.substring(pos));
        }
    }

    /**
     * Updates this response from a map.
     * <p>
     * The keys {@code statusCode}, {@code statusText} and {@code body} set the
     * matching fields; every other entry is stored as a header with its key
     * upper-cased. Entries with a null value are ignored. When
     * {@code statusCode} is given without {@code statusText}, the status text
     * is reset to the one returned by {@code Http.getStatusText}. An explicit
     * {@code statusText} always takes precedence, regardless of the iteration
     * order of the map.
     *
     * @param map
     *            the values to apply
     */
    public void update(Map<?, ?> map) {
        boolean codeGiven = false;
        boolean textGiven = false;
        for (Map.Entry<?, ?> en : map.entrySet()) {
            String key = en.getKey().toString();
            Object val = en.getValue();
            if (null == val) {
                continue;
            }
            // statusCode
            if ("statusCode".equals(key)) {
                this.statusCode = Castors.me().castTo(val, Integer.class);
                codeGiven = true;
            }
            // statusText
            else if ("statusText".equals(key)) {
                this.statusText = val.toString();
                textGiven = true;
            }
            // body
            else if ("body".equals(key)) {
                this.body = utf8Bytes(val.toString());
            }
            // Everything else is a header
            else {
                this.header.put(key.toUpperCase(), val.toString());
            }
        }
        // Apply the default text only after the loop, so that an explicit
        // statusText is never overwritten just because the map happened to
        // yield it before statusCode
        if (codeGiven && !textGiven) {
            this.statusText = Http.getStatusText(statusCode);
        }
    }

    /**
     * Sets status code and status text.
     *
     * @param statusCode
     *            the HTTP status code
     * @param statusText
     *            the status text; when null, the text returned by
     *            {@code Http.getStatusText} for the code is used
     */
    public void updateCode(int statusCode, String statusText) {
        this.statusCode = statusCode;
        this.statusText = Strings.sNull(statusText, Http.getStatusText(statusCode));
    }

    /**
     * Replaces the body with the UTF-8 bytes of the given text. A blank text
     * leaves the current body untouched.
     *
     * @param body
     *            the new body text
     */
    public void updateBody(String body) {
        if (!Strings.isBlank(body)) {
            this.body = utf8Bytes(body);
        }
    }

    /**
     * Writes this response to the servlet response.
     * <p>
     * Status and headers are set first, then the body (if any) is written and
     * the output stream closed. {@code sendError} is used only for status
     * codes of 400 and above when there are neither headers nor a body. The
     * status code is passed on as is; no default is applied here.
     *
     * @param resp
     *            the servlet response to write to
     */
    public void render(HttpServletResponse resp) {
        resp.setStatus(statusCode);
        // Whether sendError is still needed once headers and body are handled
        boolean needSendError = statusCode >= 400;
        if (null != header && header.size() > 0) {
            for (Map.Entry<String, String> en : header.entrySet()) {
                resp.setHeader(en.getKey(), en.getValue());
            }
            needSendError = false;
        }
        if (body != null) {
            resp.setContentLength(body.length);
            OutputStream out;
            try {
                out = resp.getOutputStream();
            }
            catch (IOException e) {
                throw Lang.wrapThrow(e);
            }
            Streams.writeAndClose(out, body);
            needSendError = false;
        }
        if (needSendError) {
            try {
                resp.sendError(statusCode);
            }
            catch (IOException e) {
                // Best effort: a failure here is only logged
                log.debugf("sendError(%d) failed -- %s", statusCode, e.getMessage());
            }
        }
    }

    /**
     * Encodes a text as UTF-8, wrapping the checked encoding exception.
     */
    private static byte[] utf8Bytes(String s) {
        try {
            return s.getBytes(Encoding.UTF8);
        }
        catch (UnsupportedEncodingException e) {
            throw Lang.wrapThrow(e);
        }
    }
}