/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package org.jboss.netty.handler.codec.http.websocketx;
import static org.jboss.netty.handler.codec.http.HttpHeaders.isKeepAlive;
import static org.jboss.netty.handler.codec.http.HttpMethod.GET;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelFutureListener;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.handler.codec.http.DefaultHttpResponse;
import org.jboss.netty.handler.codec.http.HttpHeaders;
import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.ssl.SslHandler;
import org.jboss.netty.logging.InternalLogger;
import org.jboss.netty.logging.InternalLoggerFactory;
/**
 * Handles the HTTP handshake (the HTTP Upgrade request) that precedes a WebSocket session.
 *
 * <p>For every {@link HttpRequest} it receives, this handler:
 * <ul>
 *   <li>answers non-GET requests with {@code 403 Forbidden} and closes the channel;</li>
 *   <li>otherwise asks a {@link WebSocketServerHandshakerFactory} for a handshaker matching the
 *       request; if none is available it sends the "unsupported WebSocket version" response;</li>
 *   <li>if a handshaker is available, starts the handshake, registers the handshaker through
 *       {@link WebSocketServerProtocolHandler#setHandshaker}, and replaces itself in the pipeline
 *       with the handler returned by
 *       {@link WebSocketServerProtocolHandler#forbiddenHttpRequestResponder()}.</li>
 * </ul>
 *
 * <p>Messages that are not an {@link HttpRequest} are ignored by this handler; they are not
 * forwarded upstream.
 */
public class WebSocketServerProtocolHandshakeHandler extends SimpleChannelUpstreamHandler {

    private static final InternalLogger logger =
            InternalLoggerFactory.getInstance(WebSocketServerProtocolHandshakeHandler.class);

    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;

    /**
     * Creates a new handshake handler.
     *
     * @param websocketPath   path appended to {@code ws(s)://<Host header>} when building the
     *                        WebSocket location handed to the handshaker factory
     * @param subprotocols    sub-protocol specification passed unchanged to
     *                        {@link WebSocketServerHandshakerFactory}
     * @param allowExtensions extension flag passed unchanged to
     *                        {@link WebSocketServerHandshakerFactory}
     */
    public WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
            boolean allowExtensions) {
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
    }

    /**
     * Performs the WebSocket handshake for an incoming {@link HttpRequest}.
     *
     * @param ctx the context of this handler
     * @param e   the received message event; only {@link HttpRequest} payloads are acted upon
     * @throws Exception if the handshaker factory or handshaker throws
     */
    @Override
    public void messageReceived(final ChannelHandlerContext ctx, MessageEvent e) throws Exception {
        if (e.getMessage() instanceof HttpRequest) {
            HttpRequest req = (HttpRequest) e.getMessage();

            // Compare by value, not by reference: HttpMethod is a regular class, so a request
            // carrying an equal-but-not-identical GET instance must still be accepted. Calling
            // equals() on the constant also keeps a null method on the "forbidden" path.
            if (!GET.equals(req.getMethod())) {
                sendHttpResponse(ctx, req, new DefaultHttpResponse(HTTP_1_1, FORBIDDEN));
                return;
            }

            final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                    getWebSocketLocation(ctx.getPipeline(), req, websocketPath), subprotocols, allowExtensions);
            final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
            if (handshaker == null) {
                // No handshaker matches the request.
                wsFactory.sendUnsupportedWebSocketVersionResponse(ctx.getChannel());
            } else {
                final ChannelFuture handshakeFuture = handshaker.handshake(ctx.getChannel(), req);
                // Surface a failed handshake as an exception event instead of losing it.
                handshakeFuture.addListener(new ChannelFutureListener() {
                    public void operationComplete(ChannelFuture future) throws Exception {
                        if (!future.isSuccess()) {
                            Channels.fireExceptionCaught(ctx, future.getCause());
                        }
                    }
                });
                WebSocketServerProtocolHandler.setHandshaker(ctx, handshaker);
                // This handler is done once the handshake has been started; swap in the
                // responder supplied by WebSocketServerProtocolHandler under the same slot.
                ctx.getPipeline().replace(this, "WS403Responder",
                        WebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
            }
        }
    }

    /**
     * Logs the given cause at error level and closes the channel.
     *
     * <p>NOTE(review): this method is declared without {@code @Override} and takes a
     * {@link Throwable}; confirm whether the framework ever invokes this overload, since the
     * upstream-handler exception callback in this API generation is believed to receive an
     * {@code ExceptionEvent} instead.
     *
     * @param ctx   the context of this handler
     * @param cause the failure to log
     * @throws Exception declared for compatibility; the visible body does not throw explicitly
     */
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        logger.error("Exception Caught", cause);
        ctx.getChannel().close();
    }

    /**
     * Writes {@code res} to the channel and closes the channel afterwards unless the request is
     * keep-alive and the response status code is 200.
     */
    private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
        ChannelFuture f = ctx.getChannel().write(res);
        if (!isKeepAlive(req) || res.getStatus().getCode() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * Builds the WebSocket location as {@code <protocol>://<Host header><path>}, where the
     * protocol is {@code wss} when the pipeline contains an {@link SslHandler} and {@code ws}
     * otherwise.
     *
     * <p>NOTE(review): if the request carries no Host header, the header lookup presumably
     * returns {@code null}, which string concatenation would render as the text "null" -
     * confirm whether callers can reach this with a Host-less request.
     */
    private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
        String protocol = "ws";
        if (cp.get(SslHandler.class) != null) {
            // SSL in use so use Secure WebSockets
            protocol = "wss";
        }
        return protocol + "://" + req.headers().get(HttpHeaders.Names.HOST) + path;
    }
}