/*
* Copyright 2012 Future Systems
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.krakenapps.httpd;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.handler.codec.http.HttpHeaders;
import org.jboss.netty.handler.codec.http.websocket.WebSocketFrameEncoder;
import org.krakenapps.httpd.impl.WebSocketChannel;
import org.krakenapps.httpd.impl.WebSocketFrameDecoderWithHost;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Servlet that answers a WebSocket upgrade request on an HTTP GET, rewires the
 * underlying Netty channel pipeline for WebSocket framing, and registers the
 * resulting socket with the {@link WebSocketManager}.
 *
 * <p>
 * Two handshake variants are handled: the challenge-based one (request carries
 * both {@code Sec-WebSocket-Key1} and {@code Sec-WebSocket-Key2} headers plus
 * an 8 byte body) and the older one without a challenge.
 */
public class WebSocketServlet extends HttpServlet {
    private static final long serialVersionUID = 1L;

    /** Upper bound handed to the WebSocket frame decoder installed on upgrade. */
    private static final int MAX_WEBSOCKET_FRAME_SIZE = 8 * 1024 * 1024;

    /** Number of body bytes ("key3") that accompany the challenge handshake. */
    private static final int CHALLENGE_BODY_LENGTH = 8;

    // static so the logger is not part of the (Serializable) servlet's instance state
    private static final Logger logger = LoggerFactory.getLogger(WebSocketServlet.class);

    private final WebSocketManager manager;

    /**
     * @param manager
     *            receives every successfully upgraded socket and supplies the
     *            path used in the WebSocket location header
     */
    public WebSocketServlet(WebSocketManager manager) {
        this.manager = manager;
    }

    /**
     * Performs the WebSocket handshake for an upgrade request.
     *
     * <p>
     * Requests that do not carry {@code Connection: Upgrade} and
     * {@code Upgrade: WebSocket} are ignored and the response is left untouched.
     * Any failure during the handshake is logged and not propagated to the
     * caller.
     */
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        if (!HttpHeaders.Values.UPGRADE.equalsIgnoreCase(req.getHeader(HttpHeaders.Names.CONNECTION))
                || !HttpHeaders.Values.WEBSOCKET.equalsIgnoreCase(req.getHeader(HttpHeaders.Names.UPGRADE))) {
            return;
        }

        try {
            // create websocket handshake response
            resp.setStatus(HttpServletResponse.SC_SWITCHING_PROTOCOLS);
            resp.addHeader(HttpHeaders.Names.UPGRADE, HttpHeaders.Values.WEBSOCKET);
            resp.addHeader(HttpHeaders.Names.CONNECTION, HttpHeaders.Values.UPGRADE);

            // fill in the headers and contents depending on handshake method
            String key1 = req.getHeader(HttpHeaders.Names.SEC_WEBSOCKET_KEY1);
            String key2 = req.getHeader(HttpHeaders.Names.SEC_WEBSOCKET_KEY2);
            if (key1 != null && key2 != null) {
                // New handshake method with a challenge
                resp.addHeader(HttpHeaders.Names.SEC_WEBSOCKET_ORIGIN, req.getHeader(HttpHeaders.Names.ORIGIN));
                resp.addHeader(HttpHeaders.Names.SEC_WEBSOCKET_LOCATION, getWebSocketLocation(req));
                String protocol = req.getHeader(HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL);
                if (protocol != null) {
                    resp.addHeader(HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL, protocol);
                }

                // the answer of the challenge is the handshake response body
                byte[] key3 = readChallengeBody(req);
                resp.getOutputStream().write(computeChallengeResponse(key1, key2, key3));
            } else {
                // Old handshake method with no challenge
                resp.addHeader(HttpHeaders.Names.WEBSOCKET_ORIGIN, req.getHeader(HttpHeaders.Names.ORIGIN));
                resp.addHeader(HttpHeaders.Names.WEBSOCKET_LOCATION, getWebSocketLocation(req));
                String protocol = req.getHeader(HttpHeaders.Names.WEBSOCKET_PROTOCOL);
                if (protocol != null) {
                    resp.addHeader(HttpHeaders.Names.WEBSOCKET_PROTOCOL, protocol);
                }
            }

            // upgrade the connection and send the handshake response
            String host = req.getHeader(HttpHeaders.Names.HOST);
            Channel channel = (Channel) req.getAttribute("netty.channel");
            ChannelPipeline p = channel.getPipeline();
            p.remove("aggregator");
            p.replace("decoder", "wsdecoder", new WebSocketFrameDecoderWithHost(host, MAX_WEBSOCKET_FRAME_SIZE));

            // The output stream is closed while the original "encoder" handler is
            // still installed and only then is the encoder swapped; keep this order.
            // NOTE(review): presumably closing the stream is what flushes the
            // handshake response through the HTTP encoder - confirm against the
            // response implementation.
            resp.getOutputStream().close();
            p.replace("encoder", "wsencoder", new WebSocketFrameEncoder());

            // open session
            WebSocket socket = new WebSocketChannel(channel);
            manager.register(socket);
        } catch (Throwable t) {
            logger.error("kraken httpd: websocket handshake failed", t);
        }
    }

    /**
     * Reads exactly {@link #CHALLENGE_BODY_LENGTH} bytes from the request body.
     *
     * <p>
     * A single {@code read(byte[])} call may legally return fewer bytes than
     * requested, which would silently corrupt the challenge answer, so the read
     * is repeated until the buffer is full.
     *
     * @throws IOException
     *             if the body ends before all bytes were read
     */
    private static byte[] readChallengeBody(HttpServletRequest req) throws IOException {
        byte[] body = new byte[CHALLENGE_BODY_LENGTH];
        java.io.InputStream in = req.getInputStream();
        int filled = 0;
        while (filled < body.length) {
            int n = in.read(body, filled, body.length - filled);
            // A conforming stream never returns 0 for a non-zero length; treating it
            // like end-of-stream avoids spinning on a misbehaving implementation.
            if (n <= 0) {
                throw new IOException("websocket challenge body is truncated: expected " + body.length
                        + " bytes, got " + filled);
            }
            filled += n;
        }
        return body;
    }

    /**
     * Computes the handshake answer: the MD5 digest of the two decoded keys
     * (each as a big-endian 32 bit integer) followed by the 8 body bytes.
     *
     * @throws IllegalArgumentException
     *             if either key is malformed
     */
    private static byte[] computeChallengeResponse(String key1, String key2, byte[] key3)
            throws java.security.NoSuchAlgorithmException {
        ByteBuffer input = ByteBuffer.allocate(4 + 4 + CHALLENGE_BODY_LENGTH);
        input.putInt(decodeChallengeKey(key1));
        input.putInt(decodeChallengeKey(key2));
        input.put(key3);
        return MessageDigest.getInstance("MD5").digest(input.array());
    }

    /**
     * Decodes one challenge key: the number formed by all decimal digits in the
     * key divided by the number of space characters in the key, truncated to 32
     * bits.
     *
     * @throws IllegalArgumentException
     *             if the key contains no digit or no space (the latter would
     *             otherwise surface as a division by zero)
     */
    private static int decodeChallengeKey(String key) {
        String digits = key.replaceAll("[^0-9]", "");
        int spaces = key.replaceAll("[^ ]", "").length();
        if (digits.length() == 0 || spaces == 0) {
            throw new IllegalArgumentException("malformed websocket challenge key: " + key);
        }
        return (int) (Long.parseLong(digits) / spaces);
    }

    /**
     * Builds the location value sent back in the handshake from the request's
     * {@code Host} header and the manager's path. The scheme is always
     * {@code ws://}.
     */
    private String getWebSocketLocation(HttpServletRequest req) {
        return "ws://" + req.getHeader(HttpHeaders.Names.HOST) + manager.getPath();
    }
}