/* * Copyright 2012 The Netty Project * * The Netty Project licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package io.netty.handler.codec.http.websocketx; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpContentDecompressor; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequestEncoder; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpScheme; import io.netty.util.NetUtil; import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.ThrowableUtil; import java.net.URI; import java.nio.channels.ClosedChannelException; /** * Base class for web socket client handshake implementations */ public abstract class WebSocketClientHandshaker { private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = ThrowableUtil.unknownStackTrace( new ClosedChannelException(), WebSocketClientHandshaker.class, "processHandshake(...)"); private final URI uri; private final WebSocketVersion version; private volatile boolean handshakeComplete; private final String expectedSubprotocol; private volatile String actualSubprotocol; protected final HttpHeaders customHeaders; private final int maxFramePayloadLength; /** * Base constructor * * @param uri * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be * sent to this URL. * @param version * Version of web socket specification to use to connect to the server * @param subprotocol * Sub protocol request sent to the server. * @param customHeaders * Map of custom headers to add to the client request * @param maxFramePayloadLength * Maximum length of a frame's payload */ protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol, HttpHeaders customHeaders, int maxFramePayloadLength) { this.uri = uri; this.version = version; expectedSubprotocol = subprotocol; this.customHeaders = customHeaders; this.maxFramePayloadLength = maxFramePayloadLength; } /** * Returns the URI to the web socket. e.g. "ws://myhost.com/path" */ public URI uri() { return uri; } /** * Version of the web socket specification that is being used */ public WebSocketVersion version() { return version; } /** * Returns the max length for any frame's payload */ public int maxFramePayloadLength() { return maxFramePayloadLength; } /** * Flag to indicate if the opening handshake is complete */ public boolean isHandshakeComplete() { return handshakeComplete; } private void setHandshakeComplete() { handshakeComplete = true; } /** * Returns the CSV of requested subprotocol(s) sent to the server as specified in the constructor */ public String expectedSubprotocol() { return expectedSubprotocol; } /** * Returns the subprotocol response sent by the server. Only available after end of handshake. * Null if no subprotocol was requested or confirmed by the server. */ public String actualSubprotocol() { return actualSubprotocol; } private void setActualSubprotocol(String actualSubprotocol) { this.actualSubprotocol = actualSubprotocol; } /** * Begins the opening handshake * * @param channel * Channel */ public ChannelFuture handshake(Channel channel) { if (channel == null) { throw new NullPointerException("channel"); } return handshake(channel, channel.newPromise()); } /** * Begins the opening handshake * * @param channel * Channel * @param promise * the {@link ChannelPromise} to be notified when the opening handshake is sent */ public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) { FullHttpRequest request = newHandshakeRequest(); HttpResponseDecoder decoder = channel.pipeline().get(HttpResponseDecoder.class); if (decoder == null) { HttpClientCodec codec = channel.pipeline().get(HttpClientCodec.class); if (codec == null) { promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " + "a HttpResponseDecoder or HttpClientCodec")); return promise; } } channel.writeAndFlush(request).addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) { if (future.isSuccess()) { ChannelPipeline p = future.channel().pipeline(); ChannelHandlerContext ctx = p.context(HttpRequestEncoder.class); if (ctx == null) { ctx = p.context(HttpClientCodec.class); } if (ctx == null) { promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " + "a HttpRequestEncoder or HttpClientCodec")); return; } p.addAfter(ctx.name(), "ws-encoder", newWebSocketEncoder()); promise.setSuccess(); } else { promise.setFailure(future.cause()); } } }); return promise; } /** * Returns a new {@link FullHttpRequest) which will be used for the handshake. */ protected abstract FullHttpRequest newHandshakeRequest(); /** * Validates and finishes the opening handshake initiated by {@link #handshake}}. * * @param channel * Channel * @param response * HTTP response containing the closing handshake details */ public final void finishHandshake(Channel channel, FullHttpResponse response) { verify(response); // Verify the subprotocol that we received from the server. // This must be one of our expected subprotocols - or null/empty if we didn't want to speak a subprotocol String receivedProtocol = response.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL); receivedProtocol = receivedProtocol != null ? receivedProtocol.trim() : null; String expectedProtocol = expectedSubprotocol != null ? expectedSubprotocol : ""; boolean protocolValid = false; if (expectedProtocol.isEmpty() && receivedProtocol == null) { // No subprotocol required and none received protocolValid = true; setActualSubprotocol(expectedSubprotocol); // null or "" - we echo what the user requested } else if (!expectedProtocol.isEmpty() && receivedProtocol != null && !receivedProtocol.isEmpty()) { // We require a subprotocol and received one -> verify it for (String protocol : expectedProtocol.split(",")) { if (protocol.trim().equals(receivedProtocol)) { protocolValid = true; setActualSubprotocol(receivedProtocol); break; } } } // else mixed cases - which are all errors if (!protocolValid) { throw new WebSocketHandshakeException(String.format( "Invalid subprotocol. Actual: %s. Expected one of: %s", receivedProtocol, expectedSubprotocol)); } setHandshakeComplete(); final ChannelPipeline p = channel.pipeline(); // Remove decompressor from pipeline if its in use HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class); if (decompressor != null) { p.remove(decompressor); } // Remove aggregator if present before HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class); if (aggregator != null) { p.remove(aggregator); } ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class); if (ctx == null) { ctx = p.context(HttpClientCodec.class); if (ctx == null) { throw new IllegalStateException("ChannelPipeline does not contain " + "a HttpRequestEncoder or HttpClientCodec"); } final HttpClientCodec codec = (HttpClientCodec) ctx.handler(); // Remove the encoder part of the codec as the user may start writing frames after this method returns. codec.removeOutboundHandler(); p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder()); // Delay the removal of the decoder so the user can setup the pipeline if needed to handle // WebSocketFrame messages. // See https://github.com/netty/netty/issues/4533 channel.eventLoop().execute(new Runnable() { @Override public void run() { p.remove(codec); } }); } else { if (p.get(HttpRequestEncoder.class) != null) { // Remove the encoder part of the codec as the user may start writing frames after this method returns. p.remove(HttpRequestEncoder.class); } final ChannelHandlerContext context = ctx; p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder()); // Delay the removal of the decoder so the user can setup the pipeline if needed to handle // WebSocketFrame messages. // See https://github.com/netty/netty/issues/4533 channel.eventLoop().execute(new Runnable() { @Override public void run() { p.remove(context.handler()); } }); } } /** * Process the opening handshake initiated by {@link #handshake}}. * * @param channel * Channel * @param response * HTTP response containing the closing handshake details * @return future * the {@link ChannelFuture} which is notified once the handshake completes. */ public final ChannelFuture processHandshake(final Channel channel, HttpResponse response) { return processHandshake(channel, response, channel.newPromise()); } /** * Process the opening handshake initiated by {@link #handshake}}. * * @param channel * Channel * @param response * HTTP response containing the closing handshake details * @param promise * the {@link ChannelPromise} to notify once the handshake completes. * @return future * the {@link ChannelFuture} which is notified once the handshake completes. */ public final ChannelFuture processHandshake(final Channel channel, HttpResponse response, final ChannelPromise promise) { if (response instanceof FullHttpResponse) { try { finishHandshake(channel, (FullHttpResponse) response); promise.setSuccess(); } catch (Throwable cause) { promise.setFailure(cause); } } else { ChannelPipeline p = channel.pipeline(); ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class); if (ctx == null) { ctx = p.context(HttpClientCodec.class); if (ctx == null) { return promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " + "a HttpResponseDecoder or HttpClientCodec")); } } // Add aggregator and ensure we feed the HttpResponse so it is aggregated. A limit of 8192 should be more // then enough for the websockets handshake payload. // // TODO: Make handshake work without HttpObjectAggregator at all. String aggregatorName = "httpAggregator"; p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192)); p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpResponse>() { @Override protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception { // Remove ourself and do the actual handshake ctx.pipeline().remove(this); try { finishHandshake(channel, msg); promise.setSuccess(); } catch (Throwable cause) { promise.setFailure(cause); } } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { // Remove ourself and fail the handshake promise. ctx.pipeline().remove(this); promise.setFailure(cause); } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { // Fail promise if Channel was closed promise.tryFailure(CLOSED_CHANNEL_EXCEPTION); ctx.fireChannelInactive(); } }); try { ctx.fireChannelRead(ReferenceCountUtil.retain(response)); } catch (Throwable cause) { promise.setFailure(cause); } } return promise; } /** * Verify the {@link FullHttpResponse} and throws a {@link WebSocketHandshakeException} if something is wrong. */ protected abstract void verify(FullHttpResponse response); /** * Returns the decoder to use after handshake is complete. */ protected abstract WebSocketFrameDecoder newWebsocketDecoder(); /** * Returns the encoder to use after the handshake is complete. */ protected abstract WebSocketFrameEncoder newWebSocketEncoder(); /** * Performs the closing handshake * * @param channel * Channel * @param frame * Closing Frame that was received */ public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) { if (channel == null) { throw new NullPointerException("channel"); } return close(channel, frame, channel.newPromise()); } /** * Performs the closing handshake * * @param channel * Channel * @param frame * Closing Frame that was received * @param promise * the {@link ChannelPromise} to be notified when the closing handshake is done */ public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) { if (channel == null) { throw new NullPointerException("channel"); } return channel.writeAndFlush(frame, promise); } /** * Return the constructed raw path for the give {@link URI}. */ static String rawPath(URI wsURL) { String path = wsURL.getRawPath(); String query = wsURL.getRawQuery(); if (query != null && !query.isEmpty()) { path = path + '?' + query; } return path == null || path.isEmpty() ? "/" : path; } static int websocketPort(URI wsURL) { // Format request int wsPort = wsURL.getPort(); // check if the URI contained a port if not set the correct one depending on the schema. // See https://github.com/netty/netty/pull/1558 if (wsPort == -1) { return WebSocketScheme.WSS.name().contentEquals(wsURL.getScheme()) ? WebSocketScheme.WSS.port() : WebSocketScheme.WS.port(); } return wsPort; } static CharSequence websocketHostValue(URI wsURL) { int port = wsURL.getPort(); if (port == -1) { return wsURL.getHost(); } String host = wsURL.getHost(); if (port == HttpScheme.HTTP.port()) { return HttpScheme.HTTP.name().contentEquals(wsURL.getScheme()) || WebSocketScheme.WS.name().contentEquals(wsURL.getScheme()) ? host : NetUtil.toSocketAddressString(host, port); } if (port == HttpScheme.HTTPS.port()) { return HttpScheme.HTTPS.name().contentEquals(wsURL.getScheme()) || WebSocketScheme.WSS.name().contentEquals(wsURL.getScheme()) ? host : NetUtil.toSocketAddressString(host, port); } // if the port is not standard (80/443) its needed to add the port to the header. // See http://tools.ietf.org/html/rfc6454#section-6.2 return NetUtil.toSocketAddressString(host, port); } static CharSequence websocketOriginValue(String host, int wsPort) { String originValue = (wsPort == HttpScheme.HTTPS.port() ? HttpScheme.HTTPS.name() : HttpScheme.HTTP.name()) + "://" + host; if (wsPort != HttpScheme.HTTP.port() && wsPort != HttpScheme.HTTPS.port()) { // if the port is not standard (80/443) its needed to add the port to the header. // See http://tools.ietf.org/html/rfc6454#section-6.2 return NetUtil.toSocketAddressString(originValue, wsPort); } return originValue; } }