/*
* JBoss, Home of Professional Open Source
* Copyright 2010 Red Hat Inc. and/or its affiliates and other
* contributors as indicated by the @author tags. All rights reserved.
* See the copyright.txt in the distribution for a full listing of
* individual contributors.
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*/
package org.infinispan.server.websocket;
import static org.jboss.netty.handler.codec.http.HttpHeaders.*;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.*;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Values.*;
import static org.jboss.netty.handler.codec.http.HttpMethod.*;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.*;
import static org.jboss.netty.handler.codec.http.HttpVersion.*;
import java.io.StringWriter;
import java.security.MessageDigest;
import java.util.Map;
import org.infinispan.Cache;
import org.infinispan.manager.CacheContainer;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelFutureListener;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.handler.codec.http.DefaultHttpResponse;
import org.jboss.netty.handler.codec.http.HttpHeaders;
import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.jboss.netty.handler.codec.http.HttpHeaders.Names;
import org.jboss.netty.handler.codec.http.HttpHeaders.Values;
import org.jboss.netty.handler.codec.http.websocket.WebSocketFrame;
import org.jboss.netty.handler.codec.http.websocket.WebSocketFrameDecoder;
import org.jboss.netty.handler.codec.http.websocket.WebSocketFrameEncoder;
import org.jboss.netty.util.CharsetUtil;
import org.json.JSONException;
import org.json.JSONObject;
/**
 * Web Socket Server Handler (Netty).
 * <p/>
 * Serves the "infinispan-ws.js" client script over plain HTTP and performs the WebSocket handshake.
 * It supports both the challenge-based variant (Sec-WebSocket-Key1/Key2) and the older challenge-less
 * variant. After the upgrade, it dispatches JSON text frames to the registered {@link OpHandler}s.
 * <p/>
 * Websocket specific code lifted from Netty WebSocket Server example.
 */
public class WebSocketServerHandler extends SimpleChannelUpstreamHandler {

    private static final String INFINISPAN_WS_JS_FILENAME = "infinispan-ws.js";

    private final CacheContainer cacheContainer;
    private final Map<String, OpHandler> operationHandlers;
    // NOTE(review): never assigned in this class, so it is always false here.
    private boolean connectionUpgraded;
    // Started caches keyed by cache name ("" for the default cache); supplied by the caller.
    private final Map<String, Cache> startedCaches;

    /**
     * @param cacheContainer    container used to obtain caches the first time they are requested
     * @param operationHandlers handlers keyed by the op code carried in each JSON frame
     * @param startedCaches     registry of already-started caches, keyed by cache name
     */
    public WebSocketServerHandler(CacheContainer cacheContainer, Map<String, OpHandler> operationHandlers, Map<String, Cache> startedCaches) {
        this.cacheContainer = cacheContainer;
        this.operationHandlers = operationHandlers;
        this.startedCaches = startedCaches;
    }

    @Override
    public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
        Object msg = e.getMessage();
        if (msg instanceof HttpRequest) {
            handleHttpRequest(ctx, (HttpRequest) msg);
        } else if (msg instanceof WebSocketFrame) {
            handleWebSocketFrame(ctx, (WebSocketFrame) msg);
        }
    }

    /**
     * Serves the client script, answers a WebSocket handshake, or rejects the request with 403.
     */
    private void handleHttpRequest(ChannelHandlerContext ctx, HttpRequest req) throws Exception {
        // Allow only GET methods (compared by value, not identity).
        if (!GET.equals(req.getMethod())) {
            sendHttpResponse(ctx, req, new DefaultHttpResponse(HTTP_1_1, FORBIDDEN));
            return;
        }

        if (!connectionUpgraded && req.getUri().equalsIgnoreCase("/" + INFINISPAN_WS_JS_FILENAME)) {
            DefaultHttpResponse res = new DefaultHttpResponse(HTTP_1_1, OK);
            loadScriptToResponse(req, res);
            sendHttpResponse(ctx, req, res);
            return;
        }

        if (!Values.UPGRADE.equalsIgnoreCase(req.getHeader(CONNECTION)) ||
                !WEBSOCKET.equalsIgnoreCase(req.getHeader(Names.UPGRADE))) {
            // Not a WebSocket upgrade: send an error page.
            sendHttpResponse(ctx, req, new DefaultHttpResponse(HTTP_1_1, FORBIDDEN));
            return;
        }

        // Serve the WebSocket handshake request.
        HttpResponse res = new DefaultHttpResponse(HTTP_1_1, new HttpResponseStatus(101, "Web Socket Protocol Handshake"));
        res.addHeader(Names.UPGRADE, Values.WEBSOCKET);
        res.addHeader(Names.CONNECTION, Values.UPGRADE);

        // Clients that are not browsers may omit Origin; only echo it back when present.
        String origin = req.getHeader(Names.ORIGIN);

        // Fill in the headers and contents depending on handshake method.
        if (req.containsHeader(Names.SEC_WEBSOCKET_KEY1) &&
                req.containsHeader(Names.SEC_WEBSOCKET_KEY2)) {
            // New handshake method with a challenge. Validate it before anything else so that a
            // malformed request gets an error page instead of an exception and a dropped channel.
            ChannelBuffer challenge = readChallenge(req);
            if (challenge == null) {
                sendHttpResponse(ctx, req, new DefaultHttpResponse(HTTP_1_1, FORBIDDEN));
                return;
            }
            if (origin != null) {
                res.addHeader(Names.SEC_WEBSOCKET_ORIGIN, origin);
            }
            res.addHeader(Names.SEC_WEBSOCKET_LOCATION, getWebSocketLocation(req));
            String protocol = req.getHeader(Names.SEC_WEBSOCKET_PROTOCOL);
            if (protocol != null) {
                res.addHeader(Names.SEC_WEBSOCKET_PROTOCOL, protocol);
            }
            // The answer to the challenge is the MD5 digest of the 16 challenge bytes.
            res.setContent(ChannelBuffers.wrappedBuffer(
                    MessageDigest.getInstance("MD5").digest(challenge.array())));
        } else {
            // Old handshake method with no challenge.
            if (origin != null) {
                res.addHeader(Names.WEBSOCKET_ORIGIN, origin);
            }
            res.addHeader(Names.WEBSOCKET_LOCATION, getWebSocketLocation(req));
            String protocol = req.getHeader(Names.WEBSOCKET_PROTOCOL);
            if (protocol != null) {
                res.addHeader(Names.WEBSOCKET_PROTOCOL, protocol);
            }
        }

        // Upgrade the connection and send the handshake response. The HTTP decoder is swapped out
        // before the write, and the encoder only after it, so the 101 response is still HTTP-encoded.
        ChannelPipeline p = ctx.getChannel().getPipeline();
        p.remove("aggregator");
        p.replace("decoder", "wsdecoder", new WebSocketFrameDecoder());
        ctx.getChannel().write(res);
        p.replace("encoder", "wsencoder", new WebSocketFrameEncoder());
    }

    /**
     * Assembles the 16-byte challenge of the key-based handshake. The challenge is the two decoded
     * Sec-WebSocket-Key values (4 bytes each) followed by the first 8 bytes of the request body.
     *
     * @return the challenge buffer, or {@code null} if a key or the request body is malformed
     */
    private static ChannelBuffer readChallenge(HttpRequest req) {
        int a;
        int b;
        try {
            a = decodeChallengeKey(req.getHeader(Names.SEC_WEBSOCKET_KEY1));
            b = decodeChallengeKey(req.getHeader(Names.SEC_WEBSOCKET_KEY2));
        } catch (IllegalArgumentException malformedKey) {
            return null;
        }
        ChannelBuffer body = req.getContent();
        if (body.readableBytes() < 8) {
            return null;
        }
        ChannelBuffer input = ChannelBuffers.buffer(16);
        input.writeInt(a);
        input.writeInt(b);
        input.writeLong(body.readLong());
        return input;
    }

    /**
     * Decodes one Sec-WebSocket-Key header of the challenge handshake. The result is the number
     * formed by the key's digits, divided by the number of spaces in the key, truncated to 32 bits.
     *
     * @throws IllegalArgumentException if the key has no digits or no spaces, or if its digits
     *                                  do not fit in a {@code long}
     */
    private static int decodeChallengeKey(String key) {
        String digits = key.replaceAll("[^0-9]", "");
        int spaces = key.replaceAll("[^ ]", "").length();
        if (digits.length() == 0 || spaces == 0) {
            throw new IllegalArgumentException("Malformed WebSocket challenge key: " + key);
        }
        // On overflow, Long.parseLong throws NumberFormatException, which is an IllegalArgumentException.
        return (int) (Long.parseLong(digits) / spaces);
    }

    /**
     * Parses a JSON text frame and hands it to the {@link OpHandler} registered for its op code.
     * Frames that are not valid JSON, lack a string op code, carry a non-string cache name, or name
     * an unknown op are ignored.
     */
    private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
        try {
            JSONObject payload = new JSONObject(frame.getTextData());
            Object opCode = payload.get(OpHandler.OP_CODE);
            Object cacheName = payload.opt(OpHandler.CACHE_NAME);
            if (!(opCode instanceof String) || (cacheName != null && !(cacheName instanceof String))) {
                // Wrongly-typed fields are treated like malformed JSON: the frame is dropped.
                return;
            }
            OpHandler handler = operationHandlers.get((String) opCode);
            if (handler == null) {
                // Unknown op: drop it without starting a cache on its behalf.
                return;
            }
            handler.handleOp(payload, getCache((String) cacheName), ctx);
        } catch (JSONException ignored) {
            // Unparseable frames, or frames without an op code, are dropped silently.
            // NOTE(review): nothing is reported back to the client; consider logging here.
        }
    }

    /**
     * Returns the started cache with the given name, creating and starting it on first use.
     *
     * @param cacheName cache name, or {@code null} for the container's default cache
     */
    private Cache<Object, Object> getCache(final String cacheName) {
        String key = cacheName == null ? "" : cacheName;
        // NOTE(review): this unsynchronized fast-path read assumes startedCaches tolerates
        // concurrent access (e.g. a ConcurrentHashMap). Confirm with the code that creates it.
        Cache<Object, Object> cache = startedCaches.get(key);
        if (cache != null) {
            return cache;
        }
        synchronized (startedCaches) {
            cache = startedCaches.get(key);
            if (cache == null) {
                if (cacheName != null) {
                    cache = cacheContainer.getCache(key);
                } else {
                    cache = cacheContainer.getCache();
                }
                // Start before publishing, so the fast path above never returns an unstarted cache.
                cache.start();
                startedCaches.put(key, cache);
            }
        }
        return cache;
    }

    /**
     * Writes {@code res}. Any status other than 200 gets a plain-text body naming the status.
     * The connection is closed after an error, or when the client did not ask for keep-alive.
     */
    private void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
        // Generate an error page if response status code is not OK (200).
        if (res.getStatus().getCode() != 200) {
            res.setContent(ChannelBuffers.copiedBuffer(res.getStatus().toString(), CharsetUtil.UTF_8));
            HttpHeaders.setContentLength(res, res.getContent().readableBytes());
        }
        // Send the response and close the connection if necessary.
        ChannelFuture f = ctx.getChannel().write(res);
        if (!isKeepAlive(req) || res.getStatus().getCode() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * Fills {@code res} with the client JavaScript. The script is prefixed with a
     * {@code defaultWSAddress} variable that points at this server.
     */
    private void loadScriptToResponse(HttpRequest req, DefaultHttpResponse res) {
        // The address is built from the client-supplied Host header, so escape it before
        // embedding it in a single-quoted JavaScript string literal.
        String wsAddress = escapeJavaScript(getWebSocketLocation(req));
        StringWriter writer = new StringWriter();
        writer.write("var defaultWSAddress = '" + wsAddress + "';");
        writer.write(WebSocketServer.getJavascript());
        ChannelBuffer content = ChannelBuffers.copiedBuffer(writer.toString(), CharsetUtil.UTF_8);
        res.setHeader(CONTENT_TYPE, "text/javascript; charset=UTF-8");
        setContentLength(res, content.readableBytes());
        res.setContent(content);
    }

    /**
     * Escapes {@code value} so it can be placed inside a quoted JavaScript string literal.
     * It also escapes {@code <} and {@code >}, so the result cannot close a surrounding script tag.
     */
    private static String escapeJavaScript(String value) {
        StringBuilder escaped = new StringBuilder(value.length());
        for (int i = 0; i < value.length(); i++) {
            char ch = value.charAt(i);
            switch (ch) {
                case '\\': escaped.append("\\\\"); break;
                case '\'': escaped.append("\\'"); break;
                case '"': escaped.append("\\\""); break;
                case '\n': escaped.append("\\n"); break;
                case '\r': escaped.append("\\r"); break;
                case '<': escaped.append("\\x3c"); break;
                case '>': escaped.append("\\x3e"); break;
                case '\u2028': escaped.append("\\u2028"); break;
                case '\u2029': escaped.append("\\u2029"); break;
                default:
                    if (ch < 0x20) {
                        escaped.append(String.format("\\u%04x", (int) ch));
                    } else {
                        escaped.append(ch);
                    }
            }
        }
        return escaped.toString();
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception {
        // NOTE(review): prints to stderr; consider routing through the project's logger.
        e.getCause().printStackTrace();
        e.getChannel().close();
    }

    /** Returns the WebSocket URL ("ws://" + Host header + "/") advertised to clients. */
    private String getWebSocketLocation(HttpRequest req) {
        return "ws://" + req.getHeader(HttpHeaders.Names.HOST) + "/";
    }
}