/****************************************************************************** * * Copyright 2011-2012 Tavendo GmbH * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ******************************************************************************/ package org.magnum.soda.transport.wamp; import java.io.UnsupportedEncodingException; import java.net.SocketException; import java.nio.ByteBuffer; import java.nio.channels.SocketChannel; import java.util.HashMap; import java.util.Map; import android.os.Handler; import android.os.Message; import android.util.Log; import android.util.Pair; /** * WebSocket reader, the receiving leg of a WebSockets connection. * This runs on it's own background thread and posts messages to master * thread's message queue for there to be consumed by the application. * The only method that needs to be called (from foreground thread) is quit(), * which gracefully shuts down the background receiver thread. */ public class WebSocketReader extends Thread { private static final boolean DEBUG = true; private static final String TAG = WebSocketReader.class.getName(); private final Handler mMaster; private final SocketChannel mSocket; private final WebSocketOptions mOptions; private final ByteBuffer mFrameBuffer; private NoCopyByteArrayOutputStream mMessagePayload; private final static int STATE_CLOSED = 0; private final static int STATE_CONNECTING = 1; private final static int STATE_CLOSING = 2; private final static int STATE_OPEN = 3; private boolean mStopped = false; private int mState; private boolean mInsideMessage = false; private int mMessageOpcode; /// Frame currently being received. private FrameHeader mFrameHeader; private Utf8Validator mUtf8Validator = new Utf8Validator(); /** * WebSockets frame metadata. */ private static class FrameHeader { public int mOpcode; public boolean mFin; @SuppressWarnings("unused") public int mReserved; public int mHeaderLen; public int mPayloadLen; public int mTotalLen; public byte[] mMask; } /** * Create new WebSockets background reader. * * @param master The message handler of master (foreground thread). * @param socket The socket channel created on foreground thread. */ public WebSocketReader(Handler master, SocketChannel socket, WebSocketOptions options, String threadName) { super(threadName); mMaster = master; mSocket = socket; mOptions = options; mFrameBuffer = ByteBuffer.allocateDirect(options.getMaxFramePayloadSize() + 14); mMessagePayload = new NoCopyByteArrayOutputStream(options.getMaxMessagePayloadSize()); mFrameHeader = null; mState = STATE_CONNECTING; if (DEBUG) Log.d(TAG, "created"); } /** * Graceful shutdown of background reader thread (called from master). */ public void quit() { mStopped = true; if (DEBUG) Log.d(TAG, "quit"); } /** * Notify the master (foreground thread) of WebSockets message received * and unwrapped. * * @param message Message to send to master. */ protected void notify(Object message) { Message msg = mMaster.obtainMessage(); msg.obj = message; mMaster.sendMessage(msg); } /** * Process incoming WebSockets data (after handshake). */ private boolean processData() throws Exception { // outside frame? if (mFrameHeader == null) { // need at least 2 bytes from WS frame header to start processing if (mFrameBuffer.position() >= 2) { byte b0 = mFrameBuffer.get(0); boolean fin = (b0 & 0x80) != 0; int rsv = (b0 & 0x70) >> 4; int opcode = b0 & 0x0f; byte b1 = mFrameBuffer.get(1); boolean masked = (b1 & 0x80) != 0; int payload_len1 = b1 & 0x7f; // now check protocol compliance if (rsv != 0) { throw new WebSocketException("RSV != 0 and no extension negotiated"); } if (masked) { // currently, we don't allow this. need to see whats the final spec. throw new WebSocketException("masked server frame"); } if (opcode > 7) { // control frame if (!fin) { throw new WebSocketException("fragmented control frame"); } if (payload_len1 > 125) { throw new WebSocketException("control frame with payload length > 125 octets"); } if (opcode != 8 && opcode != 9 && opcode != 10) { throw new WebSocketException("control frame using reserved opcode " + opcode); } if (opcode == 8 && payload_len1 == 1) { throw new WebSocketException("received close control frame with payload len 1"); } } else { // message frame if (opcode != 0 && opcode != 1 && opcode != 2) { throw new WebSocketException("data frame using reserved opcode " + opcode); } if (!mInsideMessage && opcode == 0) { throw new WebSocketException("received continuation data frame outside fragmented message"); } if (mInsideMessage && opcode != 0) { throw new WebSocketException("received non-continuation data frame while inside fragmented message"); } } int mask_len = masked ? 4 : 0; int header_len = 0; if (payload_len1 < 126) { header_len = 2 + mask_len; } else if (payload_len1 == 126) { header_len = 2 + 2 + mask_len; } else if (payload_len1 == 127) { header_len = 2 + 8 + mask_len; } else { // should not arrive here throw new Exception("logic error"); } // continue when complete frame header is available if (mFrameBuffer.position() >= header_len) { // determine frame payload length int i = 2; long payload_len = 0; if (payload_len1 == 126) { payload_len = ((0xff & mFrameBuffer.get(i)) << 8) | (0xff & mFrameBuffer.get(i+1)); if (payload_len < 126) { throw new WebSocketException("invalid data frame length (not using minimal length encoding)"); } i += 2; } else if (payload_len1 == 127) { if ((0x80 & mFrameBuffer.get(i+0)) != 0) { throw new WebSocketException("invalid data frame length (> 2^63)"); } payload_len = ((0xff & mFrameBuffer.get(i+0)) << 56) | ((0xff & mFrameBuffer.get(i+1)) << 48) | ((0xff & mFrameBuffer.get(i+2)) << 40) | ((0xff & mFrameBuffer.get(i+3)) << 32) | ((0xff & mFrameBuffer.get(i+4)) << 24) | ((0xff & mFrameBuffer.get(i+5)) << 16) | ((0xff & mFrameBuffer.get(i+6)) << 8) | ((0xff & mFrameBuffer.get(i+7)) ); if (payload_len < 65536) { throw new WebSocketException("invalid data frame length (not using minimal length encoding)"); } i += 8; } else { payload_len = payload_len1; } // immediately bail out on frame too large if (payload_len > mOptions.getMaxFramePayloadSize()) { throw new WebSocketException("frame payload too large"); } // save frame header metadata mFrameHeader = new FrameHeader(); mFrameHeader.mOpcode = opcode; mFrameHeader.mFin = fin; mFrameHeader.mReserved = rsv; mFrameHeader.mPayloadLen = (int) payload_len; mFrameHeader.mHeaderLen = header_len; mFrameHeader.mTotalLen = mFrameHeader.mHeaderLen + mFrameHeader.mPayloadLen; if (masked) { mFrameHeader.mMask = new byte[4]; for (int j = 0; j < 4; ++j) { mFrameHeader.mMask[i] = (byte) (0xff & mFrameBuffer.get(i + j)); } i += 4; } else { mFrameHeader.mMask = null; } // continue processing when payload empty or completely buffered return mFrameHeader.mPayloadLen == 0 || mFrameBuffer.position() >= mFrameHeader.mTotalLen; } else { // need more data return false; } } else { // need more data return false; } } else { /// \todo refactor this for streaming processing, incl. fail fast on invalid UTF-8 within frame already // within frame // see if we buffered complete frame if (mFrameBuffer.position() >= mFrameHeader.mTotalLen) { // cut out frame payload byte[] framePayload = null; int oldPosition = mFrameBuffer.position(); if (mFrameHeader.mPayloadLen > 0) { framePayload = new byte[mFrameHeader.mPayloadLen]; mFrameBuffer.position(mFrameHeader.mHeaderLen); mFrameBuffer.get(framePayload, 0, (int) mFrameHeader.mPayloadLen); } mFrameBuffer.position(mFrameHeader.mTotalLen); mFrameBuffer.limit(oldPosition); mFrameBuffer.compact(); if (mFrameHeader.mOpcode > 7) { // control frame if (mFrameHeader.mOpcode == 8) { int code = 1005; // CLOSE_STATUS_CODE_NULL : no status code received String reason = null; if (mFrameHeader.mPayloadLen >= 2) { // parse and check close code code = (framePayload[0] & 0xff) * 256 + (framePayload[1] & 0xff); if (code < 1000 || (code >= 1000 && code <= 2999 && code != 1000 && code != 1001 && code != 1002 && code != 1003 && code != 1007 && code != 1008 && code != 1009 && code != 1010 && code != 1011) || code >= 5000) { throw new WebSocketException("invalid close code " + code); } // parse and check close reason if (mFrameHeader.mPayloadLen > 2) { byte[] ra = new byte[mFrameHeader.mPayloadLen - 2]; System.arraycopy(framePayload, 2, ra, 0, mFrameHeader.mPayloadLen - 2); Utf8Validator val = new Utf8Validator(); val.validate(ra); if (!val.isValid()) { throw new WebSocketException("invalid close reasons (not UTF-8)"); } else { reason = new String(ra, "UTF-8"); } } } onClose(code, reason); } else if (mFrameHeader.mOpcode == 9) { // dispatch WS ping onPing(framePayload); } else if (mFrameHeader.mOpcode == 10) { // dispatch WS pong onPong(framePayload); } else { // should not arrive here (handled before) throw new Exception("logic error"); } } else { // message frame if (!mInsideMessage) { // new message started mInsideMessage = true; mMessageOpcode = mFrameHeader.mOpcode; if (mMessageOpcode == 1 && mOptions.getValidateIncomingUtf8()) { mUtf8Validator.reset(); } } if (framePayload != null) { // immediately bail out on message too large if (mMessagePayload.size() + framePayload.length > mOptions.getMaxMessagePayloadSize()) { throw new WebSocketException("message payload too large"); } // validate incoming UTF-8 if (mMessageOpcode == 1 && mOptions.getValidateIncomingUtf8() && !mUtf8Validator.validate(framePayload)) { throw new WebSocketException("invalid UTF-8 in text message payload"); } // buffer frame payload for message mMessagePayload.write(framePayload); } // on final frame .. if (mFrameHeader.mFin) { if (mMessageOpcode == 1) { // verify that UTF-8 ends on codepoint if (mOptions.getValidateIncomingUtf8() && !mUtf8Validator.isValid()) { throw new WebSocketException("UTF-8 text message payload ended within Unicode code point"); } // deliver text message if (mOptions.getReceiveTextMessagesRaw()) { // dispatch WS text message as raw (but validated) UTF-8 onRawTextMessage(mMessagePayload.toByteArray()); } else { // dispatch WS text message as Java String (previously already validated) String s = new String(mMessagePayload.toByteArray(), "UTF-8"); onTextMessage(s); } } else if (mMessageOpcode == 2) { // dispatch WS binary message onBinaryMessage(mMessagePayload.toByteArray()); } else { // should not arrive here (handled before) throw new Exception("logic error"); } // ok, message completed - reset all mInsideMessage = false; mMessagePayload.reset(); } } // reset frame mFrameHeader = null; // reprocess if more data left return mFrameBuffer.position() > 0; } else { // need more data return false; } } } /** * WebSockets handshake reply from server received, default notifies master. * * @param success Success handshake flag */ protected void onHandshake(boolean success) { notify(new WebSocketMessage.ServerHandshake(success)); } /** * WebSockets close received, default notifies master. */ protected void onClose(int code, String reason) { notify(new WebSocketMessage.Close(code, reason)); } /** * WebSockets ping received, default notifies master. * * @param payload Ping payload or null. */ protected void onPing(byte[] payload) { notify(new WebSocketMessage.Ping(payload)); } /** * WebSockets pong received, default notifies master. * * @param payload Pong payload or null. */ protected void onPong(byte[] payload) { notify(new WebSocketMessage.Pong(payload)); } /** * WebSockets text message received, default notifies master. * This will only be called when the option receiveTextMessagesRaw * HAS NOT been set. * * @param payload Text message payload as Java String decoded * from raw UTF-8 payload or null (empty payload). */ protected void onTextMessage(String payload) { notify(new WebSocketMessage.TextMessage(payload)); } /** * WebSockets text message received, default notifies master. * This will only be called when the option receiveTextMessagesRaw * HAS been set. * * @param payload Text message payload as raw UTF-8 octets or * null (empty payload). */ protected void onRawTextMessage(byte[] payload) { notify(new WebSocketMessage.RawTextMessage(payload)); } /** * WebSockets binary message received, default notifies master. * * @param payload Binary message payload or null (empty payload). */ protected void onBinaryMessage(byte[] payload) { notify(new WebSocketMessage.BinaryMessage(payload)); } /** * Process WebSockets handshake received from server. */ private boolean processHandshake() throws UnsupportedEncodingException { boolean res = false; for (int pos = mFrameBuffer.position() - 4; pos >= 0; --pos) { if (mFrameBuffer.get(pos+0) == 0x0d && mFrameBuffer.get(pos+1) == 0x0a && mFrameBuffer.get(pos+2) == 0x0d && mFrameBuffer.get(pos+3) == 0x0a) { /// \todo process & verify handshake from server /// \todo forward subprotocol, if any int oldPosition = mFrameBuffer.position(); // Check HTTP status code boolean serverError = false; if (mFrameBuffer.get(0) == 'H' && mFrameBuffer.get(1) == 'T' && mFrameBuffer.get(2) == 'T' && mFrameBuffer.get(3) == 'P') { Pair<Integer, String> status = parseHttpStatus(); if (status.first >= 300) { // Invalid status code for success connection notify(new WebSocketMessage.ServerError(status.first, status.second)); serverError = true; } } mFrameBuffer.position(pos + 4); mFrameBuffer.limit(oldPosition); mFrameBuffer.compact(); if (!serverError) { // process further when data after HTTP headers left in buffer res = mFrameBuffer.position() > 0; mState = STATE_OPEN; } else { res = true; mState = STATE_CLOSED; mStopped = true; } onHandshake(!serverError); break; } } return res; } @SuppressWarnings("unused") private Map<String, String> parseHttpHeaders(byte[] buffer) throws UnsupportedEncodingException { // TODO: use utf-8 validator? String s = new String(buffer, "UTF-8"); Map<String, String> headers = new HashMap<String, String>(); String[] lines = s.split("\r\n"); for (String line : lines) { if (line.length() > 0) { String[] h = line.split(": "); if (h.length == 2) { headers.put(h[0], h[1]); Log.w(TAG, String.format("'%s'='%s'", h[0], h[1])); } } } return headers; } private Pair<Integer, String> parseHttpStatus() throws UnsupportedEncodingException { int beg, end; // Find first space for (beg = 4; beg < mFrameBuffer.position(); ++beg) { if (mFrameBuffer.get(beg) == ' ') break; } // Find second space for (end = beg + 1; end < mFrameBuffer.position(); ++end) { if (mFrameBuffer.get(end) == ' ') break; } // Parse status code between them ++beg; int statusCode = 0; for (int i = 0; beg + i < end; ++i) { int digit = (mFrameBuffer.get(beg + i) - 0x30); statusCode *= 10; statusCode += digit; } // Find end of line to extract error message ++end; int eol; for (eol = end; eol < mFrameBuffer.position(); ++eol) { if (mFrameBuffer.get(eol) == 0x0d) break; } int statusMessageLength = eol - end; byte[] statusBuf = new byte[statusMessageLength]; mFrameBuffer.position(end); mFrameBuffer.get(statusBuf, 0, statusMessageLength); String statusMessage = new String(statusBuf, "UTF-8"); if (DEBUG) Log.w(TAG, String.format("Status: %d (%s)", statusCode, statusMessage)); return new Pair<Integer, String>(statusCode, statusMessage); } /** * Consume data buffered in mFrameBuffer. */ private boolean consumeData() throws Exception { if (mState == STATE_OPEN || mState == STATE_CLOSING) { return processData(); } else if (mState == STATE_CONNECTING) { return processHandshake(); } else if (mState == STATE_CLOSED) { return false; } else { // should not arrive here return false; } } /** * Run the background reader thread loop. */ @Override public void run() { if (DEBUG) Log.d(TAG, "running"); try { mFrameBuffer.clear(); do { // blocking read on socket int len = mSocket.read(mFrameBuffer); if (len > 0) { // process buffered data while (consumeData()) { } } else if (len < 0) { if (DEBUG) Log.d(TAG, "run() : ConnectionLost"); notify(new WebSocketMessage.ConnectionLost()); mStopped = true; } } while (!mStopped); } catch (WebSocketException e) { if (DEBUG) Log.d(TAG, "run() : WebSocketException (" + e.toString() + ")"); // wrap the exception and notify master notify(new WebSocketMessage.ProtocolViolation(e)); } catch (SocketException e) { if (DEBUG) Log.d(TAG, "run() : SocketException (" + e.toString() + ")"); // wrap the exception and notify master notify(new WebSocketMessage.ConnectionLost());; } catch (Exception e) { if (DEBUG) Log.d(TAG, "run() : Exception (" + e.toString() + ")"); // wrap the exception and notify master notify(new WebSocketMessage.Error(e)); } finally { mStopped = true; } if (DEBUG) Log.d(TAG, "ended"); } }