//======================================================================== //Copyright 2004-2008 Mort Bay Consulting Pty. Ltd. //------------------------------------------------------------------------ //Licensed under the Apache License, Version 2.0 (the "License"); //you may not use this file except in compliance with the License. //You may obtain a copy of the License at //http://www.apache.org/licenses/LICENSE-2.0 //Unless required by applicable law or agreed to in writing, software //distributed under the License is distributed on an "AS IS" BASIS, //WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //See the License for the specific language governing permissions and //limitations under the License. //======================================================================== package org.mortbay.jetty; import java.io.IOException; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import junit.framework.TestCase; import org.mortbay.jetty.handler.AbstractHandler; /** * Test {@link AbstractConnector#checkForwardedHeaders(org.mortbay.io.EndPoint, Request)}. */ public class CheckReverseProxyHeadersTest extends TestCase { Server server=new Server(); LocalConnector connector=new LocalConnector(); /** * Constructor for CheckReverseProxyHeadersTest. * @param name test case name. */ public CheckReverseProxyHeadersTest(String name) { super(name); } public void testCheckReverseProxyHeaders() throws Exception { // Classic ProxyPass from example.com:80 to localhost:8080 testRequest("Host: localhost:8080\n" + "X-Forwarded-For: 10.20.30.40\n" + "X-Forwarded-Host: example.com", new RequestValidator() { public void validate(HttpServletRequest request) { assertEquals("example.com", request.getServerName()); assertEquals(80, request.getServerPort()); assertEquals("10.20.30.40", request.getRemoteAddr()); assertEquals("10.20.30.40", request.getRemoteHost()); assertEquals("example.com", request.getHeader("Host")); } }); // ProxyPass from example.com:81 to localhost:8080 testRequest("Host: localhost:8080\n" + "X-Forwarded-For: 10.20.30.40\n" + "X-Forwarded-Host: example.com:81\n" + "X-Forwarded-Server: example.com", new RequestValidator() { public void validate(HttpServletRequest request) { assertEquals("example.com", request.getServerName()); assertEquals(81, request.getServerPort()); assertEquals("10.20.30.40", request.getRemoteAddr()); assertEquals("10.20.30.40", request.getRemoteHost()); assertEquals("example.com:81", request.getHeader("Host")); } }); // Multiple ProxyPass from example.com:80 to rp.example.com:82 to localhost:8080 testRequest("Host: localhost:8080\n" + "X-Forwarded-For: 10.20.30.40, 10.0.0.1\n" + "X-Forwarded-Host: example.com, rp.example.com:82\n" + "X-Forwarded-Server: example.com, rp.example.com", new RequestValidator() { public void validate(HttpServletRequest request) { assertEquals("example.com", request.getServerName()); assertEquals(80, request.getServerPort()); assertEquals("10.20.30.40", request.getRemoteAddr()); assertEquals("10.20.30.40", request.getRemoteHost()); assertEquals("example.com", request.getHeader("Host")); } }); } private void testRequest(String headers, RequestValidator requestValidator) throws Exception { Server server = new Server(); LocalConnector connector = new LocalConnector(); // Activate reverse proxy headers checking connector.setForwarded(true); server.setConnectors(new Connector[] {connector}); ValidationHandler validationHandler = new ValidationHandler(requestValidator); server.setHandler(validationHandler); try { server.start(); connector.getResponses("GET / HTTP/1.1\n" + headers + "\n\n"); Error error = validationHandler.getError(); if (error != null) { throw error; } } finally { server.stop(); } } /** * Interface for validate a wrapped request. */ private static interface RequestValidator { /** * Validate the current request. * @param request the request. */ void validate(HttpServletRequest request); } /** * Handler for validation. */ private static class ValidationHandler extends AbstractHandler { private RequestValidator _requestValidator; private Error _error; /** * Create the validation handler with a request validator. * @param requestValidator the request validator. */ public ValidationHandler(RequestValidator requestValidator) { _requestValidator = requestValidator; } /** * Retrieve the validation error. * @return the validation error or <code>null</code> if there was no error. */ public Error getError() { return _error; } public void handle(String target, HttpServletRequest request, HttpServletResponse response, int dispatch) throws IOException, ServletException { try { _requestValidator.validate(request); } catch (Error e) { _error = e; } catch (Throwable e) { _error = new Error(e); } } } }