/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.reactive.function.server;
import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.charset.Charset;
import java.security.Principal;
import java.time.Instant;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.concurrent.ConcurrentHashMap;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpRange;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.server.WebSession;
/**
 * Mock implementation of {@link ServerRequest}, intended for testing handler functions.
 *
 * <p>Instances are created through {@link #builder()}. The request body is never decoded:
 * the object passed to {@link Builder#body(Object)} is returned as-is (cast to the requested
 * type) from {@link #body(BodyExtractor)}, {@link #body(BodyExtractor, Map)},
 * {@link #bodyToMono(Class)} and {@link #bodyToFlux(Class)}. The extractor, hints and element
 * class arguments are ignored. Callers must therefore supply a body of the matching type,
 * e.g. a {@code Mono} when the handler calls {@code bodyToMono}.
 *
 * @author Arjen Poutsma
 */
public class MockServerRequest implements ServerRequest {

	private final HttpMethod method;

	private final URI uri;

	private final MockHeaders headers;

	private final Object body;

	private final Map<String, Object> attributes;

	private final MultiValueMap<String, String> queryParams;

	private final Map<String, String> pathVariables;

	private final WebSession session;

	private final Principal principal;


	private MockServerRequest(HttpMethod method, URI uri,
			MockHeaders headers, Object body, Map<String, Object> attributes,
			MultiValueMap<String, String> queryParams,
			Map<String, String> pathVariables, WebSession session, Principal principal) {

		this.method = method;
		this.uri = uri;
		this.headers = headers;
		this.body = body;
		this.attributes = attributes;
		this.queryParams = queryParams;
		this.pathVariables = pathVariables;
		this.session = session;
		this.principal = principal;
	}


	@Override
	public HttpMethod method() {
		return this.method;
	}

	@Override
	public URI uri() {
		return this.uri;
	}

	@Override
	public Headers headers() {
		return this.headers;
	}

	/**
	 * Return the configured body cast to {@code S}; the extractor is not invoked.
	 */
	@Override
	@SuppressWarnings("unchecked")
	public <S> S body(BodyExtractor<S, ? super ServerHttpRequest> extractor) {
		return (S) this.body;
	}

	/**
	 * Return the configured body cast to {@code S}; the extractor and hints are not used.
	 */
	@Override
	@SuppressWarnings("unchecked")
	public <S> S body(BodyExtractor<S, ? super ServerHttpRequest> extractor, Map<String, Object> hints) {
		return (S) this.body;
	}

	/**
	 * Return the configured body cast to a {@code Mono}; {@code elementClass} is ignored.
	 * A {@code ClassCastException} follows if the body is not a {@code Mono}.
	 */
	@Override
	@SuppressWarnings("unchecked")
	public <S> Mono<S> bodyToMono(Class<? extends S> elementClass) {
		return (Mono<S>) this.body;
	}

	/**
	 * Return the configured body cast to a {@code Flux}; {@code elementClass} is ignored.
	 * A {@code ClassCastException} follows if the body is not a {@code Flux}.
	 */
	@Override
	@SuppressWarnings("unchecked")
	public <S> Flux<S> bodyToFlux(Class<? extends S> elementClass) {
		return (Flux<S>) this.body;
	}

	@Override
	@SuppressWarnings("unchecked")
	public <S> Optional<S> attribute(String name) {
		return Optional.ofNullable((S) this.attributes.get(name));
	}

	/**
	 * Return the attribute map. This is the same map instance the builder populated
	 * (not a copy), so modifications are visible through {@link #attribute(String)}.
	 */
	@Override
	public Map<String, Object> attributes() {
		return this.attributes;
	}

	/**
	 * Return the values registered for the given query parameter.
	 * @param name the query parameter name
	 * @return an unmodifiable view of the values, or an empty list if no value
	 * was registered under {@code name}
	 */
	@Override
	public List<String> queryParams(String name) {
		List<String> values = this.queryParams.get(name);
		// MultiValueMap#get returns null for an absent key, and
		// Collections.unmodifiableList(null) would throw a NullPointerException.
		return (values != null ? Collections.unmodifiableList(values) : Collections.emptyList());
	}

	/**
	 * Return an unmodifiable view of the configured path variables.
	 */
	@Override
	public Map<String, String> pathVariables() {
		return Collections.unmodifiableMap(this.pathVariables);
	}

	/**
	 * Return the configured session, or an empty {@code Mono} if none was set.
	 */
	@Override
	public Mono<WebSession> session() {
		return Mono.justOrEmpty(this.session);
	}

	/**
	 * Return the configured principal, or an empty {@code Mono} if none was set.
	 */
	@Override
	public Mono<? extends Principal> principal() {
		return Mono.justOrEmpty(this.principal);
	}

	/**
	 * Create a new {@link Builder}. It defaults to a {@code GET} request for
	 * {@code http://localhost} with no headers, attributes, query parameters,
	 * path variables, session or principal.
	 */
	public static Builder builder() {
		return new BuilderImpl();
	}


	/**
	 * Builder for {@link MockServerRequest}. All setters reject {@code null} arguments
	 * with an {@code IllegalArgumentException}.
	 */
	public interface Builder {

		Builder method(HttpMethod method);

		Builder uri(URI uri);

		/**
		 * Add a header value to the current header set.
		 */
		Builder header(String key, String value);

		/**
		 * Replace the current header set with the given headers.
		 */
		Builder headers(HttpHeaders headers);

		Builder attribute(String name, Object value);

		/**
		 * Replace the attribute map with the given map (used as-is, not copied).
		 */
		Builder attributes(Map<String, Object> attributes);

		Builder queryParam(String key, String value);

		/**
		 * Replace the query parameter map with the given map (used as-is, not copied).
		 */
		Builder queryParams(MultiValueMap<String, String> queryParams);

		Builder pathVariable(String key, String value);

		/**
		 * Replace the path variable map with the given map (used as-is, not copied).
		 */
		Builder pathVariables(Map<String, String> pathVariables);

		Builder session(WebSession session);

		/**
		 * Set the principal returned by {@link MockServerRequest#principal()}.
		 */
		Builder principal(Principal principal);

		/**
		 * Set the principal returned by {@link MockServerRequest#principal()}.
		 * @deprecated this method is misnamed and does not set the session;
		 * use {@link #principal(Principal)} instead
		 */
		@Deprecated
		default Builder session(Principal principal) {
			return principal(principal);
		}

		/**
		 * Build the request with the given body (which may be {@code null}).
		 */
		MockServerRequest body(Object body);

		/**
		 * Build the request without a body.
		 */
		MockServerRequest build();
	}


	private static class BuilderImpl implements Builder {

		private HttpMethod method = HttpMethod.GET;

		private URI uri = URI.create("http://localhost");

		private MockHeaders headers = new MockHeaders(new HttpHeaders());

		private Object body;

		private Map<String, Object> attributes = new ConcurrentHashMap<>();

		private MultiValueMap<String, String> queryParams = new LinkedMultiValueMap<>();

		private Map<String, String> pathVariables = new LinkedHashMap<>();

		private WebSession session;

		private Principal principal;

		@Override
		public Builder method(HttpMethod method) {
			Assert.notNull(method, "'method' must not be null");
			this.method = method;
			return this;
		}

		@Override
		public Builder uri(URI uri) {
			Assert.notNull(uri, "'uri' must not be null");
			this.uri = uri;
			return this;
		}

		@Override
		public Builder header(String key, String value) {
			Assert.notNull(key, "'key' must not be null");
			Assert.notNull(value, "'value' must not be null");
			this.headers.header(key, value);
			return this;
		}

		@Override
		public Builder headers(HttpHeaders headers) {
			Assert.notNull(headers, "'headers' must not be null");
			this.headers = new MockHeaders(headers);
			return this;
		}

		@Override
		public Builder attribute(String name, Object value) {
			Assert.notNull(name, "'name' must not be null");
			Assert.notNull(value, "'value' must not be null");
			this.attributes.put(name, value);
			return this;
		}

		@Override
		public Builder attributes(Map<String, Object> attributes) {
			Assert.notNull(attributes, "'attributes' must not be null");
			this.attributes = attributes;
			return this;
		}

		@Override
		public Builder queryParam(String key, String value) {
			Assert.notNull(key, "'key' must not be null");
			Assert.notNull(value, "'value' must not be null");
			this.queryParams.add(key, value);
			return this;
		}

		@Override
		public Builder queryParams(MultiValueMap<String, String> queryParams) {
			Assert.notNull(queryParams, "'queryParams' must not be null");
			this.queryParams = queryParams;
			return this;
		}

		@Override
		public Builder pathVariable(String key, String value) {
			Assert.notNull(key, "'key' must not be null");
			Assert.notNull(value, "'value' must not be null");
			this.pathVariables.put(key, value);
			return this;
		}

		@Override
		public Builder pathVariables(Map<String, String> pathVariables) {
			Assert.notNull(pathVariables, "'pathVariables' must not be null");
			this.pathVariables = pathVariables;
			return this;
		}

		@Override
		public Builder session(WebSession session) {
			Assert.notNull(session, "'session' must not be null");
			this.session = session;
			return this;
		}

		@Override
		public Builder principal(Principal principal) {
			Assert.notNull(principal, "'principal' must not be null");
			this.principal = principal;
			return this;
		}

		@Override
		public MockServerRequest body(Object body) {
			this.body = body;
			return createRequest(this.body);
		}

		@Override
		public MockServerRequest build() {
			// Intentionally body-less, even if body(Object) was called on this builder earlier.
			return createRequest(null);
		}

		/** Create a request from the current builder state and the given body. */
		private MockServerRequest createRequest(Object requestBody) {
			return new MockServerRequest(this.method, this.uri, this.headers, requestBody,
					this.attributes, this.queryParams, this.pathVariables, this.session,
					this.principal);
		}
	}


	/**
	 * {@link ServerRequest.Headers} implementation that delegates to an {@link HttpHeaders}
	 * instance. The delegate stays mutable through {@link #header(String, String)}.
	 */
	private static class MockHeaders implements Headers {

		private final HttpHeaders headers;

		public MockHeaders(HttpHeaders headers) {
			this.headers = headers;
		}

		private HttpHeaders delegate() {
			return this.headers;
		}

		/** Add a header value to the underlying {@link HttpHeaders}. */
		public void header(String key, String value) {
			this.headers.add(key, value);
		}

		@Override
		public List<MediaType> accept() {
			return delegate().getAccept();
		}

		@Override
		public List<Charset> acceptCharset() {
			return delegate().getAcceptCharset();
		}

		@Override
		public List<Locale.LanguageRange> acceptLanguage() {
			return delegate().getAcceptLanguage();
		}

		@Override
		public OptionalLong contentLength() {
			return toOptionalLong(delegate().getContentLength());
		}

		@Override
		public Optional<MediaType> contentType() {
			return Optional.ofNullable(delegate().getContentType());
		}

		@Override
		public InetSocketAddress host() {
			return delegate().getHost();
		}

		@Override
		public List<HttpRange> range() {
			return delegate().getRange();
		}

		/**
		 * Return the values of the given header, or an empty list if it is absent.
		 */
		@Override
		public List<String> header(String headerName) {
			List<String> headerValues = delegate().get(headerName);
			return headerValues != null ? headerValues : Collections.emptyList();
		}

		/**
		 * Return a read-only view of the underlying headers.
		 */
		@Override
		public HttpHeaders asHttpHeaders() {
			return HttpHeaders.readOnlyHttpHeaders(delegate());
		}

		/** Map the {@code -1} "unknown" sentinel to an empty {@link OptionalLong}. */
		private OptionalLong toOptionalLong(long value) {
			return value != -1 ? OptionalLong.of(value) : OptionalLong.empty();
		}

		/**
		 * Convert an epoch-millisecond value to a GMT {@link ZonedDateTime}, mapping the
		 * {@code -1} sentinel to empty. Currently not called from this class.
		 */
		private Optional<ZonedDateTime> toZonedDateTime(long date) {
			if (date != -1) {
				Instant instant = Instant.ofEpochMilli(date);
				return Optional.of(ZonedDateTime.ofInstant(instant, ZoneId.of("GMT")));
			}
			else {
				return Optional.empty();
			}
		}
	}

}