// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package org.chromium.sdk.internal.websocket;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.net.InetSocketAddress;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.chromium.sdk.internal.transport.SocketWrapper;
/**
* A more or less straightforward implementation of WebSocket client-side handshake
* as defined in Internet-Draft of May 23, 2010.
* See http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-00
* <p>Note that the standard was reworked completely. This implementation is obsolete. However
* it is still compatible with the current Chrome implementation.
*/
class Hybi00Handshake {
  /** Number of random bytes sent as the third key right after the request header block. */
  private static final int KEY3_LENGTH = 8;

  /** Number of bytes read from the server as its challenge response (an MD5 digest). */
  private static final int CHALLENGE_RESPONSE_LENGTH = 16;

  /**
   * Names of the fields that a handshake response must consist of. The response is rejected
   * unless its field set is exactly this set.
   */
  private static final Set<String> EXPECTED_FIELDS = new HashSet<String>(Arrays.asList(
      "upgrade",
      "connection",
      "sec-websocket-origin",
      "sec-websocket-location"
  ));

  /**
   * Sends the client handshake request over the socket, then reads the server response and
   * validates it: status code, header fields and the MD5 challenge response.
   * Returns normally only if every check passed.
   *
   * @param socket connection to write the request to and read the response from
   * @param endpoint address used to build the Host field and the expected location URL
   * @param resourceName resource put into the request line and the expected location URL
   * @param origin value of the Origin field; also expected back in sec-websocket-origin
   * @param random source of randomness for both keys, the header order and the third key
   * @throws IOException if the connection fails, the stream ends prematurely or the response
   *     does not pass validation
   */
  static void performHandshake(SocketWrapper socket, InetSocketAddress endpoint,
      String resourceName, String origin, Random random) throws IOException {
    HandshakeUtil.checkOriginString(origin);

    OutputStream output = socket.getLoggableOutput().getOutputStream();
    Writer outputWriter = new OutputStreamWriter(output, HandshakeUtil.UTF_8_CHARSET);

    outputWriter.write("GET " + resourceName + " HTTP/1.1\r\n");

    List<String> fields = HandshakeUtil.createHttpFields(endpoint);
    fields.add("Upgrade: WebSocket");
    fields.add("Origin: " + origin);
    int port = endpoint.getPort();
    // The default (non-secure) port is omitted from the Host field.
    String portSuffix = port == 80 ? "" : ":" + port;
    fields.add("Host: " + endpoint.getHostName() + portSuffix);

    WsKey key1 = new WsKey(random);
    WsKey key2 = new WsKey(random);
    fields.add("Sec-WebSocket-Key1: " + key1.getKeySocketField());
    fields.add("Sec-WebSocket-Key2: " + key2.getKeySocketField());

    // Fields are sent in a random order.
    Collections.shuffle(fields, random);
    for (String field : fields) {
      outputWriter.write(field);
      outputWriter.write("\r\n");
    }
    outputWriter.write("\r\n");

    byte[] key3 = new byte[KEY3_LENGTH];
    random.nextBytes(key3);

    // The third key is raw binary data: push out all buffered text first and then write
    // the bytes directly to the underlying stream.
    outputWriter.flush();
    output.write(key3);
    output.flush();

    byte[] expectedMd5Bytes = computeExpectedChallengeResponse(key1, key2, key3);

    final InputStream input = socket.getLoggableInput().getInputStream();
    HandshakeUtil.LineReader lineReader = HandshakeUtil.createLineReader(input);
    HandshakeUtil.HttpResponse httpResponse = HandshakeUtil.readHttpResponse(lineReader);
    if (httpResponse.getCode() != 101) {
      throw new IOException("Unexpected response code " + httpResponse.getCode());
    }

    checkResponseFields(httpResponse.getFields(), endpoint, resourceName, origin);

    // Challenge response: raw bytes that follow the response header block.
    // NOTE(review): this reads from the same stream the line reader was created on;
    // presumably the line reader does not read ahead -- confirm in HandshakeUtil.
    byte[] actualMd5Bytes = new byte[CHALLENGE_RESPONSE_LENGTH];
    readFully(input, actualMd5Bytes);
    if (!Arrays.equals(expectedMd5Bytes, actualMd5Bytes)) {
      throw new IOException("Wrong challenge response: expected=" +
          Arrays.toString(expectedMd5Bytes) + " recieved=" + Arrays.toString(actualMd5Bytes));
    }
  }

  /**
   * Calculates the digest the server is expected to answer with: MD5 of the two key numbers
   * (4 bytes each, big-endian) followed by the bytes of the third key.
   */
  private static byte[] computeExpectedChallengeResponse(WsKey key1, WsKey key2, byte[] key3)
      throws IOException {
    ByteArrayOutputStream challengeBytes = new ByteArrayOutputStream(16);
    writeIntBigEndian(key1.getNumber(), challengeBytes);
    writeIntBigEndian(key2.getNumber(), challengeBytes);
    challengeBytes.write(key3);
    MessageDigest digest;
    try {
      digest = MessageDigest.getInstance("MD5");
    } catch (NoSuchAlgorithmException e) {
      throw new RuntimeException(e);
    }
    return digest.digest(challengeBytes.toByteArray());
  }

  /**
   * Checks the fields of the handshake response: the set of field names must be exactly
   * {@link #EXPECTED_FIELDS} and every field must carry the expected value.
   * Fields are looked up by lower-case name (presumably the names are normalized by
   * {@code HandshakeUtil.readHttpResponse} -- confirm there).
   *
   * @throws IOException if any check fails
   */
  private static void checkResponseFields(Map<String, String> responseFields,
      InetSocketAddress endpoint, String resourceName, String origin) throws IOException {
    // Same size plus "contains all expected" means the two name sets are equal.
    if (responseFields.size() != EXPECTED_FIELDS.size()) {
      throw new IOException("Malformed response");
    }
    if (!responseFields.keySet().containsAll(EXPECTED_FIELDS)) {
      throw new IOException("Malformed response");
    }
    // The upgrade value is compared exactly, the connection value ignoring case.
    if (!"WebSocket".equals(responseFields.get("upgrade"))) {
      throw new IOException("Malformed response");
    }
    if (!"upgrade".equalsIgnoreCase(responseFields.get("connection"))) {
      throw new IOException("Malformed response");
    }
    if (!origin.equals(responseFields.get("sec-websocket-origin"))) {
      throw new IOException("Malformed response");
    }
    String expectedUrl = createUrl(endpoint, resourceName, false);
    if (!expectedUrl.equals(responseFields.get("sec-websocket-location"))) {
      throw new IOException("Malformed response: unexpected sec-websocket-location");
    }
  }

  /**
   * Fills the whole buffer with bytes from the stream, issuing as many reads as it takes.
   *
   * @throws IOException if the stream ends before the buffer is full
   */
  private static void readFully(InputStream input, byte[] buffer) throws IOException {
    int readPos = 0;
    while (readPos < buffer.length) {
      int readRes = input.read(buffer, readPos, buffer.length - readPos);
      if (readRes == -1) {
        throw new IOException("End of stream");
      }
      readPos += readRes;
    }
  }

  /**
   * Builds the WebSocket URL of the resource: scheme, host name, port (left out when it is
   * the default one for the scheme: 443 for secure, 80 otherwise) and the resource name.
   */
  private static String createUrl(InetSocketAddress endpoint, String resourceName,
      boolean secure) {
    int defaultPort = secure ? 443 : 80;
    boolean needPort = endpoint.getPort() != defaultPort;
    return (secure ? "wss://" : "ws://") +
        endpoint.getHostName() +
        (needPort ? ":" + endpoint.getPort() : "") +
        resourceName;
  }

  /**
   * Writes the 32 low-order bits of the value as 4 bytes, most significant byte first.
   */
  private static void writeIntBigEndian(long value, OutputStream output) throws IOException {
    // OutputStream.write(int) only takes the 8 low-order bits of its argument.
    output.write((int) ((value >> (3 * 8)) & 0xFF));
    output.write((int) ((value >> (2 * 8)) & 0xFF));
    output.write((int) ((value >> (1 * 8)) & 0xFF));
    output.write((int) (value & 0xFF));
  }

  /**
   * A random key for a Sec-WebSocket-Key field. A key is a number in the range
   * [0, SPEC_MAX / spaces] that is encoded into a string as follows: the decimal digits of
   * {@code number * spaces} are mixed with 1 to 12 random non-digit printable characters and
   * with {@code spaces} space characters (1 to 12 of them); the first and the last characters
   * of the string are never spaces.
   */
  private static class WsKey {
    /** Upper bound of the product encoded in the key string (2^32 - 1). */
    private static final long SPEC_MAX = 4294967295L;

    /** Placeholder that marks a digit position while the character order is being chosen. */
    private static final byte DIGIT_PLACEHOLDER = (byte) '1';

    private final long resNumber;
    private final String keyString;

    /**
     * Generates a new key.
     *
     * @param random source of all random choices
     */
    WsKey(Random random) {
      int spaces = random.nextInt(12) + 1;
      long max = SPEC_MAX / spaces;
      long randomLong = random.nextLong();
      // Math.abs(Long.MIN_VALUE) overflows and stays negative, which would result in
      // a negative number and a '-' sign in the key; map this single value to 0.
      long nonNegative = randomLong == Long.MIN_VALUE ? 0L : Math.abs(randomLong);
      long number = nonNegative % (max + 1);
      resNumber = number;

      long product = number * spaces;
      assert(product <= SPEC_MAX);

      String productStr = Long.toString(product);

      // Choose positions first: digits are represented by placeholders, which are replaced
      // with the actual digits (in their original order) at the very end.
      List<Byte> keyBytes = new ArrayList<Byte>(40);
      keyBytes.addAll(Collections.nCopies(productStr.length(), DIGIT_PLACEHOLDER));
      int stuffByteNumber = random.nextInt(12) + 1;
      for (int i = 0; i < stuffByteNumber; i++) {
        // Stuff bytes are never digits, so they cannot be confused with the placeholder.
        keyBytes.add(StuffBytes.getByte(random));
      }
      Collections.shuffle(keyBytes, random);

      // Insert spaces right before the last element and shuffle everything but the first
      // and the last elements, so that spaces only ever end up in the interior.
      keyBytes.subList(0, keyBytes.size() - 1).addAll(Collections.nCopies(spaces, (byte) ' '));
      Collections.shuffle(keyBytes.subList(1, keyBytes.size() - 1), random);

      byte[] resultBytes = new byte[keyBytes.size()];
      int strPos = 0;
      for (int i = 0; i < resultBytes.length; i++) {
        byte b = keyBytes.get(i);
        if (b == DIGIT_PLACEHOLDER) {
          b = (byte) productStr.charAt(strPos);
          strPos++;
        }
        resultBytes[i] = b;
      }
      assert(strPos == productStr.length());

      keyString = new String(resultBytes, HandshakeUtil.ASCII_CHARSET);
    }

    /** Returns the key encoded as a string, ready to be used as a field value. */
    String getKeySocketField() {
      return keyString;
    }

    /** Returns the number behind the key, the one that goes into the challenge. */
    long getNumber() {
      return resNumber;
    }

    /**
     * Generator of random characters that are mixed into the key string. Characters come
     * from the ranges 0x21-0x2F and 0x3A-0x7E, which leaves out the space and the digits.
     */
    private static class StuffBytes {
      // Range ends are exclusive.
      private static final byte RANGE_1_BEGIN = 0x21;
      private static final byte RANGE_1_END = 0x2F + 1;
      private static final byte RANGE_2_BEGIN = 0x3A;
      private static final byte RANGE_2_END = 0x7E + 1;

      /** Number of characters in the first range. */
      private static final int RANDOM_RANGE_1 = RANGE_1_END - RANGE_1_BEGIN;
      /** Number of characters in both ranges together. */
      private static final int RANDOM_RANGE = RANDOM_RANGE_1 + RANGE_2_END - RANGE_2_BEGIN;

      /** Returns a random character picked uniformly from the two ranges. */
      private static byte getByte(Random random) {
        int i = random.nextInt(RANDOM_RANGE);
        if (i < RANDOM_RANGE_1) {
          return (byte) (i + RANGE_1_BEGIN);
        } else {
          return (byte) (i - RANDOM_RANGE_1 + RANGE_2_BEGIN);
        }
      }
    }
  }
}