// --------------------------------------------------------------------------- // jWebSocket - WebSocket Handshake // Copyright (c) 2010 Innotrade GmbH, jWebSocket.org // --------------------------------------------------------------------------- // This program is free software; you can redistribute it and/or modify it // under the terms of the GNU Lesser General Public License as published by the // Free Software Foundation; either version 3 of the License, or (at your // option) any later version. // This program is distributed in the hope that it will be useful, but WITHOUT // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or // FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for // more details. // You should have received a copy of the GNU Lesser General Public License along // with this program; if not, see <http://www.gnu.org/licenses/lgpl.html>. // --------------------------------------------------------------------------- package org.jwebsocket.kit; import java.io.IOException; import java.io.InputStream; import java.io.UnsupportedEncodingException; import java.math.BigInteger; import java.net.URI; import java.nio.ByteBuffer; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.Arrays; import java.util.Map; import javolution.util.FastMap; /** * Utility class for all the handshaking related request/response. * @author aschulze * @version $Id:$ */ public final class WebSocketHandshake { /** * */ public static int MAX_HEADER_SIZE = 16834; private String mKey1 = null; private String mKey2 = null; private byte[] mKey3 = null; private byte[] mExpectedServerResponse = null; private URI mURL = null; private String mOrigin = null; private String mProtocol = null; private String mDraft = null; /** * * @param aURL */ public WebSocketHandshake(URI aURL) { this(aURL, null, null); } /** * * @param aURL * @param aProtocol */ public WebSocketHandshake(URI aURL, String aProtocol) { this(aURL, aProtocol, null); } /** * * @param aURL * @param aProtocol * @param aDraft */ public WebSocketHandshake(URI aURL, String aProtocol, String aDraft) { this.mURL = aURL; this.mProtocol = aProtocol; this.mDraft = aDraft; generateKeys(); } /** * Generates the initial handshake request from a client to the jWebSocket * Server. This is send from a Java client to the server when a connection * is about to be established. The browser's implement that internally. * * @param aHost * @param aPath * @return */ // public static byte[] generateC2SRequest(URI aURI) { public static byte[] generateC2SRequest(String aHost, String aPath) { // String lPath = aURI.getPath(); // String lHost = aURI.getHost(); String lOrigin = "http://" + aHost; String lHandshake = "GET " + aPath + " HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Host: " + aHost + "\r\n" + "Origin: " + lOrigin + "\r\n" + "\r\n"; byte[] lBA = null; try { lBA = lHandshake.getBytes("US-ASCII"); } catch (Exception lEx) { } return lBA; } private static long calcSecKeyNum(String aKey) { StringBuilder lSB = new StringBuilder(); // StringBuuffer lSB = new StringBuuffer(); int lSpaces = 0; for (int lIdx = 0; lIdx < aKey.length(); lIdx++) { char lC = aKey.charAt(lIdx); if (lC == ' ') { lSpaces++; } else if (lC >= '0' && lC <= '9') { lSB.append(lC); } } long lRes = -1; if (lSpaces > 0) { try { lRes = Long.parseLong(lSB.toString()) / lSpaces; // log.debug("Key: " + aKey + ", Numbers: " + lSB.toString() + // ", Spaces: " + lSpaces + ", Result: " + lRes); } catch (NumberFormatException lEx) { // use default result } } return lRes; } /** * Parses the response from the client on an initial client's handshake * request. This is always performed on the server only when a client - * irrespective of if it is a Java Client or Browser Client - initiates a * connection. * * @param aReq * @return */ public static Map parseC2SRequest(byte[] aReq) { String lHost = null; String lOrigin = null; String lLocation = null; String lPath = null; String lSubProt = null; String lDraft = null; String lSecKey1 = null; String lSecKey2 = null; byte[] lSecKey3 = new byte[8]; Boolean lIsSecure = false; Long lSecNum1 = null; Long lSecNum2 = null; byte[] lSecKeyResp = new byte[8]; Map lRes = new FastMap(); int lReqLen = aReq.length; String lRequest = ""; try { lRequest = new String(aReq, "US-ASCII"); } catch (Exception lEx) { // TODO: add exception handling } if (lRequest.indexOf("policy-file-request") >= 0) { // "<policy-file-request/>" lRes.put("policy-file-request", lRequest); return lRes; } lIsSecure = (lRequest.indexOf("Sec-WebSocket") > 0); if (lIsSecure) { lReqLen -= 8; for (int lIdx = 0; lIdx < 8; lIdx++) { lSecKey3[lIdx] = aReq[lReqLen + lIdx]; } } // now parse header for correct handshake.... // get host.... int lPos = lRequest.indexOf("Host:"); lPos += 6; lHost = lRequest.substring(lPos); lPos = lHost.indexOf("\r\n"); lHost = lHost.substring(0, lPos); // get origin.... lPos = lRequest.indexOf("Origin:"); lPos += 8; lOrigin = lRequest.substring(lPos); lPos = lOrigin.indexOf("\r\n"); lOrigin = lOrigin.substring(0, lPos); // get path.... lPos = lRequest.indexOf("GET"); lPos += 4; lPath = lRequest.substring(lPos); lPos = lPath.indexOf("HTTP"); lPath = lPath.substring(0, lPos - 1); lLocation = "ws://" + lHost + lPath; // get websocket sub protocol (irrespective of Sec- prefix for older browsers) lPos = lRequest.indexOf("WebSocket-Protocol:"); if (lPos > 0) { lPos += 20; lSubProt = lRequest.substring(lPos); lPos = lSubProt.indexOf("\r\n"); lSubProt = lSubProt.substring(0, lPos); } // Sec-WebSocket-Draft: This field was introduced with hybi-03 web socket protocol draft. // See: http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-03 // // Specification proposes the use of draft number (without any prefixes or suffixes) as a value // for this field. For example: "Sec-WebSocket-Draft: 3" indicates that the communication will proceed // according to #03 draft. If the value is something that the server doesn't recognize, // then the handshake should fail and web socket connection must be aborted. // // If present, then BaseEngine & BaseConnector (their subclasses) should process further // packets according to this field. If it's not present, then all the logic defaults to hixie drafts // (see: http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76). lPos = lRequest.indexOf("Sec-WebSocket-Draft:"); if (lPos > 0) { lPos += 21; lDraft = lRequest.substring(lPos); lPos = lDraft.indexOf("\r\n"); lDraft = lDraft.substring(0, lPos); } // the following section implements the sec-key process in WebSocket // Draft 76 /* * To prove that the handshake was received, the server has to take * three pieces of information and combine them to form a response. The * first two pieces of information come from the |Sec-WebSocket-Key1| * and |Sec-WebSocket-Key2| fields in the client handshake. * * Sec-WebSocket-Key1: 18x 6]8vM;54 *(5: { U1]8 z [ 8 * Sec-WebSocket-Key2: 1_ tx7X d < nw 334J702) 7]o}` 0 * * For each of these fields, the server has to take the digits from the * value to obtain a number (in this case 1868545188 and 1733470270 * respectively), then divide that number by the number of spaces * characters in the value (in this case 12 and 10) to obtain a 32-bit * number (155712099 and 173347027). These two resulting numbers are * then used in the server handshake, as described below. */ lPos = lRequest.indexOf("Sec-WebSocket-Key1:"); if (lPos > 0) { lPos += 20; lSecKey1 = lRequest.substring(lPos); lPos = lSecKey1.indexOf("\r\n"); lSecKey1 = lSecKey1.substring(0, lPos); lSecNum1 = calcSecKeyNum(lSecKey1); // log.debug("Sec-WebSocket-Key1:" + secKey1 + " => " + secNum1); } lPos = lRequest.indexOf("Sec-WebSocket-Key2:"); if (lPos > 0) { lPos += 20; lSecKey2 = lRequest.substring(lPos); lPos = lSecKey2.indexOf("\r\n"); lSecKey2 = lSecKey2.substring(0, lPos); lSecNum2 = calcSecKeyNum(lSecKey2); // log.debug("Sec-WebSocket-Key2:" + secKey2 + " => " + secNum2); } /* * The third piece of information is given after the fields, in the last * eight bytes of the handshake, expressed here as they would be seen if * interpreted as ASCII: Tm[K T2u The concatenation of the number * obtained from processing the |Sec- WebSocket-Key1| field, expressed * as a big-endian 32 bit number, the number obtained from processing * the |Sec-WebSocket-Key2| field, again expressed as a big-endian 32 * bit number, and finally the eight bytes at the end of the handshake, * form a 128 bit string whose MD5 sum is then used by the server to * prove that it read the handshake. */ if (lSecNum1 != null && lSecNum2 != null) { // log.debug("Sec-WebSocket-Key3:" + new String(secKey3, "UTF-8")); BigInteger lSec1 = new BigInteger(lSecNum1.toString()); BigInteger lSec2 = new BigInteger(lSecNum2.toString()); // concatenate 3 parts secNum1 + secNum2 + secKey (16 Bytes) byte[] l128Bit = new byte[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; byte[] lTmp; int lOfs; lTmp = lSec1.toByteArray(); int lIdx = lTmp.length; int lCnt = 0; while (lIdx > 0 && lCnt < 4) { lIdx--; lCnt++; l128Bit[4 - lCnt] = lTmp[lIdx]; } lTmp = lSec2.toByteArray(); lIdx = lTmp.length; lCnt = 0; while (lIdx > 0 && lCnt < 4) { lIdx--; lCnt++; l128Bit[8 - lCnt] = lTmp[lIdx]; } lTmp = lSecKey3; System.arraycopy(lSecKey3, 0, l128Bit, 8, 8); // build md5 sum of this new 128 byte string try { MessageDigest lMD = MessageDigest.getInstance("MD5"); lSecKeyResp = lMD.digest(l128Bit); } catch (Exception lEx) { // log.error("getMD5: " + ex.getMessage()); } } lRes.put(RequestHeader.WS_PATH, lPath); lRes.put(RequestHeader.WS_HOST, lHost); lRes.put(RequestHeader.WS_ORIGIN, lOrigin); lRes.put(RequestHeader.WS_LOCATION, lLocation); lRes.put(RequestHeader.WS_PROTOCOL, lSubProt); lRes.put(RequestHeader.WS_SECKEY1, lSecKey1); lRes.put(RequestHeader.WS_SECKEY2, lSecKey2); lRes.put(RequestHeader.WS_DRAFT, lDraft); lRes.put("isSecure", lIsSecure); lRes.put("secKeyResponse", lSecKeyResp); return lRes; } /** * Generates the response for the server to answer an initial client * request. This is performed on the server only as an answer to a client's * request - irrespective of if it is a Java or Browser Client. * * @param aRequest * @return */ public static byte[] generateS2CResponse(Map aRequest) { String lPolicyFileRequest = (String) aRequest.get("policy-file-request"); if (lPolicyFileRequest != null) { byte[] lBA; try { lBA = ("<cross-domain-policy>" + "<allow-access-from domain=\"*\" to-ports=\"*\" />" + "</cross-domain-policy>\n").getBytes("US-ASCII"); } catch (UnsupportedEncodingException lEx) { lBA = null; } return lBA; } // now that we have parsed the header send handshake... // since 0.9.0.0609 considering Sec-WebSocket-Key processing Boolean lIsSecure = (Boolean) aRequest.get("isSecure"); String lOrigin = (String) aRequest.get(RequestHeader.WS_ORIGIN); String lLocation = (String) aRequest.get(RequestHeader.WS_LOCATION); String lSubProt = (String) aRequest.get(RequestHeader.WS_PROTOCOL); String lRes = // since IETF draft 76 "WebSocket Protocol" not "Web Socket Protocol" // change implemented since v0.9.5.0701 "HTTP/1.1 101 Web" + (lIsSecure ? "" : " ") + "Socket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + (lSubProt != null ? (lIsSecure ? "Sec-" : "") + "WebSocket-Protocol: " + lSubProt + "\r\n" : "") + (lIsSecure ? "Sec-" : "") + "WebSocket-Origin: " + lOrigin + "\r\n" + (lIsSecure ? "Sec-" : "") + "WebSocket-Location: " + lLocation + "\r\n" + "\r\n"; byte[] lBA; try { lBA = lRes.getBytes("US-ASCII"); // if Sec-WebSocket-Keys are used send security response first if (lIsSecure) { byte[] lSecKey = (byte[]) aRequest.get("secKeyResponse"); byte[] lResult = new byte[lBA.length + lSecKey.length]; System.arraycopy(lBA, 0, lResult, 0, lBA.length); System.arraycopy(lSecKey, 0, lResult, lBA.length, lSecKey.length); return lResult; } else { return lBA; } } catch (UnsupportedEncodingException lEx) { return null; } } /** * Reads the handshake response from the server into an byte array. This is * used on clients only. The browser client implement that internally. * * @param aIS * @return */ public static byte[] readS2CResponse(InputStream aIS) { byte[] lBuff = new byte[MAX_HEADER_SIZE]; boolean lContinue = true; int lIdx = 0; int lB1 = 0, lB2 = 0, lB3 = 0, lB4 = 0; while (lContinue && lIdx < MAX_HEADER_SIZE) { int lIn; try { lIn = aIS.read(); if (lIn < 0) { return null; } } catch (IOException lIOEx) { return null; } // build mini queue to check for \r\n\r\n sequence in handshake lB1 = lB2; lB2 = lB3; lB3 = lB4; lB4 = lIn; lContinue = !(lB1 == 13 && lB2 == 10 && lB3 == 13 && lB4 == 10); lBuff[lIdx] = (byte) lIn; lIdx++; } byte[] lRes = new byte[lIdx]; System.arraycopy(lBuff, 0, lRes, 0, lIdx); return lRes; } /* * Parses the websocket handshake response from the server. This is * performed on Java Client only, the browsers implement that internally. * * @param aResp * * @return */ /** * * @param aResp * @return */ public static Map parseS2CResponse(byte[] aResp) { Map lRes = new FastMap(); String lResp = null; try { lResp = new String(aResp, "US-ASCII"); } catch (Exception lEx) { // TODO: add exception handling } return lRes; } /** * * @return */ public byte[] getHandshake() { String lPath = mURL.getPath(); String lHost = mURL.getHost(); mOrigin = "http://" + lHost; if ("".equals(lPath)) { lPath = "/"; } String lHandshake = "GET " + lPath + " HTTP/1.1\r\n" + "Host: " + lHost + "\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key2: " + mKey2 + "\r\n"; if (mProtocol != null) { lHandshake += "Sec-WebSocket-Protocol: " + mProtocol + "\r\n"; } if(mDraft != null) { lHandshake += "Sec-WebSocket-Draft: " + mDraft + "\r\n"; } lHandshake += "Upgrade: WebSocket\r\n" + "Sec-WebSocket-Key1: " + mKey1 + "\r\n" + "Origin: " + mOrigin + "\r\n" + "\r\n"; byte[] lHandshakeBytes = new byte[lHandshake.getBytes().length + 8]; System.arraycopy(lHandshake.getBytes(), 0, lHandshakeBytes, 0, lHandshake.getBytes().length); System.arraycopy(mKey3, 0, lHandshakeBytes, lHandshake.getBytes().length, 8); return lHandshakeBytes; } /** * * @param aBytes * @throws WebSocketException */ public void verifyServerResponse(byte[] aBytes) throws WebSocketException { if (!Arrays.equals(aBytes, mExpectedServerResponse)) { throw new WebSocketException("not a WebSocket Server"); } } /** * * @param aStatusLine * @throws WebSocketException */ public void verifyServerStatusLine(String aStatusLine) throws WebSocketException { int lStatusCode = Integer.valueOf(aStatusLine.substring(9, 12)); if (lStatusCode == 407) { throw new WebSocketException("connection failed: proxy authentication not supported"); } else if (lStatusCode == 404) { throw new WebSocketException("connection failed: 404 not found"); } else if (lStatusCode != 101) { throw new WebSocketException("connection failed: unknown status code " + lStatusCode); } } /** * * @param aHeaders * @throws WebSocketException */ public void verifyServerHandshakeHeaders(Map<String, String> aHeaders) throws WebSocketException { if (!aHeaders.get("Upgrade").equals("WebSocket")) { throw new WebSocketException("connection failed: missing header field in server handshake: Upgrade"); } else if (!aHeaders.get("Connection").equals("Upgrade")) { throw new WebSocketException("connection failed: missing header field in server handshake: Connection"); } else if (!aHeaders.get("Sec-WebSocket-Origin").equals(mOrigin)) { throw new WebSocketException("connection failed: missing header field in server handshake: Sec-WebSocket-Origin"); } else if(aHeaders.containsKey("Sec-WebSocket-Protocol") && (mProtocol.indexOf(aHeaders.get("Sec-WebSocket-Protocol")) == -1)) { // server returned sub protocol that wasn't proposed by the client? Illegal answer from server. throw new WebSocketException( "connection failed: invalid header field in server handshake: Sec-WebSocket-Protocol," + " expected one of : " + mProtocol + ", but got: " + aHeaders.get("Sec-WebSocket-Protocol")); } } private void generateKeys() { int lSpaces1 = rand(1, 12); int lSpaces2 = rand(1, 12); int lMax1 = Integer.MAX_VALUE / lSpaces1; int lMax2 = Integer.MAX_VALUE / lSpaces2; int lNumber1 = rand(0, lMax1); int lNumber2 = rand(0, lMax2); int lProduct1 = lNumber1 * lSpaces1; int lProduct2 = lNumber2 * lSpaces2; mKey1 = Integer.toString(lProduct1); mKey2 = Integer.toString(lProduct2); mKey1 = insertRandomCharacters(mKey1); mKey2 = insertRandomCharacters(mKey2); mKey1 = insertSpaces(mKey1, lSpaces1); mKey2 = insertSpaces(mKey2, lSpaces2); mKey3 = createRandomBytes(); ByteBuffer lBuffer = ByteBuffer.allocate(4); lBuffer.putInt(lNumber1); byte[] lNumber1Array = lBuffer.array(); lBuffer = ByteBuffer.allocate(4); lBuffer.putInt(lNumber2); byte[] lNumber2Array = lBuffer.array(); byte[] lChallenge = new byte[16]; System.arraycopy(lNumber1Array, 0, lChallenge, 0, 4); System.arraycopy(lNumber2Array, 0, lChallenge, 4, 4); System.arraycopy(mKey3, 0, lChallenge, 8, 8); mExpectedServerResponse = md5(lChallenge); } private String insertRandomCharacters(String aKey) { int lCount = rand(1, 12); char[] lRandomChars = new char[lCount]; int lRandCount = 0; while (lRandCount < lCount) { int lRand = (int) (Math.random() * 0x7e + 0x21); if (((0x21 < lRand) && (lRand < 0x2f)) || ((0x3a < lRand) && (lRand < 0x7e))) { lRandomChars[lRandCount] = (char) lRand; lRandCount += 1; } } for (int lIdx = 0; lIdx < lCount; lIdx++) { // updated by Alex 2010-10-25 after Roderik's hint: // int lSplit = rand(0, aKey.length()); int lSplit = rand(1, aKey.length() - 1); String lPart1 = aKey.substring(0, lSplit); String lPart2 = aKey.substring(lSplit); aKey = lPart1 + lRandomChars[lIdx] + lPart2; } return aKey; } private String insertSpaces(String aKey, int aSpaces) { for (int lIdx = 0; lIdx < aSpaces; lIdx++) { // updated by Alex 2010-10-25 after Roderik's hint: // int lSplit = rand(0, aKey.length()); int lSplit = rand(1, aKey.length() - 1); String lPart1 = aKey.substring(0, lSplit); String lPart2 = aKey.substring(lSplit); aKey = lPart1 + " " + lPart2; } return aKey; } private byte[] createRandomBytes() { byte[] lBytes = new byte[8]; for (int lIdx = 0; lIdx < 8; lIdx++) { lBytes[lIdx] = (byte) rand(0, 255); } return lBytes; } private byte[] md5(byte[] aBytes) { try { MessageDigest lMD = MessageDigest.getInstance("MD5"); return lMD.digest(aBytes); } catch (NoSuchAlgorithmException lEx) { return null; } } private int rand(int aMin, int aMax) { int lRand = (int) (Math.random() * aMax + aMin); return lRand; } }