package de.is24.infrastructure.gridfs.http.security;

import de.is24.infrastructure.gridfs.http.utils.HostnameResolver;
import org.junit.Before;
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import javax.servlet.http.HttpServletRequest;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.UnknownHostException;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;

/**
 * Unit tests for {@link WhiteListAuthenticationFilter}.
 *
 * <p>The tests build filters with different white lists and check the pre-authenticated
 * principal / credentials returned for mock requests that arrive either directly from a host
 * or from the configured load balancer address carrying an {@code X-Forwarded-For} header.
 */
public class WhiteListAuthenticationFilterTest {

  // Addresses used as remote host / remote address of the mock requests.
  private static final String LOCAL_IP = "127.0.0.1";
  private static final String LOCAL_IPV6 = "0:0:0:0:0:0:0:1";
  private static final String ARBITRARY_IPV6 = "abc:def:123:456:789:aaaa:bbbb:cccc";
  private static final String ARBITRARY_IP = "192.168.5.5";
  private static final String ANOTHER_IP = "192.168.6.6";

  // Address handed to the HostnameResolver and used as the sender of forwarded requests.
  private static final String LOADBALANCER_IP = "10.99.10.12";

  private static final String X_FORWARDED_FOR = "X-Forwarded-For";
  private static final String BASE64_AUTH_STRING = "Basic Zm9vOmJhcg==";

  private HostnameResolver hostnameResolver;

  /** Creates a fresh resolver, configured with the load balancer address, before each test. */
  @Before
  public void setup() {
    hostnameResolver = new HostnameResolver(LOADBALANCER_IP);
  }

  /** A request from the local IPv4 address yields a principal when its hostname is white-listed. */
  @Test
  public void detectHostnameFromIP() throws Exception {
    WhiteListAuthenticationFilter filter = createFilter(localHostname());

    Object principal = filter.getPreAuthenticatedPrincipal(request(LOCAL_IP));

    assertThat(principal, notNullValue());
  }

  /** A request from the local IPv6 address yields a principal when its hostname is white-listed. */
  @Test
  public void detectHostnameFromIPv6() throws Exception {
    WhiteListAuthenticationFilter filter = createFilter(localIPv6Hostname());

    Object principal = filter.getPreAuthenticatedPrincipal(request(LOCAL_IPV6));

    assertThat(principal, notNullValue());
  }

  /** A white-listed arbitrary IPv6 literal used as the remote address still yields a principal. */
  @Test
  public void handleUnresolvableIPv6Addresses() throws Exception {
    WhiteListAuthenticationFilter filter = createFilter(ARBITRARY_IPV6);

    Object principal = filter.getPreAuthenticatedPrincipal(request(ARBITRARY_IPV6));

    assertThat(principal, notNullValue());
  }

  /** Load balancer request forwarding a single white-listed address yields credentials. */
  @Test
  public void allowLoadBalancerRequestsWithXForwardedFor() throws Exception {
    WhiteListAuthenticationFilter filter = createFilter(ARBITRARY_IP);
    HttpServletRequest forwarded = request(LOADBALANCER_IP, ARBITRARY_IP);

    Object credentials = filter.getPreAuthenticatedCredentials(forwarded);

    assertThat(credentials, notNullValue());
  }

  /** Credentials are returned when the white-listed address is the last entry of the chain. */
  @Test
  public void allowLoadBalancerRequestsWithXForwardedForChain() throws Exception {
    WhiteListAuthenticationFilter filter = createFilter(ARBITRARY_IP);
    String forwardedChain = ANOTHER_IP + ", " + ARBITRARY_IP;
    HttpServletRequest forwarded = request(LOADBALANCER_IP, forwardedChain);

    Object credentials = filter.getPreAuthenticatedCredentials(forwarded);

    assertThat(credentials, notNullValue());
  }

  /** A forwarded local address yields credentials when its hostname is white-listed. */
  @Test
  public void allowLoadBalancerRequestsWithXForwardedForResolvableHostname() throws Exception {
    WhiteListAuthenticationFilter filter = createFilter(localHostname());
    HttpServletRequest forwarded = request(LOADBALANCER_IP, LOCAL_IP);

    Object credentials = filter.getPreAuthenticatedCredentials(forwarded);

    assertThat(credentials, notNullValue());
  }

  /** Load balancer request forwarding an address that is not white-listed yields no credentials. */
  @Test
  public void denyLoadBalancerRequestsWithUnauthorizedXForwardedFor() throws Exception {
    WhiteListAuthenticationFilter filter = createFilter(ARBITRARY_IP);
    HttpServletRequest forwarded = request(LOADBALANCER_IP, ANOTHER_IP);

    Object credentials = filter.getPreAuthenticatedCredentials(forwarded);

    assertThat(credentials, nullValue());
  }

  /** No credentials are returned when the last entry of the chain is not white-listed. */
  @Test
  public void denyLoadBalancerRequestsWithUnauthorizedXForwardedForChain() throws Exception {
    WhiteListAuthenticationFilter filter = createFilter(ARBITRARY_IP);
    String forwardedChain = ARBITRARY_IP + ", " + ANOTHER_IP;
    HttpServletRequest forwarded = request(LOADBALANCER_IP, forwardedChain);

    Object credentials = filter.getPreAuthenticatedCredentials(forwarded);

    assertThat(credentials, nullValue());
  }

  /** A load balancer request without the forwarding header yields no credentials. */
  @Test
  public void denyLoadBalancerRequestsWithoutXForwardedFor() throws Exception {
    WhiteListAuthenticationFilter filter = createFilter("");
    MockHttpServletRequest direct = request(LOADBALANCER_IP);

    Object credentials = filter.getPreAuthenticatedCredentials(direct);

    assertThat(credentials, nullValue());
  }

  /** Changing the white list must fail while modification is not enabled on the filter. */
  @Test(expected = IllegalStateException.class)
  public void denySettingWhiteListWithoutFeatureEnabled() throws Exception {
    WhiteListAuthenticationFilter filter = createFilter("");

    filter.setWhiteListedHosts("foo.bar");
  }

  /** With modification enabled, a newly set white list is returned by the getter. */
  @Test
  public void setWhiteListViaProperty() throws Exception {
    WhiteListAuthenticationFilter filter = createFilter("", true);

    filter.setWhiteListedHosts(ARBITRARY_IP);

    assertThat(filter.getWhiteListedHosts(), is(ARBITRARY_IP));
  }

  /** The value of the {@code Username} header becomes the principal for a white-listed host. */
  @Test
  public void setUsernameHeaderAsPrincipalForWhitelistedHosts() throws Exception {
    WhiteListAuthenticationFilter filter = createFilter(ARBITRARY_IP);
    MockHttpServletRequest incoming = request(ARBITRARY_IP);
    incoming.addHeader("Username", "foo");

    Object principal = filter.getPreAuthenticatedPrincipal(incoming);

    assertThat(principal, is("foo"));
  }

  /** Without a {@code Username} header the white-listed host itself is the principal. */
  @Test
  public void setHostnameAsPrincipalForWhitelistedHosts() throws Exception {
    WhiteListAuthenticationFilter filter = createFilter(ARBITRARY_IP);
    MockHttpServletRequest incoming = request(ARBITRARY_IP);

    Object principal = filter.getPreAuthenticatedPrincipal(incoming);

    assertThat(principal, is(ARBITRARY_IP));
  }

  /** A white-listed host sending an {@code Authorization} header gets no principal. */
  @Test
  public void denyWhiteListedHostIfBasicAuthHeaderIsGiven() {
    WhiteListAuthenticationFilter filter = createFilter(ARBITRARY_IP);
    MockHttpServletRequest incoming = request(ARBITRARY_IP);
    incoming.addHeader("Authorization", BASE64_AUTH_STRING);

    Object principal = filter.getPreAuthenticatedPrincipal(incoming);

    assertThat(principal, nullValue());
  }

  /** Builds a filter for the given white list with white-list modification disabled. */
  private WhiteListAuthenticationFilter createFilter(String whiteListedHosts) {
    return createFilter(whiteListedHosts, false);
  }

  /**
   * Builds a filter using the shared {@link HostnameResolver}; the third constructor argument is
   * always passed as {@code null} in these tests.
   */
  private WhiteListAuthenticationFilter createFilter(String whiteListedHosts, boolean enableModification) {
    return new WhiteListAuthenticationFilter(
        whiteListedHosts,
        enableModification,
        null,
        hostnameResolver);
  }

  /** Returns the hostname the JVM reports for {@link #LOCAL_IP}. */
  private String localHostname() throws UnknownHostException {
    return Inet4Address.getByName(LOCAL_IP).getHostName();
  }

  /** Returns the hostname the JVM reports for {@link #LOCAL_IPV6}. */
  private String localIPv6Hostname() throws UnknownHostException {
    return Inet6Address.getByName(LOCAL_IPV6).getHostName();
  }

  /** Creates a mock request whose remote host and remote address are both {@code ip}. */
  private MockHttpServletRequest request(String ip) {
    MockHttpServletRequest mockRequest = new MockHttpServletRequest();
    mockRequest.setRemoteHost(ip);
    mockRequest.setRemoteAddr(ip);
    return mockRequest;
  }

  /**
   * Creates a mock request sent from {@code loadBalancerIP} that carries {@code ip} as the value
   * of the {@code X-Forwarded-For} header.
   */
  private HttpServletRequest request(String loadBalancerIP, String ip) {
    MockHttpServletRequest forwardedRequest = request(loadBalancerIP);
    forwardedRequest.addHeader(X_FORWARDED_FOR, ip);
    return forwardedRequest;
  }
}