package com.openrobot.common;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Iterator;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.LinkedBlockingQueue;
/**
 * <tt>WebSocketServer</tt> is an abstract class that only takes care of the
 * HTTP handshake portion of WebSockets. It's up to a subclass to add
 * functionality/purpose to the server.
 * <p>
 * The server runs a single selector loop (see {@link #run()}) that accepts
 * clients and dispatches read/write readiness to the {@link WebSocket}
 * attached to each selection key. Call {@link #start()} to run that loop on
 * a new thread and {@link #stop()} to shut it down.
 * @author Nathan Rajlich
 */
public abstract class WebSocketServer implements Runnable, WebSocketListener {
    // CONSTANTS ///////////////////////////////////////////////////////////////
    /**
     * The value of <var>handshake</var> when a Flash client requests a policy
     * file on this server.
     */
    private static final String FLASH_POLICY_REQUEST = "<policy-file-request/>\0";

    /**
     * Number of bytes of <var>key3</var> that take part in the draft 76
     * challenge.
     */
    private static final int KEY3_LENGTH = 8;

    // INSTANCE PROPERTIES /////////////////////////////////////////////////////
    /**
     * Holds the list of active WebSocket connections. "Active" means WebSocket
     * handshake is complete and socket can be written to, or read from.
     */
    private final CopyOnWriteArraySet<WebSocket> connections;
    /**
     * The port number that this WebSocket server should listen on. Default is
     * WebSocket.DEFAULT_PORT.
     */
    private int port;
    /**
     * The socket channel for this WebSocket server. Written by the server
     * thread and read by {@link #stop()}, hence volatile.
     */
    private volatile ServerSocketChannel server;
    /**
     * The 'Selector' used to get event keys from the underlying socket.
     * Written by the server thread and read by {@link #stop()}, hence volatile.
     */
    private volatile Selector selector;
    /**
     * The Draft of the WebSocket protocol the Server is adhering to.
     */
    private final WebSocketDraft draft;

    // CONSTRUCTORS ////////////////////////////////////////////////////////////
    /**
     * Nullary constructor. Creates a WebSocketServer that will attempt to
     * listen on port WebSocket.DEFAULT_PORT.
     */
    public WebSocketServer() {
        this(WebSocket.DEFAULT_PORT, WebSocketDraft.AUTO);
    }

    /**
     * Creates a WebSocketServer that will attempt to listen on port
     * <var>port</var>.
     * @param port The port number this server should listen on.
     */
    public WebSocketServer(int port) {
        this(port, WebSocketDraft.AUTO);
    }

    /**
     * Creates a WebSocketServer that will attempt to listen on port <var>port</var>,
     * and comply with <tt>WebSocketDraft</tt> version <var>draft</var>.
     * @param port The port number this server should listen on.
     * @param draft The version of the WebSocket protocol that this server
     *              instance should comply to. <tt>null</tt> is treated as
     *              <tt>WebSocketDraft.AUTO</tt>.
     */
    public WebSocketServer(int port, WebSocketDraft draft) {
        this.connections = new CopyOnWriteArraySet<WebSocket>();
        // A null draft would make every handshake fail with a
        // NullPointerException in the draft switch, so fall back to AUTO.
        this.draft = (draft == null) ? WebSocketDraft.AUTO : draft;
        // Assign directly instead of calling the overridable setPort() from
        // a constructor.
        this.port = port;
    }

    /**
     * Starts the server thread that binds to the currently set port number and
     * listens for WebSocket connection requests.
     */
    public void start() {
        (new Thread(this)).start();
    }

    /**
     * Closes all connected clients sockets, then closes the underlying
     * ServerSocketChannel and Selector, which ends the server thread and
     * frees the port the server was bound to.
     * <p>
     * Every resource is closed even if closing an earlier one fails; the first
     * failure encountered is rethrown once everything has been attempted.
     * Calling this method on a server that was never started only closes the
     * (empty) set of connections.
     * @throws IOException When socket related I/O errors occur.
     */
    public void stop() throws IOException {
        IOException firstFailure = null;
        for (WebSocket ws : connections) {
            try {
                ws.close();
            } catch (IOException ex) {
                if (firstFailure == null) {
                    firstFailure = ex;
                }
            }
        }
        ServerSocketChannel currentServer = this.server;
        if (currentServer != null) {
            try {
                currentServer.close();
            } catch (IOException ex) {
                if (firstFailure == null) {
                    firstFailure = ex;
                }
            }
        }
        // Closing the selector wakes up the server thread if it is blocked in
        // select(), which lets run() return.
        Selector currentSelector = this.selector;
        if (currentSelector != null) {
            try {
                currentSelector.close();
            } catch (IOException ex) {
                if (firstFailure == null) {
                    firstFailure = ex;
                }
            }
        }
        if (firstFailure != null) {
            throw firstFailure;
        }
    }

    /**
     * Sends <var>text</var> to all currently connected WebSocket clients.
     * @param text The String to send across the network.
     * @throws IOException When socket related I/O errors occur.
     */
    public void sendToAll(String text) throws IOException {
        for (WebSocket c : this.connections) {
            c.send(text);
        }
    }

    /**
     * Sends <var>text</var> to all currently connected WebSocket clients,
     * except for the specified <var>connection</var>.
     * @param connection The {@link WebSocket} connection to ignore.
     * @param text The String to send to every connection except <var>connection</var>.
     * @throws NullPointerException If <var>connection</var> is null.
     * @throws IOException When socket related I/O errors occur.
     */
    public void sendToAllExcept(WebSocket connection, String text) throws IOException {
        if (connection == null) {
            throw new NullPointerException("'connection' cannot be null");
        }
        for (WebSocket c : this.connections) {
            if (!connection.equals(c)) {
                c.send(text);
            }
        }
    }

    /**
     * Sends <var>text</var> to all currently connected WebSocket clients,
     * except for those found in the Set <var>connections</var>.
     * @param connections The {@link WebSocket} connections to ignore.
     * @param text The String to send to every connection not contained in
     *             <var>connections</var>.
     * @throws NullPointerException If <var>connections</var> is null.
     * @throws IOException When socket related I/O errors occur.
     */
    public void sendToAllExcept(Set<WebSocket> connections, String text) throws IOException {
        if (connections == null) {
            throw new NullPointerException("'connections' cannot be null");
        }
        for (WebSocket c : this.connections) {
            if (!connections.contains(c)) {
                c.send(text);
            }
        }
    }

    /**
     * Returns a WebSocket[] of currently connected clients.
     * @return The currently connected clients in a WebSocket[].
     */
    public WebSocket[] connections() {
        return this.connections.toArray(new WebSocket[0]);
    }

    /**
     * Sets the port that this WebSocketServer should listen on. The value is
     * read when the server thread binds its socket in {@link #run()}.
     * @param port The port number to listen on.
     */
    public void setPort(int port) {
        this.port = port;
    }

    /**
     * Gets the port number that this server listens on.
     * @return The port number.
     */
    public int getPort() {
        return this.port;
    }

    /**
     * Gets the draft of the WebSocket protocol this server adheres to.
     * @return The configured {@link WebSocketDraft}.
     */
    public WebSocketDraft getDraft() {
        return this.draft;
    }

    // Runnable IMPLEMENTATION /////////////////////////////////////////////////
    /**
     * Binds the server socket to the configured port and runs the selector
     * loop until the selector or server channel is closed (see
     * {@link #stop()}) or the running thread is interrupted. The selector and
     * server channel are always closed before this method returns.
     */
    public void run() {
        ServerSocketChannel serverChannel = null;
        Selector sel = null;
        try {
            serverChannel = ServerSocketChannel.open();
            // Publish as soon as it exists so stop() can close it.
            this.server = serverChannel;
            serverChannel.configureBlocking(false);
            serverChannel.socket().bind(new java.net.InetSocketAddress(port));
            sel = Selector.open();
            this.selector = sel;
            serverChannel.register(sel, serverChannel.validOps());
        } catch (IOException ex) {
            ex.printStackTrace();
            // Do not leak whatever was opened before the failure.
            if (sel != null) {
                try {
                    sel.close();
                } catch (IOException ignored) {
                    // Nothing more can be done while cleaning up.
                }
            }
            if (serverChannel != null) {
                try {
                    serverChannel.close();
                } catch (IOException ignored) {
                    // Nothing more can be done while cleaning up.
                }
            }
            return;
        }

        try {
            while (sel.isOpen() && serverChannel.isOpen()
                    && !Thread.currentThread().isInterrupted()) {
                try {
                    sel.select();
                    Iterator<SelectionKey> i = sel.selectedKeys().iterator();

                    while (i.hasNext()) {
                        SelectionKey key = i.next();

                        // Remove the current key
                        i.remove();

                        // A key cancelled since selection throws on any
                        // readiness query, so skip it.
                        if (!key.isValid()) {
                            continue;
                        }

                        // if isAcceptable == true
                        // then a client required a connection
                        if (key.isAcceptable()) {
                            // accept() returns null on a non-blocking channel
                            // when no connection is actually pending.
                            SocketChannel client = serverChannel.accept();
                            if (client != null) {
                                boolean registered = false;
                                try {
                                    client.configureBlocking(false);
                                    WebSocket c = new WebSocket(client, new LinkedBlockingQueue<ByteBuffer>(), this);
                                    client.register(sel, SelectionKey.OP_READ, c);
                                    registered = true;
                                } finally {
                                    if (!registered) {
                                        try {
                                            client.close();
                                        } catch (IOException ignored) {
                                            // The original failure is the one
                                            // that propagates.
                                        }
                                    }
                                }
                            }
                        }

                        // if isReadable == true
                        // then the server is ready to read
                        if (key.isValid() && key.isReadable()) {
                            WebSocket conn = (WebSocket) key.attachment();
                            conn.handleRead();
                        }

                        // if isWritable == true
                        // then we need to send the rest of the data to the client
                        // (the key may have been cancelled by handleRead())
                        if (key.isValid() && key.isWritable()) {
                            WebSocket conn = (WebSocket) key.attachment();
                            if (conn.handleWrite()) {
                                conn.socketChannel().register(sel,
                                        SelectionKey.OP_READ, conn);
                            }
                        }
                    }

                    for (WebSocket conn : this.connections) {
                        // We have to do this check here, and not in the thread that
                        // adds the buffered data to the WebSocket, because the
                        // Selector is not thread-safe, and can only be accessed
                        // by this thread.
                        if (conn.hasBufferedData()) {
                            conn.socketChannel().register(sel,
                                    SelectionKey.OP_READ | SelectionKey.OP_WRITE, conn);
                        }
                    }
                } catch (java.nio.channels.ClosedSelectorException ex) {
                    // stop() closed the selector: normal shutdown.
                    break;
                } catch (IOException ex) {
                    ex.printStackTrace();
                } catch (RuntimeException ex) {
                    ex.printStackTrace();
                } catch (NoSuchAlgorithmException ex) {
                    ex.printStackTrace();
                }
            }
        } finally {
            try {
                sel.close();
            } catch (IOException ignored) {
                // Shutting down; nothing more can be done.
            }
            try {
                serverChannel.close();
            } catch (IOException ignored) {
                // Shutting down; nothing more can be done.
            }
        }
    }

    /**
     * Gets the XML string that should be returned if a client requests a Flash
     * security policy.
     *
     * The default implementation allows access from all remote domains, but
     * only on the port that this WebSocketServer is listening on.
     *
     * This is specifically implemented for gimite's WebSocket client for Flash:
     *     http://github.com/gimite/web-socket-js
     *
     * @return An XML String that conforms to Flash's security policy. You MUST
     *         not include the null char at the end, it is appended automatically.
     */
    protected String getFlashSecurityPolicy() {
        return "<cross-domain-policy><allow-access-from domain=\"*\" to-ports=\""
                + getPort() + "\" /></cross-domain-policy>";
    }

    // WebSocketListener IMPLEMENTATION ////////////////////////////////////////
    /**
     * Called by a {@link WebSocket} instance when a client connection has
     * finished sending a handshake. This method verifies that the handshake is
     * a valid WebSocket client request. Then sends a WebSocket server handshake
     * if it is valid, or returns false if it is not.
     * <p>
     * A Flash policy file request is answered with the policy returned by
     * {@link #getFlashSecurityPolicy()} and then reported as "not a WebSocket
     * handshake" (false).
     * @param conn The {@link WebSocket} instance whose handshake has been received.
     * @param handshake The entire UTF-8 decoded handshake from the connection.
     * @param key3 The bytes following the handshake headers that take part in
     *             the draft 76 challenge, or null when the client sent none.
     * @return True if the client sent a valid WebSocket handshake and this server
     *         successfully sent a WebSocket server handshake, false otherwise.
     * @throws IOException When socket related I/O errors occur.
     * @throws NoSuchAlgorithmException When the MD5 digest is not available.
     */
    public boolean onHandshakeRecieved(WebSocket conn, String handshake, byte[] key3) throws IOException, NoSuchAlgorithmException {
        // If a Flash client requested the Policy File...
        if (FLASH_POLICY_REQUEST.equals(handshake)) {
            String policy = getFlashSecurityPolicy() + "\0";
            conn.socketChannel().write(ByteBuffer.wrap(policy.getBytes(WebSocket.UTF8_CHARSET.toString())));
            return false;
        }

        if (handshake == null) {
            return false;
        }
        String[] requestLines = handshake.split("\r\n");
        // split() yields an empty array when the handshake is only line breaks.
        if (requestLines.length == 0) {
            return false;
        }

        boolean isWebSocketRequest = true;
        String path = null;
        // Request line must look like: GET <path> HTTP/1.1
        String[] firstLineTokens = requestLines[0].trim().split(" ");
        if (firstLineTokens.length >= 3
                && "GET".equals(firstLineTokens[0])
                && "HTTP/1.1".equals(firstLineTokens[firstLineTokens.length - 1])) {
            path = firstLineTokens[1];
        } else {
            isWebSocketRequest = false;
        }

        // 'p' will hold the HTTP headers
        Properties p = new Properties();
        for (int i = 1; i < requestLines.length; i++) {
            String line = requestLines[i];
            int firstColon = line.indexOf(":");
            if (firstColon != -1) {
                p.setProperty(line.substring(0, firstColon).trim(), line.substring(firstColon + 1).trim());
            }
        }

        if (!"WebSocket".equals(p.getProperty("Upgrade"))) {
            isWebSocketRequest = false;
        }
        if (!"Upgrade".equals(p.getProperty("Connection"))) {
            isWebSocketRequest = false;
        }

        String key1 = p.getProperty("Sec-WebSocket-Key1");
        String key2 = p.getProperty("Sec-WebSocket-Key2");
        switch (this.draft) {
            case DRAFT75:
                // Draft 75 handshakes carry no keys at all.
                if (key1 != null || key2 != null || key3 != null) {
                    isWebSocketRequest = false;
                }
                break;
            case DRAFT76:
                // Draft 76 handshakes must carry all three keys.
                if (key1 == null || key2 == null || key3 == null) {
                    isWebSocketRequest = false;
                }
                break;
        }

        if (!isWebSocketRequest) {
            // The client sent an invalid handshake; returning false makes the
            // WebSocket object close the connection.
            return false;
        }

        String headerPrefix = "";
        byte[] responseChallenge = null;
        if (key1 != null && key2 != null && key3 != null) {
            byte[] part1 = getPart(key1);
            byte[] part2 = getPart(key2);
            // Malformed keys cannot produce a challenge response: reject the
            // handshake instead of failing with an unchecked exception.
            if (part1 == null || part2 == null || key3.length < KEY3_LENGTH) {
                return false;
            }
            headerPrefix = "Sec-";
            // challenge = part1 (4 bytes) + part2 (4 bytes) + key3 (8 bytes)
            byte[] challenge = new byte[16];
            System.arraycopy(part1, 0, challenge, 0, 4);
            System.arraycopy(part2, 0, challenge, 4, 4);
            System.arraycopy(key3, 0, challenge, 8, KEY3_LENGTH);
            MessageDigest md5 = MessageDigest.getInstance("MD5");
            responseChallenge = md5.digest(challenge);
        }

        StringBuilder responseHandshake = new StringBuilder();
        responseHandshake.append("HTTP/1.1 101 Web Socket Protocol Handshake\r\n");
        responseHandshake.append("Upgrade: WebSocket\r\n");
        responseHandshake.append("Connection: Upgrade\r\n");
        responseHandshake.append(headerPrefix).append("WebSocket-Origin: ")
                .append(p.getProperty("Origin")).append("\r\n");
        responseHandshake.append(headerPrefix).append("WebSocket-Location: ws://")
                .append(p.getProperty("Host")).append(path).append("\r\n");
        // Echo the protocol using the same (possibly "Sec-" prefixed) header
        // name it was looked up with.
        String protocolHeader = headerPrefix + "WebSocket-Protocol";
        if (p.containsKey(protocolHeader)) {
            responseHandshake.append(protocolHeader).append(": ")
                    .append(p.getProperty(protocolHeader)).append("\r\n");
        }
        if (p.containsKey("Cookie")) {
            responseHandshake.append("Cookie: ").append(p.getProperty("Cookie")).append("\r\n");
        }
        responseHandshake.append("\r\n"); // Signifies end of handshake

        // NOTE(review): encoded with the platform default charset, as before.
        // The challenge bytes are written separately below, so an explicit
        // charset looks safe here - confirm before changing.
        // NOTE(review): these are single writes on a non-blocking channel and
        // may be partial - confirm how WebSocket handles that.
        conn.socketChannel().write(ByteBuffer.wrap(responseHandshake.toString().getBytes()));
        // Only set when Draft 76
        if (responseChallenge != null) {
            conn.socketChannel().write(ByteBuffer.wrap(responseChallenge));
        }
        return true;
    }

    /**
     * Forwards a received message to {@link #onClientMessage(WebSocket, String)}.
     * @param conn The connection the message arrived on.
     * @param message The received message.
     */
    public void onMessage(WebSocket conn, String message) {
        onClientMessage(conn, message);
    }

    /**
     * Adds <var>conn</var> to the set of active connections and, if it was not
     * already present, notifies {@link #onClientOpen(WebSocket)}.
     * @param conn The connection that was opened.
     */
    public void onOpen(WebSocket conn) {
        if (this.connections.add(conn)) {
            onClientOpen(conn);
        }
    }

    /**
     * Removes <var>conn</var> from the set of active connections and, if it
     * was present, notifies {@link #onClientClose(WebSocket)}.
     * @param conn The connection that was closed.
     */
    public void onClose(WebSocket conn) {
        if (this.connections.remove(conn)) {
            onClientClose(conn);
        }
    }

    /**
     * Computes the 4-byte, big-endian challenge part for one of the
     * <tt>Sec-WebSocket-Key</tt> headers: the number formed by the digits of
     * <var>key</var>, divided by the number of space characters in it.
     * @param key The header value to process.
     * @return The 4 bytes of the result, or null when <var>key</var> contains
     *         no digits, no spaces, or a number too large to parse.
     */
    private static byte[] getPart(String key) {
        String digits = key.replaceAll("[^0-9]", "");
        int spaces = 0;
        for (int i = 0; i < key.length(); i++) {
            if (key.charAt(i) == '\u0020') {
                spaces++;
            }
        }
        // No digits would fail to parse; no spaces would divide by zero.
        if (digits.length() == 0 || spaces == 0) {
            return null;
        }
        long keyNumber;
        try {
            keyNumber = Long.parseLong(digits);
        } catch (NumberFormatException ex) {
            // More digits than fit in a long.
            return null;
        }
        long part = keyNumber / spaces;
        return new byte[] {
            (byte) (part >> 24),
            (byte) (part >> 16),
            (byte) (part >> 8),
            (byte) part
        };
    }

    // ABSTRACT METHODS ////////////////////////////////////////////////////////
    /**
     * Called when a connection has been added to the set of active connections.
     * @param conn The newly opened connection.
     */
    public abstract void onClientOpen(WebSocket conn);

    /**
     * Called when a connection has been removed from the set of active
     * connections.
     * @param conn The closed connection.
     */
    public abstract void onClientClose(WebSocket conn);

    /**
     * Called when a message has been received from a client.
     * @param conn The connection the message arrived on.
     * @param message The received message.
     */
    public abstract void onClientMessage(WebSocket conn, String message);
}