/* * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.test.web.reactive.server; import java.net.URI; import java.nio.charset.Charset; import java.time.Duration; import java.time.ZonedDateTime; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.UnaryOperator; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.core.ResolvableType; import org.springframework.core.io.ByteArrayResource; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.client.reactive.ClientHttpConnector; import org.springframework.http.client.reactive.ClientHttpRequest; import org.springframework.test.util.JsonExpectationsHelper; import org.springframework.util.Assert; import org.springframework.util.MimeType; import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.BodyInserter; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFilterFunction; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.util.UriBuilder; import static java.nio.charset.StandardCharsets.UTF_8; import static org.springframework.test.util.AssertionErrors.assertEquals; import static org.springframework.test.util.AssertionErrors.assertTrue; import static org.springframework.web.reactive.function.BodyExtractors.toFlux; import static org.springframework.web.reactive.function.BodyExtractors.toMono; /** * Default implementation of {@link WebTestClient}. * * @author Rossen Stoyanchev * @since 5.0 */ class DefaultWebTestClient implements WebTestClient { private final WebClient webClient; private final WiretapConnector wiretapConnector; private final ExchangeMutatingWebFilter exchangeMutatingWebFilter; private final Duration timeout; private final AtomicLong requestIndex = new AtomicLong(); DefaultWebTestClient(WebClient.Builder webClientBuilder, ClientHttpConnector connector, ExchangeMutatingWebFilter filter, Duration timeout) { Assert.notNull(webClientBuilder, "WebClient.Builder is required"); this.wiretapConnector = new WiretapConnector(connector); this.webClient = webClientBuilder.clientConnector(this.wiretapConnector).build(); this.exchangeMutatingWebFilter = (filter != null ? filter : new ExchangeMutatingWebFilter()); this.timeout = (timeout != null ? timeout : Duration.ofSeconds(5)); } private DefaultWebTestClient(DefaultWebTestClient webTestClient, ExchangeFilterFunction filter) { this.webClient = webTestClient.webClient.filter(filter); this.wiretapConnector = webTestClient.wiretapConnector; this.exchangeMutatingWebFilter = webTestClient.exchangeMutatingWebFilter; this.timeout = webTestClient.timeout; } private Duration getTimeout() { return this.timeout; } @Override public UriSpec<RequestHeadersSpec<?>> get() { return toUriSpec(wc -> wc.method(HttpMethod.GET)); } @Override public UriSpec<RequestHeadersSpec<?>> head() { return toUriSpec(wc -> wc.method(HttpMethod.HEAD)); } @Override public UriSpec<RequestBodySpec> post() { return toUriSpec(wc -> wc.method(HttpMethod.POST)); } @Override public UriSpec<RequestBodySpec> put() { return toUriSpec(wc -> wc.method(HttpMethod.PUT)); } @Override public UriSpec<RequestBodySpec> patch() { return toUriSpec(wc -> wc.method(HttpMethod.PATCH)); } @Override public UriSpec<RequestHeadersSpec<?>> delete() { return toUriSpec(wc -> wc.method(HttpMethod.DELETE)); } @Override public UriSpec<RequestHeadersSpec<?>> options() { return toUriSpec(wc -> wc.method(HttpMethod.OPTIONS)); } private <S extends RequestHeadersSpec<?>> UriSpec<S> toUriSpec( Function<WebClient, WebClient.UriSpec<WebClient.RequestBodySpec>> function) { return new DefaultUriSpec<>(function.apply(this.webClient)); } @Override public WebTestClient filter(ExchangeFilterFunction filter) { return new DefaultWebTestClient(this, filter); } @Override public WebTestClient exchangeMutator(UnaryOperator<ServerWebExchange> mutator) { Assert.notNull(this.exchangeMutatingWebFilter, "This option is applicable only for tests without an actual running server"); return filter((request, next) -> { String requestId = request.headers().getFirst(WiretapConnector.REQUEST_ID_HEADER_NAME); Assert.notNull(requestId, "No request-id header"); this.exchangeMutatingWebFilter.registerPerRequestMutator(requestId, mutator); return next.exchange(request); }); } @SuppressWarnings("unchecked") private class DefaultUriSpec<S extends RequestHeadersSpec<?>> implements UriSpec<S> { private final WebClient.UriSpec<WebClient.RequestBodySpec> uriSpec; DefaultUriSpec(WebClient.UriSpec<WebClient.RequestBodySpec> spec) { this.uriSpec = spec; } @Override public S uri(URI uri) { return (S) new DefaultRequestBodySpec(this.uriSpec.uri(uri)); } @Override public S uri(String uriTemplate, Object... uriVariables) { return (S) new DefaultRequestBodySpec(this.uriSpec.uri(uriTemplate, uriVariables)); } @Override public S uri(String uriTemplate, Map<String, ?> uriVariables) { return (S) new DefaultRequestBodySpec(this.uriSpec.uri(uriTemplate, uriVariables)); } @Override public S uri(Function<UriBuilder, URI> uriBuilder) { return (S) new DefaultRequestBodySpec(this.uriSpec.uri(uriBuilder)); } } private class DefaultRequestBodySpec implements RequestBodySpec { private final WebClient.RequestBodySpec bodySpec; private final String requestId; DefaultRequestBodySpec(WebClient.RequestBodySpec spec) { this.bodySpec = spec; this.requestId = String.valueOf(requestIndex.incrementAndGet()); this.bodySpec.header(WiretapConnector.REQUEST_ID_HEADER_NAME, this.requestId); } @Override public RequestBodySpec header(String headerName, String... headerValues) { this.bodySpec.header(headerName, headerValues); return this; } @Override public RequestBodySpec headers(HttpHeaders headers) { this.bodySpec.headers(headers); return this; } @Override public RequestBodySpec accept(MediaType... acceptableMediaTypes) { this.bodySpec.accept(acceptableMediaTypes); return this; } @Override public RequestBodySpec acceptCharset(Charset... acceptableCharsets) { this.bodySpec.acceptCharset(acceptableCharsets); return this; } @Override public RequestBodySpec contentType(MediaType contentType) { this.bodySpec.contentType(contentType); return this; } @Override public RequestBodySpec contentLength(long contentLength) { this.bodySpec.contentLength(contentLength); return this; } @Override public RequestBodySpec cookie(String name, String value) { this.bodySpec.cookie(name, value); return this; } @Override public RequestBodySpec cookies(MultiValueMap<String, String> cookies) { this.bodySpec.cookies(cookies); return this; } @Override public RequestBodySpec ifModifiedSince(ZonedDateTime ifModifiedSince) { this.bodySpec.ifModifiedSince(ifModifiedSince); return this; } @Override public RequestBodySpec ifNoneMatch(String... ifNoneMatches) { this.bodySpec.ifNoneMatch(ifNoneMatches); return this; } @Override public ResponseSpec exchange() { return toResponseSpec(this.bodySpec.exchange()); } @Override public RequestHeadersSpec<?> body(BodyInserter<?, ? super ClientHttpRequest> inserter) { this.bodySpec.body(inserter); return this; } @Override public <T, S extends Publisher<T>> RequestHeadersSpec<?> body(S publisher, Class<T> elementClass) { this.bodySpec.body(publisher, elementClass); return this; } @Override public RequestHeadersSpec<?> syncBody(Object body) { this.bodySpec.syncBody(body); return this; } private DefaultResponseSpec toResponseSpec(Mono<ClientResponse> mono) { ClientResponse clientResponse = mono.block(getTimeout()); ExchangeResult exchangeResult = wiretapConnector.claimRequest(this.requestId); return new DefaultResponseSpec(exchangeResult, clientResponse, getTimeout()); } } private static class UndecodedExchangeResult extends ExchangeResult { private final ClientResponse response; private final Duration timeout; UndecodedExchangeResult(ExchangeResult result, ClientResponse response, Duration timeout) { super(result); this.response = response; this.timeout = timeout; } @SuppressWarnings("unchecked") public <T> EntityExchangeResult<T> decode(ResolvableType bodyType) { T body = (T) this.response.body(toMono(bodyType)).block(this.timeout); return new EntityExchangeResult<>(this, body); } public <T> EntityExchangeResult<List<T>> decodeToList(ResolvableType elementType) { Flux<T> flux = this.response.body(toFlux(elementType)); List<T> body = flux.collectList().block(this.timeout); return new EntityExchangeResult<>(this, body); } public <T> FluxExchangeResult<T> decodeToFlux(ResolvableType elementType) { Flux<T> body = this.response.body(toFlux(elementType)); return new FluxExchangeResult<>(this, body, this.timeout); } public EntityExchangeResult<byte[]> decodeToByteArray() { ByteArrayResource resource = this.response.body(toMono(ByteArrayResource.class)).block(this.timeout); byte[] body = (resource != null ? resource.getByteArray() : null); return new EntityExchangeResult<>(this, body); } } private static class DefaultResponseSpec implements ResponseSpec { private final UndecodedExchangeResult result; DefaultResponseSpec(ExchangeResult result, ClientResponse response, Duration timeout) { this.result = new UndecodedExchangeResult(result, response, timeout); } @Override public StatusAssertions expectStatus() { return new StatusAssertions(this.result, this); } @Override public HeaderAssertions expectHeader() { return new HeaderAssertions(this.result, this); } @Override public <B> BodySpec<B, ?> expectBody(Class<B> bodyType) { return expectBody(ResolvableType.forClass(bodyType)); } @Override public <B> BodySpec<B, ?> expectBody(ResolvableType bodyType) { return new DefaultBodySpec<>(this.result.decode(bodyType)); } @Override public <E> ListBodySpec<E> expectBodyList(Class<E> elementType) { return expectBodyList(ResolvableType.forClass(elementType)); } @Override public <E> ListBodySpec<E> expectBodyList(ResolvableType elementType) { return new DefaultListBodySpec<>(this.result.decodeToList(elementType)); } @Override public BodyContentSpec expectBody() { return new DefaultBodyContentSpec(this.result.decodeToByteArray()); } @Override public <T> FluxExchangeResult<T> returnResult(Class<T> elementType) { return returnResult(ResolvableType.forClass(elementType)); } @Override public <T> FluxExchangeResult<T> returnResult(ResolvableType elementType) { return this.result.decodeToFlux(elementType); } } private static class DefaultBodySpec<B, S extends BodySpec<B, S>> implements BodySpec<B, S> { private final EntityExchangeResult<B> result; DefaultBodySpec(EntityExchangeResult<B> result) { this.result = result; } protected EntityExchangeResult<B> getResult() { return this.result; } @Override public <T extends S> T isEqualTo(B expected) { B actual = this.result.getResponseBody(); this.result.assertWithDiagnostics(() -> assertEquals("Response body", expected, actual)); return self(); } @Override public <T extends S> T consumeWith(Consumer<B> consumer) { B actual = this.result.getResponseBody(); this.result.assertWithDiagnostics(() -> consumer.accept(actual)); return self(); } @SuppressWarnings("unchecked") private <T extends S> T self() { return (T) this; } @Override public EntityExchangeResult<B> returnResult() { return this.result; } } private static class DefaultListBodySpec<E> extends DefaultBodySpec<List<E>, ListBodySpec<E>> implements ListBodySpec<E> { DefaultListBodySpec(EntityExchangeResult<List<E>> result) { super(result); } @Override public ListBodySpec<E> hasSize(int size) { List<E> actual = getResult().getResponseBody(); String message = "Response body does not contain " + size + " elements"; getResult().assertWithDiagnostics(() -> assertEquals(message, size, actual.size())); return this; } @Override @SuppressWarnings("unchecked") public ListBodySpec<E> contains(E... elements) { List<E> expected = Arrays.asList(elements); List<E> actual = getResult().getResponseBody(); String message = "Response body does not contain " + expected; getResult().assertWithDiagnostics(() -> assertTrue(message, actual.containsAll(expected))); return this; } @Override @SuppressWarnings("unchecked") public ListBodySpec<E> doesNotContain(E... elements) { List<E> expected = Arrays.asList(elements); List<E> actual = getResult().getResponseBody(); String message = "Response body should have contained " + expected; getResult().assertWithDiagnostics(() -> assertTrue(message, !actual.containsAll(expected))); return this; } @Override public EntityExchangeResult<List<E>> returnResult() { return getResult(); } } private static class DefaultBodyContentSpec implements BodyContentSpec { private final EntityExchangeResult<byte[]> result; private final boolean isEmpty; DefaultBodyContentSpec(EntityExchangeResult<byte[]> result) { this.result = result; this.isEmpty = (result.getResponseBody() == null); } @Override public EntityExchangeResult<Void> isEmpty() { this.result.assertWithDiagnostics(() -> assertTrue("Expected empty body", this.isEmpty)); return new EntityExchangeResult<>(this.result, null); } @Override public BodyContentSpec json(String json) { this.result.assertWithDiagnostics(() -> { try { new JsonExpectationsHelper().assertJsonEqual(json, getBodyAsString()); } catch (Exception ex) { throw new AssertionError("JSON parsing error", ex); } }); return this; } @Override public JsonPathAssertions jsonPath(String expression, Object... args) { return new JsonPathAssertions(this, expression, args); } @Override public BodyContentSpec consumeAsStringWith(Consumer<String> consumer) { this.result.assertWithDiagnostics(() -> consumer.accept(getBodyAsString())); return this; } private String getBodyAsString() { if (this.isEmpty) { return null; } MediaType mediaType = this.result.getResponseHeaders().getContentType(); Charset charset = Optional.ofNullable(mediaType).map(MimeType::getCharset).orElse(UTF_8); return new String(this.result.getResponseBody(), charset); } @Override public BodyContentSpec consumeWith(Consumer<byte[]> consumer) { this.result.assertWithDiagnostics(() -> consumer.accept(this.result.getResponseBody())); return this; } @Override public EntityExchangeResult<byte[]> returnResult() { return this.result; } } }