/* * Copyright 2016 Netflix, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package io.reactivex.netty.protocol.http.ws.server; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderValues; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocket13FrameDecoder; import io.netty.handler.codec.http.websocketx.WebSocket13FrameEncoder; import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrameDecoder; import io.netty.handler.codec.http.websocketx.WebSocketFrameEncoder; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker07; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker08; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker13; import io.netty.util.CharsetUtil; import io.reactivex.netty.channel.Connection; import io.reactivex.netty.channel.MarkAwarePipeline; import io.reactivex.netty.protocol.http.internal.AbstractHttpConnectionBridge; import io.reactivex.netty.protocol.http.server.HttpServerRequest; import io.reactivex.netty.protocol.http.server.HttpServerResponse; import io.reactivex.netty.protocol.http.ws.WebSocketConnection; import io.reactivex.netty.protocol.http.ws.internal.WsUtils; import io.reactivex.netty.protocol.http.ws.server.V7to13Handshaker.State; import rx.Observable; import rx.Observable.OnSubscribe; import rx.Subscriber; import rx.functions.Action0; import rx.functions.Func0; import static io.netty.handler.codec.http.HttpHeaderNames.*; import static io.netty.handler.codec.http.HttpHeaderValues.*; import static io.reactivex.netty.protocol.http.HttpHandlerNames.*; /** * A websocket upgrade handler for upgrading to WebSocket versions 7 to 13. This handler listens for * {@link WebSocket7To13UpgradeAcceptedEvent} and upon recieving such an event, it sets up the * {@link WebSocketConnection} to hand it over to the associated {@link WebSocketHandler} */ public final class Ws7To13UpgradeHandler extends ChannelDuplexHandler { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof WebSocket7To13UpgradeAcceptedEvent) { final WebSocket7To13UpgradeAcceptedEvent wsUpEvt = (WebSocket7To13UpgradeAcceptedEvent) evt; final State state = wsUpEvt.state; final Subscriber<? super Void> subscriber = wsUpEvt.acceptUpgradeSubscriber; String errorIfAny = configureResponseForWs(state); if (null != errorIfAny) { subscriber.onError(new IllegalStateException(errorIfAny)); return; } final MarkAwarePipeline pipeline = state.getUpgradeResponse().unsafeConnection() .getResettableChannelPipeline(); @SuppressWarnings("unchecked") final Connection<WebSocketFrame, WebSocketFrame> wsConn = (Connection<WebSocketFrame, WebSocketFrame>) wsUpEvt.state.getUpgradeResponse().unsafeConnection(); wsUpEvt.request.discardContent() .onErrorResumeNext(Observable.<Void>empty()) // In case, the request content was read, ignore. .concatWith(state.getUpgradeResponse().setTransferEncodingChunked().sendHeaders()) .doOnCompleted(new Action0() { @Override public void call() { /*We are no more talking HTTP*/ pipeline.remove(HttpServerEncoder.getName()); pipeline.remove(HttpServerDecoder.getName()); pipeline.channel().attr(AbstractHttpConnectionBridge.CONNECTION_UPGRADED).set(true); } }) .concatWith(Observable.defer(new Func0<Observable<Void>>() { @Override public Observable<Void> call() { return wsUpEvt.handler.handle(new WebSocketConnection(wsConn)); } })) .concatWith(Observable.create(new OnSubscribe<Void>() { @Override public void call(Subscriber<? super Void> sub) { /* * In this case, the client did not send a close frame but the server end processing * is over, so we should send a close frame to indicate closure from server. */ if (wsConn.unsafeNettyChannel().isOpen()) { wsConn.write(Observable.<WebSocketFrame>just(new CloseWebSocketFrame())) .concatWith(wsConn.close()) .unsafeSubscribe(sub); } } })) .unsafeSubscribe(subscriber); /*Unsafe as the subscriber is coming from the user.*/ } ctx.fireUserEventTriggered(evt); } private static String configureResponseForWs(State state) { String acceptGuid; switch (state.getVersion()) { case V07: acceptGuid = WebSocketServerHandshaker07.WEBSOCKET_07_ACCEPT_GUID; break; case V08: acceptGuid = WebSocketServerHandshaker08.WEBSOCKET_08_ACCEPT_GUID; break; case V13: acceptGuid = WebSocketServerHandshaker13.WEBSOCKET_13_ACCEPT_GUID; break; default: return "Unsupported web socket version: " + state.getVersion(); } WebSocketFrameEncoder wsEncoder = new WebSocket13FrameEncoder(false /*servers should set this to false.*/); WebSocketFrameDecoder wsDecoder = new WebSocket13FrameDecoder(true/*servers should set this to true.*/, state.isAllowExtensions(), state.getMaxFramePayloadLength(), true); final HttpServerResponse<?> upgradeResponse = state.getUpgradeResponse(); final MarkAwarePipeline pipeline = upgradeResponse.unsafeConnection().getResettableChannelPipeline(); ChannelHandlerContext httpDecoderCtx = pipeline.context(HttpServerDecoder.getName()); if (null == httpDecoderCtx) { return "No HTTP decoder found, can not upgrade to WebSocket."; } ChannelHandlerContext httpEncoderCtx = pipeline.context(HttpServerEncoder.getName()); if (null == httpEncoderCtx) { return "No HTTP encoder found, can not upgrade to WebSocket."; } pipeline.addAfter(httpDecoderCtx.name(), WsServerDecoder.getName(), wsDecoder); pipeline.addBefore(httpEncoderCtx.name(), WsServerEncoder.getName(), wsEncoder); updateHandshakeHeaders(state, acceptGuid, upgradeResponse); return null; } private static void updateHandshakeHeaders(State state, String acceptGuid, HttpServerResponse<?> upgradeResponse) { String acceptSeed = state.getSecWSkey() + acceptGuid; byte[] sha1 = WsUtils.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII)); String accept = WsUtils.base64(sha1); upgradeResponse.addHeader(SEC_WEBSOCKET_ACCEPT, accept); upgradeResponse.setStatus(HttpResponseStatus.SWITCHING_PROTOCOLS); upgradeResponse.addHeader(HttpHeaderNames.UPGRADE, WEBSOCKET); upgradeResponse.addHeader(CONNECTION, HttpHeaderValues.UPGRADE); if (state.getRequestSubProtocols() != null) { String selectedSubprotocol = WebSocketHandshaker.selectSubprotocol(state.getRequestSubProtocols(), state.getSupportedSubProtocols()); if (selectedSubprotocol != null) { state.getUpgradeResponse().addHeader(SEC_WEBSOCKET_PROTOCOL, selectedSubprotocol); } } } public static class WebSocket7To13UpgradeAcceptedEvent { private final Subscriber<? super Void> acceptUpgradeSubscriber; private final WebSocketHandler handler; private final State state; private final HttpServerRequest<?> request; WebSocket7To13UpgradeAcceptedEvent(Subscriber<? super Void> acceptUpgradeSubscriber, WebSocketHandler handler, State state, HttpServerRequest<?> request) { this.acceptUpgradeSubscriber = acceptUpgradeSubscriber; this.handler = handler; this.state = state; this.request = request; } } }