/*
 * Copyright 2017 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.google.firebase.database.tubesock;

import static com.google.common.base.Preconditions.checkState;

import java.io.DataInputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.Socket;
import java.net.URI;
import java.net.UnknownHostException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import javax.net.SocketFactory;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

/**
 * This is the main class used to create a websocket connection. Create a new instance, set an event
 * handler, and then call connect(). Once the event handler's onOpen method has been called, call
 * send() on the websocket to transmit data.
 */
public class WebSocket {
  static final byte OPCODE_NONE = 0x0;
  static final byte OPCODE_TEXT = 0x1;
  static final byte OPCODE_BINARY = 0x2;
  static final byte OPCODE_CLOSE = 0x8;
  static final byte OPCODE_PING = 0x9;
  static final byte OPCODE_PONG = 0xA;

  private static final String THREAD_BASE_NAME = "TubeSock";
  private static final AtomicInteger clientCount = new AtomicInteger(0);
  private static final Charset UTF8 = StandardCharsets.UTF_8;

  /** Maximum number of bytes (including the trailing CRLF) accepted for one handshake line. */
  private static final int MAX_HANDSHAKE_LINE_LENGTH = 1000;

  private static ThreadFactory threadFactory = Executors.defaultThreadFactory();
  private static ThreadInitializer initializer =
      new ThreadInitializer() {
        @Override
        public void setName(Thread t, String name) {
          t.setName(name);
        }
      };

  private final URI url;
  private final WebSocketReceiver receiver;
  private final WebSocketWriter writer;
  private final WebSocketHandshake handshake;
  private final int clientId = clientCount.incrementAndGet();

  private Thread innerThread;
  private volatile State state = State.NONE;
  private volatile Socket socket = null;
  private WebSocketEventHandler eventHandler = null;

  /**
   * Create a websocket to connect to a given server.
   *
   * @param url The URL of a websocket server
   */
  public WebSocket(URI url) {
    this(url, null);
  }

  /**
   * Create a websocket to connect to a given server. Include protocol in websocket handshake.
   *
   * @param url The URL of a websocket server
   * @param protocol The protocol to include in the handshake. If null, it will be omitted
   */
  public WebSocket(URI url, String protocol) {
    this(url, protocol, null);
  }

  /**
   * Create a websocket to connect to a given server. Include the given protocol in the handshake,
   * as well as any extra HTTP headers specified. Useful if you would like to include a User-Agent
   * or other header.
   *
   * @param url The URL of a websocket server
   * @param protocol The protocol to include in the handshake. If null, it will be omitted
   * @param extraHeaders Any extra HTTP headers to be included with the initial request. Pass null
   *     if no extra headers are needed
   */
  public WebSocket(URI url, String protocol, Map<String, String> extraHeaders) {
    this.url = url;
    handshake = new WebSocketHandshake(url, protocol, extraHeaders);
    receiver = new WebSocketReceiver(this);
    writer = new WebSocketWriter(this, THREAD_BASE_NAME, clientId);
  }

  static ThreadFactory getThreadFactory() {
    return threadFactory;
  }

  static ThreadInitializer getIntializer() {
    return initializer;
  }

  /**
   * Replace the factory used to create the library's threads and the hook used to name them.
   *
   * @param threadFactory factory for the reader and writer threads
   * @param intializer hook invoked to assign each new thread its name
   */
  public static void setThreadFactory(ThreadFactory threadFactory, ThreadInitializer intializer) {
    WebSocket.threadFactory = threadFactory;
    WebSocket.initializer = intializer;
  }

  WebSocketEventHandler getEventHandler() {
    return this.eventHandler;
  }

  /**
   * Must be called before connect(). Set the handler for all websocket-related events.
   *
   * @param eventHandler The handler to be triggered with relevant events
   */
  public void setEventHandler(WebSocketEventHandler eventHandler) {
    this.eventHandler = eventHandler;
  }

  /**
   * Start up the socket. This is non-blocking: it fires up the threads used by the library and
   * then triggers the onOpen handler once the connection is established. Calling it a second time
   * reports an error and closes the websocket.
   */
  public synchronized void connect() {
    if (state != State.NONE) {
      eventHandler.onError(new WebSocketException("connect() already called"));
      close();
      return;
    }
    state = State.CONNECTING;
    start();
  }

  /** Creates, names and starts the reader thread, which connects and then reads frames. */
  private synchronized void start() {
    checkState(innerThread == null, "Inner thread already started");
    innerThread =
        getThreadFactory()
            .newThread(
                new Runnable() {
                  @Override
                  public void run() {
                    runReader();
                  }
                });
    getIntializer().setName(innerThread, THREAD_BASE_NAME + "Reader-" + clientId);
    innerThread.start();
  }

  /**
   * Send a TEXT message over the socket.
   *
   * @param data The text payload to be sent
   */
  public synchronized void send(String data) {
    send(OPCODE_TEXT, data.getBytes(UTF8));
  }

  /**
   * Send a BINARY message over the socket.
   *
   * @param data The binary payload to be sent
   */
  public synchronized void send(byte[] data) {
    send(OPCODE_BINARY, data);
  }

  private synchronized void send(byte opcode, byte[] data) {
    if (state != State.CONNECTED) {
      // We might have been disconnected on another thread, just report an error
      eventHandler.onError(new WebSocketException("error while sending data: not connected"));
    } else {
      try {
        writer.send(opcode, true, data);
      } catch (IOException e) {
        eventHandler.onError(new WebSocketException("Failed to send frame", e));
        close();
      }
    }
  }

  synchronized void pong(byte[] data) {
    send(OPCODE_PONG, data);
  }

  void handleReceiverError(WebSocketException e) {
    eventHandler.onError(e);
    if (state == State.CONNECTED) {
      close();
    }
    closeSocket();
  }

  /**
   * Close down the socket. Will trigger the onClose handler if the socket has not been previously
   * closed.
   */
  public synchronized void close() {
    // CSOFF: MissingSwitchDefaultCheck
    switch (state) {
      case NONE:
        state = State.DISCONNECTED;
        return;
      case CONNECTING:
        // don't wait for an established connection, just close the tcp socket
        closeSocket();
        return;
      case CONNECTED:
        // This method also shuts down the writer
        // the socket will be closed once the ack for the close was received
        sendCloseHandshake();
        return;
      case DISCONNECTING:
        return; // no-op;
      case DISCONNECTED:
        return; // No-op
    }
    // CSON: MissingSwitchDefaultCheck
  }

  void onCloseOpReceived() {
    closeSocket();
  }

  /**
   * Stops the reader and writer, closes the TCP socket, moves to DISCONNECTED and fires onClose.
   * Does nothing if already DISCONNECTED.
   */
  private synchronized void closeSocket() {
    if (state == State.DISCONNECTED) {
      return;
    }
    receiver.stopit();
    writer.stopIt();
    if (socket != null) {
      try {
        socket.close();
      } catch (IOException e) {
        // Report the failure but still complete the transition below; otherwise the socket would
        // be stuck in a half-closed state and onClose() would never fire.
        eventHandler.onError(new WebSocketException("Failed to close socket", e));
      }
    }
    state = State.DISCONNECTED;
    eventHandler.onClose();
  }

  private void sendCloseHandshake() {
    try {
      state = State.DISCONNECTING;
      // Set the stop flag then queue up a message. This ensures that the writer thread
      // will wake up, and since we set the stop flag, it will exit its run loop.
      writer.stopIt();
      writer.send(OPCODE_CLOSE, true, new byte[0]);
    } catch (IOException e) {
      eventHandler.onError(new WebSocketException("Failed to send close frame", e));
    }
  }

  /**
   * Opens a plain ({@code ws}) or TLS ({@code wss}) socket to the configured URL, defaulting the
   * port to 80 or 443 respectively.
   *
   * @throws WebSocketException if the host is unknown, the socket cannot be opened, or the scheme
   *     is neither {@code ws} nor {@code wss}
   */
  private Socket createSocket() {
    String scheme = url.getScheme();
    String host = url.getHost();
    int port = url.getPort();

    Socket socket;
    if (scheme != null && scheme.equals("ws")) {
      if (port == -1) {
        port = 80;
      }
      try {
        socket = new Socket(host, port);
      } catch (UnknownHostException uhe) {
        throw new WebSocketException("unknown host: " + host, uhe);
      } catch (IOException ioe) {
        throw new WebSocketException("error while creating socket to " + url, ioe);
      }
    } else if (scheme != null && scheme.equals("wss")) {
      if (port == -1) {
        port = 443;
      }
      try {
        SocketFactory factory = SSLSocketFactory.getDefault();
        SSLSocket sslSocket = (SSLSocket) factory.createSocket(host, port);
        // Ensure proper hostname verification, per
        // https://tersesystems.com/2014/03/23/fixing-hostname-verification/
        // TODO: This code is different than Android. We should refactor it
        // into JvmPlatform.
        SSLParameters sslParams = new SSLParameters();
        sslParams.setEndpointIdentificationAlgorithm("HTTPS");
        sslSocket.setSSLParameters(sslParams);
        socket = sslSocket;
      } catch (UnknownHostException uhe) {
        throw new WebSocketException("unknown host: " + host, uhe);
      } catch (IOException ioe) {
        throw new WebSocketException("error while creating secure socket to " + url, ioe);
      }
    } else {
      throw new WebSocketException("unsupported protocol: " + scheme);
    }

    return socket;
  }

  /**
   * Blocks until both threads exit. The actual close must be triggered separately. This is just a
   * convenience method to make sure everything shuts down, if desired.
   */
  public void blockClose() throws InterruptedException {
    writer.waitForTermination();
    Thread thread;
    synchronized (this) {
      if (innerThread == null) {
        return;
      }
      thread = innerThread;
    }
    thread.join();
  }

  /**
   * Body of the reader thread: opens the socket, performs the opening handshake, then hands the
   * streams to the writer/receiver and runs the receive loop. Errors are reported via onError, and
   * close() is always called on exit.
   */
  private void runReader() {
    try {
      Socket socket = createSocket();
      synchronized (this) {
        this.socket = socket;
        if (this.state == State.DISCONNECTED) {
          // The connection has been closed while creating the socket, close it
          // immediately and return
          try {
            this.socket.close();
          } catch (IOException e) {
            throw new RuntimeException(e);
          }
          this.socket = null;
          return;
        }
      }

      DataInputStream input = new DataInputStream(socket.getInputStream());
      OutputStream output = socket.getOutputStream();

      output.write(handshake.getHandshake());

      ArrayList<String> handshakeLines = readHandshakeLines(input);
      if (handshakeLines.isEmpty()) {
        throw new WebSocketException("Server sent an empty handshake response");
      }
      handshake.verifyServerStatusLine(handshakeLines.get(0));
      handshakeLines.remove(0);
      handshake.verifyServerHandshakeHeaders(parseHandshakeHeaders(handshakeLines));

      synchronized (this) {
        if (state == State.DISCONNECTED) {
          // close() ran while the handshake was in flight. closeSocket() has already closed the
          // socket and fired onClose, so don't resurrect the connection.
          return;
        }
        writer.setOutput(output);
        receiver.setInput(input);
        state = State.CONNECTED;
      }
      writer.start();
      eventHandler.onOpen();
      receiver.run();
    } catch (WebSocketException wse) {
      eventHandler.onError(wse);
    } catch (IOException ioe) {
      eventHandler.onError(
          new WebSocketException("error while connecting: " + ioe.getMessage(), ioe));
    } finally {
      close();
    }
  }

  /**
   * Reads the server's handshake response up to and including the blank line that terminates the
   * HTTP headers.
   *
   * @param input stream positioned at the start of the server's response
   * @return the non-blank response lines, trimmed, in the order received (status line first)
   * @throws WebSocketException if the stream ends before the blank line, or a line exceeds
   *     {@link #MAX_HANDSHAKE_LINE_LENGTH} bytes
   * @throws IOException if reading from the stream fails
   */
  private static ArrayList<String> readHandshakeLines(DataInputStream input) throws IOException {
    byte[] buffer = new byte[MAX_HANDSHAKE_LINE_LENGTH];
    int pos = 0;
    ArrayList<String> lines = new ArrayList<>();
    while (true) {
      int b = input.read();
      if (b == -1) {
        throw new WebSocketException("Connection closed before handshake was complete");
      }
      buffer[pos] = (byte) b;
      pos += 1;
      // Lines end with CRLF. Require pos >= 2 so a bare LF at the start of a line cannot index
      // before the start of the buffer.
      if (pos >= 2 && buffer[pos - 1] == 0x0A && buffer[pos - 2] == 0x0D) {
        // Decode only the bytes belonging to this line; the buffer is reused across lines.
        String line = new String(buffer, 0, pos, UTF8).trim();
        if (line.isEmpty()) {
          return lines;
        }
        lines.add(line);
        pos = 0;
      } else if (pos == buffer.length) {
        // This really shouldn't happen, handshake lines are short, but just to be safe...
        String line = new String(buffer, 0, pos, UTF8);
        throw new WebSocketException("Unexpected long line in handshake: " + line);
      }
    }
  }

  /**
   * Parses {@code Name: value} header lines into a map. The value has surrounding whitespace
   * removed, so both {@code Name: value} and {@code Name:value} are accepted.
   *
   * @param lines header lines, excluding the status line
   * @return header names mapped to values; a repeated name keeps its last value
   * @throws WebSocketException if a line has no colon or an empty header name
   */
  private static HashMap<String, String> parseHandshakeHeaders(Iterable<String> lines) {
    HashMap<String, String> headers = new HashMap<>();
    for (String line : lines) {
      int colon = line.indexOf(':');
      if (colon <= 0) {
        throw new WebSocketException("Malformed handshake header: " + line);
      }
      headers.put(line.substring(0, colon), line.substring(colon + 1).trim());
    }
    return headers;
  }

  private enum State {
    NONE,
    CONNECTING,
    CONNECTED,
    DISCONNECTING,
    DISCONNECTED
  }
}