/* * Copyright (c) 2002-2012 Alibaba Group Holding Limited. * All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.alibaba.citrus.util.internal; import static com.alibaba.citrus.util.Assert.*; import static com.alibaba.citrus.util.CollectionUtil.*; import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.EventListener; import java.util.List; import javax.servlet.ServletOutputStream; import javax.servlet.WriteListener; import javax.servlet.http.HttpServletRequest; import net.sf.cglib.reflect.FastClass; import net.sf.cglib.reflect.FastMethod; /** * Servlet 3.0 Support - 即使在非servlet 3.0的环境中,也不会出错。 * 此类会引用如下几个Servlet 3.0的接口。在Servlet 2.5的环境中,接口由<code>citrus-common-servlet</code>项目提供。 * <ul> * <li><code>WriteListener</code></li> * </ul> * * @author Michael Zhou */ public class Servlet3Util { public static final Enum<?> DISPATCHER_TYPE_FORWARD = getEnum("javax.servlet.DispatcherType", "FORWARD"); public static final Enum<?> DISPATCHER_TYPE_INCLUDE = getEnum("javax.servlet.DispatcherType", "INCLUDE"); public static final Enum<?> DISPATCHER_TYPE_REQUEST = getEnum("javax.servlet.DispatcherType", "REQUEST"); public static final Enum<?> DISPATCHER_TYPE_ASYNC = getEnum("javax.servlet.DispatcherType", "ASYNC"); public static final Enum<?> DISPATCHER_TYPE_ERROR = getEnum("javax.servlet.DispatcherType", "ERROR"); public static final Class<?> asyncContextClass = loadClass("javax.servlet.AsyncContext"); public static final Class<?> asyncListenerClass = loadClass("javax.servlet.AsyncListener"); public static final Class<?> asyncEventClass = loadClass("javax.servlet.AsyncEvent"); public static final Class<?> writeListenerClass = loadClass("javax.servlet.WriteListener"); private static final InterfaceImplementorBuilder asyncListenerBuilder; private static final MethodInfo[] methods; private static final int request_isAsyncStarted; private static final int request_getAsyncContext; private static final int request_getDispatcherType; private static final int asyncContext_addListener; private static final int asyncEvent_getAsyncContext; private static final int servletOutputStream_isReady; private static final int servletOutputStream_setWriteListener; private static boolean disableServlet3Features = false; private static final boolean servlet3; static { List<MethodInfo> methodList = createLinkedList(); int count = 0; methodList.add(new MethodInfo(Boolean.class, false, HttpServletRequest.class, "isAsyncStarted")); request_isAsyncStarted = count++; methodList.add(new MethodInfo(Object.class, null, HttpServletRequest.class, "getAsyncContext")); request_getAsyncContext = count++; methodList.add(new MethodInfo(Enum.class, null, HttpServletRequest.class, "getDispatcherType")); request_getDispatcherType = count++; methodList.add(new MethodInfo(null, null, asyncContextClass, "addListener", asyncListenerClass)); asyncContext_addListener = count++; methodList.add(new MethodInfo(asyncContextClass, null, asyncEventClass, "getAsyncContext")); asyncEvent_getAsyncContext = count++; methodList.add(new MethodInfo(Boolean.class, true, ServletOutputStream.class, "isReady")); servletOutputStream_isReady = count++; // 这里不能硬编码WriteListener.class,否则在servlet 3.0环境中会失败。 if (writeListenerClass == null) { methodList.add(new MethodInfo(null, null)); } else { methodList.add(new MethodInfo(null, null, ServletOutputStream.class, "setWriteListener", writeListenerClass)); } servletOutputStream_setWriteListener = count++; methods = methodList.toArray(new MethodInfo[methodList.size()]); servlet3 = !methods[request_getAsyncContext].isDisabled(); asyncListenerBuilder = asyncListenerClass == null ? null : new InterfaceImplementorBuilder().addInterface(asyncListenerClass).setOverriderClass(MyAsyncListener.class).init(); } public static boolean isServlet3() { return servlet3; } /** * 设置强制禁用servlet 3.0特性。 * 有一种情况:当用httpunit测试时,虽然存在servlet 3.0的API包,但是由于httpunit未实现servlet 3.0而报错。 * 在这种情况下,可强制禁用servlet 3.0,让测试通过。 */ public static boolean setDisableServlet3Features(boolean disabled) { boolean originalValue = disableServlet3Features; disableServlet3Features = disabled; return originalValue; } public static boolean request_isAsyncStarted(HttpServletRequest request) { return (Boolean) invoke(request_isAsyncStarted, request); } public static Object /* AsyncContext */ request_getAsyncContext(HttpServletRequest request) { int index = request_getAsyncContext; if (methods[index].isDisabled()) { throw new IllegalStateException("request.getAsyncContext"); } return invoke(index, request); } public static boolean request_isDispatcherType(HttpServletRequest request, Enum<?> type) { Enum<?> dispatcherType = (Enum<?>) invoke(request_getDispatcherType, request); if (dispatcherType == null || type == null) { return false; // unsupported } else { return dispatcherType == type; } } public static Object /* AsyncContext */ asyncEvent_getAsyncContext(Object /* AsyncEvent */ event) { return invoke(asyncEvent_getAsyncContext, event); } public static void asyncContext_addAsyncListener(Object /* AsyncContext */ asyncContext, Object /* AsyncListener */ listener) { invoke(asyncContext_addListener, asyncContext, listener); } public static void request_registerAsyncListener(HttpServletRequest request, MyAsyncListener listenerImpl) { Object /* AsyncContext */ asyncContext = request_getAsyncContext(request); Object listener = assertNotNull(asyncListenerBuilder, "asyncListenerBuilder").toObject(listenerImpl); // builder should not be null invoke(asyncContext_addListener, asyncContext, listener); } private static Class<?> loadClass(String className) { try { return Servlet3Util.class.getClassLoader().loadClass(className); } catch (ClassNotFoundException e) { return null; } } private static Object invoke(int methodIndex, Object target, Object... args) { MethodInfo method = methods[methodIndex]; if (method.isDisabled()) { return method.defaultReturnValue; } try { if (method.returnValueType == null) { method.method.invoke(target, args); return null; } else { return method.returnValueType.cast(method.method.invoke(target, args)); } } catch (InvocationTargetException e) { Throwable t = e.getTargetException(); if (t instanceof RuntimeException) { throw (RuntimeException) t; } else if (t instanceof Error) { throw (Error) t; } else { throw new RuntimeException(t); } } } private static Enum<?> getEnum(String className, String name) { Class<?> enumClass = null; try { enumClass = Servlet3Util.class.getClassLoader().loadClass(className); } catch (ClassNotFoundException e) { return null; } assertTrue(Enum.class.isAssignableFrom(enumClass), "%s is not a enum class", enumClass.getName()); try { return (Enum<?>) enumClass.getField(name).get(null); } catch (Exception e) { unexpectedException(e); return null; } } private static class MethodInfo { private final FastMethod method; private final Object defaultReturnValue; private final Class<?> returnValueType; // 创建一个空的method public <T> MethodInfo(Class<T> returnValueType, T defaultReturnValue) { this(returnValueType, defaultReturnValue, null, null, (Class<?>[]) null); } public <T> MethodInfo(Class<T> returnValueType, T defaultReturnValue, Class<?> declaringClass, String methodName, Class<?>... parameterTypes) { this.returnValueType = returnValueType; this.defaultReturnValue = defaultReturnValue; Method javaMethod; FastMethod method = null; if (declaringClass != null) { try { javaMethod = declaringClass.getMethod(methodName, parameterTypes); method = FastClass.create(getClass().getClassLoader(), declaringClass).getMethod(javaMethod); } catch (NoSuchMethodException e) { } } this.method = method; } public boolean isDisabled() { return disableServlet3Features || method == null; } } public interface MyAsyncListener extends EventListener { void onComplete(Object /* AsyncEvent */ event) throws IOException; void onTimeout(Object /* AsyncEvent */ event) throws IOException; void onError(Object /* AsyncEvent */ event) throws IOException; void onStartAsync(Object /* AsyncEvent */ event) throws IOException; } /** 一个可同时在servlet 3.0和servlet 2.5环境下使用的基类。 */ public static abstract class Servlet3OutputStream extends ServletOutputStream { protected final ServletOutputStream originalStream; public Servlet3OutputStream(ServletOutputStream originalStream) { this.originalStream = originalStream; } // @Override public boolean isReady() { if (originalStream != null) { return (Boolean) invoke(servletOutputStream_isReady, originalStream); } return true; } // @Override public void setWriteListener(WriteListener writeListener) { if (originalStream != null) { invoke(servletOutputStream_setWriteListener, originalStream, writeListener); } } } }