//
//  ========================================================================
//  Copyright (c) 1995-2013 Mort Bay Consulting Pty. Ltd.
//  ------------------------------------------------------------------------
//  All rights reserved. This program and the accompanying materials
//  are made available under the terms of the Eclipse Public License v1.0
//  and Apache License v2.0 which accompanies this distribution.
//
//      The Eclipse Public License is available at
//      http://www.eclipse.org/legal/epl-v10.html
//
//      The Apache License v2.0 is available at
//      http://www.opensource.org/licenses/apache2.0.php
//
//  You may elect to redistribute this code under either of these licenses.
//  ========================================================================
//

package logbook.server.proxy;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.SequenceInputStream;
import java.net.InetAddress;
import java.net.URI;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import javax.servlet.AsyncContext;
import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.UnavailableException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.eclipse.jetty.client.HttpClient;
import org.eclipse.jetty.client.api.ContentProvider;
import org.eclipse.jetty.client.api.Request;
import org.eclipse.jetty.client.api.Response;
import org.eclipse.jetty.client.api.Result;
import org.eclipse.jetty.client.util.InputStreamContentProvider;
import org.eclipse.jetty.http.HttpField;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpMethod;
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.proxy.ConnectHandler;
import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.util.HttpCookieStore;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.util.thread.QueuedThreadPool;

/**
 * Asynchronous ProxyServlet.
 * <p/>
 * Forwards requests to another server either as a standard web reverse proxy
 * (as defined by RFC2616) or as a transparent reverse proxy.
 * <p/>
 * To facilitate JMX monitoring, the {@link HttpClient} instance is set as context attribute,
 * prefixed with the servlet's name and exposed by the mechanism provided by
 * {@link ContextHandler#MANAGED_ATTRIBUTES}.
 * <p/>
 * The following init parameters may be used to configure the servlet:
 * <ul>
 * <li>hostHeader - forces the host header to a particular value</li>
 * <li>viaHost - the name to use in the Via header: Via: http/1.1 &lt;viaHost&gt;</li>
 * <li>whiteList - comma-separated list of allowed proxy hosts</li>
 * <li>blackList - comma-separated list of forbidden proxy hosts</li>
 * </ul>
 * <p/>
 * In addition, see {@link #createHttpClient()} for init parameters used to configure
 * the {@link HttpClient} instance.
 * <p/>
 * Each proxied request is driven by a private per-request listener which buffers the
 * beginning of the request body so that the request can be re-sent once when the
 * upstream connection fails with an {@link EOFException} before any response begins.
 *
 * @see ConnectHandler
 */
public class ProxyServlet extends HttpServlet {
    /** Request attribute name under which the {@link AsyncContext} of a proxied request is stored. */
    protected static final String ASYNC_CONTEXT = ProxyServlet.class.getName() + ".asyncContext";

    /** Lower-case names of headers that are not copied between the two sides of the proxy. */
    private static final Set<String> HOP_HEADERS = new HashSet<>();
    static {
        HOP_HEADERS.add("proxy-connection");
        HOP_HEADERS.add("connection");
        HOP_HEADERS.add("keep-alive");
        HOP_HEADERS.add("transfer-encoding");
        HOP_HEADERS.add("te");
        HOP_HEADERS.add("trailer");
        HOP_HEADERS.add("proxy-authorization");
        HOP_HEADERS.add("proxy-authenticate");
        HOP_HEADERS.add("upgrade");
    }

    private final Set<String> _whiteList = new HashSet<>();
    private final Set<String> _blackList = new HashSet<>();

    protected Logger _log;
    private String _hostHeader;
    private String _viaHost;
    private HttpClient _client;
    private long _timeout;

    /**
     * Reads the servlet init parameters, creates and publishes the {@link HttpClient}
     * and fills the whitelist and blacklist.
     *
     * @throws ServletException wrapping any exception raised while creating the client
     *                          or parsing the host lists
     */
    @Override
    public void init() throws ServletException {
        this._log = this.createLogger();

        ServletConfig config = this.getServletConfig();

        this._hostHeader = config.getInitParameter("hostHeader");

        this._viaHost = config.getInitParameter("viaHost");
        if (this._viaHost == null) {
            this._viaHost = viaHost();
        }

        try {
            this._client = this.createHttpClient();

            // Put the HttpClient in the context to leverage ContextHandler.MANAGED_ATTRIBUTES
            this.getServletContext().setAttribute(config.getServletName() + ".HttpClient", this._client);

            String whiteList = config.getInitParameter("whiteList");
            if (whiteList != null) {
                this.getWhiteListHosts().addAll(this.parseList(whiteList));
            }

            String blackList = config.getInitParameter("blackList");
            if (blackList != null) {
                this.getBlackListHosts().addAll(this.parseList(blackList));
            }
        } catch (Exception e) {
            throw new ServletException(e);
        }
    }

    /**
     * @return the total timeout, in milliseconds, applied to every proxy request
     */
    public long getTimeout() {
        return this._timeout;
    }

    /**
     * @param timeout the total timeout, in milliseconds, to apply to every proxy request
     */
    public void setTimeout(long timeout) {
        this._timeout = timeout;
    }

    /**
     * @return the live, mutable set of whitelisted {@code host:port} entries
     */
    public Set<String> getWhiteListHosts() {
        return this._whiteList;
    }

    /**
     * @return the live, mutable set of blacklisted {@code host:port} entries
     */
    public Set<String> getBlackListHosts() {
        return this._blackList;
    }

    /**
     * @return the local host name, or {@code "localhost"} when it cannot be resolved
     */
    protected static String viaHost() {
        try {
            return InetAddress.getLocalHost().getHostName();
        } catch (UnknownHostException x) {
            return "localhost";
        }
    }

    /**
     * @return a logger instance with a name derived from this servlet's name.
     */
    protected Logger createLogger() {
        String name = this.getServletConfig().getServletName();
        name = name.replace('-', '.');
        return Log.getLogger(name);
    }

    /**
     * Stops the {@link HttpClient}; any exception raised while stopping is only logged at debug level.
     */
    @Override
    public void destroy() {
        try {
            this._client.stop();
        } catch (Exception x) {
            this._log.debug(x);
        }
    }

    /**
     * Creates a {@link HttpClient} instance, configured with init parameters of this servlet.
     * <p/>
     * The init parameters used to configure the {@link HttpClient} instance are:
     * <table>
     * <thead>
     * <tr>
     * <th>init-param</th>
     * <th>default</th>
     * <th>description</th>
     * </tr>
     * </thead>
     * <tbody>
     * <tr>
     * <td>maxThreads</td>
     * <td>256</td>
     * <td>The max number of threads of HttpClient's Executor</td>
     * </tr>
     * <tr>
     * <td>maxConnections</td>
     * <td>32768</td>
     * <td>The max number of connections per destination, see {@link HttpClient#setMaxConnectionsPerDestination(int)}</td>
     * </tr>
     * <tr>
     * <td>idleTimeout</td>
     * <td>30000</td>
     * <td>The idle timeout in milliseconds, see {@link HttpClient#setIdleTimeout(long)}</td>
     * </tr>
     * <tr>
     * <td>timeout</td>
     * <td>60000</td>
     * <td>The total timeout in milliseconds, see {@link Request#timeout(long, TimeUnit)}</td>
     * </tr>
     * <tr>
     * <td>requestBufferSize</td>
     * <td>HttpClient's default</td>
     * <td>The request buffer size, see {@link HttpClient#setRequestBufferSize(int)}</td>
     * </tr>
     * <tr>
     * <td>responseBufferSize</td>
     * <td>HttpClient's default</td>
     * <td>The response buffer size, see {@link HttpClient#setResponseBufferSize(int)}</td>
     * </tr>
     * </tbody>
     * </table>
     *
     * @return a {@link HttpClient} configured from the {@link #getServletConfig() servlet configuration}
     * @throws ServletException if the {@link HttpClient} cannot be created
     */
    protected HttpClient createHttpClient() throws ServletException {
        ServletConfig config = this.getServletConfig();

        HttpClient client = this.newHttpClient();
        // Redirects must be proxied as is, not followed
        client.setFollowRedirects(false);

        // Must not store cookies, otherwise cookies of different clients will mix
        client.setCookieStore(new HttpCookieStore.Empty());

        String value = config.getInitParameter("maxThreads");
        if (value == null) {
            value = "256";
        }
        QueuedThreadPool executor = new QueuedThreadPool(Integer.parseInt(value));
        // Name the pool after the last dot-separated segment of the servlet name
        String servletName = config.getServletName();
        int dot = servletName.lastIndexOf('.');
        if (dot >= 0) {
            servletName = servletName.substring(dot + 1);
        }
        executor.setName(servletName);
        client.setExecutor(executor);

        value = config.getInitParameter("maxConnections");
        if (value == null) {
            value = "32768";
        }
        client.setMaxConnectionsPerDestination(Integer.parseInt(value));

        value = config.getInitParameter("idleTimeout");
        if (value == null) {
            value = "30000";
        }
        client.setIdleTimeout(Long.parseLong(value));

        value = config.getInitParameter("timeout");
        if (value == null) {
            value = "60000";
        }
        this._timeout = Long.parseLong(value);

        value = config.getInitParameter("requestBufferSize");
        if (value != null) {
            client.setRequestBufferSize(Integer.parseInt(value));
        }

        value = config.getInitParameter("responseBufferSize");
        if (value != null) {
            client.setResponseBufferSize(Integer.parseInt(value));
        }

        try {
            client.start();

            // Content must not be decoded, otherwise the client gets confused
            client.getContentDecoderFactories().clear();

            return client;
        } catch (Exception x) {
            throw new ServletException(x);
        }
    }

    /**
     * @return a new HttpClient instance
     */
    protected HttpClient newHttpClient() {
        return new HttpClient();
    }

    /**
     * Splits a comma-separated list into its trimmed, non-empty entries.
     *
     * @param list the comma-separated list
     * @return the set of entries found in the list
     */
    private Set<String> parseList(String list) {
        Set<String> result = new HashSet<>();
        String[] hosts = list.split(",");
        for (String host : hosts) {
            host = host.trim();
            if (host.length() == 0) {
                continue;
            }
            result.add(host);
        }
        return result;
    }

    /**
     * Checks the given {@code host} and {@code port} against whitelist and blacklist.
     * <p/>
     * Both lists are matched against the literal string {@code host:port}; an empty list
     * imposes no restriction.
     *
     * @param host the host to check
     * @param port the port to check
     * @return true if it is allowed to proxy to the given host and port
     */
    public boolean validateDestination(String host, int port) {
        String hostPort = host + ":" + port;
        if (!this._whiteList.isEmpty()) {
            if (!this._whiteList.contains(hostPort)) {
                this._log.debug("Host {}:{} not whitelisted", host, port);
                return false;
            }
        }
        if (!this._blackList.isEmpty()) {
            if (this._blackList.contains(hostPort)) {
                this._log.debug("Host {}:{} blacklisted", host, port);
                return false;
            }
        }
        return true;
    }

    /**
     * Rewrites the request URI, starts asynchronous processing and sends the proxy request.
     * Responds with 403 (Forbidden) when {@link #rewriteURI(HttpServletRequest)} returns null.
     */
    @Override
    protected void service(final HttpServletRequest request, final HttpServletResponse response)
            throws ServletException, IOException {
        final int requestId = this.getRequestId(request);

        URI rewrittenURI = this.rewriteURI(request);

        if (this._log.isDebugEnabled()) {
            StringBuffer uri = request.getRequestURL();
            if (request.getQueryString() != null) {
                uri.append("?").append(request.getQueryString());
            }
            this._log.debug("{} rewriting: {} -> {}", requestId, uri, rewrittenURI);
        }

        if (rewrittenURI == null) {
            response.sendError(HttpServletResponse.SC_FORBIDDEN);
            return;
        }

        AsyncContext asyncContext = request.startAsync();
        // We do not timeout the continuation, but the proxy request
        asyncContext.setTimeout(0);
        request.setAttribute(ASYNC_CONTEXT, asyncContext);

        ProxyRequestHandler proxyRequestHandler = new ProxyRequestHandler(request, response, rewrittenURI);
        proxyRequestHandler.send();
    }

    /**
     * Builds the upstream request: same method and protocol version as the incoming request,
     * all non hop-by-hop headers copied, the Host header optionally forced, the given content
     * attached and the configured timeout applied.
     *
     * @param request         the incoming request
     * @param response        the outgoing response (not used here)
     * @param targetUri       the URI to proxy to
     * @param contentProvider the provider of the request body
     * @return the request, ready to be sent
     */
    private Request createProxyRequest(HttpServletRequest request, HttpServletResponse response, URI targetUri,
            ContentProvider contentProvider) {
        final Request proxyRequest = this._client.newRequest(targetUri)
                .method(HttpMethod.fromString(request.getMethod()))
                .version(HttpVersion.fromString(request.getProtocol()));

        // Copy headers
        for (Enumeration<String> headerNames = request.getHeaderNames(); headerNames.hasMoreElements();) {
            String headerName = headerNames.nextElement();
            String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH);

            // Remove hop-by-hop headers
            if (HOP_HEADERS.contains(lowerHeaderName)) {
                continue;
            }

            if ((this._hostHeader != null) && lowerHeaderName.equals("host")) {
                continue;
            }

            for (Enumeration<String> headerValues = request.getHeaders(headerName); headerValues
                    .hasMoreElements();) {
                String headerValue = headerValues.nextElement();
                if (headerValue != null) {
                    proxyRequest.header(headerName, headerValue);
                }
            }
        }

        // Force the Host header if configured
        if (this._hostHeader != null) {
            proxyRequest.header(HttpHeader.HOST, this._hostHeader);
        }

        proxyRequest.content(contentProvider);
        this.customizeProxyRequest(proxyRequest, request);
        proxyRequest.timeout(this.getTimeout(), TimeUnit.MILLISECONDS);

        return proxyRequest;
    }

    /**
     * Copies the upstream response headers to the downstream response, skipping hop-by-hop
     * headers and headers for which {@link #filterResponseHeader} returns null or a blank value.
     *
     * @param request       the request being proxied
     * @param response      the downstream response
     * @param proxyResponse the upstream response
     */
    protected void onResponseHeaders(HttpServletRequest request, HttpServletResponse response,
            Response proxyResponse) {
        for (HttpField field : proxyResponse.getHeaders()) {
            String headerName = field.getName();
            String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH);
            if (HOP_HEADERS.contains(lowerHeaderName)) {
                continue;
            }

            String newHeaderValue = this.filterResponseHeader(request, headerName, field.getValue());
            if ((newHeaderValue == null) || (newHeaderValue.trim().length() == 0)) {
                continue;
            }

            response.addHeader(headerName, newHeaderValue);
        }
    }

    /**
     * Writes a chunk of upstream response content to the downstream response.
     *
     * @param request       the request being proxied
     * @param response      the downstream response
     * @param proxyResponse the upstream response
     * @param buffer        the array holding the content
     * @param offset        the index of the first content byte in {@code buffer}
     * @param length        the number of content bytes
     * @throws IOException if writing to the downstream output stream fails
     */
    protected void onResponseContent(HttpServletRequest request, HttpServletResponse response,
            Response proxyResponse, byte[] buffer, int offset, int length) throws IOException {
        response.getOutputStream().write(buffer, offset, length);
        this._log.debug("{} proxying content to downstream: {} bytes", this.getRequestId(request), length);
    }

    /**
     * Completes the {@link AsyncContext} stored on the request once the upstream response succeeded.
     *
     * @param request       the request being proxied
     * @param response      the downstream response
     * @param proxyResponse the upstream response
     */
    protected void onResponseSuccess(HttpServletRequest request, HttpServletResponse response,
            Response proxyResponse) {
        AsyncContext asyncContext = (AsyncContext) request.getAttribute(ASYNC_CONTEXT);
        asyncContext.complete();
        this._log.debug("{} proxying successful", this.getRequestId(request));
    }

    /**
     * Reports an upstream failure downstream: unless the response is already committed the status
     * becomes 504 for a {@link TimeoutException} and 502 otherwise; the {@link AsyncContext} is
     * then completed.
     *
     * @param request       the request being proxied
     * @param response      the downstream response
     * @param proxyResponse the upstream response
     * @param failure       the failure reported by the HTTP client
     */
    protected void onResponseFailure(HttpServletRequest request, HttpServletResponse response,
            Response proxyResponse, Throwable failure) {
        this._log.debug(this.getRequestId(request) + " proxying failed", failure);
        if (!response.isCommitted()) {
            if (failure instanceof TimeoutException) {
                response.setStatus(HttpServletResponse.SC_GATEWAY_TIMEOUT);
            } else {
                response.setStatus(HttpServletResponse.SC_BAD_GATEWAY);
            }
        }
        AsyncContext asyncContext = (AsyncContext) request.getAttribute(ASYNC_CONTEXT);
        asyncContext.complete();
    }

    /**
     * @param request the request being proxied
     * @return the identity hash code of the request, used to correlate log lines
     */
    protected int getRequestId(HttpServletRequest request) {
        return System.identityHashCode(request);
    }

    /**
     * Computes the URI to proxy to. The default implementation proxies to the requested URL
     * (including the query string) as is.
     *
     * @param request the request to proxy
     * @return the target URI, or null if the request's server name and port are rejected by
     *         {@link #validateDestination(String, int)}
     */
    protected URI rewriteURI(HttpServletRequest request) {
        if (!this.validateDestination(request.getServerName(), request.getServerPort())) {
            return null;
        }

        StringBuffer uri = request.getRequestURL();
        String query = request.getQueryString();
        if (query != null) {
            uri.append("?").append(query);
        }

        return URI.create(uri.toString());
    }

    /**
     * Extension point for subclasses to customize the proxy request.
     * The default implementation does nothing.
     *
     * @param proxyRequest the proxy request to customize
     * @param request the request to be proxied
     */
    protected void customizeProxyRequest(Request proxyRequest, HttpServletRequest request) {
    }

    /**
     * Extension point for remote server response header filtering.
     * The default implementation returns the header value as is.
     * If null is returned, this header won't be forwarded back to the client.
     *
     * @param headerName the header name
     * @param headerValue the header value
     * @param request the request to proxy
     * @return filteredHeaderValue the new header value
     */
    protected String filterResponseHeader(HttpServletRequest request, String headerName, String headerValue) {
        return headerValue;
    }

    /**
     * Transparent Proxy.
     * <p/>
     * This convenience extension to ProxyServlet configures the servlet as a transparent proxy.
     * The servlet is configured with init parameters:
     * <ul>
     * <li>proxyTo - a URI like http://host:80/context to which the request is proxied.
     * <li>prefix - a URI prefix that is stripped from the start of the forwarded URI.
     * </ul>
     * For example, if a request is received at /foo/bar and the 'proxyTo' parameter is "http://host:80/context"
     * and the 'prefix' parameter is "/foo", then the request would be proxied to "http://host:80/context/bar".
     */
    public static class Transparent extends ProxyServlet {
        private String _proxyTo;
        private String _prefix;

        /**
         * Creates an instance whose 'proxyTo' and 'prefix' come from the servlet init parameters.
         */
        public Transparent() {
        }

        /**
         * Creates an instance with programmatic defaults; init parameters of the same name,
         * when present, take precedence in {@link #init()}.
         *
         * @param proxyTo the URI to proxy to (normalized before being stored)
         * @param prefix  the URI prefix to strip (normalized before being stored)
         */
        public Transparent(String proxyTo, String prefix) {
            this._proxyTo = URI.create(proxyTo).normalize().toString();
            this._prefix = URI.create(prefix).normalize().toString();
        }

        /**
         * Initializes the base servlet, then resolves 'prefix' (prepended with the context path)
         * and 'proxyTo'.
         *
         * @throws ServletException     if the base initialization fails
         * @throws UnavailableException if 'proxyTo' is missing or the resulting prefix does not
         *                              start with '/'
         */
        @Override
        public void init() throws ServletException {
            super.init();

            ServletConfig config = this.getServletConfig();

            String prefix = config.getInitParameter("prefix");
            this._prefix = prefix == null ? this._prefix : prefix;

            // Adjust prefix value to account for context path
            String contextPath = this.getServletContext().getContextPath();
            this._prefix = this._prefix == null ? contextPath : (contextPath + this._prefix);

            String proxyTo = config.getInitParameter("proxyTo");
            this._proxyTo = proxyTo == null ? this._proxyTo : proxyTo;

            if (this._proxyTo == null) {
                throw new UnavailableException("Init parameter 'proxyTo' is required.");
            }

            if (!this._prefix.startsWith("/")) {
                throw new UnavailableException("Init parameter 'prefix' parameter must start with a '/'.");
            }

            this._log.debug(config.getServletName() + " @ " + this._prefix + " to " + this._proxyTo);
        }

        /**
         * Replaces the configured prefix of the request URI with 'proxyTo', keeping the query string.
         *
         * @param request the request to proxy
         * @return the normalized target URI, or null if the request URI does not start with the
         *         prefix or the target is rejected by {@link #validateDestination(String, int)}
         */
        @Override
        protected URI rewriteURI(HttpServletRequest request) {
            String path = request.getRequestURI();
            if (!path.startsWith(this._prefix)) {
                return null;
            }

            StringBuilder uri = new StringBuilder(this._proxyTo);
            uri.append(path.substring(this._prefix.length()));
            String query = request.getQueryString();
            if (query != null) {
                uri.append("?").append(query);
            }
            URI rewrittenURI = URI.create(uri.toString()).normalize();

            if (!this.validateDestination(rewrittenURI.getHost(), rewrittenURI.getPort())) {
                return null;
            }

            return rewrittenURI;
        }
    }

    /**
     * Per-request response listener that sends the proxy request and relays the upstream
     * response downstream through the enclosing servlet's {@code onResponse*} hooks.
     * <p/>
     * While the request body is streamed upstream, its bytes are also copied into
     * {@link #contentBuffer}. If the exchange fails with an {@link EOFException} before any
     * response has begun, the request is HTTP/1.1 and the buffer limit was not exceeded, the
     * request is re-sent exactly once, replaying the buffered bytes followed by whatever is
     * left in the request input stream.
     */
    private class ProxyRequestHandler extends Response.Listener.Empty {
        // Amount of request body data remembered for a retry. The limit is checked before
        // each write, so the buffer may overshoot it by at most one read chunk.
        private static final int RETRY_MAX_SIZE = 256 * 1024;

        private final HttpServletRequest request;
        private final HttpServletResponse response;
        private final URI targetUri;
        private final InputStream contentInputStream;
        private final ByteArrayOutputStream contentBuffer = new ByteArrayOutputStream();
        // True while a single retry of the proxy request is still permitted
        private boolean retryEnabled = true;

        /**
         * @param request   the request being proxied
         * @param response  the downstream response
         * @param targetUri the URI to proxy to
         * @throws IOException if the request input stream cannot be obtained
         */
        public ProxyRequestHandler(HttpServletRequest request, HttpServletResponse response, URI targetUri)
                throws IOException {
            this.request = request;
            this.response = response;
            this.targetUri = targetUri;
            this.contentInputStream = request.getInputStream();
        }

        /**
         * Creates the content provider for the retry attempt. Only meaningful while
         * {@link #retryEnabled} is set, i.e. while everything read so far is still buffered.
         *
         * @return a provider that replays the buffered request bytes and then continues with
         *         the remainder of the request input stream
         */
        private ContentProvider createRetryContentProvider() {
            final int requestId = ProxyServlet.this.getRequestId(this.request);
            final HttpServletRequest request = this.request;

            return new InputStreamContentProvider(
                    new SequenceInputStream(new ByteArrayInputStream(this.contentBuffer.toByteArray()),
                            this.contentInputStream)) {
                @Override
                public long getLength() {
                    return request.getContentLength();
                }

                @Override
                protected ByteBuffer onRead(byte[] buffer, int offset, int length) {
                    ProxyServlet.this._log
                            .debug("{} proxying content to upstream: {} bytes", requestId, length);
                    return super.onRead(buffer, offset, length);
                }
            };
        }

        /**
         * Builds the proxy request with a content provider that records the streamed body for
         * a possible retry, logs the request at debug level and sends it with this handler as
         * the response listener.
         */
        public void send() {
            final int requestId = ProxyServlet.this.getRequestId(this.request);
            final HttpServletRequest request = this.request;
            final ByteArrayOutputStream contentBuffer = this.contentBuffer;

            Request proxyRequest = ProxyServlet.this.createProxyRequest(request, this.response, this.targetUri,
                    new InputStreamContentProvider(this.contentInputStream) {
                        @Override
                        public long getLength() {
                            return request.getContentLength();
                        }

                        @Override
                        protected ByteBuffer onRead(byte[] buffer, int offset, int length) {
                            if (length > 0) {
                                if (contentBuffer.size() < RETRY_MAX_SIZE) {
                                    contentBuffer.write(buffer, offset, length);
                                } else {
                                    // Too much data; a retry is no longer possible
                                    ProxyRequestHandler.this.retryEnabled = false;
                                }
                            }
                            ProxyServlet.this._log
                                    .debug("{} proxying content to upstream: {} bytes", requestId, length);
                            return super.onRead(buffer, offset, length);
                        }
                    });

            if (ProxyServlet.this._log.isDebugEnabled()) {
                StringBuilder builder = new StringBuilder(this.request.getMethod());
                builder.append(" ").append(this.request.getRequestURI());
                String query = this.request.getQueryString();
                if (query != null) {
                    builder.append("?").append(query);
                }
                builder.append(" ").append(this.request.getProtocol()).append("\r\n");
                for (Enumeration<String> headerNames = this.request.getHeaderNames(); headerNames
                        .hasMoreElements();) {
                    String headerName = headerNames.nextElement();
                    builder.append(headerName).append(": ");
                    for (Enumeration<String> headerValues = this.request.getHeaders(headerName); headerValues
                            .hasMoreElements();) {
                        String headerValue = headerValues.nextElement();
                        if (headerValue != null) {
                            builder.append(headerValue);
                        }
                        if (headerValues.hasMoreElements()) {
                            builder.append(",");
                        }
                    }
                    builder.append("\r\n");
                }
                builder.append("\r\n");

                // Six placeholders for six arguments: the original pattern had only five,
                // leaving the proxied request headers without a placeholder.
                ProxyServlet.this._log.debug("{} proxying to upstream:{}{}{}{}{}",
                        requestId,
                        System.lineSeparator(),
                        builder,
                        proxyRequest,
                        System.lineSeparator(),
                        proxyRequest.getHeaders().toString().trim());
            }

            proxyRequest.send(this);
        }

        /**
         * Copies the upstream status downstream and disables the retry.
         */
        @Override
        public void onBegin(Response proxyResponse) {
            // Once a response arrives the server-side processing is done, so do not retry
            this.retryEnabled = false;
            this.response.setStatus(proxyResponse.getStatus());
        }

        /**
         * Relays the upstream response headers downstream and logs both header sets at debug level.
         */
        @Override
        public void onHeaders(Response proxyResponse) {
            ProxyServlet.this.onResponseHeaders(this.request, this.response, proxyResponse);

            if (ProxyServlet.this._log.isDebugEnabled()) {
                StringBuilder builder = new StringBuilder("\r\n");
                builder.append(this.request.getProtocol()).append(" ").append(this.response.getStatus())
                        .append(" ").append(proxyResponse.getReason()).append("\r\n");
                for (String headerName : this.response.getHeaderNames()) {
                    builder.append(headerName).append(": ");
                    for (Iterator<String> headerValues = this.response.getHeaders(headerName)
                            .iterator(); headerValues.hasNext();) {
                        String headerValue = headerValues.next();
                        if (headerValue != null) {
                            builder.append(headerValue);
                        }
                        if (headerValues.hasNext()) {
                            builder.append(",");
                        }
                    }
                    builder.append("\r\n");
                }
                ProxyServlet.this._log.debug("{} proxying to downstream:{}{}{}{}{}",
                        ProxyServlet.this.getRequestId(this.request),
                        System.lineSeparator(),
                        proxyResponse,
                        System.lineSeparator(),
                        proxyResponse.getHeaders().toString().trim(),
                        System.lineSeparator(),
                        builder);
            }
        }

        /**
         * Relays a chunk of upstream content downstream; aborts the upstream response if the
         * downstream write fails.
         */
        @Override
        public void onContent(Response proxyResponse, ByteBuffer content) {
            byte[] buffer;
            int offset;
            int length = content.remaining();
            if (content.hasArray()) {
                buffer = content.array();
                // Buffer position p maps to array index p + arrayOffset(), so the remaining
                // bytes start at arrayOffset() + position(), not at arrayOffset() alone.
                offset = content.arrayOffset() + content.position();
            } else {
                buffer = new byte[length];
                content.get(buffer);
                offset = 0;
            }

            try {
                ProxyServlet.this.onResponseContent(this.request, this.response, proxyResponse, buffer, offset,
                        length);
            } catch (IOException x) {
                proxyResponse.abort(x);
            }
        }

        /**
         * Delegates to {@link ProxyServlet#onResponseSuccess}.
         */
        @Override
        public void onSuccess(Response proxyResponse) {
            ProxyServlet.this.onResponseSuccess(this.request, this.response, proxyResponse);
        }

        /**
         * @param failure the failure reported by the HTTP client
         * @return true if a retry is still permitted, the failure is an {@link EOFException}
         *         and the incoming request uses HTTP/1.1
         */
        private boolean isRetry(Throwable failure) {
            return this.retryEnabled && (failure instanceof EOFException)
                    && (HttpVersion.fromString(this.request.getProtocol()) == HttpVersion.HTTP_1_1);
        }

        /**
         * Reports the failure downstream unless it qualifies for a retry, in which case the
         * retry is left to {@link #onComplete(Result)}.
         */
        @Override
        public void onFailure(Response proxyResponse, Throwable failure) {
            if (!this.isRetry(failure)) {
                // Not retrying
                this.retryEnabled = false;
                ProxyServlet.this.onResponseFailure(this.request, this.response, proxyResponse, failure);
            }
        }

        /**
         * Re-sends the proxy request once if a retry is still permitted; otherwise only logs completion.
         */
        @Override
        public void onComplete(Result result) {
            if (this.retryEnabled) {
                // Do not retry a second time
                this.retryEnabled = false;
                ProxyServlet.this._log.debug("{} retrying proxy request",
                        ProxyServlet.this.getRequestId(this.request));
                Request proxyRequest = ProxyServlet.this.createProxyRequest(this.request, this.response,
                        this.targetUri, this.createRetryContentProvider());
                proxyRequest.send(this);
            } else {
                ProxyServlet.this._log.debug("{} proxying complete",
                        ProxyServlet.this.getRequestId(this.request));
            }
        }
    }
}