/*
 * Copyright 2002-2017 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.reactive.socket.client;

import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.reactivex.netty.protocol.http.client.HttpClient;
import io.reactivex.netty.protocol.http.client.HttpClientRequest;
import io.reactivex.netty.protocol.http.ws.WebSocketConnection;
import io.reactivex.netty.protocol.http.ws.client.WebSocketRequest;
import io.reactivex.netty.protocol.http.ws.client.WebSocketResponse;
import io.reactivex.netty.threads.RxEventLoopProvider;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
import rx.Observable;
import rx.RxReactiveStreams;

import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.util.ObjectUtils;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.RxNettyWebSocketSession;

import static io.reactivex.netty.protocol.http.HttpHandlerNames.WsClientDecoder;

/**
 * {@link WebSocketClient} implementation for use with RxNetty.
 * For internal use within the framework.
 *
 * <p><strong>Note: </strong> RxNetty {@link HttpClient} instances require a host
 * and port in order to be created. Hence it is not possible to configure a
 * single {@code HttpClient} instance to use upfront. Instead the constructors
 * accept a function for obtaining client instances when establishing a
 * connection to a specific URI. By default new instances are created per
 * connection with a shared Netty {@code EventLoopGroup}. See constructors for
 * more details.
 *
 * @author Rossen Stoyanchev
 * @since 5.0
 */
public class RxNettyWebSocketClient extends WebSocketClientSupport implements WebSocketClient {

	private final Function<URI, HttpClient<ByteBuf, ByteBuf>> httpClientProvider;


	/**
	 * Default constructor that creates {@code HttpClient} instances via
	 * {@link HttpClient#newClient(String, int)} using port 80 or 443 depending
	 * on the target URL scheme.
	 *
	 * <p><strong>Note: </strong> By default a new {@link HttpClient} instance
	 * is created per WebSocket connection. Those instances will share a global
	 * {@code EventLoopGroup} that RxNetty obtains via
	 * {@link RxEventLoopProvider#globalClientEventLoop(boolean)}.
	 */
	public RxNettyWebSocketClient() {
		this(RxNettyWebSocketClient::getDefaultHttpClientProvider);
	}

	/**
	 * Constructor with a function to use to obtain {@link HttpClient} instances.
	 * @param httpClientProvider the function applied to the target URI each
	 * time a connection is established
	 */
	public RxNettyWebSocketClient(Function<URI, HttpClient<ByteBuf, ByteBuf>> httpClientProvider) {
		this.httpClientProvider = httpClientProvider;
	}


	/**
	 * Default provider: create a new {@link HttpClient} for the host of the
	 * given URL. An explicit port in the URL wins; otherwise 443 is used for
	 * the "wss" scheme and 80 for anything else.
	 * @param url the full URL of the WebSocket endpoint
	 * @return a new client for the URL's host and resolved port
	 */
	private static HttpClient<ByteBuf, ByteBuf> getDefaultHttpClientProvider(URI url) {
		// URI schemes are case-insensitive, so "WSS" must select the secure port too
		boolean secure = "wss".equalsIgnoreCase(url.getScheme());
		int port = (url.getPort() > 0 ? url.getPort() : secure ? 443 : 80);
		return HttpClient.newClient(url.getHost(), port);
	}


	/**
	 * Return the configured {@link HttpClient} provider depending on which
	 * constructor was used.
	 */
	public Function<URI, HttpClient<ByteBuf, ByteBuf>> getHttpClientProvider() {
		return this.httpClientProvider;
	}

	/**
	 * Return an {@link HttpClient} instance to use to connect to the given URI.
	 * The default implementation invokes the
	 * {@link #getHttpClientProvider() provider} function created or supplied
	 * at construction time.
	 * @param url the full URL of the WebSocket endpoint.
	 */
	public HttpClient<ByteBuf, ByteBuf> getHttpClient(URI url) {
		return this.httpClientProvider.apply(url);
	}


	@Override
	public Mono<Void> execute(URI url, WebSocketHandler handler) {
		return execute(url, new HttpHeaders(), handler);
	}

	@Override
	public Mono<Void> execute(URI url, HttpHeaders headers, WebSocketHandler handler) {
		Observable<Void> completion = executeInternal(url, headers, handler);
		return Mono.from(RxReactiveStreams.toPublisher(completion));
	}

	/**
	 * Build the upgrade request, pair the response with its WebSocket
	 * connection, and hand a new {@link RxNettyWebSocketSession} to the
	 * given handler.
	 * @param url the full URL of the WebSocket endpoint
	 * @param headers the headers to send with the handshake request
	 * @param handler the handler invoked with the established session
	 * @return an {@code Observable} adapted from the publisher returned by
	 * {@link WebSocketHandler#handle}
	 */
	@SuppressWarnings("cast")
	private Observable<Void> executeInternal(URI url, HttpHeaders headers, WebSocketHandler handler) {
		// Note: invoked when this method is called, not when the result is subscribed to
		String[] protocols = beforeHandshake(url, headers, handler);
		return createRequest(url, headers, protocols)
				.flatMap(response -> {
					Observable<WebSocketConnection> conn = response.getWebSocketConnection();
					// following cast is necessary to enable compilation on Eclipse 4.6
					return (Observable<Tuple2<WebSocketResponse<ByteBuf>, WebSocketConnection>>)
							Observable.zip(Observable.just(response), conn, Tuples::of);
				})
				.flatMap(tuple -> {
					WebSocketResponse<ByteBuf> response = tuple.getT1();
					WebSocketConnection conn = tuple.getT2();

					HandshakeInfo info = afterHandshake(url, toHttpHeaders(response));
					// Buffers for the session are created from the channel's own allocator
					ByteBufAllocator allocator = response.unsafeNettyChannel().alloc();
					NettyDataBufferFactory factory = new NettyDataBufferFactory(allocator);
					RxNettyWebSocketSession session = new RxNettyWebSocketSession(conn, info, factory);
					session.aggregateFrames(response.unsafeNettyChannel(), WsClientDecoder.getName());

					return RxReactiveStreams.toObservable(handler.handle(session));
				});
	}

	/**
	 * Create the WebSocket upgrade request for the given URL.
	 * <p>The request target is the raw path of the URL, or "/" if the URL has
	 * no path, followed by "?" and the raw query when a query is present.
	 * @param url the full URL of the WebSocket endpoint
	 * @param headers headers to copy onto the request; skipped when empty
	 * @param protocols the sub-protocols to request; none are requested when
	 * the array is {@code null} or empty
	 * @return the upgrade request
	 */
	private WebSocketRequest<ByteBuf> createRequest(URI url, HttpHeaders headers, String[] protocols) {
		String path = url.getRawPath();
		if (path == null || path.isEmpty()) {
			// e.g. "ws://host:8080" -- an empty request target is not valid, use the root path
			path = "/";
		}
		String query = url.getRawQuery();
		String requestUrl = path + (query != null ? "?" + query : "");
		HttpClientRequest<ByteBuf, ByteBuf> request = getHttpClient(url).createGet(requestUrl);

		if (!headers.isEmpty()) {
			// Copy into the Map<String, List<Object>> shape expected by setHeaders
			Map<String, List<Object>> map = new HashMap<>(headers.size());
			headers.forEach((key, values) -> map.put(key, new ArrayList<>(values)));
			request = request.setHeaders(map);
		}

		return (ObjectUtils.isEmpty(protocols) ?
				request.requestWebSocketUpgrade() :
				request.requestWebSocketUpgrade().requestSubProtocols(protocols));
	}

	/**
	 * Copy the headers of the given handshake response into a new
	 * {@link HttpHeaders} instance, with all values of each header name.
	 * @param response the handshake response
	 * @return the response headers
	 */
	private HttpHeaders toHttpHeaders(WebSocketResponse<ByteBuf> response) {
		HttpHeaders headers = new HttpHeaders();
		response.headerIterator().forEachRemaining(entry -> {
			String name = entry.getKey().toString();
			headers.put(name, response.getAllHeaderValues(name));
		});
		return headers;
	}

}