/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.reactive.function.server;
import java.net.URI;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.server.WebSession;
import org.springframework.web.util.UriUtils;
import org.springframework.web.util.pattern.PathPattern;
import org.springframework.web.util.pattern.PathPatternParser;
/**
 * Implementations of {@link RequestPredicate} that implement various useful
 * request matching operations, such as matching based on path, HTTP method, etc.
 *
 * @author Arjen Poutsma
 * @since 5.0
 */
public abstract class RequestPredicates {

	private static final Log logger = LogFactory.getLog(RequestPredicates.class);

	private static final PathPatternParser DEFAULT_PATTERN_PARSER = new PathPatternParser();


	/**
	 * Return a {@code RequestPredicate} that always matches.
	 * @return a predicate that always matches
	 */
	public static RequestPredicate all() {
		return request -> true;
	}

	/**
	 * Return a {@code RequestPredicate} that tests against the given HTTP method.
	 * @param httpMethod the HTTP method to match to
	 * @return a predicate that tests against the given HTTP method
	 */
	public static RequestPredicate method(HttpMethod httpMethod) {
		return new HttpMethodPredicate(httpMethod);
	}

	/**
	 * Return a {@code RequestPredicate} that tests the request path against the given path pattern.
	 * <p>The pattern is parsed with a default {@link PathPatternParser}; use
	 * {@link #pathPredicates(PathPatternParser)} to supply a customized parser.
	 * @param pattern the pattern to match to
	 * @return a predicate that tests against the given path pattern
	 */
	public static RequestPredicate path(String pattern) {
		Assert.notNull(pattern, "'pattern' must not be null");
		return pathPredicates(DEFAULT_PATTERN_PARSER).apply(pattern);
	}

	/**
	 * Return a function that creates new path-matching {@code RequestPredicates} from pattern
	 * Strings using the given {@link PathPatternParser}. This method can be used to specify a
	 * non-default, customized {@code PathPatternParser} when resolving path patterns.
	 * @param patternParser the parser used to parse patterns given to the returned function
	 * @return a function that resolves pattern Strings into path-matching
	 * {@code RequestPredicate}s
	 */
	public static Function<String, RequestPredicate> pathPredicates(PathPatternParser patternParser) {
		Assert.notNull(patternParser, "'patternParser' must not be null");
		return pattern -> {
			// Same validation as path(String), for callers that use the returned function directly.
			Assert.notNull(pattern, "'pattern' must not be null");
			return new PathPatternPredicate(patternParser.parse(pattern));
		};
	}

	/**
	 * Return a {@code RequestPredicate} that tests the request's headers against the given headers predicate.
	 * @param headersPredicate a predicate that tests against the request headers
	 * @return a predicate that tests against the given header predicate
	 */
	public static RequestPredicate headers(Predicate<ServerRequest.Headers> headersPredicate) {
		return new HeadersPredicate(headersPredicate);
	}

	/**
	 * Return a {@code RequestPredicate} that tests if the request's
	 * {@linkplain ServerRequest.Headers#contentType() content type} is
	 * {@linkplain MediaType#includes(MediaType) included} by any of the given media types.
	 * <p>A request without a Content-Type header is treated as
	 * {@link MediaType#APPLICATION_OCTET_STREAM}.
	 * @param mediaTypes the media types to match the request's content type against
	 * @return a predicate that tests the request's content type against the given media types
	 */
	public static RequestPredicate contentType(MediaType... mediaTypes) {
		Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty");
		Set<MediaType> mediaTypeSet = new HashSet<>(Arrays.asList(mediaTypes));
		return headers(new Predicate<ServerRequest.Headers>() {
			@Override
			public boolean test(ServerRequest.Headers headers) {
				MediaType contentType =
						headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
				boolean match = mediaTypeSet.stream()
						.anyMatch(mediaType -> mediaType.includes(contentType));
				traceMatch("Content-Type", mediaTypeSet, contentType, match);
				return match;
			}
			@Override
			public String toString() {
				return String.format("Content-Type: %s", mediaTypeSet);
			}
		});
	}

	/**
	 * Return a {@code RequestPredicate} that tests if the request's
	 * {@linkplain ServerRequest.Headers#accept() accept} header is
	 * {@linkplain MediaType#isCompatibleWith(MediaType) compatible} with any of the given media types.
	 * <p>A request without an Accept header is treated as accepting {@link MediaType#ALL}.
	 * @param mediaTypes the media types to match the request's accept header against
	 * @return a predicate that tests the request's accept header against the given media types
	 */
	public static RequestPredicate accept(MediaType... mediaTypes) {
		Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty");
		Set<MediaType> mediaTypeSet = new HashSet<>(Arrays.asList(mediaTypes));
		return headers(new Predicate<ServerRequest.Headers>() {
			@Override
			public boolean test(ServerRequest.Headers headers) {
				List<MediaType> acceptedMediaTypes = acceptedMediaTypes(headers);
				boolean match = acceptedMediaTypes.stream()
						.anyMatch(acceptedMediaType -> mediaTypeSet.stream()
								.anyMatch(acceptedMediaType::isCompatibleWith));
				traceMatch("Accept", mediaTypeSet, acceptedMediaTypes, match);
				return match;
			}
			@Override
			public String toString() {
				return String.format("Accept: %s", mediaTypeSet);
			}
		});
	}

	/**
	 * Return the media types accepted by the request, sorted by specificity and quality.
	 * <p>The list returned by {@code headers.accept()} is copied before sorting. This keeps
	 * the request's own header view unmodified, and an unmodifiable list cannot trigger
	 * an {@link UnsupportedOperationException}.
	 * @param headers the request headers
	 * @return the sorted accepted media types, or a singleton list of {@link MediaType#ALL}
	 * if the request declares none
	 */
	private static List<MediaType> acceptedMediaTypes(ServerRequest.Headers headers) {
		List<MediaType> accepted = headers.accept();
		if (accepted.isEmpty()) {
			return Collections.singletonList(MediaType.ALL);
		}
		List<MediaType> sorted = new ArrayList<>(accepted);
		MediaType.sortBySpecificityAndQuality(sorted);
		return sorted;
	}

	/**
	 * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code GET}
	 * and the given {@code pattern} matches against the request path.
	 * @param pattern the path pattern to match against
	 * @return a predicate that matches if the request method is GET and if the given pattern
	 * matches against the request path
	 */
	public static RequestPredicate GET(String pattern) {
		return method(HttpMethod.GET).and(path(pattern));
	}

	/**
	 * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code HEAD}
	 * and the given {@code pattern} matches against the request path.
	 * @param pattern the path pattern to match against
	 * @return a predicate that matches if the request method is HEAD and if the given pattern
	 * matches against the request path
	 */
	public static RequestPredicate HEAD(String pattern) {
		return method(HttpMethod.HEAD).and(path(pattern));
	}

	/**
	 * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code POST}
	 * and the given {@code pattern} matches against the request path.
	 * @param pattern the path pattern to match against
	 * @return a predicate that matches if the request method is POST and if the given pattern
	 * matches against the request path
	 */
	public static RequestPredicate POST(String pattern) {
		return method(HttpMethod.POST).and(path(pattern));
	}

	/**
	 * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code PUT}
	 * and the given {@code pattern} matches against the request path.
	 * @param pattern the path pattern to match against
	 * @return a predicate that matches if the request method is PUT and if the given pattern
	 * matches against the request path
	 */
	public static RequestPredicate PUT(String pattern) {
		return method(HttpMethod.PUT).and(path(pattern));
	}

	/**
	 * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code PATCH}
	 * and the given {@code pattern} matches against the request path.
	 * @param pattern the path pattern to match against
	 * @return a predicate that matches if the request method is PATCH and if the given pattern
	 * matches against the request path
	 */
	public static RequestPredicate PATCH(String pattern) {
		return method(HttpMethod.PATCH).and(path(pattern));
	}

	/**
	 * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code DELETE}
	 * and the given {@code pattern} matches against the request path.
	 * @param pattern the path pattern to match against
	 * @return a predicate that matches if the request method is DELETE and if the given pattern
	 * matches against the request path
	 */
	public static RequestPredicate DELETE(String pattern) {
		return method(HttpMethod.DELETE).and(path(pattern));
	}

	/**
	 * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code OPTIONS}
	 * and the given {@code pattern} matches against the request path.
	 * @param pattern the path pattern to match against
	 * @return a predicate that matches if the request method is OPTIONS and if the given pattern
	 * matches against the request path
	 */
	public static RequestPredicate OPTIONS(String pattern) {
		return method(HttpMethod.OPTIONS).and(path(pattern));
	}

	/**
	 * Return a {@code RequestPredicate} that matches if the request's path has the given extension.
	 * @param extension the path extension to match against, ignoring case
	 * @return a predicate that matches if the request's path has the given file extension
	 */
	public static RequestPredicate pathExtension(String extension) {
		Assert.notNull(extension, "'extension' must not be null");
		return pathExtension(pathExtension -> {
			boolean match = extension.equalsIgnoreCase(pathExtension);
			traceMatch("Extension", extension, pathExtension, match);
			return match;
		});
	}

	/**
	 * Return a {@code RequestPredicate} that matches if the request's path extension matches
	 * the given predicate.
	 * <p>The extension is obtained via {@link UriUtils#extractFileExtension(String)} on
	 * {@link ServerRequest#path()}.
	 * @param extensionPredicate the predicate to test against the request path extension
	 * @return a predicate that matches if the given predicate matches against the request's path
	 * file extension
	 */
	public static RequestPredicate pathExtension(Predicate<String> extensionPredicate) {
		Assert.notNull(extensionPredicate, "'extensionPredicate' must not be null");
		return request -> {
			String pathExtension = UriUtils.extractFileExtension(request.path());
			return extensionPredicate.test(pathExtension);
		};
	}

	/**
	 * Return a {@code RequestPredicate} that tests the request's query parameter of the given name
	 * against the given predicate.
	 * <p>The returned predicate does not match if the query parameter is absent.
	 * @param name the name of the query parameter to test against
	 * @param predicate predicate to test against the query parameter value
	 * @return a predicate that matches the given predicate against the query parameter of the given name
	 * @see ServerRequest#queryParam(String)
	 */
	public static RequestPredicate queryParam(String name, Predicate<String> predicate) {
		Assert.notNull(name, "'name' must not be null");
		Assert.notNull(predicate, "'predicate' must not be null");
		return request -> request.queryParam(name).filter(predicate).isPresent();
	}

	/**
	 * Log the outcome of a single match attempt at TRACE level.
	 * The message is only formatted when TRACE logging is enabled.
	 * @param prefix a label describing what was matched (e.g. "Method", "Accept")
	 * @param desired the value the predicate was configured with
	 * @param actual the value taken from the request
	 * @param match whether the match succeeded
	 */
	private static void traceMatch(String prefix, Object desired, Object actual, boolean match) {
		if (logger.isTraceEnabled()) {
			String message = String.format("%s \"%s\" %s against value \"%s\"",
					prefix, desired, match ? "matches" : "does not match", actual);
			logger.trace(message);
		}
	}


	/**
	 * Predicate that matches when the request's HTTP method is identical to the configured one.
	 */
	private static class HttpMethodPredicate implements RequestPredicate {

		private final HttpMethod httpMethod;

		public HttpMethodPredicate(HttpMethod httpMethod) {
			Assert.notNull(httpMethod, "'httpMethod' must not be null");
			this.httpMethod = httpMethod;
		}

		@Override
		public boolean test(ServerRequest request) {
			HttpMethod requestMethod = request.method();
			boolean match = this.httpMethod == requestMethod;
			traceMatch("Method", this.httpMethod, requestMethod, match);
			return match;
		}

		@Override
		public String toString() {
			return this.httpMethod.toString();
		}
	}


	/**
	 * Predicate that matches the request path against a parsed {@link PathPattern}.
	 * On a successful match, the extracted URI template variables are merged into the
	 * request attribute {@link RouterFunctions#URI_TEMPLATE_VARIABLES_ATTRIBUTE}.
	 */
	private static class PathPatternPredicate implements RequestPredicate {

		private final PathPattern pattern;

		public PathPatternPredicate(PathPattern pattern) {
			Assert.notNull(pattern, "'pattern' must not be null");
			this.pattern = pattern;
		}

		@Override
		public boolean test(ServerRequest request) {
			String path = request.path();
			boolean match = this.pattern.matches(path);
			traceMatch("Pattern", this.pattern.getPatternString(), path, match);
			if (match) {
				// Extraction runs only after a confirmed match.
				mergeTemplateVariables(request, this.pattern.matchAndExtract(path));
				return true;
			}
			else {
				return false;
			}
		}

		/**
		 * Match the pattern against the start of the request path. On success, merge any
		 * extracted variables into the request and return a wrapper whose
		 * {@link ServerRequest#path()} is the remaining portion of the path, with a leading
		 * {@code "/"} added if the remainder does not already start with one.
		 */
		@Override
		public Optional<ServerRequest> nest(ServerRequest request) {
			return Optional.ofNullable(this.pattern.getPathRemaining(request.path()))
					.map(info -> {
						mergeTemplateVariables(request, info.getMatchingVariables());
						String path = info.getPathRemaining();
						if (!path.startsWith("/")) {
							path = "/" + path;
						}
						return new SubPathServerRequestWrapper(request, path);
					});
		}

		/**
		 * Combine the request's existing path variables with the given ones and store the
		 * result as an unmodifiable map under
		 * {@link RouterFunctions#URI_TEMPLATE_VARIABLES_ATTRIBUTE}. A new variable replaces an
		 * existing one with the same name. This is a no-op when {@code variables} is empty.
		 */
		private void mergeTemplateVariables(ServerRequest request, Map<String, String> variables) {
			if (!variables.isEmpty()) {
				Map<String, String> oldVariables = request.pathVariables();
				Map<String, String> mergedVariables = new LinkedHashMap<>(oldVariables);
				mergedVariables.putAll(variables);
				request.attributes().put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
						Collections.unmodifiableMap(mergedVariables));
			}
		}

		@Override
		public String toString() {
			return this.pattern.getPatternString();
		}
	}


	/**
	 * Predicate that delegates to a {@code Predicate} over the request's headers.
	 */
	private static class HeadersPredicate implements RequestPredicate {

		private final Predicate<ServerRequest.Headers> headersPredicate;

		public HeadersPredicate(Predicate<ServerRequest.Headers> headersPredicate) {
			Assert.notNull(headersPredicate, "'headersPredicate' must not be null");
			this.headersPredicate = headersPredicate;
		}

		@Override
		public boolean test(ServerRequest request) {
			return this.headersPredicate.test(request.headers());
		}

		@Override
		public String toString() {
			return this.headersPredicate.toString();
		}
	}


	/**
	 * Logical AND of two predicates. {@link #test} short-circuits, and {@link #nest}
	 * applies the left predicate's nesting first, then the right one's to its result.
	 */
	static class AndRequestPredicate implements RequestPredicate {

		private final RequestPredicate left;

		private final RequestPredicate right;

		public AndRequestPredicate(RequestPredicate left, RequestPredicate right) {
			Assert.notNull(left, "Left RequestPredicate must not be null");
			Assert.notNull(right, "Right RequestPredicate must not be null");
			this.left = left;
			this.right = right;
		}

		@Override
		public boolean test(ServerRequest t) {
			return this.left.test(t) && this.right.test(t);
		}

		@Override
		public Optional<ServerRequest> nest(ServerRequest request) {
			return this.left.nest(request).flatMap(this.right::nest);
		}

		@Override
		public String toString() {
			return String.format("(%s && %s)", this.left, this.right);
		}
	}


	/**
	 * Logical OR of two predicates. {@link #test} short-circuits, and {@link #nest}
	 * returns the left predicate's nested request if present, otherwise the right one's.
	 */
	static class OrRequestPredicate implements RequestPredicate {

		private final RequestPredicate left;

		private final RequestPredicate right;

		public OrRequestPredicate(RequestPredicate left, RequestPredicate right) {
			Assert.notNull(left, "Left RequestPredicate must not be null");
			Assert.notNull(right, "Right RequestPredicate must not be null");
			this.left = left;
			this.right = right;
		}

		@Override
		public boolean test(ServerRequest t) {
			return this.left.test(t) || this.right.test(t);
		}

		@Override
		public Optional<ServerRequest> nest(ServerRequest request) {
			Optional<ServerRequest> leftResult = this.left.nest(request);
			if (leftResult.isPresent()) {
				return leftResult;
			}
			else {
				return this.right.nest(request);
			}
		}

		@Override
		public String toString() {
			return String.format("(%s || %s)", this.left, this.right);
		}
	}


	/**
	 * {@link ServerRequest} decorator that delegates every call to the wrapped request,
	 * except {@link #path()}, which returns the sub-path left over after a nested
	 * pattern match.
	 */
	private static class SubPathServerRequestWrapper implements ServerRequest {

		private final ServerRequest request;

		private final String subPath;

		public SubPathServerRequestWrapper(ServerRequest request, String subPath) {
			this.request = request;
			this.subPath = subPath;
		}

		@Override
		public HttpMethod method() {
			return this.request.method();
		}

		@Override
		public URI uri() {
			return this.request.uri();
		}

		@Override
		public String path() {
			return this.subPath;
		}

		@Override
		public Headers headers() {
			return this.request.headers();
		}

		@Override
		public <T> T body(BodyExtractor<T, ? super ServerHttpRequest> extractor) {
			return this.request.body(extractor);
		}

		@Override
		public <T> T body(BodyExtractor<T, ? super ServerHttpRequest> extractor, Map<String, Object> hints) {
			return this.request.body(extractor, hints);
		}

		@Override
		public <T> Mono<T> bodyToMono(Class<? extends T> elementClass) {
			return this.request.bodyToMono(elementClass);
		}

		@Override
		public <T> Flux<T> bodyToFlux(Class<? extends T> elementClass) {
			return this.request.bodyToFlux(elementClass);
		}

		@Override
		public <T> Optional<T> attribute(String name) {
			return this.request.attribute(name);
		}

		@Override
		public Map<String, Object> attributes() {
			return this.request.attributes();
		}

		@Override
		public Optional<String> queryParam(String name) {
			return this.request.queryParam(name);
		}

		@Override
		public List<String> queryParams(String name) {
			return this.request.queryParams(name);
		}

		@Override
		public String pathVariable(String name) {
			return this.request.pathVariable(name);
		}

		@Override
		public Map<String, String> pathVariables() {
			return this.request.pathVariables();
		}

		@Override
		public Mono<WebSession> session() {
			return this.request.session();
		}

		@Override
		public Mono<? extends Principal> principal() {
			return this.request.principal();
		}

		@Override
		public String toString() {
			return method() + " " + path();
		}
	}

}