// ---------------------------------------------------------------------------
// jWebSocket - Copyright (c) 2010 Innotrade GmbH, jWebSocket.org
// ---------------------------------------------------------------------------
// This program is free software; you can redistribute it and/or modify it
// under the terms of the GNU Lesser General Public License as published by the
// Free Software Foundation; either version 3 of the License, or (at your
// option) any later version.
// This program is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
// FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
// more details.
// You should have received a copy of the GNU Lesser General Public License along
// with this program; if not, see <http://www.gnu.org/licenses/lgpl.html>.
// ---------------------------------------------------------------------------
package org.jwebsocket.netty.engines;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Map;
import javolution.util.FastMap;
import org.apache.log4j.Logger;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.ChannelEvent;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelFutureListener;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.ChildChannelStateEvent;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.channel.WriteCompletionEvent;
import org.jboss.netty.channel.group.ChannelGroup;
import org.jboss.netty.channel.group.DefaultChannelGroup;
import org.jboss.netty.handler.codec.http.DefaultHttpResponse;
import org.jboss.netty.handler.codec.http.HttpMethod;
import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.jboss.netty.handler.codec.http.HttpVersion;
import org.jboss.netty.handler.codec.http.websocket.WebSocketFrame;
import org.jboss.netty.handler.codec.http.websocket.WebSocketFrameDecoder;
import org.jboss.netty.handler.codec.http.websocket.WebSocketFrameEncoder;
import org.jboss.netty.handler.ssl.SslHandler;
import org.jboss.netty.util.CharsetUtil;
import org.jwebsocket.api.EngineConfiguration;
import org.jwebsocket.api.WebSocketConnector;
import org.jwebsocket.config.JWebSocketCommonConstants;
import org.jwebsocket.kit.CloseReason;
import org.jwebsocket.kit.RawPacket;
import org.jwebsocket.kit.RequestHeader;
import org.jwebsocket.kit.WebSocketRuntimeException;
import org.jwebsocket.logging.Logging;
import org.jwebsocket.netty.connectors.NettyConnector;
import org.jwebsocket.netty.http.HttpHeaders;
/**
* Handler class for the <tt>NettyEngine</tt> that recieves the events based on
* event types and notifies the client connectors. This handler also handles the
* initial handshaking for WebSocket connection with a appropriate hand shake
* response. This handler is created for each new connection channel.
* <p>
* Once the handshaking is successful after sending the handshake {@code
* HttpResponse} it replaces the {@code HttpRequestDecoder} and {@code
* HttpResponseEncoder} from the channel pipeline with {@code
* WebSocketFrameDecoder} as WebSocket frame data decoder and {@code
* WebSocketFrameEncoder} as WebSocket frame data encoder. Also it starts the
* <tt>NettyConnector</tt>.
* </p>
*
* @author <a href="http://www.purans.net/">Puran Singh</a>
* @version $Id: NettyEngineHandler.java 613 2010-07-01 07:13:29Z mailtopuran@gmail.com $
*/
public class NettyEngineHandler extends SimpleChannelUpstreamHandler {
private static final Logger mLog = Logging.getLogger(NettyEngineHandler.class);
private NettyEngine mEngine = null;
private WebSocketConnector mConnector = null;
private ChannelHandlerContext mContext = null;
private static final ChannelGroup mChannels = new DefaultChannelGroup();
private static final String CONTENT_LENGTH = "Content-Length";
/* Removed by Alex because these constants now are maintained in RequestHeader
private static final String ARGS = "args";
private static final String ORIGIN = "origin";
private static final String LOCATION = "location";
private static final String PATH = "path";
private static final String SEARCH_STRING = "searchString";
private static final String HOST = "host";
*/
public NettyEngineHandler(NettyEngine aEngine) {
this.mEngine = aEngine;
}
/**
* {@inheritDoc }
*/
@Override
public void channelBound(ChannelHandlerContext aCtx, ChannelStateEvent aEvent) throws Exception {
this.mContext = aCtx;
super.channelBound(aCtx, aEvent);
}
/**
* {@inheritDoc }
*/
@Override
public void channelClosed(ChannelHandlerContext aCtx, ChannelStateEvent aEvent) throws Exception {
this.mContext = aCtx;
super.channelClosed(aCtx, aEvent);
}
/**
* {@inheritDoc }
*/
@Override
public void channelConnected(ChannelHandlerContext aCtx, ChannelStateEvent aEvent) throws Exception {
this.mContext = aCtx;
// Get the SslHandler in the current pipeline.
final SslHandler sslHandler = aCtx.getPipeline().get(SslHandler.class);
// Get notified when SSL handshake is done.
// Added by Alex to prevent exceptions
// TODO: Fix this exceptions on connect!
// ADD-START
if (sslHandler != null)
// ADD-END
{
try {
ChannelFuture lHandshakeFuture = sslHandler.handshake();
lHandshakeFuture.addListener(new SecureWebSocketConnectionListener(sslHandler));
} catch (Exception es) {
es.printStackTrace();
}
}
}
/**
* {@inheritDoc }
*/
@Override
public void channelDisconnected(ChannelHandlerContext aCtx, ChannelStateEvent aEvent) throws Exception {
if (mLog.isDebugEnabled()) {
mLog.debug("Channel is disconnected");
}
// remove the channel
mChannels.remove(aEvent.getChannel());
this.mContext = aCtx;
super.channelDisconnected(aCtx, aEvent);
mEngine.connectorStopped(mConnector, CloseReason.CLIENT);
}
/**
* {@inheritDoc }
*/
@Override
public void channelInterestChanged(ChannelHandlerContext aCtx, ChannelStateEvent aEvent) throws Exception {
this.mContext = aCtx;
super.channelInterestChanged(aCtx, aEvent);
}
/**
* {@inheritDoc}
*/
@Override
public void channelOpen(ChannelHandlerContext aCtx, ChannelStateEvent aEvent) throws Exception {
this.mContext = aCtx;
super.channelOpen(aCtx, aEvent);
}
/**
* {@inheritDoc}
*/
@Override
public void channelUnbound(ChannelHandlerContext aCtx, ChannelStateEvent aEvent) throws Exception {
this.mContext = aCtx;
super.channelUnbound(aCtx, aEvent);
}
/**
* {@inheritDoc}
*/
@Override
public void childChannelClosed(ChannelHandlerContext aCtx, ChildChannelStateEvent aEvent) throws Exception {
this.mContext = aCtx;
super.childChannelClosed(aCtx, aEvent);
}
/**
* {@inheritDoc}
*/
@Override
public void childChannelOpen(ChannelHandlerContext aCtx, ChildChannelStateEvent aEvent) throws Exception {
this.mContext = aCtx;
super.childChannelOpen(aCtx, aEvent);
}
/**
* {@inheritDoc}
*/
@Override
public void exceptionCaught(ChannelHandlerContext aCtx, ExceptionEvent aEvent) throws Exception {
this.mContext = aCtx;
if (mLog.isDebugEnabled()) {
mLog.debug("Channel is disconnected:" + aEvent.getCause().getLocalizedMessage());
}
}
/**
* {@inheritDoc}
*/
@Override
public void handleUpstream(ChannelHandlerContext aCtx, ChannelEvent aEvent) throws Exception {
this.mContext = aCtx;
super.handleUpstream(aCtx, aEvent);
}
/**
* {@inheritDoc}
*/
@Override
public void messageReceived(ChannelHandlerContext aCtx, MessageEvent aEvent) throws Exception {
this.mContext = aCtx;
if (mLog.isDebugEnabled()) {
mLog.debug("message received in the engine handler");
}
Object lMsg = aEvent.getMessage();
if (lMsg instanceof HttpRequest) {
handleHttpRequest(aCtx, (HttpRequest) lMsg);
} else if (lMsg instanceof WebSocketFrame) {
handleWebSocketFrame(aCtx, (WebSocketFrame) lMsg);
}
}
/**
* private method that sends the handshake response for WebSocket connection
*
* @param aCtx the channel context
* @param aReq http request object
* @param aResp http response object
*/
private void sendHttpResponse(ChannelHandlerContext aCtx, HttpRequest aReq, HttpResponse aResp) {
// Generate an error page if response status code is not OK (200).
if (aResp.getStatus().getCode() != 200) {
aResp.setContent(ChannelBuffers.copiedBuffer(aResp.getStatus().toString(), CharsetUtil.UTF_8));
setContentLength(aResp, aResp.getContent().readableBytes());
}
// Send the response and close the connection if necessary.
ChannelFuture lCF = aCtx.getChannel().write(aResp);
if (!isKeepAlive(aReq) || aResp.getStatus().getCode() != 200) {
lCF.addListener(ChannelFutureListener.CLOSE);
}
}
/**
* Check if the request header has Keep-Alive
*
* @param aReq the http request object
* @return {@code true} if keep-alive is set in the header {@code false}
* otherwise
*/
private boolean isKeepAlive(HttpRequest aReq) {
String lKeepAlive = aReq.getHeader(HttpHeaders.Values.KEEP_ALIVE);
if (lKeepAlive != null && lKeepAlive.length() > 0) {
return true;
} else {
return false;
}
}
/**
* Set the content length in the response
*
* @param res the http response object
* @param aReadableBytes the length of the bytes
*/
private void setContentLength(HttpResponse aResp, int aReadableBytes) {
aResp.setHeader(CONTENT_LENGTH, aReadableBytes);
}
/**
* private method that handles the web socket frame data, this method is
* used only after the WebSocket connection is established.
*
* @param aCtx the channel handler context
* @param aMsg the web socket frame data
*/
private void handleWebSocketFrame(ChannelHandlerContext aCtx, WebSocketFrame aMsg) throws WebSocketRuntimeException {
String lTextData = "";
if (aMsg.isBinary()) {
// TODO: handle binary data
} else if (aMsg.isText()) {
lTextData = aMsg.getTextData();
} else {
throw new WebSocketRuntimeException("Frame Doesn't contain any type of data");
}
mEngine.processPacket(mConnector, new RawPacket(lTextData));
}
/**
* Handles the initial HTTP request for handshaking if the http request
* contains Upgrade header value as WebSocket then this method sends the
* handshake response and also fires the events on client connector.
*
* @param aCtx the channel handler context
* @param req the request message
*/
private void handleHttpRequest(ChannelHandlerContext aCtx, HttpRequest aReq) {
// Allow only GET methods.
if (aReq.getMethod() != HttpMethod.GET) {
sendHttpResponse(aCtx, aReq, new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN));
return;
}
// Serve the WebSocket handshake request.
if (HttpHeaders.Values.UPGRADE.equalsIgnoreCase(aReq.getHeader(HttpHeaders.Names.CONNECTION))
&& HttpHeaders.Values.WEBSOCKET.equalsIgnoreCase(aReq.getHeader(HttpHeaders.Names.UPGRADE))) {
// Create the WebSocket handshake response.
HttpResponse lResp = null;
try {
lResp = constructHandShakeResponse(aReq, aCtx);
} catch (NoSuchAlgorithmException lNSAEx) {
// better to close the channel
mLog.debug("Channel is disconnected");
aCtx.getChannel().close();
}
// write the response
aCtx.getChannel().write(lResp);
mChannels.add(aCtx.getChannel());
// since handshaking is done, replace the encoder/decoder with
// web socket data frame encoder/decoder
ChannelPipeline lPipeline = aCtx.getChannel().getPipeline();
lPipeline.remove("aggregator");
EngineConfiguration lConfig = mEngine.getConfiguration();
if (lConfig == null || lConfig.getMaxFramesize() == 0) {
lPipeline.replace("decoder", "jwsdecoder", new WebSocketFrameDecoder(JWebSocketCommonConstants.DEFAULT_MAX_FRAME_SIZE));
} else {
lPipeline.replace("decoder", "jwsdecoder", new WebSocketFrameDecoder(lConfig.getMaxFramesize()));
}
lPipeline.replace("encoder", "jwsencoder", new WebSocketFrameEncoder());
//if the WebSocket connection URI is wss then start SSL TLS handshaking
if (aReq.getUri().startsWith("wss:")) {
// Get the SslHandler in the current pipeline.
final SslHandler sslHandler = aCtx.getPipeline().get(SslHandler.class);
// Get notified when SSL handshake is done.
ChannelFuture lHandshakeFuture = sslHandler.handshake();
lHandshakeFuture.addListener(new SecureWebSocketConnectionListener(sslHandler));
}
// initialize the connector
mConnector = initializeConnector(aCtx, aReq);
return;
}
// Send an error page otherwise.
sendHttpResponse(aCtx, aReq, new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN));
}
/**
* Constructs the <tt>HttpResponse</tt> object for the handshake response
*
* @param aReq the http request object
* @param aCtx the channel handler context
* @return the http handshake response
* @throws NoSuchAlgorithmException
*/
private HttpResponse constructHandShakeResponse(HttpRequest aReq, ChannelHandlerContext aCtx) throws NoSuchAlgorithmException {
HttpResponse lResp = new DefaultHttpResponse(HttpVersion.HTTP_1_1, new HttpResponseStatus(101, "Web Socket Protocol Handshake"));
lResp.addHeader(HttpHeaders.Names.UPGRADE, HttpHeaders.Values.WEBSOCKET);
lResp.addHeader(HttpHeaders.Names.CONNECTION, HttpHeaders.Values.UPGRADE);
// Fill in the headers and contents depending on handshake method.
if (aReq.containsHeader(HttpHeaders.Names.SEC_WEBSOCKET_KEY1) && aReq.containsHeader(HttpHeaders.Names.SEC_WEBSOCKET_KEY2)) {
// New handshake method with a challenge:
lResp.addHeader(HttpHeaders.Names.SEC_WEBSOCKET_ORIGIN, aReq.getHeader(HttpHeaders.Names.ORIGIN));
lResp.addHeader(HttpHeaders.Names.SEC_WEBSOCKET_LOCATION, getWebSocketLocation(aReq));
String lProtocol = aReq.getHeader(HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL);
// Added by Alex 2010-10-25:
// fallback for FlashBridge (which sends "WebSocket-Protocol"
// instead of "Sec-WebSocket-Protocol"
if (lProtocol != null) {
lResp.addHeader(HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL, lProtocol);
} else {
lProtocol = aReq.getHeader(HttpHeaders.Names.WEBSOCKET_PROTOCOL);
if (lProtocol != null) {
lResp.addHeader(HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL, lProtocol);
}
}
// Calculate the answer of the challenge.
String lKey1 = aReq.getHeader(HttpHeaders.Names.SEC_WEBSOCKET_KEY1);
String lKey2 = aReq.getHeader(HttpHeaders.Names.SEC_WEBSOCKET_KEY2);
int lA = (int) (Long.parseLong(lKey1.replaceAll("[^0-9]", "")) / lKey1.replaceAll("[^ ]", "").length());
int lB = (int) (Long.parseLong(lKey2.replaceAll("[^0-9]", "")) / lKey2.replaceAll("[^ ]", "").length());
long lC = aReq.getContent().readLong();
ChannelBuffer lInput = ChannelBuffers.buffer(16);
lInput.writeInt(lA);
lInput.writeInt(lB);
lInput.writeLong(lC);
ChannelBuffer lOutput = ChannelBuffers.wrappedBuffer(MessageDigest.getInstance("MD5").digest(lInput.array()));
lResp.setContent(lOutput);
} else {
// Old handshake method with no challenge:
lResp.addHeader(HttpHeaders.Names.WEBSOCKET_ORIGIN, aReq.getHeader(HttpHeaders.Names.ORIGIN));
lResp.addHeader(HttpHeaders.Names.WEBSOCKET_LOCATION, getWebSocketLocation(aReq));
String lProtocol = aReq.getHeader(HttpHeaders.Names.WEBSOCKET_PROTOCOL);
if (lProtocol != null) {
lResp.addHeader(HttpHeaders.Names.WEBSOCKET_PROTOCOL, lProtocol);
}
}
return lResp;
}
/**
* Initialize the {@code NettyConnector} after initial handshaking is
* successfull.
*
* @param aCtx the channel handler context
* @param req the http request object
*/
private WebSocketConnector initializeConnector(ChannelHandlerContext aCtx, HttpRequest aReq) {
RequestHeader lHeader = getRequestHeader(aReq);
int lSessionTimeout = lHeader.getTimeout(JWebSocketCommonConstants.DEFAULT_TIMEOUT);
if (lSessionTimeout > 0) {
aCtx.getChannel().getConfig().setConnectTimeoutMillis(lSessionTimeout);
}
// create connector
WebSocketConnector lConnector = new NettyConnector(mEngine, this);
lConnector.setHeader(lHeader);
mEngine.getConnectors().put(lConnector.getId(), lConnector);
lConnector.startConnector();
// allow descendant classes to handle connector started event
mEngine.connectorStarted(lConnector);
return lConnector;
}
/**
* Construct the request header to save it in the connector
*
* @param aReq the http request header
* @return the request header
*/
private RequestHeader getRequestHeader(HttpRequest aReq) {
RequestHeader lHeader = new RequestHeader();
Map<String, String> lArgs = new FastMap<String, String>();
String lSearchString = "";
String lPath = aReq.getUri();
// isolate search string
int lPos = lPath.indexOf(JWebSocketCommonConstants.PATHARG_SEPARATOR);
if (lPos >= 0) {
lSearchString = lPath.substring(lPos + 1);
if (lSearchString.length() > 0) {
String[] lKeyValPairs = lSearchString.split(JWebSocketCommonConstants.ARGARG_SEPARATOR);
for (int lIdx = 0; lIdx < lKeyValPairs.length; lIdx++) {
String[] lKeyVal = lKeyValPairs[lIdx].split(JWebSocketCommonConstants.KEYVAL_SEPARATOR, 2);
if (lKeyVal.length == 2) {
lArgs.put(lKeyVal[0], lKeyVal[1]);
if (mLog.isDebugEnabled()) {
mLog.debug("arg" + lIdx + ": " + lKeyVal[0] + "=" + lKeyVal[1]);
}
}
}
}
}
// set default sub protocol if none passed
// if no sub protocol given in request header,
// try to get it from arguments
String lSubProt = aReq.getHeader(HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL);
if (lSubProt == null) {
lSubProt = lArgs.get(RequestHeader.WS_PROTOCOL);
}
if (lSubProt == null) {
lSubProt = JWebSocketCommonConstants.WS_SUBPROT_DEFAULT;
}
lHeader.put(RequestHeader.URL_ARGS, lArgs);
lHeader.put(RequestHeader.WS_ORIGIN, aReq.getHeader(HttpHeaders.Names.ORIGIN));
lHeader.put(RequestHeader.WS_LOCATION, getWebSocketLocation(aReq));
lHeader.put(RequestHeader.WS_PATH, aReq.getUri());
lHeader.put(RequestHeader.WS_PROTOCOL, lSubProt);
lHeader.put(RequestHeader.WS_SEARCHSTRING, lSearchString);
lHeader.put(RequestHeader.WS_HOST, aReq.getHeader(HttpHeaders.Names.HOST));
return lHeader;
}
/**
* {@inheritDoc}
*/
@Override
public void writeComplete(ChannelHandlerContext aCtx, WriteCompletionEvent aEvent) throws Exception {
super.writeComplete(aCtx, aEvent);
}
/**
* Returns the web socket location URL
*
* @param aReq the http request object
* @return the location url string
*/
private String getWebSocketLocation(HttpRequest aReq) {
//TODO: fix this URL for wss: (secure)
String location = "ws://" + aReq.getHeader(HttpHeaders.Names.HOST) + aReq.getUri();
return location;
}
/**
* Returns the channel context
*
* @return the channel context
*/
public ChannelHandlerContext getChannelHandlerContext() {
return mContext;
}
/**
* Listener class for SSL TLS handshake completion.
*/
private static final class SecureWebSocketConnectionListener implements ChannelFutureListener {
private final SslHandler mSSLHandler;
SecureWebSocketConnectionListener(SslHandler aSSLHandler) {
this.mSSLHandler = aSSLHandler;
}
@Override
public void operationComplete(ChannelFuture aFuture) throws Exception {
if (aFuture.isSuccess()) {
// that means SSL handshaking is done.
if (mLog.isInfoEnabled()) {
mLog.info("SSL handshaking success");
}
} else {
aFuture.getChannel().close();
}
}
}
}