/**
* Copyright 2013 the original author or authors.
* <p/>
* Licensed under the Apache License, Version 2.0 the "License";
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p/>
* http://www.apache.org/licenses/LICENSE-2.0
* <p/>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
**/
package io.neba.core.web;
import io.neba.core.web.BackgroundServletRequestWrapper;
import io.neba.core.web.NebaRequestContextFilter;
import org.apache.sling.bgservlets.BackgroundHttpServletRequest;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.context.i18n.LocaleContext;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.util.Locale;
import java.util.concurrent.ExecutorService;
import static java.util.concurrent.Executors.newSingleThreadExecutor;
import static org.apache.commons.io.IOUtils.toByteArray;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Matchers.isA;
import static org.mockito.Mockito.*;
import static org.springframework.context.i18n.LocaleContextHolder.getLocaleContext;
import static org.springframework.util.ReflectionUtils.findMethod;
import static org.springframework.web.context.request.RequestAttributes.SCOPE_REQUEST;
import static org.springframework.web.context.request.RequestContextHolder.getRequestAttributes;
/**
* @author Olaf Otto
*/
@RunWith(MockitoJUnitRunner.class)
public class NebaRequestContextFilterTest {
private final ExecutorService executorService = newSingleThreadExecutor();
@Mock
private HttpServletRequest request;
@Mock
private HttpServletResponse response;
@Mock
private FilterChain chain;
@Mock
private Runnable destructionCallback;
private final Locale locale = Locale.ENGLISH;
private RequestAttributes requestAttributes;
private LocaleContext localeContext;
private RequestAttributes inheritedRequestAttributes;
private LocaleContext inheritedLocalContext;
@InjectMocks
private NebaRequestContextFilter testee;
@Before
public void setUp() throws IOException, ServletException {
doAnswer(invocationOnMock -> {
requestAttributes = getRequestAttributes();
if (requestAttributes != null) {
requestAttributes.registerDestructionCallback("TEST CALLBACK", destructionCallback, SCOPE_REQUEST);
}
localeContext = getLocaleContext();
executorService.submit(() -> {
inheritedRequestAttributes = getRequestAttributes();
inheritedLocalContext = getLocaleContext();
}).get();
return null;
}).when(chain).doFilter(isA(HttpServletRequest.class), isA(HttpServletResponse.class));
doReturn(locale).when(request).getLocale();
}
@After
public void assertThreadLocalesAreRemoved() throws Exception {
assertThat(getRequestAttributes()).isNull();
assertThat(getLocaleContext()).isNull();
}
@After
public void verifyServletAttributesAreCompleted() throws Exception {
verify(destructionCallback).run();
}
@Test
public void testForegroundRequestAreNotWrappedWithBackgroundRequestWrapper() throws Exception {
doFilter();
assertRequestAttributesAreProvided();
assertExposedRequestIsNotModified();
}
@Test
public void testBackgroundRequestsAreWrappedWithBackgroundRequestWrapper() throws Exception {
withBackgroundRequest();
doFilter();
assertRequestAttributesAreProvided();
assertExposedRequestIsBackgroundRequestWrapper();
}
@Test
public void testFilterProvidesRequestLocaleInLocaleContext() throws Exception {
doFilter();
assertLocaleContextProvidesRequestLocale();
}
@Test
public void testContextsAreInheritedWhenThreadContextInheritanceIsTrue() throws Exception {
withThreadContextInheritable();
doFilter();
assertContextsAreInherited();
}
@Test
public void testContextAreNotInheritedByDefault() throws Exception {
doFilter();
assertContextsAreNotInherited();
}
@Test
public void testFilterToleratesAbsenceOfOptionalDependencyToBgHttpServletRequest() throws Exception {
ClassLoader classLoaderWithoutBgServlets = new ClassLoader(getClass().getClassLoader()) {
@Override
public Class<?> loadClass(String name) throws ClassNotFoundException {
if (BackgroundHttpServletRequest.class.getName().equals(name)) {
// This optional dependency is not present on the class path in this test scenario.
throw new ClassNotFoundException("THIS IS AN EXPECTED TEST EXCEPTION. The dependency to bgservlets is optional.");
}
if (NebaRequestContextFilter.class.getName().equals(name)) {
// Define the test subject's class class in this class loader, thus its dependencies -
// such as the background servlet request - are also loaded via this class loader.
try {
byte[] classFileData = toByteArray(getResourceAsStream(name.replace('.', '/').concat(".class")));
return defineClass(name, classFileData, 0, classFileData.length);
} catch (IOException e) {
throw new ClassNotFoundException("Unable to load " + name + ".", e);
}
}
return super.loadClass(name);
}
};
Class<?> filterClass = classLoaderWithoutBgServlets.loadClass(NebaRequestContextFilter.class.getName());
assertThat(valueOfField(filterClass, "IS_BGSERVLETS_PRESENT")).isFalse();
Object filter = filterClass.newInstance();
invoke(filter, "doFilter", request, response, chain);
}
private void invoke(Object o, String methodName, Object... args) throws InvocationTargetException, IllegalAccessException {
findMethod(
o.getClass(),
methodName,
ServletRequest.class, ServletResponse.class, FilterChain.class)
.invoke(o, args);
}
private boolean valueOfField(Class type, String fieldName) throws NoSuchFieldException, IllegalAccessException {
Field declaredField = type.getDeclaredField(fieldName);
declaredField.setAccessible(true);
return (boolean) declaredField.get(this.testee);
}
private void assertContextsAreNotInherited() {
assertThat(inheritedRequestAttributes).isNull();
assertThat(inheritedLocalContext).isNull();
}
private void assertContextsAreInherited() {
assertThat(inheritedRequestAttributes).isEqualTo(requestAttributes);
assertThat(inheritedLocalContext).isEqualTo(localeContext);
}
private void withThreadContextInheritable() {
testee.setThreadContextInheritable(true);
}
private void assertLocaleContextProvidesRequestLocale() {
assertThat(localeContext).isNotNull();
assertThat(localeContext.getLocale()).isSameAs(locale);
}
private void assertExposedRequestIsBackgroundRequestWrapper() {
assertThat(((ServletRequestAttributes) requestAttributes).getRequest()).isInstanceOf(BackgroundServletRequestWrapper.class);
}
private void assertRequestAttributesAreProvided() {
assertThat(requestAttributes).isInstanceOf(ServletRequestAttributes.class);
}
private void withBackgroundRequest() {
request = mock(BackgroundHttpServletRequest.class);
}
private void doFilter() throws ServletException, IOException {
testee.doFilter(request, response, chain);
}
private void assertExposedRequestIsNotModified() {
assertThat(((ServletRequestAttributes) requestAttributes).getRequest()).isSameAs(request);
}
}