package net.tootallnate.websocket;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.Collections;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
/**
 * <tt>WebSocketServer</tt> is an abstract class that only takes care of the
 * HTTP handshake portion of WebSockets. It's up to a subclass to add
 * functionality/purpose to the server.
 * <p>
 * All socket I/O is performed by a single selector loop ({@link #run()}),
 * normally executed on the thread created by {@link #start()}. Other threads
 * signal pending output through {@link #onWriteDemand(WebSocket)}, which only
 * wakes the selector.
 *
 * @author Nathan Rajlich
 */
public abstract class WebSocketServer extends WebSocketAdapter implements Runnable {

    // INSTANCE PROPERTIES /////////////////////////////////////////////////////

    /**
     * Holds the list of active WebSocket connections. "Active" means WebSocket
     * handshake is complete and socket can be written to, or read from.
     */
    private final CopyOnWriteArraySet<WebSocket> connections;

    /**
     * The address (host and port) that this WebSocket server binds to. The
     * nullary constructor uses the local host and WebSocket.DEFAULT_PORT.
     */
    private InetSocketAddress address;

    /**
     * The socket channel for this WebSocket server. Opened by the selector loop;
     * volatile because {@link #stop()} reads it from another thread.
     */
    private volatile ServerSocketChannel server;

    /**
     * The 'Selector' used to get event keys from the underlying socket.
     * Volatile because {@link #onWriteDemand(WebSocket)} may be invoked from
     * threads other than the selector thread.
     */
    private volatile Selector selector;

    /**
     * The Draft of the WebSocket protocol the Server is adhering to. May be null
     * (the one- and zero-argument constructors pass null) — NOTE(review): how a
     * null draft is interpreted is decided by WebSocket, not visible here.
     */
    private Draft draft;

    /**
     * The thread that runs (or is about to run) the selector loop. Assigned by
     * {@link #start()} before the thread is launched, or by {@link #run()} when
     * it is invoked directly. Written while holding the lock on this; volatile
     * so that {@link #stop()} observes it.
     */
    private volatile Thread thread;

    /** Whether {@link #run()} has ever been entered. Guarded by the lock on this. */
    private boolean runEntered;

    // CONSTRUCTORS ////////////////////////////////////////////////////////////

    /**
     * Nullary constructor. Creates a WebSocketServer that will attempt to
     * listen on the local host at port WebSocket.DEFAULT_PORT.
     *
     * @throws UnknownHostException
     *             If the local host name cannot be resolved to an address.
     */
    public WebSocketServer() throws UnknownHostException {
        this( new InetSocketAddress( InetAddress.getLocalHost(), WebSocket.DEFAULT_PORT ), null );
    }

    /**
     * Creates a WebSocketServer that will attempt to bind to <var>address</var>.
     *
     * @param address
     *            The address (host and port) this server should listen on.
     */
    public WebSocketServer( InetSocketAddress address ) {
        this( address, null );
    }

    /**
     * Creates a WebSocketServer that will attempt to bind to <var>address</var>,
     * and comply with <tt>Draft</tt> version <var>draft</var>.
     *
     * @param address
     *            The address (host and port) this server should listen on.
     * @param draft
     *            The version of the WebSocket protocol that this server
     *            instance should comply to.
     */
    public WebSocketServer( InetSocketAddress address , Draft draft ) {
        this.connections = new CopyOnWriteArraySet<WebSocket>();
        this.draft = draft;
        setAddress( address );
    }

    /**
     * Starts the server thread that binds to the currently set address and
     * listens for WebSocket connection requests.
     *
     * @throws IllegalStateException
     *             If this server has already been started.
     */
    public void start() {
        Thread t;
        synchronized ( this ) {
            // Claim the thread slot before launching, so a second call fails
            // immediately instead of racing the new thread's run().
            if( thread != null )
                throw new IllegalStateException( "Already started" );
            t = new Thread( this );
            thread = t;
        }
        t.start();
    }

    /**
     * Closes all connected clients sockets, then closes the underlying
     * ServerSocketChannel, effectively killing the server socket thread and
     * freeing the port the server was bound to. Safe to call on a server that
     * was never started.
     *
     * @throws IOException
     *             When socket related I/O errors occur.
     */
    public void stop() throws IOException {
        for( WebSocket ws : connections ) {
            ws.close( CloseFrame.NORMAL );
        }
        Thread t = thread;
        if( t != null ) {
            // Interrupting a thread blocked in Selector.select() makes it return.
            t.interrupt();
        }
        ServerSocketChannel s = server;
        if( s != null ) {
            s.close();
        }
    }

    /**
     * Sends <var>text</var> to all currently connected WebSocket clients.
     *
     * @param text
     *            The String to send across the network.
     * @throws InterruptedException
     *             Propagated from {@link WebSocket#send(String)}.
     */
    public void sendToAll( String text ) throws InterruptedException {
        for( WebSocket c : this.connections ) {
            c.send( text );
        }
    }

    /**
     * Sends <var>text</var> to all currently connected WebSocket clients,
     * except for the specified <var>connection</var>.
     *
     * @param connection
     *            The {@link WebSocket} connection to ignore. Must not be null.
     * @param text
     *            The String to send to every connection except <var>connection</var>.
     * @throws InterruptedException
     *             Propagated from {@link WebSocket#send(String)}.
     */
    public void sendToAllExcept( WebSocket connection, String text ) throws InterruptedException {
        if( connection == null ) {
            throw new NullPointerException( "'connection' cannot be null" );
        }
        for( WebSocket c : this.connections ) {
            if( !connection.equals( c ) ) {
                c.send( text );
            }
        }
    }

    /**
     * Sends <var>text</var> to all currently connected WebSocket clients,
     * except for those found in the Set <var>connections</var>.
     *
     * @param connections
     *            The connections to skip. Must not be null.
     * @param text
     *            The String to send to every other connection.
     * @throws InterruptedException
     *             Propagated from {@link WebSocket#send(String)}.
     */
    public void sendToAllExcept( Set<WebSocket> connections, String text ) throws InterruptedException {
        if( connections == null ) {
            throw new NullPointerException( "'connections' cannot be null" );
        }
        for( WebSocket c : this.connections ) {
            if( !connections.contains( c ) ) {
                c.send( text );
            }
        }
    }

    /**
     * Returns the currently connected clients.
     *
     * @return A read-only live view of the currently connected clients.
     */
    public Set<WebSocket> connections() {
        return Collections.unmodifiableSet( this.connections );
    }

    /**
     * Sets the address that this WebSocketServer should listen on. The address
     * is read once when the selector loop binds, so changes after the server
     * has started have no effect on the bound socket.
     *
     * @param address
     *            The address (host and port) to listen on.
     */
    public void setAddress( InetSocketAddress address ) {
        this.address = address;
    }

    /**
     * @return The address this server is configured to listen on.
     */
    public InetSocketAddress getAddress() {
        return this.address;
    }

    /**
     * Gets the port number that this server listens on.
     *
     * @return The port number.
     */
    public int getPort() {
        return getAddress().getPort();
    }

    /**
     * @return The draft passed at construction time; may be null.
     */
    public Draft getDraft(){
        return this.draft;
    }

    // Runnable IMPLEMENTATION /////////////////////////////////////////////////

    /**
     * Binds the server socket and runs the selector loop until the executing
     * thread is interrupted (as done by {@link #stop()}). The selector and the
     * server channel are closed when the loop ends for any reason.
     *
     * @throws IllegalStateException
     *             If run() was already entered, or if another thread has been
     *             assigned to this server by {@link #start()}.
     */
    public void run() {
        synchronized ( this ) {
            Thread current = Thread.currentThread();
            if( runEntered || ( thread != null && thread != current ) )
                throw new IllegalStateException( "This instance of " + getClass().getSimpleName() + " can only be started once the same time." );
            runEntered = true;
            thread = current;
        }
        try {
            if( openChannels() ) {
                selectLoop();
            }
        } finally {
            releaseChannels();
        }
    }

    /** Opens, binds and registers the server channel; reports failures via onError. */
    private boolean openChannels() {
        try {
            server = ServerSocketChannel.open();
            server.configureBlocking( false );
            server.socket().bind( address );
            selector = Selector.open();
            server.register( selector, server.validOps() );
            return true;
        } catch ( IOException ex ) {
            onError( null, ex );
            return false;
        }
    }

    /** Dispatches selector events until the current thread is interrupted. */
    private void selectLoop() {
        Thread current = Thread.currentThread();
        while ( !current.isInterrupted() ) {
            SelectionKey key = null;
            WebSocket conn = null;
            try {
                selector.select();
                Iterator<SelectionKey> i = selector.selectedKeys().iterator();
                while ( i.hasNext() ) {
                    key = i.next();
                    // Reset per key so a failure is never blamed on the
                    // connection handled by a previous key.
                    conn = null;
                    // Remove the current key
                    i.remove();
                    // a client requested a connection
                    if( key.isAcceptable() ) {
                        acceptConnection();
                    }
                    // the server is ready to read
                    if( key.isReadable() ) {
                        conn = (WebSocket) key.attachment();
                        conn.handleRead();
                    }
                    // we need to send the rest of the data to the client
                    if( key.isValid() && key.isWritable() ) {
                        conn = (WebSocket) key.attachment();
                        conn.flush();
                        key.channel().register( selector, SelectionKey.OP_READ, conn );
                    }
                }
                // The flush pass below is not tied to any selection key; keeping
                // the last key here would cancel an unrelated connection on error.
                key = null;
                for( WebSocket ws : connections ) {
                    // We have to do this check here, and not in the thread that
                    // adds the buffered data to the WebSocket, because the
                    // Selector is not thread-safe, and can only be accessed
                    // by this thread.
                    conn = ws;
                    if( conn.hasBufferedData() ) {
                        conn.flush();
                    }
                }
            } catch ( IOException ex ) {
                // Never cancel the server's accept key: that would stop the
                // server from accepting any further connections.
                if( key != null && key.channel() != server )
                    key.cancel();
                onError( conn, ex );// conn may be null here
                if( conn != null ) {
                    conn.close( CloseFrame.ABNROMAL_CLOSE );
                }
            }
        }
    }

    /** Accepts one pending client and registers it for reads. */
    private void acceptConnection() throws IOException {
        SocketChannel client = server.accept();
        if( client == null ) {
            // Non-blocking accept returns null when no connection is pending.
            return;
        }
        try {
            client.configureBlocking( false );
            WebSocket c = new WebSocket( this, Collections.singletonList( draft ), client );
            client.register( selector, SelectionKey.OP_READ, c );
        } catch ( IOException ex ) {
            try {
                client.close();
            } catch ( IOException ignored ) {
                // The setup failure below is the error worth reporting.
            }
            throw ex;
        }
    }

    /** Closes the selector and server channel, reporting close failures via onError. */
    private void releaseChannels() {
        Selector sel = selector;
        if( sel != null ) {
            try {
                sel.close();
            } catch ( IOException ex ) {
                onError( null, ex );
            }
        }
        ServerSocketChannel srv = server;
        if( srv != null ) {
            try {
                srv.close();
            } catch ( IOException ex ) {
                onError( null, ex );
            }
        }
    }

    /**
     * Gets the XML string that should be returned if a client requests a Flash
     * security policy.
     *
     * The default implementation allows access from all remote domains, but
     * only on the port that this WebSocketServer is listening on.
     *
     * This is specifically implemented for gimite's WebSocket client for Flash:
     * http://github.com/gimite/web-socket-js
     *
     * @return An XML String that conforms to Flash's security policy. You MUST
     *         not include the null char at the end, it is appended automatically.
     */
    protected String getFlashSecurityPolicy() {
        return "<cross-domain-policy><allow-access-from domain=\"*\" to-ports=\"" + getPort() + "\" /></cross-domain-policy>";
    }

    /** Forwards every text message to {@link #onClientMessage(WebSocket, String)}. */
    @Override
    public void onMessage( WebSocket conn, String message ) {
        onClientMessage( conn, message );
    }

    /** Tracks the connection and notifies the subclass only the first time it is added. */
    @Override
    public void onOpen( WebSocket conn, Handshakedata handshake ) {
        if( this.connections.add( conn ) ) {
            onClientOpen( conn, handshake );
        }
    }

    /** Untracks the connection and notifies the subclass only if it was tracked. */
    @Override
    public void onClose( WebSocket conn, int code, String reason, boolean remote ) {
        if( this.connections.remove( conn ) ) {
            onClientClose( conn, code, reason, remote );
        }
    }

    /**
     * Wakes the selector thread so it flushes buffered data. Does nothing if the
     * selector has not been opened yet.
     */
    @Override
    public void onWriteDemand( WebSocket conn ) {
        Selector sel = selector;
        if( sel != null ) {
            sel.wakeup();
        }
    }

    // ABSTRACT METHODS ////////////////////////////////////////////////////////

    /** Called once a connection has been added to the active set. */
    public abstract void onClientOpen( WebSocket conn, Handshakedata handshake );

    /** Called once a previously active connection has been removed. */
    public abstract void onClientClose( WebSocket conn, int code, String reason, boolean remote );

    /** Called for every text message received on a connection. */
    public abstract void onClientMessage( WebSocket conn, String message );

    /**
     * @param conn
     *            may be null if the error does not belong to a single connection
     * @param ex
     *            the error that occurred
     */
    public abstract void onError( WebSocket conn, Exception ex );
}