/**
* Copyright (c) Codice Foundation
* <p/>
* This is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
* General Public License as published by the Free Software Foundation, either version 3 of the
* License, or any later version.
* <p/>
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details. A copy of the GNU Lesser General Public License
* is distributed along with this program and can be found at
* <http://www.gnu.org/licenses/lgpl.html>.
*/
package org.codice.ddf.platform.filter.delegate;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.osgi.framework.BundleContext;
import org.osgi.framework.InvalidSyntaxException;
import org.osgi.framework.ServiceReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Tests that the DelegateServletFilter is functionality properly.
*
*/
public class DelegateServletFilterTest {
private static final Logger LOGGER = LoggerFactory.getLogger(DelegateServletFilterTest.class);
FilterChain initialChain;
private Filter filter1;
private Filter filter2;
private Filter filter3;
@Before
public void resetGlobals() {
initialChain = mock(FilterChain.class);
}
/**
* Tests the main logic of performing the filter with adding filters.
*
* @throws ServletException
* @throws IOException
* @throws InvalidSyntaxException
*/
@Test
public void testDoFilterWithFilters()
throws IOException, ServletException, InvalidSyntaxException {
ServletRequest request = mock(HttpServletRequest.class);
ServletResponse response = mock(HttpServletResponse.class);
final BundleContext context = createMockContext(true);
DelegateServletFilter filter = new DelegateServletFilter() {
@Override
protected BundleContext getContext() {
return context;
}
};
filter.doFilter(request, response, initialChain);
verifyFiltersCalled(request, response, initialChain);
}
/**
* Tests the main logic of performing the filter with no incoming filters.
*
* @throws ServletException
* @throws IOException
* @throws InvalidSyntaxException
*/
@Test
public void testDoFilterWithNoFilters()
throws IOException, ServletException, InvalidSyntaxException {
ServletRequest request = mock(HttpServletRequest.class);
ServletResponse response = mock(HttpServletResponse.class);
final BundleContext context = createMockContext(false);
DelegateServletFilter filter = new DelegateServletFilter() {
@Override
protected BundleContext getContext() {
return context;
}
};
filter.doFilter(request, response, initialChain);
verifyFiltersNotCalled(request, response, initialChain);
}
private void verifyFiltersCalled(ServletRequest request, ServletResponse response,
FilterChain initialChain) throws IOException, ServletException {
// verify that all of the filters were called once
verify(filter1).doFilter(eq(request), eq(response), any(FilterChain.class));
verify(filter2).doFilter(eq(request), eq(response), any(FilterChain.class));
verify(filter3).doFilter(eq(request), eq(response), any(FilterChain.class));
// verify initial chain was called once
verify(initialChain).doFilter(request, response);
}
private void verifyFiltersNotCalled(ServletRequest request, ServletResponse response,
FilterChain initialChain) throws IOException, ServletException {
// verify that none of the filters were called
verify(filter1, never()).doFilter(eq(request), eq(response), any(FilterChain.class));
verify(filter2, never()).doFilter(eq(request), eq(response), any(FilterChain.class));
verify(filter3, never()).doFilter(eq(request), eq(response), any(FilterChain.class));
// verify initial chain was called once
verify(initialChain).doFilter(request, response);
}
private List<Filter> mockFilters(boolean includeFilters)
throws InvalidSyntaxException, IOException, ServletException {
List<Filter> filters = new ArrayList<Filter>(3);
filter1 = createMockFilter("filter1");
filter2 = createMockFilter("filter2");
filter3 = createMockFilter("filter3");
if (includeFilters) {
filters.add(filter1);
filters.add(filter2);
filters.add(filter3);
}
return filters;
}
private Filter createMockFilter(final String name) throws IOException, ServletException {
Filter mockFilter = mock(Filter.class);
when(mockFilter.toString()).thenReturn(name);
Mockito.doAnswer(new Answer<Object>() {
@Override
public Object answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments();
LOGGER.debug("{} was called.", name);
((FilterChain) args[2])
.doFilter(((ServletRequest) args[0]), ((ServletResponse) args[1]));
return null;
}
})
.when(mockFilter).doFilter(any(ServletRequest.class), any(ServletResponse.class),
any(FilterChain.class));
return mockFilter;
}
private BundleContext createMockContext(boolean includeFilters)
throws InvalidSyntaxException, IOException, ServletException {
BundleContext context = mock(BundleContext.class);
List<Filter> mockFilters = mockFilters(includeFilters);
List<ServiceReference<Filter>> referenceList = new ArrayList<ServiceReference<Filter>>();
for (Filter curFilter : mockFilters) {
ServiceReference<Filter> mockRef = mock(ServiceReference.class);
when(context.getService(mockRef)).thenReturn(curFilter);
referenceList.add(mockRef);
}
when(context.getServiceReferences(Filter.class, null)).thenReturn(referenceList);
return context;
}
}