// Copyright (c) 2011 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. package org.chromium.sdk.internal.websocket; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Random; import javax.xml.bind.DatatypeConverter; import org.chromium.sdk.internal.websocket.ManualLoggingSocketWrapper.LoggableInput; import org.chromium.sdk.util.BasicUtil; /** * WebSocket connection handshake. * @see http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17 */ class Hybi17Handshake { static Result performHandshake(ManualLoggingSocketWrapper socket, InetSocketAddress endpoint, String resourceName, Random random) throws IOException { final ManualLoggingSocketWrapper.LoggableInput input = socket.getLoggableInput(); ManualLoggingSocketWrapper.LoggableOutput output = socket.getLoggableOutput(); writeHttpLine(output, "GET " + resourceName + " HTTP/1.1"); List<String> headerFields = HandshakeUtil.createHttpFields(endpoint); headerFields.add("Upgrade: websocket"); headerFields.add("Host: " + endpoint.getHostName()); byte[] secKeyBytes = new byte[16]; random.nextBytes(secKeyBytes); String secKeyString = DatatypeConverter.printBase64Binary(secKeyBytes); headerFields.add("Sec-WebSocket-Key: " + secKeyString); headerFields.add("Sec-WebSocket-Version: 13"); Collections.shuffle(headerFields, random); for (String field : headerFields) { writeHttpLine(output, field); } writeHttpLine(output, ""); HandshakeUtil.LineReader lineReader = new HandshakeUtil.LineReader() { @Override byte[] readUpTo0x0D0A() throws IOException { ByteBuffer buffer = input.readUpTo0x0D0A(); byte[] result = new byte[buffer.limit()]; buffer.get(result); return result; } }; HandshakeUtil.HttpResponse httpResponse = HandshakeUtil.readHttpResponse(lineReader); if (httpResponse.getCode() != 101) { return processResult(input, httpResponse); } Map<String, String> responseFields = httpResponse.getFields(); if (!"websocket".equalsIgnoreCase(responseFields.get("upgrade"))) { throw new IOException("Malformed response"); } if (!"upgrade".equalsIgnoreCase(responseFields.get("connection"))) { throw new IOException("Malformed response"); } if (responseFields.get("sec-websocket-extensions") != null) { throw new IOException("Malformed response"); } if (responseFields.get("sec-websocket-protocol") != null) { throw new IOException("Malformed response"); } String secAcceptString = responseFields.get("sec-websocket-accept"); if (secAcceptString == null) { throw new IOException("Malformed response"); } String expectedConcatenation = secKeyString + GUID; byte[] expectedAcceptSha1; { MessageDigest digest; try { digest = MessageDigest.getInstance("SHA-1"); } catch (NoSuchAlgorithmException e) { throw new RuntimeException(e); } expectedAcceptSha1 = digest.digest(expectedConcatenation.getBytes()); } String expectedAcceptString = DatatypeConverter.printBase64Binary(expectedAcceptSha1); if (!BasicUtil.eq(expectedAcceptString, secAcceptString)) { throw new IOException("Malformed response"); } return CONNECTED_RESULT; } static abstract class Result { abstract <R> R accept(Visitor<R> visitor); interface Visitor<R> { R visitConnected(); R visitUnknownError(Exception exception); R visitErrorMessage(int code, String errorName, String text); } static Result createError(final Exception exception) { return new Result() { @Override <R> R accept(Visitor<R> visitor) { return visitor.visitUnknownError(exception); } }; } } private static final Result CONNECTED_RESULT = new Result() { @Override <R> R accept(Visitor<R> visitor) { return visitor.visitConnected(); } }; private static Result processResult(LoggableInput input, final HandshakeUtil.HttpResponse httpResponse) throws IOException { Map<String, String> fields = httpResponse.getFields(); String contentType = fields.get("content-type"); String contentLength = fields.get("content-length"); if ("text/html".equals(contentType) && contentLength != null) { int length; try { length = Integer.parseInt(contentLength); } catch (NumberFormatException e) { return Result.createError(new Exception("Failed to parse context-length field", e)); } byte[] response = input.readBytes(length); final String contentText = new String(response, HandshakeUtil.ASCII_CHARSET); return new Result() { @Override <R> R accept(Visitor<R> visitor) { return visitor.visitErrorMessage(httpResponse.getCode(), httpResponse.getReasonPhrase(), contentText); } }; } return Result.createError(new Exception("Error response: " + httpResponse.getCode() + " " + httpResponse.getReasonPhrase())); } private static void writeHttpLine(ManualLoggingSocketWrapper.LoggableOutput output, String line) throws IOException { output.writeAsciiString(line + "\r\n"); } private static final String GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; }