package com.hwlcn.security.web.util;
import com.hwlcn.security.SecurityUtils;
import com.hwlcn.security.session.Session;
import com.hwlcn.security.subject.Subject;
import com.hwlcn.security.subject.support.DefaultSubjectContext;
import com.hwlcn.security.util.StringUtils;
import com.hwlcn.security.web.filter.AccessControlFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.util.Map;
/**
 * Static helpers for servlet-based request handling.
 *
 * <p>The helpers cover:
 * <ul>
 *   <li>extracting and normalizing the request path;</li>
 *   <li>unwrapping servlet request/response objects from a {@link RequestPairSource};</li>
 *   <li>reading the per-request session-creation flag;</li>
 *   <li>issuing redirects;</li>
 *   <li>saving and restoring the request that preceded a redirect.</li>
 * </ul>
 */
public class WebUtils {

    private static final Logger log = LoggerFactory.getLogger(WebUtils.class);

    public static final String SERVLET_REQUEST_KEY = ServletRequest.class.getName() + "_SECURITY_THREAD_CONTEXT_KEY";
    public static final String SERVLET_RESPONSE_KEY = ServletResponse.class.getName() + "_SECURITY_THREAD_CONTEXT_KEY";

    /** Session attribute key under which {@link #saveRequest(ServletRequest)} stores the {@link SavedRequest}. */
    public static final String SAVED_REQUEST_KEY = "securitySavedRequest";

    // Request attributes set by the container while processing a RequestDispatcher include (Servlet 2.3+).
    public static final String INCLUDE_REQUEST_URI_ATTRIBUTE = "javax.servlet.include.request_uri";
    public static final String INCLUDE_CONTEXT_PATH_ATTRIBUTE = "javax.servlet.include.context_path";
    public static final String INCLUDE_SERVLET_PATH_ATTRIBUTE = "javax.servlet.include.servlet_path";
    public static final String INCLUDE_PATH_INFO_ATTRIBUTE = "javax.servlet.include.path_info";
    public static final String INCLUDE_QUERY_STRING_ATTRIBUTE = "javax.servlet.include.query_string";

    // Request attributes set by the container while processing a RequestDispatcher forward (Servlet 2.4+).
    public static final String FORWARD_REQUEST_URI_ATTRIBUTE = "javax.servlet.forward.request_uri";
    public static final String FORWARD_CONTEXT_PATH_ATTRIBUTE = "javax.servlet.forward.context_path";
    public static final String FORWARD_SERVLET_PATH_ATTRIBUTE = "javax.servlet.forward.servlet_path";
    public static final String FORWARD_PATH_INFO_ATTRIBUTE = "javax.servlet.forward.path_info";
    public static final String FORWARD_QUERY_STRING_ATTRIBUTE = "javax.servlet.forward.query_string";

    /** Encoding used to decode request strings when the request declares no character encoding. */
    public static final String DEFAULT_CHARACTER_ENCODING = "UTF-8";

    /**
     * Returns the request path relative to the web application's context path.
     *
     * <p>Example: for context path {@code /app} and request URI {@code /app/foo/bar}, the result is
     * {@code /foo/bar}. If nothing remains after removing the context path, {@code "/"} is returned.
     * If the URI does not start with the context path (case-insensitive), the whole URI is returned.
     *
     * <p>NOTE(review): {@link #getRequestUri} can return {@code null} (a path that climbs above the
     * root). What happens then depends on how {@code StringUtils.startsWithIgnoreCase} handles
     * {@code null}; confirm.
     *
     * @param request the current HTTP request
     * @return the application-relative path
     */
    public static String getPathWithinApplication(HttpServletRequest request) {
        String contextPath = getContextPath(request);
        String requestUri = getRequestUri(request);
        if (StringUtils.startsWithIgnoreCase(requestUri, contextPath)) {
            String path = requestUri.substring(contextPath.length());
            return StringUtils.hasText(path) ? path : "/";
        }
        return requestUri;
    }

    /**
     * Returns the request URI after three steps: path parameters ({@code ;...}) are removed,
     * the URI is URL-decoded, and the result is normalized (see {@link #normalize(String)}).
     *
     * <p>The include attribute is consulted first. During a {@code RequestDispatcher.include},
     * {@link HttpServletRequest#getRequestURI()} still reports the outer request. The container
     * exposes the URI of the included resource in {@value #INCLUDE_REQUEST_URI_ATTRIBUTE}.
     *
     * @param request the current HTTP request
     * @return the cleaned URI, or {@code null} if normalization found a {@code ..} segment that
     *         would climb above the root
     */
    public static String getRequestUri(HttpServletRequest request) {
        String uri = (String) request.getAttribute(INCLUDE_REQUEST_URI_ATTRIBUTE);
        if (uri == null) {
            uri = request.getRequestURI();
        }
        return normalize(decodeAndCleanUriString(request, uri));
    }

    /**
     * Normalizes a path, converting backslashes to forward slashes first.
     *
     * @param path the path to normalize; may be {@code null}
     * @return the normalized path, or {@code null} if {@code path} is {@code null} or a
     *         {@code ..} segment would climb above the root
     * @see #normalize(String, boolean)
     */
    public static String normalize(String path) {
        return normalize(path, true);
    }

    /**
     * Normalizes a path in these steps:
     * <ol>
     *   <li>Optionally converts {@code '\'} to {@code '/'}.</li>
     *   <li>Guarantees a leading {@code '/'}.</li>
     *   <li>Collapses {@code //} to {@code /}.</li>
     *   <li>Removes {@code .} segments.</li>
     *   <li>Resolves {@code ..} segments against their parent.</li>
     * </ol>
     * A trailing {@code '/'} is present in the result only if it was present in the input
     * (the bare root {@code "/"} is always kept).
     *
     * @param path             the path to normalize; may be {@code null}
     * @param replaceBackSlash whether to convert backslashes to forward slashes first
     * @return the normalized path, or {@code null} if {@code path} is {@code null} or a
     *         {@code ..} segment would climb above the root
     */
    private static String normalize(String path, boolean replaceBackSlash) {
        if (path == null) {
            return null;
        }
        String normalized = path;
        if (replaceBackSlash && normalized.indexOf('\\') >= 0) {
            normalized = normalized.replace('\\', '/');
        }
        if (normalized.equals("/.")) {
            return "/";
        }
        if (!normalized.startsWith("/")) {
            normalized = "/" + normalized;
        }

        // The "/./" and "/../" scans below only match dot segments that are followed by '/'.
        // A trailing "/." or "/.." would otherwise survive unresolved (e.g. "/xxx/.." would not
        // become "/"), letting a request slip past path-based rules. Terminate such a segment
        // temporarily and remove the added slash again at the end.
        boolean addedTrailingSlash = false;
        if (normalized.endsWith("/.") || normalized.endsWith("/..")) {
            normalized = normalized + "/";
            addedTrailingSlash = true;
        }

        int index;
        // Collapse repeated slashes.
        while ((index = normalized.indexOf("//")) >= 0) {
            normalized = normalized.substring(0, index) + normalized.substring(index + 1);
        }
        // Drop "current directory" segments.
        while ((index = normalized.indexOf("/./")) >= 0) {
            normalized = normalized.substring(0, index) + normalized.substring(index + 2);
        }
        // Resolve "parent directory" segments against the preceding segment.
        while ((index = normalized.indexOf("/../")) >= 0) {
            if (index == 0) {
                // ".." directly under the root would escape it.
                return null;
            }
            int parentStart = normalized.lastIndexOf('/', index - 1);
            normalized = normalized.substring(0, parentStart) + normalized.substring(index + 3);
        }

        if (addedTrailingSlash && normalized.length() > 1) {
            normalized = normalized.substring(0, normalized.length() - 1);
        }
        return normalized;
    }

    /**
     * Removes path parameters from a raw URI, then URL-decodes it.
     *
     * <p>Path parameters are stripped before decoding. Only a literal {@code ';'} in the raw URI
     * acts as a parameter delimiter; an encoded {@code %3B} is path data. Parameters are removed
     * per segment, so {@code /a;x=1/b} becomes {@code /a/b} and later segments are not lost.
     *
     * @param request the request supplying the character encoding
     * @param uri     the raw (still encoded) request URI
     * @return the decoded URI without path parameters
     */
    private static String decodeAndCleanUriString(HttpServletRequest request, String uri) {
        return decodeRequestString(request, removePathParameters(uri));
    }

    /**
     * Removes every {@code ;...} path-parameter section: each run from a {@code ';'} up to, but
     * not including, the next {@code '/'} (or the end of the string).
     */
    private static String removePathParameters(String uri) {
        if (uri.indexOf(';') < 0) {
            return uri;
        }
        StringBuilder cleaned = new StringBuilder(uri.length());
        int i = 0;
        int length = uri.length();
        while (i < length) {
            char c = uri.charAt(i);
            if (c == ';') {
                int nextSlash = uri.indexOf('/', i);
                if (nextSlash < 0) {
                    break;
                }
                i = nextSlash;
            } else {
                cleaned.append(c);
                i++;
            }
        }
        return cleaned.toString();
    }

    /**
     * Returns the decoded context path. The include attribute
     * {@value #INCLUDE_CONTEXT_PATH_ATTRIBUTE} takes precedence over
     * {@link HttpServletRequest#getContextPath()}. A root context path of {@code "/"} is returned
     * as {@code ""}.
     *
     * @param request the current HTTP request
     * @return the decoded context path, {@code ""} for the root context
     */
    public static String getContextPath(HttpServletRequest request) {
        String contextPath = (String) request.getAttribute(INCLUDE_CONTEXT_PATH_ATTRIBUTE);
        if (contextPath == null) {
            contextPath = request.getContextPath();
        }
        if ("/".equals(contextPath)) {
            contextPath = "";
        }
        return decodeRequestString(request, contextPath);
    }

    /**
     * URL-decodes {@code source} using the request's character encoding, or
     * {@link #DEFAULT_CHARACTER_ENCODING} if the request declares none. If that encoding is not
     * supported, a warning is logged and the platform default encoding is used.
     *
     * <p>NOTE(review): {@link URLDecoder} implements {@code application/x-www-form-urlencoded}
     * decoding, so {@code '+'} becomes a space. That is not strictly correct for URI paths; it is
     * left unchanged for compatibility with existing callers.
     *
     * @param request the request supplying the character encoding
     * @param source  the encoded string
     * @return the decoded string
     */
    @SuppressWarnings({"deprecation"})
    public static String decodeRequestString(HttpServletRequest request, String source) {
        String enc = determineEncoding(request);
        try {
            return URLDecoder.decode(source, enc);
        } catch (UnsupportedEncodingException ex) {
            log.warn("Could not decode request string [{}] with encoding '{}': falling back to platform default "
                    + "encoding; exception message: {}", source, enc, ex.getMessage());
            return URLDecoder.decode(source);
        }
    }

    /**
     * Returns the request's declared character encoding, or {@link #DEFAULT_CHARACTER_ENCODING}
     * if none is set.
     *
     * @param request the current HTTP request
     * @return the encoding name, never {@code null}
     */
    protected static String determineEncoding(HttpServletRequest request) {
        String enc = request.getCharacterEncoding();
        if (enc == null) {
            enc = DEFAULT_CHARACTER_ENCODING;
        }
        return enc;
    }

    /**
     * Returns {@code true} if the argument is a {@link RequestPairSource} whose request and
     * response are both non-null.
     */
    public static boolean isWeb(Object requestPairSource) {
        return requestPairSource instanceof RequestPairSource && isWeb((RequestPairSource) requestPairSource);
    }

    /**
     * Returns {@code true} if the argument is a {@link RequestPairSource} holding an
     * {@link HttpServletRequest} and an {@link HttpServletResponse}.
     */
    public static boolean isHttp(Object requestPairSource) {
        return requestPairSource instanceof RequestPairSource && isHttp((RequestPairSource) requestPairSource);
    }

    /**
     * Returns the servlet request held by a {@link RequestPairSource}, or {@code null} if the
     * argument is not one.
     */
    public static ServletRequest getRequest(Object requestPairSource) {
        if (requestPairSource instanceof RequestPairSource) {
            return ((RequestPairSource) requestPairSource).getServletRequest();
        }
        return null;
    }

    /**
     * Returns the servlet response held by a {@link RequestPairSource}, or {@code null} if the
     * argument is not one.
     */
    public static ServletResponse getResponse(Object requestPairSource) {
        if (requestPairSource instanceof RequestPairSource) {
            return ((RequestPairSource) requestPairSource).getServletResponse();
        }
        return null;
    }

    /**
     * Returns the request from {@link #getRequest(Object)} if it is an
     * {@link HttpServletRequest}; otherwise {@code null}.
     */
    public static HttpServletRequest getHttpRequest(Object requestPairSource) {
        ServletRequest request = getRequest(requestPairSource);
        if (request instanceof HttpServletRequest) {
            return (HttpServletRequest) request;
        }
        return null;
    }

    /**
     * Returns the response from {@link #getResponse(Object)} if it is an
     * {@link HttpServletResponse}; otherwise {@code null}.
     */
    public static HttpServletResponse getHttpResponse(Object requestPairSource) {
        ServletResponse response = getResponse(requestPairSource);
        if (response instanceof HttpServletResponse) {
            return (HttpServletResponse) response;
        }
        return null;
    }

    private static boolean isWeb(RequestPairSource source) {
        ServletRequest request = source.getServletRequest();
        ServletResponse response = source.getServletResponse();
        return request != null && response != null;
    }

    private static boolean isHttp(RequestPairSource source) {
        ServletRequest request = source.getServletRequest();
        ServletResponse response = source.getServletResponse();
        return request instanceof HttpServletRequest && response instanceof HttpServletResponse;
    }

    /**
     * Returns the session-creation flag for the request held by a {@link RequestPairSource}.
     * Returns {@code true} if the argument is not a {@code RequestPairSource}.
     *
     * @see #_isSessionCreationEnabled(ServletRequest)
     */
    public static boolean _isSessionCreationEnabled(Object requestPairSource) {
        if (requestPairSource instanceof RequestPairSource) {
            RequestPairSource source = (RequestPairSource) requestPairSource;
            return _isSessionCreationEnabled(source.getServletRequest());
        }
        return true;
    }

    /**
     * Returns the {@link Boolean} stored in the request attribute
     * {@code DefaultSubjectContext.SESSION_CREATION_ENABLED}. Defaults to {@code true} if the
     * request is {@code null} or the attribute is absent or not a {@code Boolean}.
     */
    public static boolean _isSessionCreationEnabled(ServletRequest request) {
        if (request != null) {
            Object val = request.getAttribute(DefaultSubjectContext.SESSION_CREATION_ENABLED);
            // instanceof is already false for null.
            if (val instanceof Boolean) {
                return (Boolean) val;
            }
        }
        return true;
    }

    /**
     * Casts to {@link HttpServletRequest}.
     *
     * @throws ClassCastException if {@code request} is not an HTTP request
     */
    public static HttpServletRequest toHttp(ServletRequest request) {
        return (HttpServletRequest) request;
    }

    /**
     * Casts to {@link HttpServletResponse}.
     *
     * @throws ClassCastException if {@code response} is not an HTTP response
     */
    public static HttpServletResponse toHttp(ServletResponse response) {
        return (HttpServletResponse) response;
    }

    /**
     * Sends a redirect to {@code url} by rendering a {@link RedirectView}.
     *
     * @param request          the current request; must be an {@link HttpServletRequest}
     * @param response         the current response; must be an {@link HttpServletResponse}
     * @param url              the redirect target
     * @param queryParams      query parameters passed to the view as its model; may be {@code null}
     * @param contextRelative  passed through to the {@link RedirectView} constructor
     * @param http10Compatible passed through to the {@link RedirectView} constructor
     * @throws IOException if rendering the redirect fails
     */
    public static void issueRedirect(ServletRequest request, ServletResponse response, String url, Map queryParams, boolean contextRelative, boolean http10Compatible) throws IOException {
        RedirectView view = new RedirectView(url, contextRelative, http10Compatible);
        view.renderMergedOutputModel(queryParams, toHttp(request), toHttp(response));
    }

    /**
     * Redirects to {@code url} with no query parameters, {@code contextRelative = true} and
     * {@code http10Compatible = true}.
     */
    public static void issueRedirect(ServletRequest request, ServletResponse response, String url) throws IOException {
        issueRedirect(request, response, url, null, true, true);
    }

    /**
     * Redirects to {@code url} with the given query parameters, {@code contextRelative = true}
     * and {@code http10Compatible = true}.
     */
    public static void issueRedirect(ServletRequest request, ServletResponse response, String url, Map queryParams) throws IOException {
        issueRedirect(request, response, url, queryParams, true, true);
    }

    /**
     * Redirects to {@code url} with the given query parameters and context-relative flag, and
     * {@code http10Compatible = true}.
     */
    public static void issueRedirect(ServletRequest request, ServletResponse response, String url, Map queryParams, boolean contextRelative) throws IOException {
        issueRedirect(request, response, url, queryParams, contextRelative, true);
    }

    /**
     * Returns {@code true} if the cleaned request parameter {@code paramName} equals one of the
     * following, ignoring case: {@code true}, {@code t}, {@code 1}, {@code enabled}, {@code y},
     * {@code yes} or {@code on}.
     */
    public static boolean isTrue(ServletRequest request, String paramName) {
        String value = getCleanParam(request, paramName);
        return value != null &&
                (value.equalsIgnoreCase("true") ||
                        value.equalsIgnoreCase("t") ||
                        value.equalsIgnoreCase("1") ||
                        value.equalsIgnoreCase("enabled") ||
                        value.equalsIgnoreCase("y") ||
                        value.equalsIgnoreCase("yes") ||
                        value.equalsIgnoreCase("on"));
    }

    /**
     * Returns the request parameter {@code paramName} passed through {@code StringUtils.clean}.
     */
    public static String getCleanParam(ServletRequest request, String paramName) {
        return StringUtils.clean(request.getParameter(paramName));
    }

    /**
     * Stores a {@link SavedRequest} built from {@code request} in the current subject's session
     * under {@link #SAVED_REQUEST_KEY}. The session comes from {@code subject.getSession()};
     * {@link #getSavedRequest} uses {@code getSession(false)} instead.
     *
     * @throws ClassCastException if {@code request} is not an {@link HttpServletRequest}
     */
    public static void saveRequest(ServletRequest request) {
        Subject subject = SecurityUtils.getSubject();
        Session session = subject.getSession();
        HttpServletRequest httpRequest = toHttp(request);
        SavedRequest savedRequest = new SavedRequest(httpRequest);
        session.setAttribute(SAVED_REQUEST_KEY, savedRequest);
    }

    /**
     * Returns the saved request, if any, and removes it from the session.
     *
     * @return the saved request, or {@code null} if there was none
     */
    public static SavedRequest getAndClearSavedRequest(ServletRequest request) {
        SavedRequest savedRequest = getSavedRequest(request);
        if (savedRequest != null) {
            Subject subject = SecurityUtils.getSubject();
            Session session = subject.getSession();
            session.removeAttribute(SAVED_REQUEST_KEY);
        }
        return savedRequest;
    }

    /**
     * Returns the {@link SavedRequest} stored in the current subject's existing session, without
     * creating a session ({@code getSession(false)}).
     *
     * @return the saved request, or {@code null} if there is no session or no saved request
     */
    public static SavedRequest getSavedRequest(ServletRequest request) {
        SavedRequest savedRequest = null;
        Subject subject = SecurityUtils.getSubject();
        Session session = subject.getSession(false);
        if (session != null) {
            savedRequest = (SavedRequest) session.getAttribute(SAVED_REQUEST_KEY);
        }
        return savedRequest;
    }

    /**
     * Redirects to the previously saved request, consuming it.
     *
     * <p>A saved request is only reused if it was a GET. Its URL is then used with
     * {@code contextRelative = false}. Otherwise {@code fallbackUrl} is used as a
     * context-relative URL.
     *
     * @param request     the current request
     * @param response    the current response
     * @param fallbackUrl the URL to use if no usable saved request exists; may be {@code null}
     * @throws IOException           if the redirect fails
     * @throws IllegalStateException if neither a saved GET request nor a fallback URL is available
     */
    public static void redirectToSavedRequest(ServletRequest request, ServletResponse response, String fallbackUrl)
            throws IOException {
        String successUrl = null;
        boolean contextRelative = true;
        SavedRequest savedRequest = WebUtils.getAndClearSavedRequest(request);
        // The constant is compared first so a saved request with a null method cannot throw an NPE.
        if (savedRequest != null && AccessControlFilter.GET_METHOD.equalsIgnoreCase(savedRequest.getMethod())) {
            successUrl = savedRequest.getRequestUrl();
            contextRelative = false;
        }
        if (successUrl == null) {
            successUrl = fallbackUrl;
        }
        if (successUrl == null) {
            throw new IllegalStateException("Success URL not available via saved request or via the " +
                    "successUrlFallback method parameter. One of these must be non-null for " +
                    "issueSuccessRedirect() to work.");
        }
        WebUtils.issueRedirect(request, response, successUrl, null, contextRelative);
    }
}