/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.yarn.server.webproxy.amfilter; import java.io.IOException; import java.util.Collections; import java.util.Enumeration; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import junit.framework.Assert; import org.junit.Test; import org.mockito.Mockito; public class TestAmFilter { private String proxyHost = "bogushost.com"; private String proxyUri = "http://bogus"; private class TestAmIpFilter extends AmIpFilter { private Set<String> proxyAddresses = null; protected Set<String> getProxyAddresses() { if(proxyAddresses == null) { proxyAddresses = new HashSet<String>(); } proxyAddresses.add(proxyHost); return proxyAddresses; } } private static class DummyFilterConfig implements FilterConfig { final Map<String, String> map; DummyFilterConfig(Map<String,String> map) { this.map = map; } @Override public String getFilterName() { return "dummy"; } @Override public String getInitParameter(String arg0) { return map.get(arg0); } @Override public Enumeration<String> getInitParameterNames() { return Collections.enumeration(map.keySet()); } @Override public ServletContext getServletContext() { return null; } } @Test public void filterNullCookies() throws Exception { HttpServletRequest request = Mockito.mock(HttpServletRequest.class); Mockito.when(request.getCookies()).thenReturn(null); Mockito.when(request.getRemoteAddr()).thenReturn(proxyHost); HttpServletResponse response = Mockito.mock(HttpServletResponse.class); final AtomicBoolean invoked = new AtomicBoolean(); FilterChain chain = new FilterChain() { @Override public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) throws IOException, ServletException { invoked.set(true); } }; Map<String, String> params = new HashMap<String, String>(); params.put(AmIpFilter.PROXY_HOST, proxyHost); params.put(AmIpFilter.PROXY_URI_BASE, proxyUri); FilterConfig conf = new DummyFilterConfig(params); Filter filter = new TestAmIpFilter(); filter.init(conf); filter.doFilter(request, response, chain); Assert.assertTrue(invoked.get()); filter.destroy(); } }