package com.ctriposs.baiji.rpc.server; import com.ctriposs.baiji.rpc.common.util.DaemonThreadFactory; import com.ctriposs.baiji.rpc.server.util.ConfigUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.servlet.*; import javax.servlet.http.HttpServletRequest; import java.io.IOException; import java.util.concurrent.Callable; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; /** * An async {@link javax.servlet.http.HttpServlet} that responds to Baiji RPC requests. */ public class AsyncBaijiServlet extends BaijiServletBase { private static final int DEFAULT_ASYNC_TIMEOUT = 60 * 1000; // In ms. private static final int DEFAULT_CORE_THREAD_COUNT = 20; private static final int DEFAULT_MAX_THREAD_COUNT = 100; private static final int DEFAULT_KEEP_ALIVE_TIME = 60 * 1000; // Servlet Init Params private static final String ASYNC_TIMEOUT_PARAM = "async-timeout"; private static final String CORE_THREAD_COUNT_PARAM = "core-thread-count"; private static final String MAX_THREAD_COUNT_PARAM = "max-thread-count"; private static final String KEEP_ALIVE_TIME_PARAM = "keepalive-time"; private ThreadPoolExecutor _poolExecutor; private int _asyncTimeout; private int _coreThreadCount; private int _maxThreadCount; private int _keepAliveTime; @Override public void init() { super.init(); ServletConfig servletConfig = getServletConfig(); _asyncTimeout = ConfigUtil.getIntConfig(servletConfig, ASYNC_TIMEOUT_PARAM, DEFAULT_ASYNC_TIMEOUT); _coreThreadCount = ConfigUtil.getIntConfig(servletConfig, CORE_THREAD_COUNT_PARAM, DEFAULT_CORE_THREAD_COUNT); _maxThreadCount = ConfigUtil.getIntConfig(servletConfig, MAX_THREAD_COUNT_PARAM, DEFAULT_MAX_THREAD_COUNT); _keepAliveTime = ConfigUtil.getIntConfig(servletConfig, KEEP_ALIVE_TIME_PARAM, DEFAULT_KEEP_ALIVE_TIME); _poolExecutor = new ThreadPoolExecutor(_coreThreadCount, _maxThreadCount, _keepAliveTime, TimeUnit.MILLISECONDS, new SynchronousQueue<Runnable>(), new DaemonThreadFactory()); } @Override public void service(ServletRequest req, ServletResponse resp) throws ServletException, IOException { if (_poolExecutor == null) { throw new IllegalStateException("The servlet has been destroyed."); } req.setAttribute("org.apache.catalina.ASYNC_SUPPORTED", true); AsyncContext asyncContext = req.startAsync(); asyncContext.setTimeout(_asyncTimeout); AsyncBaijiListener listener = new AsyncBaijiListener(); asyncContext.addListener(listener); try { _poolExecutor.submit(new BaijiCallable(asyncContext, listener)); } catch (RuntimeException e) { throw e; } } @Override public void destroy() { if (_poolExecutor != null) { try { _poolExecutor.awaitTermination(_asyncTimeout, TimeUnit.MILLISECONDS); _poolExecutor.shutdown(); } catch (InterruptedException e) { _poolExecutor.shutdownNow(); _poolExecutor = null; } } super.destroy(); } private class BaijiCallable implements Callable { private final AsyncContext _context; private final AsyncBaijiListener _listener; public BaijiCallable(AsyncContext context, AsyncBaijiListener listener) { _context = context; _listener = listener; } @Override public Object call() throws Exception { try { processRequest(_context.getRequest(), _context.getResponse()); } catch (Throwable t) { _logger.error("BaijiCallable execute error.", t); } finally { if (!_listener.isTimedOut()) { try { _context.complete(); } catch (Throwable t) { _logger.error("AsyncContext complete error.", t); } } } return null; } } private static class AsyncBaijiListener implements AsyncListener { private static final Logger _logger = LoggerFactory.getLogger(AsyncBaijiListener.class); private boolean _timedOut; public boolean isTimedOut() { return _timedOut; } @Override public void onComplete(AsyncEvent event) throws IOException { } @Override public void onTimeout(AsyncEvent event) throws IOException { _timedOut = true; _logger.error("Access {} timeout in AsyncBaijiServlet.", ((HttpServletRequest) event.getAsyncContext().getRequest()).getRequestURL()); } @Override public void onError(AsyncEvent event) throws IOException { _logger.error("Error while access {} in AsyncBaijiServlet.", ((HttpServletRequest) event.getAsyncContext().getRequest()).getRequestURL()); } @Override public void onStartAsync(AsyncEvent event) throws IOException { } } }