/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.server.handler;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.Test;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpStatus;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebExceptionHandler;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebHandler;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
/**
* @author Rossen Stoyanchev
*/
public class FilteringWebHandlerTests {
private static Log logger = LogFactory.getLog(FilteringWebHandlerTests.class);
@Test
public void multipleFilters() throws Exception {
TestFilter filter1 = new TestFilter();
TestFilter filter2 = new TestFilter();
TestFilter filter3 = new TestFilter();
StubWebHandler targetHandler = new StubWebHandler();
new FilteringWebHandler(targetHandler, Arrays.asList(filter1, filter2, filter3))
.handle(MockServerHttpRequest.get("/").toExchange())
.block(Duration.ZERO);
assertTrue(filter1.invoked());
assertTrue(filter2.invoked());
assertTrue(filter3.invoked());
assertTrue(targetHandler.invoked());
}
@Test
public void zeroFilters() throws Exception {
StubWebHandler targetHandler = new StubWebHandler();
new FilteringWebHandler(targetHandler, Collections.emptyList())
.handle(MockServerHttpRequest.get("/").toExchange())
.block(Duration.ZERO);
assertTrue(targetHandler.invoked());
}
@Test
public void shortcircuitFilter() throws Exception {
TestFilter filter1 = new TestFilter();
ShortcircuitingFilter filter2 = new ShortcircuitingFilter();
TestFilter filter3 = new TestFilter();
StubWebHandler targetHandler = new StubWebHandler();
new FilteringWebHandler(targetHandler, Arrays.asList(filter1, filter2, filter3))
.handle(MockServerHttpRequest.get("/").toExchange())
.block(Duration.ZERO);
assertTrue(filter1.invoked());
assertTrue(filter2.invoked());
assertFalse(filter3.invoked());
assertFalse(targetHandler.invoked());
}
@Test
public void asyncFilter() throws Exception {
AsyncFilter filter = new AsyncFilter();
StubWebHandler targetHandler = new StubWebHandler();
new FilteringWebHandler(targetHandler, Collections.singletonList(filter))
.handle(MockServerHttpRequest.get("/").toExchange())
.block(Duration.ofSeconds(5));
assertTrue(filter.invoked());
assertTrue(targetHandler.invoked());
}
@Test
public void handleErrorFromFilter() throws Exception {
MockServerHttpRequest request = MockServerHttpRequest.get("/").build();
MockServerHttpResponse response = new MockServerHttpResponse();
TestExceptionHandler exceptionHandler = new TestExceptionHandler();
WebHttpHandlerBuilder.webHandler(new StubWebHandler())
.filters(Collections.singletonList(new ExceptionFilter()))
.exceptionHandlers(Collections.singletonList(exceptionHandler)).build()
.handle(request, response)
.block();
assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, response.getStatusCode());
assertNotNull(exceptionHandler.ex);
assertEquals("boo", exceptionHandler.ex.getMessage());
}
private static class TestFilter implements WebFilter {
private volatile boolean invoked;
public boolean invoked() {
return this.invoked;
}
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
this.invoked = true;
return doFilter(exchange, chain);
}
public Mono<Void> doFilter(ServerWebExchange exchange, WebFilterChain chain) {
return chain.filter(exchange);
}
}
private static class ShortcircuitingFilter extends TestFilter {
@Override
public Mono<Void> doFilter(ServerWebExchange exchange, WebFilterChain chain) {
return Mono.empty();
}
}
private static class AsyncFilter extends TestFilter {
@Override
public Mono<Void> doFilter(ServerWebExchange exchange, WebFilterChain chain) {
return doAsyncWork().flatMap(asyncResult -> {
logger.debug("Async result: " + asyncResult);
return chain.filter(exchange);
});
}
private Mono<String> doAsyncWork() {
return Mono.delay(Duration.ofMillis(100L)).map(l -> "123");
}
}
private static class ExceptionFilter implements WebFilter {
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
return Mono.error(new IllegalStateException("boo"));
}
}
private static class TestExceptionHandler implements WebExceptionHandler {
private Throwable ex;
@Override
public Mono<Void> handle(ServerWebExchange exchange, Throwable ex) {
this.ex = ex;
return Mono.error(ex);
}
}
private static class StubWebHandler implements WebHandler {
private volatile boolean invoked;
public boolean invoked() {
return this.invoked;
}
@Override
public Mono<Void> handle(ServerWebExchange exchange) {
logger.trace("StubHandler invoked.");
this.invoked = true;
return Mono.empty();
}
}
}