/*
 * Copyright (c) 2002-2012 Alibaba Group Holding Limited.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.alibaba.citrus.service.requestcontext.rewrite.impl;

import java.util.regex.MatchResult;
import java.util.regex.Pattern;
import javax.servlet.http.HttpServletRequest;

import com.alibaba.citrus.service.requestcontext.parser.ParserRequestContext;
import com.alibaba.citrus.service.requestcontext.util.RequestContextUtil;
import com.alibaba.citrus.util.regex.MatchResultSubstitution;

/**
 * Utility methods related to rewrite processing.
 *
 * @author Michael Zhou
 */
public class RewriteUtil {
    /**
     * Matches a string that begins with one or more word characters followed by a colon.
     * Compiled once instead of on every {@link #isFullURL(String)} call.
     */
    private static final Pattern FULL_URL_PATTERN = Pattern.compile("^\\w+:.*");

    /** Prefix written in place of {@code %{} when a variable name is not recognized. */
    private static final String ESCAPED_VAR_PREFIX = "\\%{";

    /**
     * Tests whether the given path starts with a scheme-like prefix, i.e. one or more word
     * characters followed by a colon (for example {@code http:}).
     *
     * @param path the path to test; must not be <code>null</code>
     * @return <code>true</code> if the whole string matches <code>^\w+:.*</code>
     */
    public static boolean isFullURL(String path) {
        return FULL_URL_PATTERN.matcher(path).matches();
    }

    /**
     * Creates a {@link MatchResultSubstitution} over the two given match results, using
     * <code>"$%"</code> as the replacement prefixes.
     * <p>
     * NOTE(review): presumably <code>$</code> refers to groups of the rule match and
     * <code>%</code> to groups of the condition match -- confirm against
     * <code>MatchResultSubstitution</code>.
     * </p>
     *
     * @param ruleMatchResult      the match result of the rule
     * @param conditionMatchResult the match result of the condition
     * @return a new substitution object
     */
    public static MatchResultSubstitution getMatchResultSubstitution(MatchResult ruleMatchResult,
                                                                     MatchResult conditionMatchResult) {
        return new MatchResultSubstitution("$%", ruleMatchResult, conditionMatchResult);
    }

    /**
     * Expands the <code>%{NAME}</code> variables in the test string (see
     * {@link #eval(String, HttpServletRequest)}) and then applies the match result
     * substitution to the expanded string.
     *
     * @param testString           the raw test string
     * @param ruleMatchResult      the match result of the rule
     * @param conditionMatchResult the match result of the condition
     * @param request              the request that variable values are read from
     * @return the fully substituted test string
     */
    public static String getSubstitutedTestString(String testString, MatchResult ruleMatchResult,
                                                  MatchResult conditionMatchResult, HttpServletRequest request) {
        String expanded = eval(testString, request);

        return getMatchResultSubstitution(ruleMatchResult, conditionMatchResult).substitute(expanded);
    }

    /**
     * Expands <code>%{NAME}</code> variables in the given expression.
     * <ul>
     * <li>If the expression contains no <code>%{</code>, or no <code>}</code> after the first
     * <code>%{</code>, it is returned unchanged (in particular, <code>%%</code> is <em>not</em>
     * collapsed in that case).</li>
     * <li><code>%%</code> is replaced by a single <code>%</code>.</li>
     * <li><code>%{NAME}</code> with a recognized name is replaced by its escaped value.</li>
     * <li><code>%{NAME}</code> with an unrecognized name is written as <code>\%{NAME}</code>.</li>
     * <li>An unterminated <code>%{...</code> is copied to the output literally. This includes a
     * bare <code>%{</code> at the very end of the expression, which used to be dropped
     * silently.</li>
     * <li>Any other character, including a <code>%</code> followed by some other character or a
     * trailing <code>%</code>, is copied as is.</li>
     * </ul>
     *
     * @param expr    the expression to expand; must not be <code>null</code>
     * @param request the request that variable values are read from
     * @return the expanded expression
     */
    protected static String eval(String expr, HttpServletRequest request) {
        int length = expr.length();
        int startIndex = expr.indexOf("%{");

        // Nothing to expand if the expression does not contain %{...}.
        if (startIndex < 0) {
            return expr;
        }

        if (expr.indexOf('}', startIndex + 2) < 0) {
            return expr;
        }

        StringBuilder result = new StringBuilder(length + 16);
        int i = 0;

        while (i < length) {
            char ch = expr.charAt(i);

            // Ordinary character, or a '%' that is the last character of the expression.
            if (ch != '%' || i + 1 >= length) {
                result.append(ch);
                i++;
                continue;
            }

            char next = expr.charAt(i + 1);

            if (next == '%') {
                // "%%" -> "%"
                result.append('%');
                i += 2;
            } else if (next == '{') {
                int closeIndex = expr.indexOf('}', i + 2);

                if (closeIndex < 0) {
                    // Unterminated variable reference: keep the remainder literally, even when
                    // nothing follows the "%{".
                    result.append(expr, i, length);
                    break;
                }

                String varName = expr.substring(i + 2, closeIndex);
                String expanded = expandAndEscape(varName, request);

                if (expanded == null) {
                    // Unknown variable: emit it escaped so that it is shown as is.
                    result.append(ESCAPED_VAR_PREFIX).append(varName).append('}');
                } else {
                    result.append(expanded);
                }

                i = closeIndex + 1;
            } else {
                // '%' followed by any other character is kept untouched.
                result.append('%').append(next);
                i += 2;
            }
        }

        return result.toString();
    }

    /**
     * Expands a variable and escapes the special characters <code>\</code>, <code>$</code> and
     * <code>%</code> in its value with a leading backslash, so that they are not interpreted
     * later.
     *
     * @param varName the variable name (the text between <code>%{</code> and <code>}</code>)
     * @param request the request that the value is read from
     * @return the escaped value; <code>null</code> means the variable is unknown and should be
     *         shown as is, e.g. <code>%{XYZ}</code>
     */
    private static String expandAndEscape(String varName, HttpServletRequest request) {
        String value = expand(varName, request);

        if (value == null) {
            return null;
        }

        int length = value.length();
        StringBuilder buf = new StringBuilder(length + 10);

        for (int i = 0; i < length; i++) {
            char ch = value.charAt(i);

            if (ch == '\\' || ch == '$' || ch == '%') {
                buf.append('\\');
            }

            buf.append(ch);
        }

        // Return the original instance when nothing had to be escaped.
        return buf.length() == length ? value : buf.toString();
    }

    /**
     * Expands a variable to its value taken from the request.
     *
     * @param varName the variable name
     * @param request the request that the value is read from
     * @return the value, or <code>""</code> if the variable is known but its value is
     *         <code>null</code>; <code>null</code> means the variable is unknown and should be
     *         shown as is, e.g. <code>%{XYZ}</code>
     */
    private static String expand(String varName, HttpServletRequest request) {
        String result;

        // =====================================================
        //  Client side of the IP connection
        // =====================================================
        if ("REMOTE_HOST".equals(varName)) {
            result = request.getRemoteHost();
        } else if ("REMOTE_ADDR".equals(varName)) {
            result = request.getRemoteAddr();
        } else if ("REMOTE_USER".equals(varName)) {
            result = request.getRemoteUser();
        } else if ("REQUEST_METHOD".equals(varName)) {
            result = request.getMethod();
        } else if ("QUERY_STRING".equals(varName)) {
            if ("post".equalsIgnoreCase(request.getMethod())) {
                // For POST requests the query string is rebuilt from the parsed parameters.
                // NOTE(review): the result of findRequestContext is dereferenced without a null
                // check -- confirm that it cannot return null here.
                ParserRequestContext parserRequestContext = RequestContextUtil.findRequestContext(request,
                                                                                                  ParserRequestContext.class);

                result = parserRequestContext.getParameters().toQueryString();
            } else {
                result = request.getQueryString();
            }
        } else if (varName.startsWith("QUERY:")) {
            // %{QUERY:name} -> value of the request parameter "name".
            // NOTE(review): same unchecked findRequestContext result as above.
            ParserRequestContext parserRequestContext = RequestContextUtil.findRequestContext(request,
                                                                                              ParserRequestContext.class);
            String paramName = varName.substring("QUERY:".length()).trim();

            result = parserRequestContext.getParameters().getString(paramName);
        } else if ("AUTH_TYPE".equals(varName)) {
            result = request.getAuthType();
        }

        // =====================================================
        //  HTTP layer details extracted from HTTP headers
        // =====================================================
        else if ("SERVER_NAME".equals(varName)) {
            result = request.getServerName();
        } else if ("SERVER_PORT".equals(varName)) {
            result = String.valueOf(request.getServerPort());
        } else if ("SERVER_PROTOCOL".equals(varName)) {
            result = request.getProtocol();
        }

        // =====================================================
        //  HTTP headers
        // =====================================================
        else if ("HTTP_USER_AGENT".equals(varName)) {
            result = request.getHeader("User-Agent");
        } else if ("HTTP_REFERER".equals(varName)) {
            result = request.getHeader("Referer");
        } else if ("HTTP_HOST".equals(varName)) {
            result = request.getHeader("Host");
        } else if ("HTTP_ACCEPT".equals(varName)) {
            result = request.getHeader("Accept");
        } else if ("HTTP_COOKIE".equals(varName)) {
            result = request.getHeader("Cookie");
        }

        // =====================================================
        //  Others
        // =====================================================
        else if ("REQUEST_URI".equals(varName)) {
            result = request.getRequestURI();
        } else {
            // Unknown variable.
            return null;
        }

        // The variable is valid; a null value is reported as "".
        return result == null ? "" : result;
    }
}