/* * Copyright (c) 2002-2012 Alibaba Group Holding Limited. * All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.alibaba.citrus.webx.servlet; import static com.alibaba.citrus.util.Assert.*; import static com.alibaba.citrus.util.CollectionUtil.*; import java.io.IOException; import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.io.UnsupportedEncodingException; import java.util.Enumeration; import java.util.Set; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.ServletOutputStream; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; import com.alibaba.citrus.util.internal.Servlet3Util.Servlet3OutputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.BeanWrapper; import org.springframework.beans.BeansException; import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.PropertyAccessorFactory; import org.springframework.beans.PropertyValue; import org.springframework.beans.PropertyValues; import org.springframework.core.io.Resource; import org.springframework.core.io.ResourceEditor; import org.springframework.core.io.ResourceLoader; import org.springframework.web.context.support.ServletContextResourceLoader; import org.springframework.web.servlet.HttpServletBean; /** * 支持注入参数的filter基类。 * <p> * 将init-params注入到filter中,将request和response转换成 <code>HttpServletRequest</code>和 * <code>HttpServletResponse</code>。 * </p> * * @author Michael Zhou * @see HttpServletBean */ public abstract class FilterBean implements Filter { protected final Logger log = LoggerFactory.getLogger(getClass()); private final Set<String> requiredProperties = createHashSet(); private FilterConfig filterConfig; /** 添加一个必选配置项。 */ protected final void addRequiredProperty(String name) { this.requiredProperties.add(name); } /** 初始化filter。 */ public final void init(FilterConfig filterConfig) throws ServletException { this.filterConfig = filterConfig; logInBothServletAndLoggingSystem("Initializing filter: " + getFilterName()); try { PropertyValues pvs = new FilterConfigPropertyValues(getFilterConfig(), requiredProperties); BeanWrapper bw = PropertyAccessorFactory.forBeanPropertyAccess(this); ResourceLoader resourceLoader = new ServletContextResourceLoader(getServletContext()); bw.registerCustomEditor(Resource.class, new ResourceEditor(resourceLoader)); initBeanWrapper(bw); bw.setPropertyValues(pvs, true); } catch (Exception e) { throw new ServletException("Failed to set bean properties on filter: " + getFilterName(), e); } try { init(); } catch (Exception e) { throw new ServletException("Failed to init filter: " + getFilterName(), e); } logInBothServletAndLoggingSystem(getClass().getSimpleName() + " - " + getFilterName() + ": initialization completed"); } protected final void logInBothServletAndLoggingSystem(String msg) { getServletContext().log(msg); log.info(msg); } /** 初始化<code>BeanWrapper</code>。 */ protected void initBeanWrapper(BeanWrapper bw) throws BeansException { } /** 初始化Filter。 */ protected void init() throws Exception { } /** 清理filter。 */ public void destroy() { } /** 取得filter的配置。 */ public final FilterConfig getFilterConfig() { return filterConfig; } /** 取得定义在web.xml中的filter名字。 */ public final String getFilterName() { return filterConfig == null ? null : filterConfig.getFilterName(); } /** 取得当前webapp的context。 */ public final ServletContext getServletContext() { return filterConfig == null ? null : filterConfig.getServletContext(); } public final void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { if (request instanceof HttpServletRequest && response instanceof HttpServletResponse) { HttpServletRequest httpRequest = (HttpServletRequest) request; HttpServletResponse httpResponse = (HttpServletResponse) response; String method = httpRequest.getMethod(); if ("HEAD".equalsIgnoreCase(method)) { httpResponse = new NoBodyResponse(httpResponse); } try { doFilter(httpRequest, httpResponse, chain); } finally { if (httpResponse instanceof NoBodyResponse) { ((NoBodyResponse) httpResponse).setContentLength(); } } } else { log.debug("Skipped filtering due to the unknown request/response types: {}, {}", request.getClass() .getName(), response.getClass().getName()); chain.doFilter(request, response); } } protected abstract void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException, ServletException; /** 从filter config中取得所有的init-params。 */ private static class FilterConfigPropertyValues extends MutablePropertyValues { private static final long serialVersionUID = -5359131251714023794L; public FilterConfigPropertyValues(FilterConfig config, Set<String> requiredProperties) throws ServletException { Set<String> missingProps = createTreeSet(requiredProperties); for (Enumeration<?> e = config.getInitParameterNames(); e.hasMoreElements(); ) { String key = (String) e.nextElement(); String value = config.getInitParameter(key); addPropertyValue(new PropertyValue(key, value)); missingProps.remove(key); } assertTrue(missingProps.isEmpty(), "Initialization for filter %s failed. " + "The following required properties were missing: %s", config.getFilterName(), missingProps); } } /** 不返回response body的response实现。 */ private static class NoBodyResponse extends HttpServletResponseWrapper { private NoBodyOutputStream noBody = new NoBodyOutputStream(); private PrintWriter writer; private boolean didSetContentLength; public NoBodyResponse(HttpServletResponse response) { super(response); } public void setContentLength() { if (!didSetContentLength) { super.setContentLength(noBody.getContentLength()); } } @Override public void setContentLength(int len) { super.setContentLength(len); didSetContentLength = true; } @Override public ServletOutputStream getOutputStream() throws IOException { return noBody; } @Override public PrintWriter getWriter() throws UnsupportedEncodingException { if (writer == null) { writer = new PrintWriter(new OutputStreamWriter(noBody, getCharacterEncoding())); } return writer; } } /** 不返回response body的servlet output stream实现。 */ private static class NoBodyOutputStream extends Servlet3OutputStream { private int contentLength; public NoBodyOutputStream() { super(null); contentLength = 0; } public int getContentLength() { return contentLength; } @Override public void write(int b) { contentLength++; } @Override public void write(byte[] buf, int offset, int len) throws IOException { if (len >= 0) { contentLength += len; } else { throw new IOException("negative length"); } } } }