// ---------------------------------------------------------------------------
// jWebSocket - Copyright (c) 2010 Innotrade GmbH
// ---------------------------------------------------------------------------
// This program is free software; you can redistribute it and/or modify it
// under the terms of the GNU Lesser General Public License as published by the
// Free Software Foundation; either version 3 of the License, or (at your
// option) any later version.
// This program is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
// FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
// more details.
// You should have received a copy of the GNU Lesser General Public License along
// with this program; if not, see <http://www.gnu.org/licenses/lgpl.html>.
// ---------------------------------------------------------------------------
package org.jwebsocket.client.java;
import java.io.*;
import java.net.Socket;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import javax.net.SocketFactory;
import javax.net.ssl.SSLSocketFactory;
import javolution.util.FastList;
import javolution.util.FastMap;
import org.jwebsocket.api.WebSocketClient;
import org.jwebsocket.api.WebSocketClientEvent;
import org.jwebsocket.api.WebSocketClientListener;
import org.jwebsocket.api.WebSocketPacket;
import org.jwebsocket.api.WebSocketStatus;
import org.jwebsocket.client.token.WebSocketClientTokenEvent;
import org.jwebsocket.config.JWebSocketCommonConstants;
import org.jwebsocket.kit.RawPacket;
import org.jwebsocket.kit.WebSocketException;
import org.jwebsocket.kit.WebSocketHandshake;
import org.jwebsocket.kit.WebSocketProtocolHandler;
/**
* Base {@code WebSocket} implementation based on
* http://weberknecht.googlecode.com by Roderick Baier. This uses thread model
* for handling WebSocket connection which is defined by the <tt>WebSocket</tt>
* protocol specification. {@linkplain http://www.whatwg.org/specs/web-socket-protocol/}
* {@linkplain http://www.w3.org/TR/websockets/}
*
* @author Roderick Baier
* @author agali
* @author puran
* @author jang
* @version $Id:$
*/
public class BaseWebSocket implements WebSocketClient {
/**
* WebSocket connection url
*/
private URI mURL = null;
/**
* list of the listeners registered
*/
private List<WebSocketClientListener> mListeners = new FastList<WebSocketClientListener>();
/**
* flag for connection test
*/
private volatile boolean mConnected = false;
/**
* TCP socket
*/
private Socket mSocket = null;
/**
* IO streams
*/
private InputStream mInput = null;
private PrintStream mOutput = null;
/**
* Data receiver
*/
private WebSocketReceiver mReceiver = null;
/**
* represents the WebSocket status
*/
private WebSocketStatus mStatus = WebSocketStatus.CLOSED;
private List<SubProtocol> mSubprotocols;
private SubProtocol mNegotiatedSubprotocol;
private String mDraft = JWebSocketCommonConstants.WS_DRAFT_DEFAULT;
/**
* Base constructor
*/
public BaseWebSocket() {
}
/**
* {@inheritDoc}
*/
@Override
public void open(String aURIString) throws WebSocketException {
URI lURI = null;
try {
lURI = new URI(aURIString);
} catch (URISyntaxException lEx) {
throw new WebSocketException("Error parsing WebSocket URL:" + aURIString, lEx);
}
this.mURL = lURI;
String lSubProtocol = makeSubprotocolHeader();
WebSocketHandshake lHandshake = new WebSocketHandshake(mURL, lSubProtocol, mDraft);
try {
mSocket = createSocket();
mInput = mSocket.getInputStream();
mOutput = new PrintStream(mSocket.getOutputStream());
mOutput.write(lHandshake.getHandshake());
boolean lHandshakeComplete = false;
boolean lHeader = true;
int len = 1000;
byte[] lBuffer = new byte[len];
int lPos = 0;
ArrayList<String> lHandshakeLines = new ArrayList<String>();
byte[] lServerResponse = new byte[16];
while (!lHandshakeComplete) {
mStatus = WebSocketStatus.CONNECTING;
int lB = mInput.read();
lBuffer[lPos] = (byte) lB;
lPos += 1;
if (!lHeader) {
lServerResponse[lPos - 1] = (byte) lB;
if (lPos == 16) {
lHandshakeComplete = true;
}
} else if (lBuffer[lPos - 1] == 0x0A && lBuffer[lPos - 2] == 0x0D) {
String line = new String(lBuffer, "UTF-8");
if (line.trim().equals("")) {
lHeader = false;
} else {
lHandshakeLines.add(line.trim());
}
lBuffer = new byte[len];
lPos = 0;
}
}
lHandshake.verifyServerStatusLine(lHandshakeLines.get(0));
lHandshake.verifyServerResponse(lServerResponse);
lHandshakeLines.remove(0);
Map<String, String> lHeaders = new FastMap<String, String>();
for (String lLine : lHandshakeLines) {
String[] lKeyVal = lLine.split(": ", 2);
lHeaders.put(lKeyVal[0], lKeyVal[1]);
}
lHandshake.verifyServerHandshakeHeaders(lHeaders);
// set negotiated sub protocol
if (lHeaders.containsKey("Sec-WebSocket-Protocol")) {
String llHeader = lHeaders.get("Sec-WebSocket-Protocol");
if (llHeader.indexOf('/') == -1) {
mNegotiatedSubprotocol = new SubProtocol(llHeader, JWebSocketCommonConstants.WS_FORMAT_DEFAULT);
} else {
String[] lSplit = llHeader.split("/");
mNegotiatedSubprotocol = new SubProtocol(lSplit[0], lSplit[1]);
}
} else {
// just default to 'jwebsocket.org/json'
mNegotiatedSubprotocol = new SubProtocol(JWebSocketCommonConstants.WS_SUBPROTOCOL_DEFAULT,
JWebSocketCommonConstants.WS_FORMAT_DEFAULT);
}
mReceiver = new WebSocketReceiver(mInput);
// TODO: Add event parameter
notifyOpened(null);
mReceiver.start();
mConnected = true;
mStatus = WebSocketStatus.OPEN;
} catch (IOException lIOEx) {
throw new WebSocketException("error while connecting: " + lIOEx.getMessage(), lIOEx);
}
}
@Override
public void send(byte[] aData) throws WebSocketException {
if (isHixieDraft()) {
sendInternal(aData);
} else {
WebSocketPacket lPacket = new RawPacket(aData);
lPacket.setFrameType(WebSocketProtocolHandler.toRawPacketType(mNegotiatedSubprotocol.mFormat));
sendInternal(WebSocketProtocolHandler.toProtocolPacket(lPacket));
}
}
/**
* {@inheritDoc}
*/
@Override
public void send(String aData, String aEncoding) throws WebSocketException {
byte[] lData;
try {
lData = aData.getBytes(aEncoding);
} catch (UnsupportedEncodingException lEx) {
throw new WebSocketException("Encoding exception while sending the data:" + lEx.getMessage(), lEx);
}
send(lData);
}
/**
* {@inheritDoc}
*/
@Override
public void send(WebSocketPacket aDataPacket) throws WebSocketException {
if (isHixieDraft()) {
sendInternal(aDataPacket.getByteArray());
} else {
if (isBinaryFormat() && (aDataPacket.getFrameType() != RawPacket.FRAMETYPE_BINARY)) {
// we negotiated binary format with the server
throw new WebSocketException("Only binary packets are allowed for this connection");
}
sendInternal(WebSocketProtocolHandler.toProtocolPacket(aDataPacket));
}
}
private void sendInternal(byte[] aData) throws WebSocketException {
if (!mConnected) {
throw new WebSocketException("error while sending binary data: not connected");
}
try {
if (isHixieDraft()) {
if (isBinaryFormat()) {
mOutput.write(0x80);
// TODO: what if frame is longer than 255 characters (8bit?) Refer to IETF spec!
mOutput.write(aData.length);
mOutput.write(aData);
} else {
mOutput.write(0x00);
mOutput.write(aData);
mOutput.write(0xff);
}
} else {
mOutput.write(aData);
}
mOutput.flush();
} catch (IOException lEx) {
throw new WebSocketException("error while sending socket data: ", lEx);
}
}
public void handleReceiverError() {
try {
if (mConnected) {
mStatus = WebSocketStatus.CLOSING;
close();
}
} catch (WebSocketException lWSE) {
// TODO: don't use printStackTrace
// wse.printStackTrace();
}
}
@Override
public synchronized void close() throws WebSocketException {
if (!mConnected) {
return;
}
sendCloseHandshake();
if (mReceiver.isRunning()) {
mReceiver.stopit();
}
try {
// input.close();
// output.close();
mSocket.shutdownInput();
mSocket.shutdownOutput();
mSocket.close();
mStatus = WebSocketStatus.CLOSED;
} catch (IOException lIOEx) {
throw new WebSocketException("error while closing websocket connection: ", lIOEx);
}
// TODO: add event
notifyClosed(null);
}
private void sendCloseHandshake() throws WebSocketException {
if (!mConnected) {
throw new WebSocketException("error while sending close handshake: not connected");
}
try {
if (isHixieDraft()) {
mOutput.write(0xff00);
// TODO: check if final CR/LF is required/valid!
mOutput.write("\r\n".getBytes());
// TODO: shouldn't we put a flush here?
} else {
WebSocketPacket lPacket = new RawPacket("BYE");
lPacket.setFrameType(RawPacket.FRAMETYPE_CLOSE);
send(lPacket);
}
} catch (IOException lIOEx) {
throw new WebSocketException("error while sending close handshake", lIOEx);
}
mConnected = false;
}
private Socket createSocket() throws WebSocketException {
String lScheme = mURL.getScheme();
String lHost = mURL.getHost();
int lPort = mURL.getPort();
mSocket = null;
if (lScheme != null && lScheme.equals("ws")) {
if (lPort == -1) {
lPort = 80;
}
try {
mSocket = new Socket(lHost, lPort);
} catch (UnknownHostException lUHEx) {
throw new WebSocketException("unknown host: " + lHost, lUHEx);
} catch (IOException lIOEx) {
throw new WebSocketException("error while creating socket to " + mURL, lIOEx);
}
} else if (lScheme != null && lScheme.equals("wss")) {
if (lPort == -1) {
lPort = 443;
}
try {
SocketFactory lFactory = SSLSocketFactory.getDefault();
mSocket = lFactory.createSocket(lHost, lPort);
} catch (UnknownHostException lUHEx) {
throw new WebSocketException("unknown host: " + lHost, lUHEx);
} catch (IOException lIOEx) {
throw new WebSocketException("error while creating secure socket to " + mURL, lIOEx);
}
} else {
throw new WebSocketException("unsupported protocol: " + lScheme);
}
return mSocket;
}
/**
* {@inheritDoc }
*/
@Override
public boolean isConnected() {
return mConnected && mStatus.equals(WebSocketStatus.OPEN);
}
/**
* {@inheritDoc }
*/
public WebSocketStatus getConnectionStatus() {
return mStatus;
}
/**
* @return the client socket
*/
public Socket getConnectionSocket() {
return mSocket;
}
/**
* {@inheritDoc}
*/
@Override
public void addListener(WebSocketClientListener aListener) {
mListeners.add(aListener);
}
/**
* {@inheritDoc}
*/
@Override
public void removeListener(WebSocketClientListener aListener) {
mListeners.remove(aListener);
}
/**
* {@inheritDoc}
*/
@Override
public List<WebSocketClientListener> getListeners() {
return Collections.unmodifiableList(mListeners);
}
/**
* {@inheritDoc}
*/
@Override
public void notifyOpened(WebSocketClientEvent aEvent) {
for (WebSocketClientListener lListener : getListeners()) {
lListener.processOpened(aEvent);
}
}
/**
* {@inheritDoc}
*/
@Override
public void notifyPacket(WebSocketClientEvent aEvent, WebSocketPacket aPacket) {
for (WebSocketClientListener lListener : getListeners()) {
lListener.processPacket(aEvent, aPacket);
}
}
/**
* {@inheritDoc}
*/
@Override
public void notifyClosed(WebSocketClientEvent aEvent) {
for (WebSocketClientListener lListener : getListeners()) {
lListener.processClosed(aEvent);
}
}
@Override
public void addSubProtocol(String aProtocolName, String aProtocolFormat) {
if (mSubprotocols == null) {
mSubprotocols = new ArrayList<SubProtocol>(5);
}
mSubprotocols.add(new SubProtocol(aProtocolName, aProtocolFormat));
}
@Override
public String getNegotiatedProtocolName() {
return mNegotiatedSubprotocol == null ? null : mNegotiatedSubprotocol.mName;
}
@Override
public String getNegotiatedProtocolFormat() {
return mNegotiatedSubprotocol == null ? null : mNegotiatedSubprotocol.mFormat;
}
@Override
public void setDraft(String aDraft) {
this.mDraft = aDraft;
}
/**
* Make a subprotocol string for Sec-WebSocket-Protocol header.
* The result is something like this:
* <pre>
* chat.example.com/json v2.chat.example.com/xml audio.chat.example.com/binary
* </pre>
*
* @return subprotocol list in one string
*/
private String makeSubprotocolHeader() {
if (mSubprotocols == null || mSubprotocols.size() < 1) {
//return JWebSocketCommonConstants.WS_SUBPROTOCOL_DEFAULT + '/' + JWebSocketCommonConstants.WS_FORMAT_DEFAULT;
return null;
} else {
StringBuilder lBuff = new StringBuilder();
for (SubProtocol lProt : mSubprotocols) {
lBuff.append(lProt.toString()).append(' ');
}
return lBuff.toString().trim();
}
}
private boolean isHixieDraft() {
return JWebSocketCommonConstants.WS_DRAFT_DEFAULT.equals(mDraft);
}
private boolean isBinaryFormat() {
return mNegotiatedSubprotocol != null
&& JWebSocketCommonConstants.WS_FORMAT_BINARY.equals(mNegotiatedSubprotocol.mFormat);
}
class SubProtocol {
private String mName;
private String mFormat;
private SubProtocol(String aName, String aFormat) {
this.mName = aName;
this.mFormat = aFormat;
}
@Override
public int hashCode() {
return mName.hashCode() * 31 + mFormat.hashCode();
}
@Override
public boolean equals(Object aObj) {
if (aObj instanceof SubProtocol) {
SubProtocol lOther = (SubProtocol) aObj;
return mName.equals(lOther.mName) && mFormat.equals(lOther.mFormat);
} else {
return super.equals(aObj);
}
}
@Override
public String toString() {
StringBuilder lBuff = new StringBuilder();
lBuff.append(mName).append('/').append(mFormat);
return lBuff.toString();
}
}
class WebSocketReceiver extends Thread {
private InputStream mIS = null;
private volatile boolean mStop = false;
public WebSocketReceiver(InputStream aInput) {
this.mIS = aInput;
}
@Override
public void run() {
try {
if (isHixieDraft()) {
readHixie();
} else {
readHybi();
}
} catch (Exception lEx) {
handleError();
}
}
private void readHixie() throws IOException {
boolean lFrameStart = false;
ByteArrayOutputStream lBuff = new ByteArrayOutputStream();
while (!mStop) {
int lB = mIS.read();
// TODO: support binary frames
if (lB == 0x00) {
lFrameStart = true;
} else if (lB == 0xff && lFrameStart == true) {
lFrameStart = false;
WebSocketClientEvent lWSCE = new WebSocketClientTokenEvent();
RawPacket lPacket = new RawPacket(lBuff.toByteArray());
lBuff.reset();
notifyPacket(lWSCE, lPacket);
} else if (lFrameStart == true) {
lBuff.write(lB);
} else if (lB == -1) {
handleError();
}
}
}
private void readHybi() throws WebSocketException, IOException {
int lPacketType;
// utilize data input stream, because it has convenient methods for reading
// signed/unsigned bytes, shorts, ints and longs
DataInputStream lDis = new DataInputStream(mIS);
ByteArrayOutputStream lBuff = new ByteArrayOutputStream();
while (!mStop) {
// begin normal packet read
int lFlags = lDis.read();
// determine fragmentation
boolean lFragmented = (0x01 & lFlags) == 0x01;
// shift 4 bits to skip the first bit and three RSVx bits
int lType = lFlags >> 4;
lPacketType = WebSocketProtocolHandler.toRawPacketType(lType);
if (lPacketType == -1) {
// Could not determine packet type, ignore the packet.
// Maybe we need a setting to decide, if such packets should abort the connection?
handleError();
} else {
// Ignore first bit. Payload length is next seven bits, unless its value is greater than 125.
long lPayloadLen = mIS.read() >> 1;
if (lPayloadLen == 126) {
// following two bytes are acutal payload length (16-bit unsigned integer)
lPayloadLen = lDis.readUnsignedShort();
} else if (lPayloadLen == 127) {
// following eight bytes are actual payload length (64-bit unsigned integer)
lPayloadLen = lDis.readLong();
}
if (lPayloadLen > 0) {
// payload length may be extremely long, so we read in loop rather
// than construct one byte[] array and fill it with read() method,
// because java does not allow longs as array size
while (lPayloadLen-- > 0) {
lBuff.write(lDis.read());
}
}
if (!lFragmented) {
if (lPacketType == RawPacket.FRAMETYPE_PING) {
// As per spec, we must respond to PING with PONG (maybe
// this should be handled higher up in the hierarchy?)
WebSocketPacket lPong = new RawPacket(lBuff.toByteArray());
lPong.setFrameType(RawPacket.FRAMETYPE_PONG);
send(lPong);
} else if (lPacketType == RawPacket.FRAMETYPE_CLOSE) {
close();
}
// Packet was read, pass it forward.
WebSocketPacket lPacket = new RawPacket(lBuff.toByteArray());
lPacket.setFrameType(lPacketType);
WebSocketClientEvent lWSCE = new WebSocketClientTokenEvent();
notifyPacket(lWSCE, lPacket);
lBuff.reset();
}
}
}
}
public void stopit() {
mStop = true;
}
public boolean isRunning() {
return !mStop;
}
private void handleError() {
stopit();
}
}
}