/*
* Conditions Of Use
*
* This software was developed by employees of the National Institute of
* Standards and Technology (NIST), an agency of the Federal Government.
* Pursuant to title 15 Untied States Code Section 105, works of NIST
* employees are not subject to copyright protection in the United States
* and are considered to be in the public domain. As a result, a formal
* license is not needed to use the software.
*
* This software is provided by NIST as a service and is expressly
* provided "AS IS." NIST MAKES NO WARRANTY OF ANY KIND, EXPRESS, IMPLIED
* OR STATUTORY, INCLUDING, WITHOUT LIMITATION, THE IMPLIED WARRANTY OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT
* AND DATA ACCURACY. NIST does not warrant or make any representations
* regarding the use of the software or the results thereof, including but
* not limited to the correctness, accuracy, reliability or usefulness of
* the software.
*
* Permission to use this software is contingent upon your acceptance
* of the terms of this agreement
*
* .
*
*/
package gov.nist.javax.sip.stack;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import gov.nist.core.CommonLogger;
import gov.nist.core.LogLevels;
import gov.nist.core.StackLogger;
/**
*
* Decodes a web socket frame from wire protocol version 8 format. This code was originally based on <a
* href="https://github.com/joewalnes/webbit">webbit</a>.
*
* @author vladimirralev
*
*/
public class WebSocketCodec {
private static StackLogger logger = CommonLogger
.getLogger(WebSocketCodec.class);
private static final byte OPCODE_CONT = 0x0;
private static final byte OPCODE_TEXT = 0x1;
private static final byte OPCODE_BINARY = 0x2;
private static final byte OPCODE_CLOSE = 0x8;
private static final byte OPCODE_PING = 0x9;
private static final byte OPCODE_PONG = 0xA;
private static final byte[] trivialMask = new byte[] {1,1,1,1};
// Websocket metadata
private int fragmentedFramesCount;
private boolean frameFinalFlag;
private int frameRsv;
private int frameOpcode;
private long framePayloadLength;
private byte[] maskingKey = new byte[4];
private final boolean allowExtensions;
private final boolean maskedPayload;
private boolean closeOpcodeReceived;
// THe payload inside the websocket frame starts at this index
private int payloadStartIndex = -1;
// Buffering incomplete and overflowing frames
private byte[] decodeBuffer = new byte[2048];
private int writeIndex = 0;
private int readIndex;
// Total webscoket frame (metadata + payload)
private long totalPacketLength = -1;
public WebSocketCodec(boolean maskedPayload, boolean allowExtensions) {
this.maskedPayload = maskedPayload;
this.allowExtensions = allowExtensions;
}
private byte readNextByte() {
if(readIndex >= writeIndex) {
throw new IllegalStateException();
}
return this.decodeBuffer[readIndex++];
}
public byte[] decode(InputStream is)
throws Exception {
do {
// Attempt to resize the decode buffer when it's approaching capacity
int bytesLeft = decodeBuffer.length - writeIndex;
int availToRead = is.available();
if(availToRead > bytesLeft - 1) {
int newSize = Math.max(2*decodeBuffer.length, 4*availToRead);
if(logger.isLoggingEnabled(LogLevels.TRACE_DEBUG)) {
logger.logDebug("Increasing buffer size from " + decodeBuffer.length +
" avail " + availToRead + " newSize " + newSize);
}
byte[] resizeBuffer = new byte[newSize];
System.arraycopy(this.decodeBuffer, 0, resizeBuffer, 0, writeIndex);
this.decodeBuffer = resizeBuffer;
}
int bytesRead = is.read(decodeBuffer, writeIndex, bytesLeft);
if(bytesRead < 0) bytesRead = 0;
// Update the count in the buffer
writeIndex += bytesRead;
} while(is.available()>0);
// Start over from scratch. If the frame is big this may be repeated a few times O(logN) and doesn't affect performance
readIndex = 0;
// All TCP slow-start algorithms will be cut off right here without further analysis
if(writeIndex<4) {
if(logger.isLoggingEnabled(LogLevels.TRACE_DEBUG)) {
logger.logDebug("Abort decode. Write index is at " + writeIndex);
}
return null;
}
byte b = readNextByte();
frameFinalFlag = (b & 0x80) != 0;
frameRsv = (b & 0x70) >> 4;
frameOpcode = b & 0x0F;
if(logger.isLoggingEnabled(LogLevels.TRACE_DEBUG)) {
logger.logDebug("Decoding WebSocket Frame opCode=" + frameOpcode);
}
if(frameOpcode == 8) {
//https://code.google.com/p/chromium/issues/detail?id=388243#c15
this.closeOpcodeReceived = true;
}
// MASK, PAYLOAD LEN 1
b = readNextByte();
boolean frameMasked = (b & 0x80) != 0;
int framePayloadLen1 = b & 0x7F;
if (frameRsv != 0 && !allowExtensions) {
protocolViolation("RSV != 0 and no extension negotiated, RSV:" + frameRsv);
return null;
}
if (maskedPayload && !frameMasked) {
protocolViolation("unmasked client to server frame");
return null;
}
protocolChecks();
try {
// Read frame payload length
if (framePayloadLen1 == 126) {
int byte1 = 0xff & readNextByte();
int byte2 = 0xff & readNextByte();
int value = (byte1<<8) | byte2;
framePayloadLength = value;
} else if (framePayloadLen1 == 127) {
long value = 0;
for(int q=0;q<8;q++) {
byte nextByte = readNextByte();
long valuePart = ((long)(0xff&nextByte));
valuePart <<= (7-q)*8;
value |= valuePart;
}
framePayloadLength = value;
if (framePayloadLength < 65536) {
protocolViolation("invalid data frame length (not using minimal length encoding): " + framePayloadLength);
return null;
}
} else {
framePayloadLength = framePayloadLen1;
}
if(framePayloadLength < 0) {
protocolViolation("Negative payload size: " + framePayloadLength);
}
if(logger.isLoggingEnabled(LogLevels.TRACE_DEBUG)) {
logger.logDebug("Decoding WebSocket Frame length=" + framePayloadLength);
}
// Analyze the mask
if (frameMasked) {
for(int q=0; q<4 ;q++)
maskingKey[q] = readNextByte();
}
}
catch (IllegalStateException e) {
// the stream has ended we don't have enough data to continue
return null;
}
// Remember the payload position
payloadStartIndex = readIndex;
totalPacketLength = readIndex + framePayloadLength;
// Check if we have enough data at all
if(writeIndex < totalPacketLength) {
if(logger.isLoggingEnabled(LogLevels.TRACE_DEBUG)) {
logger.logDebug("Abort decode. Write index is at " + writeIndex + " and totalPacketLength is " + totalPacketLength);
}
return null; // wait for more data
}
// Unmask data if needed and only if the condition above is true
if (frameMasked) {
unmask(decodeBuffer, payloadStartIndex, (int) (payloadStartIndex + framePayloadLength));
}
// Finally isolate the unmasked payload, the bytes are plaintext here
byte[] plainTextBytes = new byte[(int) framePayloadLength];
System.arraycopy(decodeBuffer, payloadStartIndex, plainTextBytes, 0, (int) framePayloadLength);
// Now move the pending data to the begining of the buffer so we can continue having good stream
for(int q=1; q<writeIndex - totalPacketLength; q++) {
decodeBuffer[q] = decodeBuffer[(int)totalPacketLength + q];
}
writeIndex -= totalPacketLength;
if(logger.isLoggingEnabled(LogLevels.TRACE_DEBUG)) {
logger.logDebug("writeIndex = " + writeIndex + " " + totalPacketLength);
}
// All done, we are ready to be called again
return plainTextBytes;
}
protected static byte[] encode(byte[] msg, int rsv, boolean fin, boolean maskPayload) throws Exception {
return encode(msg, rsv, fin, maskPayload, OPCODE_TEXT);
}
protected static byte[] encode(byte[] msg, int rsv, boolean fin, boolean maskPayload, byte opcode) throws Exception {
ByteArrayOutputStream frame = new ByteArrayOutputStream();
long length = msg.length;
if(logger.isLoggingEnabled(LogLevels.TRACE_DEBUG)) {
logger.logDebug("Encoding WebSocket Frame opCode=" + opcode + " length=" + length);
}
int b0 = 0;
if (fin) {
b0 |= 1 << 7;
}
b0 |= rsv % 8 << 4;
b0 |= opcode % 128;
if (length <= 125) {
frame.write(b0);
byte b = (byte) (maskPayload ? 0x80 | (byte) length : (byte) length);
frame.write(b);
} else if (length <= 0xFFFF) {
frame.write(b0);
frame.write(maskPayload ? 0xFE : 126);
frame.write((byte)(length >>> 8));
frame.write((byte)length);
} else {
frame.write(b0);
frame.write(maskPayload ? 0xFF : 127);
for(int q=0;q<8;q++) {
byte b = (byte)(length>>>(7-q)*8);
frame.write(b);
}
}
if(maskPayload) {
frame.write(trivialMask);
applyMask(msg, 0, msg.length, trivialMask);
}
frame.write(msg);
return frame.toByteArray();
}
private void unmask(byte[] frame, int startIndex, int endIndex) {
applyMask(frame, startIndex, endIndex, maskingKey);
}
public static void applyMask(byte[] frame, int startIndex, int endIndex, byte[] mask) {
for (int i = 0; i < endIndex-startIndex; i++) {
frame[startIndex+i] = (byte) (frame[startIndex+i] ^ mask[i % 4]);
}
}
private void protocolViolation(String reason) {
throw new RuntimeException(reason);
}
private void protocolChecks() {
if (frameOpcode > 7) { // control frame (have MSB in opcode set)
// control frames MUST NOT be fragmented
if (!frameFinalFlag) {
protocolViolation("fragmented control frame");
}
// check for reserved control frame opcodes
if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING || frameOpcode == OPCODE_PONG)) {
protocolViolation("control frame using reserved opcode " + frameOpcode);
}
} else { // data frame
// check for reserved data frame opcodes
if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT || frameOpcode == OPCODE_BINARY)) {
protocolViolation("data frame using reserved opcode " + frameOpcode);
}
// check opcode vs message fragmentation state 1/2
if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
protocolViolation("received continuation data frame outside fragmented message");
}
// check opcode vs message fragmentation state 2/2
if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT && frameOpcode != OPCODE_PING) {
protocolViolation("received non-continuation data frame while inside fragmented message");
}
}
}
public boolean isCloseOpcodeReceived() {
return this.closeOpcodeReceived;
}
}