/*******************************************************************************
* Copyright (c) 2012-2016 Codenvy, S.A.
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* Codenvy, S.A. - initial API and implementation
*******************************************************************************/
package org.everrest.core.servlet;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.core.HttpHeaders;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import static java.util.Collections.emptyList;
import static java.util.Collections.enumeration;
import static org.everrest.core.ExtHttpHeaders.FORWARDED_HOST;
import static org.everrest.core.ExtHttpHeaders.FORWARDED_PROTO;
import static org.junit.Assert.assertEquals;
import static org.junit.runners.Parameterized.Parameter;
import static org.junit.runners.Parameterized.Parameters;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Parameterized test for {@link ServletContainerRequest} that checks how the {@code X-Forwarded-Host}
 * header affects the base URI and request URI produced by
 * {@link ServletContainerRequest#create(HttpServletRequest)}.
 * <p>
 * The underlying {@link HttpServletRequest} is a Mockito mock whose server name, port and paths are
 * fixed; only the forwarded host header varies between test cases.
 *
 * @author Tareq Sharafy <tareq.sharafy@sap.com>
 */
@RunWith(Parameterized.class)
public class ServletContainerRequestForwardedHeadersTest {
    private static final String TEST_HOST = "test.myhost.com";
    private static final int TEST_PORT = 8080;

    // Request path components; kept in one place so the mocked request and the expected URIs agree.
    private static final String CONTEXT_PATH = "/myapp";
    private static final String SERVLET_PATH = "/myservlet";
    private static final String PATH_INFO = "/datapath";

    /**
     * Supplies the test cases.
     *
     * @return rows of {X-Forwarded-Host value, expected base URI, expected request URI}
     */
    @Parameters(name = "{index} When X-Forwarded-Host header is {0} then Request URI is {2}")
    public static List<Object[]> testData() {
        return Arrays.asList(new Object[][]{
                // --- Invalid forwarded headers, ignored: expected URIs keep TEST_HOST:TEST_PORT,
                // which is also what the mocked request reports as its own server name/port and Host.
                {"a b c", createBaseUri(TEST_HOST, TEST_PORT), createRequestUri(TEST_HOST, TEST_PORT)},
                {"myhost.com:8877:200", createBaseUri(TEST_HOST, TEST_PORT), createRequestUri(TEST_HOST, TEST_PORT)},
                {"myhost..com", createBaseUri(TEST_HOST, TEST_PORT), createRequestUri(TEST_HOST, TEST_PORT)},
                // --- Valid forwarded headers: the forwarded host (and optional port) appear in the URIs
                {"other.myhost.com", createBaseUri("other.myhost.com"), createRequestUri("other.myhost.com")},
                {"other.myhost.com:777", createBaseUri("other.myhost.com", 777), createRequestUri("other.myhost.com", 777)}
        });
    }

    /** Builds the expected base URI for an explicit host and port. */
    private static String createBaseUri(String host, int port) {
        return createBaseUri(host + ":" + port);
    }

    /** Builds the expected request URI for an explicit host and port. */
    private static String createRequestUri(String host, int port) {
        return createRequestUri(host + ":" + port);
    }

    /**
     * Builds the expected base URI (context path + servlet path) for an authority.
     * String concatenation is used rather than {@code String.format("%d")} so the port is never
     * rendered with locale-specific digits.
     */
    private static String createBaseUri(String authority) {
        return "http://" + authority + CONTEXT_PATH + SERVLET_PATH;
    }

    /** Builds the expected request URI (base URI + path info) for an authority. */
    private static String createRequestUri(String authority) {
        return createBaseUri(authority) + PATH_INFO;
    }

    // Parameters injected by the Parameterized runner; must be public for field injection.
    @Parameter(0)
    public String forwardedHost;
    @Parameter(1)
    public String expectedBaseUri;
    @Parameter(2)
    public String expectedRequestUri;

    private HttpServletRequest httpServletRequest;
    private Map<String, List<String>> httpHeaders;

    /** Creates the mocked servlet request for the current test case's forwarded host value. */
    @Before
    public void setUp() {
        httpHeaders = createForwardedHeaders(forwardedHost, null);
        httpServletRequest = mock(HttpServletRequest.class);
        // Answer (not thenReturn) so every call gets a fresh, un-consumed enumeration.
        when(httpServletRequest.getHeaderNames()).thenAnswer(getAllHeaderNames());
        when(httpServletRequest.getScheme()).thenReturn("http");
        when(httpServletRequest.getServerName()).thenReturn(TEST_HOST);
        when(httpServletRequest.getServerPort()).thenReturn(TEST_PORT);
        when(httpServletRequest.getContextPath()).thenReturn(CONTEXT_PATH);
        when(httpServletRequest.getServletPath()).thenReturn(SERVLET_PATH);
        when(httpServletRequest.getPathInfo()).thenReturn(PATH_INFO);
        when(httpServletRequest.getRequestURI()).thenReturn(CONTEXT_PATH + SERVLET_PATH + PATH_INFO);
        when(httpServletRequest.getHeaders(any(String.class))).thenAnswer(getHeadersByName());
        when(httpServletRequest.getHeader(any(String.class))).thenAnswer(getHeaderByName());
    }

    /**
     * Builds the request headers: the optional forwarded host/proto plus a Host header of
     * {@code TEST_HOST:TEST_PORT}.
     *
     * @param forwardedHost  value of the forwarded-host header, or {@code null} to omit it
     * @param forwardedProto value of the forwarded-proto header, or {@code null} to omit it
     * @return mutable header map keyed case-insensitively
     */
    private Map<String, List<String>> createForwardedHeaders(String forwardedHost, String forwardedProto) {
        // Servlet header lookups are case-insensitive, so the stubbed header map is too.
        Map<String, List<String>> finalHeaders = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        if (forwardedHost != null) {
            finalHeaders.put(FORWARDED_HOST, Arrays.asList(forwardedHost));
        }
        if (forwardedProto != null) {
            finalHeaders.put(FORWARDED_PROTO, Arrays.asList(forwardedProto));
        }
        finalHeaders.put(HttpHeaders.HOST, Arrays.asList(TEST_HOST + ":" + TEST_PORT));
        return finalHeaders;
    }

    /**
     * Returns the stubbed values of header {@code name}, or an empty list when the name is
     * {@code null} or the header is absent.
     */
    private List<String> headerValues(String name) {
        // Guard null: a comparator-based TreeMap throws on null keys.
        List<String> values = name == null ? null : httpHeaders.get(name);
        if (values == null) {
            values = emptyList();
        }
        return values;
    }

    /** Answer for {@code getHeaderNames()}: a new enumeration over the stubbed header names. */
    private Answer<Enumeration<String>> getAllHeaderNames() {
        return new Answer<Enumeration<String>>() {
            @Override
            public Enumeration<String> answer(InvocationOnMock invocation) throws Throwable {
                return enumeration(httpHeaders.keySet());
            }
        };
    }

    /** Answer for {@code getHeaders(name)}: all values of the header, empty if it is absent. */
    private Answer<Enumeration<String>> getHeadersByName() {
        return new Answer<Enumeration<String>>() {
            @Override
            public Enumeration<String> answer(InvocationOnMock invocation) throws Throwable {
                String name = (String)invocation.getArguments()[0];
                return enumeration(headerValues(name));
            }
        };
    }

    /** Answer for {@code getHeader(name)}: the first value of the header, {@code null} if absent. */
    private Answer<String> getHeaderByName() {
        return new Answer<String>() {
            @Override
            public String answer(InvocationOnMock invocation) throws Throwable {
                String name = (String)invocation.getArguments()[0];
                List<String> values = headerValues(name);
                return values.isEmpty() ? null : values.get(0);
            }
        };
    }

    /** Verifies the base and request URIs derived from the mocked request and forwarded header. */
    @Test
    public void testForwarded() {
        ServletContainerRequest req = ServletContainerRequest.create(httpServletRequest);
        assertEquals("base URI", expectedBaseUri, req.getBaseUri().toString());
        assertEquals("request URI", expectedRequestUri, req.getRequestUri().toString());
    }
}