// Copyright 2011 Google, Inc. All Rights Reserved. package com.google.jstestdriver.server.gateway; import com.google.common.collect.Iterators; import com.google.inject.Inject; import com.google.inject.assistedinject.Assisted; import com.google.jstestdriver.requesthandlers.HttpMethod; import com.google.jstestdriver.requesthandlers.RequestHandler; import org.apache.commons.httpclient.Header; import org.apache.commons.httpclient.HttpClient; import org.apache.commons.httpclient.HttpMethodBase; import org.mortbay.jetty.Response; import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; import java.util.Enumeration; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; /** * An HTTP gateway that forwards requests to another server and feeds the responses back to the * client. * TODO(rdionne): Write unit tests. * TODO(rdionne): Implement HOST header and 302 redirect host substitution. * @author rdionne@google.com (Robert Dionne) */ public class GatewayRequestHandler implements RequestHandler { public interface Factory { GatewayRequestHandler create( @Assisted("destination") String destination, @Assisted("prefix") String prefix); } private static final String HOST = "Host"; private static final String LOCATION = "Location"; private static final String PRAGMA = "Pragma"; private static final String REQUEST_URI_DOES_NOT_START_WITH_PREFIX = "Request URI '%s' does not start with prefix '%s'."; private static final String X_SUPPRESS_STATUS_CODE = "X-Suppress-Status-Code"; private static final String X_SUPPRESSED_STATUS_CODE = "X-Suppressed-Status-Code"; private static final String X_SUPPRESSED_REASON_PHRASE = "X-Suppressed-Reason-Phrase"; private final HttpClient client; private final HttpServletRequest request; private final HttpServletResponse response; private final String destination; private final String prefix; @Inject public GatewayRequestHandler( final HttpClient client, final HttpServletRequest req, final HttpServletResponse res, @Assisted("destination") String destination, @Assisted("prefix") String prefix) { this.client = client; this.request = req; this.response = res; this.destination = destination; this.prefix = prefix; } @Override public void handleIt() throws IOException { final HttpMethodBase method = getMethod(request); addRequestHeaders(method, request); method.setQueryString(request.getQueryString()); spoofHostHeader(method); try { final int statusCode = client.executeMethod(method); response.setStatus(statusCode); addResponseHeaders(method, response); if (isRedirect(statusCode)) { spoofLocationHeader(request, (Response) response); } if (isStatusCodeSuppressed(request)) { response.setStatus(HttpServletResponse.SC_OK); response.addIntHeader(X_SUPPRESSED_STATUS_CODE, statusCode); response.addHeader(X_SUPPRESSED_REASON_PHRASE, method.getStatusText()); } // TODO(rdionne): Substitute the JsTD server address for the destination address in any redirects. Streams.copy(method.getResponseBodyAsStream(), response.getOutputStream()); } catch (IOException e) { response.setStatus(HttpServletResponse.SC_BAD_GATEWAY); e.printStackTrace(response.getWriter()); } finally { method.releaseConnection(); } } private boolean isStatusCodeSuppressed(final HttpServletRequest request) { return Iterators.contains( Iterators.forEnumeration(request.getHeaders(PRAGMA)), X_SUPPRESS_STATUS_CODE); } private HttpMethodBase getMethod(final HttpServletRequest request) throws IOException { final HttpMethod method = HttpMethod.valueOf(request.getMethod()); String uri = request.getRequestURI(); if (prefix != null && !uri.startsWith(prefix)) { // Probably impossible. throw new RuntimeException( String.format(REQUEST_URI_DOES_NOT_START_WITH_PREFIX, uri, prefix)); } String url = prefix == null ? destination + uri : destination + uri.substring(prefix.length()); switch (method) { case POST: case PUT: return new GatewayEntityMethod( method.name(), url, request.getInputStream()); default: return new GatewayMethod(method.name(), url); } } private void addRequestHeaders(final HttpMethodBase method, final HttpServletRequest request) { final Enumeration headers = request.getHeaderNames(); while (headers.hasMoreElements()) { final String name = (String) headers.nextElement(); final Enumeration values = request.getHeaders(name); while (values.hasMoreElements()) { final String value = (String) values.nextElement(); method.addRequestHeader(name, value); } } } private void spoofHostHeader(HttpMethodBase method) { method.setRequestHeader(HOST, parseUri(destination).getAuthority()); } private URI parseUri(String uri) { try { return new URI(uri); } catch (URISyntaxException badUriSyntax) { throw new RuntimeException(badUriSyntax); } } private void addResponseHeaders(final HttpMethodBase method, final HttpServletResponse response) { for (final Header header : method.getResponseHeaders()) { response.addHeader(header.getName(), header.getValue()); } } private boolean isRedirect(int statusCode) { switch (statusCode) { case HttpServletResponse.SC_MOVED_PERMANENTLY: case HttpServletResponse.SC_FOUND: case HttpServletResponse.SC_SEE_OTHER: case HttpServletResponse.SC_TEMPORARY_REDIRECT: return true; default: return false; } } void spoofLocationHeader(HttpServletRequest request, Response response) { URI location = parseUri(response.getHeader(LOCATION)); String host = request.getServerName(); int port = request.getServerPort(); URI destination = parseUri(this.destination); if (location.getHost() == null || !location.getHost().equals(destination.getHost())) { return; } if (location.getPort() == destination.getPort() || location.getPort() == -1 && (destination.getPort() == 80 || destination.getPort() == 443)) { response.setHeader(LOCATION, buildLocationHeader(location, host, port)); } } private String buildLocationHeader(URI location, String host, int port) { try { return new URI( location.getScheme(), String.format("%s:%s", host, port), location.getPath(), location.getQuery(), location.getFragment()).toString(); } catch (URISyntaxException badUriSyntax) { throw new RuntimeException(badUriSyntax); } } }