/*
* Copyright 2012-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.servlet.support;
import java.io.IOException;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.RequestDispatcher;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import org.junit.Rule;
import org.junit.Test;
import org.springframework.boot.testutil.InternalOutputCapture;
import org.springframework.boot.web.server.ErrorPage;
import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockFilterConfig;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockRequestDispatcher;
import org.springframework.web.context.request.async.DeferredResult;
import org.springframework.web.context.request.async.StandardServletAsyncWebRequest;
import org.springframework.web.context.request.async.WebAsyncManager;
import org.springframework.web.context.request.async.WebAsyncUtils;
import org.springframework.web.util.NestedServletException;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
 * Tests for {@link ErrorPageFilter}.
 * <p>
 * Each test drives the filter with a {@link MockFilterChain} (frequently overridden to
 * send an error status or throw an exception) and then inspects the wrapped and
 * underlying responses, the error attributes left on the request, and any forward
 * that the filter performed.
 *
 * @author Dave Syer
 * @author Andy Wilkinson
 */
public class ErrorPageFilterTests {

	private final ErrorPageFilter filter = new ErrorPageFilter();

	private final DispatchRecordingMockHttpServletRequest request = new DispatchRecordingMockHttpServletRequest();

	private final MockHttpServletResponse response = new MockHttpServletResponse();

	// Not final: many tests replace the chain with a customised subclass.
	private MockFilterChain chain = new MockFilterChain();

	@Rule
	public InternalOutputCapture output = new InternalOutputCapture();

	@Test
	public void notAnError() throws Exception {
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(this.chain.getRequest()).isEqualTo(this.request);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse())
				.isEqualTo(this.response);
		assertThat(this.response.isCommitted()).isTrue();
		assertThat(this.response.getForwardedUrl()).isNull();
	}

	@Test
	public void notAnErrorButNotOK() throws Exception {
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				((HttpServletResponse) response).setStatus(201);
				super.doFilter(request, response);
				response.flushBuffer();
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(((HttpServletResponse) this.chain.getResponse()).getStatus())
				.isEqualTo(201);
		assertThat(((HttpServletResponse) ((HttpServletResponseWrapper) this.chain
				.getResponse()).getResponse()).getStatus()).isEqualTo(201);
		assertThat(this.response.isCommitted()).isTrue();
	}

	@Test
	public void unauthorizedWithErrorPath() throws Exception {
		this.filter.addErrorPages(new ErrorPage("/error"));
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				((HttpServletResponse) response).sendError(401, "UNAUTHORIZED");
				super.doFilter(request, response);
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(this.chain.getRequest()).isEqualTo(this.request);
		HttpServletResponseWrapper wrapper = (HttpServletResponseWrapper) this.chain
				.getResponse();
		assertThat(wrapper.getResponse()).isEqualTo(this.response);
		assertThat(this.response.isCommitted()).isTrue();
		assertThat(wrapper.getStatus()).isEqualTo(401);
		// The real response has to be 401 as well...
		assertThat(this.response.getStatus()).isEqualTo(401);
		assertThat(this.response.getForwardedUrl()).isEqualTo("/error");
	}

	@Test
	public void responseCommitted() throws Exception {
		this.filter.addErrorPages(new ErrorPage("/error"));
		this.response.setCommitted(true);
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				((HttpServletResponse) response).sendError(400, "BAD");
				super.doFilter(request, response);
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(this.chain.getRequest()).isEqualTo(this.request);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse())
				.isEqualTo(this.response);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus())
				.isEqualTo(400);
		assertThat(this.response.getForwardedUrl()).isNull();
		assertThat(this.response.isCommitted()).isTrue();
	}

	@Test
	public void responseUncommittedWithoutErrorPage() throws Exception {
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				((HttpServletResponse) response).sendError(400, "BAD");
				super.doFilter(request, response);
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(this.chain.getRequest()).isEqualTo(this.request);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse())
				.isEqualTo(this.response);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus())
				.isEqualTo(400);
		assertThat(this.response.getForwardedUrl()).isNull();
		assertThat(this.response.isCommitted()).isTrue();
	}

	@Test
	public void oncePerRequest() throws Exception {
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				((HttpServletResponse) response).sendError(400, "BAD");
				assertThat(request.getAttribute("FILTER.FILTERED")).isNotNull();
				super.doFilter(request, response);
			}

		};
		this.filter.init(new MockFilterConfig("FILTER"));
		this.filter.doFilter(this.request, this.response, this.chain);
		// Guard against a vacuous pass: the in-chain assertion above only runs if the
		// chain was actually invoked.
		assertThat(this.chain.getRequest()).isNotNull();
	}

	@Test
	public void globalError() throws Exception {
		this.filter.addErrorPages(new ErrorPage("/error"));
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				((HttpServletResponse) response).sendError(400, "BAD");
				super.doFilter(request, response);
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus())
				.isEqualTo(400);
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_STATUS_CODE))
				.isEqualTo(400);
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE))
				.isEqualTo("BAD");
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI))
				.isEqualTo("/test/path");
		assertThat(this.response.isCommitted()).isTrue();
		assertThat(this.response.getForwardedUrl()).isEqualTo("/error");
	}

	@Test
	public void statusError() throws Exception {
		this.filter.addErrorPages(new ErrorPage(HttpStatus.BAD_REQUEST, "/400"));
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				((HttpServletResponse) response).sendError(400, "BAD");
				super.doFilter(request, response);
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus())
				.isEqualTo(400);
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_STATUS_CODE))
				.isEqualTo(400);
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE))
				.isEqualTo("BAD");
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI))
				.isEqualTo("/test/path");
		assertThat(this.response.isCommitted()).isTrue();
		assertThat(this.response.getForwardedUrl()).isEqualTo("/400");
	}

	@Test
	public void statusErrorWithCommittedResponse() throws Exception {
		this.filter.addErrorPages(new ErrorPage(HttpStatus.BAD_REQUEST, "/400"));
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				((HttpServletResponse) response).sendError(400, "BAD");
				// Commit the response before control returns to the filter.
				response.flushBuffer();
				super.doFilter(request, response);
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus())
				.isEqualTo(400);
		assertThat(this.response.isCommitted()).isTrue();
		assertThat(this.response.getForwardedUrl()).isNull();
	}

	@Test
	public void exceptionError() throws Exception {
		this.filter.addErrorPages(new ErrorPage(RuntimeException.class, "/500"));
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				super.doFilter(request, response);
				throw new RuntimeException("BAD");
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus())
				.isEqualTo(500);
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_STATUS_CODE))
				.isEqualTo(500);
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE))
				.isEqualTo("BAD");
		// The exception attributes are visible during the forward...
		Map<String, Object> requestAttributes = getAttributesForDispatch("/500");
		assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE))
				.isEqualTo(RuntimeException.class);
		assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION))
				.isInstanceOf(RuntimeException.class);
		// ...but are no longer present on the request afterwards.
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE))
				.isNull();
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull();
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI))
				.isEqualTo("/test/path");
		assertThat(this.response.isCommitted()).isTrue();
		assertThat(this.response.getForwardedUrl()).isEqualTo("/500");
	}

	@Test
	public void exceptionErrorWithCommittedResponse() throws Exception {
		this.filter.addErrorPages(new ErrorPage(RuntimeException.class, "/500"));
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				super.doFilter(request, response);
				// Commit the response before the exception reaches the filter.
				response.flushBuffer();
				throw new RuntimeException("BAD");
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(this.response.getForwardedUrl()).isNull();
	}

	@Test
	public void statusCode() throws Exception {
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				assertThat(((HttpServletResponse) response).getStatus()).isEqualTo(200);
				super.doFilter(request, response);
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus())
				.isEqualTo(200);
	}

	@Test
	public void subClassExceptionError() throws Exception {
		this.filter.addErrorPages(new ErrorPage(RuntimeException.class, "/500"));
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				super.doFilter(request, response);
				throw new IllegalStateException("BAD");
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus())
				.isEqualTo(500);
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_STATUS_CODE))
				.isEqualTo(500);
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE))
				.isEqualTo("BAD");
		Map<String, Object> requestAttributes = getAttributesForDispatch("/500");
		assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE))
				.isEqualTo(IllegalStateException.class);
		assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION))
				.isInstanceOf(IllegalStateException.class);
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE))
				.isNull();
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull();
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI))
				.isEqualTo("/test/path");
		assertThat(this.response.isCommitted()).isTrue();
		// The subclass must be routed to the superclass's error page.
		assertThat(this.response.getForwardedUrl()).isEqualTo("/500");
	}

	@Test
	public void responseIsNotCommittedWhenRequestIsAsync() throws Exception {
		this.request.setAsyncStarted(true);
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(this.chain.getRequest()).isEqualTo(this.request);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse())
				.isEqualTo(this.response);
		assertThat(this.response.isCommitted()).isFalse();
	}

	@Test
	public void responseIsCommittedWhenRequestIsAsyncAndExceptionIsThrown()
			throws Exception {
		this.filter.addErrorPages(new ErrorPage("/error"));
		this.request.setAsyncStarted(true);
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				super.doFilter(request, response);
				throw new RuntimeException("BAD");
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(this.chain.getRequest()).isEqualTo(this.request);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse())
				.isEqualTo(this.response);
		assertThat(this.response.isCommitted()).isTrue();
	}

	@Test
	public void responseIsCommittedWhenRequestIsAsyncAndStatusIs400Plus()
			throws Exception {
		this.filter.addErrorPages(new ErrorPage("/error"));
		this.request.setAsyncStarted(true);
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				super.doFilter(request, response);
				((HttpServletResponse) response).sendError(400, "BAD");
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(this.chain.getRequest()).isEqualTo(this.request);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse())
				.isEqualTo(this.response);
		assertThat(this.response.isCommitted()).isTrue();
	}

	@Test
	public void responseIsNotCommittedDuringAsyncDispatch() throws Exception {
		setUpAsyncDispatch();
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(this.chain.getRequest()).isEqualTo(this.request);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse())
				.isEqualTo(this.response);
		assertThat(this.response.isCommitted()).isFalse();
	}

	@Test
	public void responseIsCommittedWhenExceptionIsThrownDuringAsyncDispatch()
			throws Exception {
		this.filter.addErrorPages(new ErrorPage("/error"));
		setUpAsyncDispatch();
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				super.doFilter(request, response);
				throw new RuntimeException("BAD");
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(this.chain.getRequest()).isEqualTo(this.request);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse())
				.isEqualTo(this.response);
		assertThat(this.response.isCommitted()).isTrue();
	}

	@Test
	public void responseIsCommittedWhenStatusIs400PlusDuringAsyncDispatch()
			throws Exception {
		this.filter.addErrorPages(new ErrorPage("/error"));
		setUpAsyncDispatch();
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				super.doFilter(request, response);
				((HttpServletResponse) response).sendError(400, "BAD");
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(this.chain.getRequest()).isEqualTo(this.request);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse())
				.isEqualTo(this.response);
		assertThat(this.response.isCommitted()).isTrue();
	}

	@Test
	public void responseIsNotFlushedIfStatusIsLessThan400AndItHasAlreadyBeenCommitted()
			throws Exception {
		HttpServletResponse committedResponse = mock(HttpServletResponse.class);
		given(committedResponse.isCommitted()).willReturn(true);
		given(committedResponse.getStatus()).willReturn(200);
		this.filter.doFilter(this.request, committedResponse, this.chain);
		verify(committedResponse, times(0)).flushBuffer();
	}

	@Test
	public void errorMessageForRequestWithoutPathInfo() throws Exception {
		this.request.setServletPath("/test");
		this.filter.addErrorPages(new ErrorPage("/error"));
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				super.doFilter(request, response);
				throw new RuntimeException();
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(this.output.toString()).contains("request [/test]");
	}

	@Test
	public void errorMessageForRequestWithPathInfo() throws Exception {
		this.request.setServletPath("/test");
		this.request.setPathInfo("/alpha");
		this.filter.addErrorPages(new ErrorPage("/error"));
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				super.doFilter(request, response);
				throw new RuntimeException();
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(this.output.toString()).contains("request [/test/alpha]");
	}

	@Test
	public void nestedServletExceptionIsUnwrapped() throws Exception {
		this.filter.addErrorPages(new ErrorPage(RuntimeException.class, "/500"));
		this.chain = new MockFilterChain() {

			@Override
			public void doFilter(ServletRequest request, ServletResponse response)
					throws IOException, ServletException {
				super.doFilter(request, response);
				throw new NestedServletException("Wrapper", new RuntimeException("BAD"));
			}

		};
		this.filter.doFilter(this.request, this.response, this.chain);
		assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus())
				.isEqualTo(500);
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_STATUS_CODE))
				.isEqualTo(500);
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE))
				.isEqualTo("BAD");
		// The attributes describe the wrapped cause, not the NestedServletException.
		Map<String, Object> requestAttributes = getAttributesForDispatch("/500");
		assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE))
				.isEqualTo(RuntimeException.class);
		assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION))
				.isInstanceOf(RuntimeException.class);
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE))
				.isNull();
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull();
		assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI))
				.isEqualTo("/test/path");
		assertThat(this.response.isCommitted()).isTrue();
		assertThat(this.response.getForwardedUrl()).isEqualTo("/500");
	}

	/**
	 * Put {@link #request} into the state of an async dispatch. The request is marked
	 * async-supported and async-started, and its {@link WebAsyncManager} is given a
	 * {@link StandardServletAsyncWebRequest} on which {@link DeferredResult} processing
	 * is started.
	 * @throws Exception if starting deferred result processing fails
	 */
	private void setUpAsyncDispatch() throws Exception {
		this.request.setAsyncSupported(true);
		this.request.setAsyncStarted(true);
		DeferredResult<String> result = new DeferredResult<>();
		WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(this.request);
		asyncManager.setAsyncWebRequest(
				new StandardServletAsyncWebRequest(this.request, this.response));
		asyncManager.startDeferredResultProcessing(result);
	}

	/**
	 * Return the request attributes that were present when the filter forwarded to
	 * {@code path}.
	 * @param path the dispatch path
	 * @return the attributes captured at forward time
	 * @throws AssertionError if no dispatcher was requested for {@code path}
	 */
	private Map<String, Object> getAttributesForDispatch(String path) {
		DispatchRecordingMockHttpServletRequest.AttributeCapturingRequestDispatcher dispatcher = this.request
				.getDispatcher(path);
		// Fail with a clear message rather than an NPE when no dispatch happened.
		assertThat(dispatcher).as("No request dispatcher was obtained for '%s'", path)
				.isNotNull();
		return dispatcher.getRequestAttributes();
	}

	/**
	 * {@link MockHttpServletRequest} for {@code GET /test/path}. It remembers the
	 * last dispatcher created for each path so that tests can inspect the attributes
	 * captured during a forward.
	 */
	private static final class DispatchRecordingMockHttpServletRequest
			extends MockHttpServletRequest {

		private final Map<String, AttributeCapturingRequestDispatcher> dispatchers = new HashMap<>();

		private DispatchRecordingMockHttpServletRequest() {
			super("GET", "/test/path");
		}

		@Override
		public RequestDispatcher getRequestDispatcher(String path) {
			AttributeCapturingRequestDispatcher dispatcher = new AttributeCapturingRequestDispatcher(
					path);
			// A later request for the same path replaces the earlier dispatcher.
			this.dispatchers.put(path, dispatcher);
			return dispatcher;
		}

		/**
		 * Return the dispatcher most recently created for {@code path}.
		 * @param path the dispatch path
		 * @return the dispatcher, or {@code null} if none was requested for the path
		 */
		private AttributeCapturingRequestDispatcher getDispatcher(String path) {
			return this.dispatchers.get(path);
		}

		/**
		 * {@link MockRequestDispatcher} that copies every request attribute at the
		 * moment {@link #forward} is called, before delegating to the mock
		 * implementation.
		 */
		private static final class AttributeCapturingRequestDispatcher
				extends MockRequestDispatcher {

			private final Map<String, Object> requestAttributes = new HashMap<>();

			private AttributeCapturingRequestDispatcher(String resource) {
				super(resource);
			}

			@Override
			public void forward(ServletRequest request, ServletResponse response) {
				captureAttributes(request);
				super.forward(request, response);
			}

			/**
			 * Copy all of the request's current attributes into
			 * {@link #requestAttributes}.
			 * @param request the request being forwarded
			 */
			private void captureAttributes(ServletRequest request) {
				Enumeration<String> names = request.getAttributeNames();
				while (names.hasMoreElements()) {
					String name = names.nextElement();
					this.requestAttributes.put(name, request.getAttribute(name));
				}
			}

			private Map<String, Object> getRequestAttributes() {
				return this.requestAttributes;
			}

		}

	}

}