/*
* Copyright 2017 Google Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.firebase.database.connection;
import com.google.firebase.database.connection.util.StringListReader;
import com.google.firebase.database.logging.LogWrapper;
import com.google.firebase.database.tubesock.WebSocket;
import com.google.firebase.database.tubesock.WebSocketEventHandler;
import com.google.firebase.database.tubesock.WebSocketException;
import com.google.firebase.database.tubesock.WebSocketMessage;
import com.google.firebase.database.util.JsonMapper;
import java.io.EOFException;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
/**
 * A websocket connection built on Tubesock.
 *
 * <p>Outgoing messages are serialized to JSON and split into frames of at most 16384 characters;
 * when more than one frame is needed, the frame count is sent first. Incoming frames are
 * reassembled the same way and the decoded map is handed to the {@link Delegate}. A {@code "0"}
 * keep-alive frame is sent after 45 seconds without sent or received traffic, and the socket is
 * closed if it has not opened within 30 seconds of {@link #open()}.
 *
 * <p>Tubesock callbacks are re-dispatched onto the connection context's executor service before
 * they touch any state of this class.
 */
class WebsocketConnection {

  private static final long KEEP_ALIVE_TIMEOUT_MS = 45 * 1000; // 45 seconds
  private static final long CONNECT_TIMEOUT_MS = 30 * 1000; // 30 seconds
  private static final int MAX_FRAME_SIZE = 16384;

  /** Counter used only to give each connection a distinct log tag ("ws_N"). */
  private static long connectionId = 0;

  private final ConnectionContext connectionContext;
  private final ScheduledExecutorService executorService;
  private final LogWrapper logger;

  /** Underlying socket; set to null once the socket reports it has closed. */
  private WSClient conn;

  private boolean everConnected = false;
  private boolean isClosed = false;

  /** Number of frames still expected for the message currently being buffered. */
  private long totalFrames = 0;

  /** Accumulates the frames of a multi-frame message; null when no message is being buffered. */
  private StringListReader frameReader;

  private Delegate delegate;
  private ScheduledFuture<?> keepAlive;
  private ScheduledFuture<?> connectTimeout;

  /**
   * Creates, but does not open, a websocket connection.
   *
   * @param connectionContext supplies the executor service, logger and user agent
   * @param hostInfo host, namespace and security settings used to build the connection URL
   * @param optCachedHost host to connect to instead of {@code hostInfo.getHost()}, or null
   * @param delegate receives decoded messages and the disconnect notification
   * @param optLastSessionId previous session id to include in the connection URL, or null
   */
  public WebsocketConnection(
      ConnectionContext connectionContext,
      HostInfo hostInfo,
      String optCachedHost,
      Delegate delegate,
      String optLastSessionId) {
    this.connectionContext = connectionContext;
    this.executorService = connectionContext.getExecutorService();
    this.delegate = delegate;
    long connId = connectionId++;
    logger = new LogWrapper(connectionContext.getLogger(), "WebSocket", "ws_" + connId);
    conn = createConnection(hostInfo, optCachedHost, optLastSessionId);
  }

  /**
   * Splits {@code src} into consecutive chunks of at most {@code maxFrameSize} characters.
   *
   * @return a single-element array holding {@code src} when it already fits (including the empty
   *     string), otherwise the ordered chunks
   */
  private static String[] splitIntoFrames(String src, int maxFrameSize) {
    if (src.length() <= maxFrameSize) {
      return new String[] {src};
    }
    ArrayList<String> segs = new ArrayList<>();
    for (int i = 0; i < src.length(); i += maxFrameSize) {
      int end = Math.min(i + maxFrameSize, src.length());
      segs.add(src.substring(i, end));
    }
    return segs.toArray(new String[segs.size()]);
  }

  /** Builds the Tubesock-backed client for the connection URL, sending our User-Agent header. */
  private WSClient createConnection(
      HostInfo hostInfo, String optCachedHost, String optLastSessionId) {
    String host = (optCachedHost != null) ? optCachedHost : hostInfo.getHost();
    URI uri =
        HostInfo.getConnectionUrl(
            host, hostInfo.isSecure(), hostInfo.getNamespace(), optLastSessionId);
    Map<String, String> extraHeaders = new HashMap<>();
    extraHeaders.put("User-Agent", this.connectionContext.getUserAgent());
    WebSocket ws = new WebSocket(uri, /*protocol=*/ null, extraHeaders);
    return new WSClientTubesock(ws);
  }

  /**
   * Starts connecting and schedules a timeout that closes the socket if it never opens.
   */
  public void open() {
    conn.connect();
    connectTimeout =
        executorService.schedule(
            new Runnable() {
              @Override
              public void run() {
                closeIfNeverConnected();
              }
            },
            CONNECT_TIMEOUT_MS,
            TimeUnit.MILLISECONDS);
  }

  /** Does nothing; kept for interface parity with other platforms. */
  public void start() {
    // No-op in java
  }

  /**
   * Marks this connection closed, closes the socket (if it is still present) and cancels the
   * connect-timeout and keep-alive tasks.
   */
  public void close() {
    if (logger.logsDebug()) {
      logger.debug("websocket is being closed");
    }
    isClosed = true;
    // onClosed() nulls conn once the socket has gone away, so close() may be reached afterwards.
    if (conn != null) {
      conn.close();
    }
    // Although mayInterruptIfRunning is true for both cancellations, these tasks run on the same
    // executor this method is expected to run on, so neither of them can be running right now.
    if (connectTimeout != null) {
      connectTimeout.cancel(true);
    }
    if (keepAlive != null) {
      keepAlive.cancel(true);
    }
  }

  /**
   * Serializes {@code message} to JSON and sends it, split into frames if necessary.
   *
   * <p>When more than one frame is needed, the frame count is sent first. On serialization
   * failure the error is logged and the connection is shut down.
   */
  public void send(Map<String, Object> message) {
    resetKeepAlive();
    try {
      String toSend = JsonMapper.serializeJson(message);
      String[] segs = splitIntoFrames(toSend, MAX_FRAME_SIZE);
      if (segs.length > 1) {
        conn.send("" + segs.length);
      }
      for (String seg : segs) {
        conn.send(seg);
      }
    } catch (IOException e) {
      logger.error("Failed to serialize message: " + message.toString(), e);
      shutdown();
    }
  }

  /**
   * Adds one frame to the message being buffered. Once the last expected frame arrives, decodes
   * the message and passes it to the delegate.
   *
   * <p>If decoding fails, or a {@link ClassCastException} escapes decoding or the delegate, the
   * error is logged and the connection is closed and shut down.
   */
  private void appendFrame(String message) {
    frameReader.addString(message);
    totalFrames -= 1;
    if (totalFrames != 0) {
      return;
    }
    // Take ownership of the reader and clear the field up front. The error handlers below then
    // always have a valid reader to log from, even when the exception comes from
    // delegate.onMessage() after buffering has ended.
    StringListReader reader = frameReader;
    frameReader = null;
    try {
      reader.freeze();
      Map<String, Object> decoded = JsonMapper.parseJson(reader.toString());
      if (logger.logsDebug()) {
        logger.debug("handleIncomingFrame complete frame: " + decoded);
      }
      delegate.onMessage(decoded);
    } catch (IOException e) {
      logger.error("Error parsing frame: " + reader.toString(), e);
      close();
      shutdown();
    } catch (ClassCastException e) {
      logger.error("Error parsing frame (cast error): " + reader.toString(), e);
      close();
      shutdown();
    }
  }

  /** Starts buffering a new message made of {@code numFrames} frames. */
  private void handleNewFrameCount(int numFrames) {
    totalFrames = numFrames;
    frameReader = new StringListReader();
    if (logger.logsDebug()) {
      logger.debug("HandleNewFrameCount: " + totalFrames);
    }
  }

  /**
   * Interprets the first frame of an incoming message.
   *
   * <p>A short numeric frame is treated as a frame count: a positive count starts buffering, and a
   * non-positive value (such as a {@code "0"} keep-alive) is consumed. In both cases null is
   * returned. Any other frame is a complete single-frame message and is returned for appending.
   */
  private String extractFrameCount(String message) {
    // TODO: The server is only supposed to send up to 9999 frames (i.e. length <= 4), but that
    // isn't being enforced currently. So allowing larger frame counts (length <= 6).
    // See https://app.asana.com/0/search/8688598998380/8237608042508
    if (message.length() <= 6) {
      try {
        int frameCount = Integer.parseInt(message);
        if (frameCount > 0) {
          handleNewFrameCount(frameCount);
        }
        return null;
      } catch (NumberFormatException e) {
        // not a number, default to framecount 1
      }
    }
    handleNewFrameCount(1);
    return message;
  }

  /** Handles one incoming text frame; frames received after close are ignored. */
  private void handleIncomingFrame(String message) {
    if (isClosed) {
      return;
    }
    resetKeepAlive();
    if (isBuffering()) {
      appendFrame(message);
    } else {
      String remaining = extractFrameCount(message);
      if (remaining != null) {
        appendFrame(remaining);
      }
    }
  }

  /** Cancels any pending keep-alive and schedules a new one, unless the connection is closed. */
  private void resetKeepAlive() {
    if (isClosed) {
      return;
    }
    if (keepAlive != null) {
      keepAlive.cancel(false);
      if (logger.logsDebug()) {
        logger.debug("Reset keepAlive. Remaining: " + keepAlive.getDelay(TimeUnit.MILLISECONDS));
      }
    } else if (logger.logsDebug()) {
      logger.debug("Reset keepAlive");
    }
    keepAlive = executorService.schedule(nop(), KEEP_ALIVE_TIMEOUT_MS, TimeUnit.MILLISECONDS);
  }

  /** Returns a task that sends a {@code "0"} keep-alive frame and reschedules itself. */
  private Runnable nop() {
    return new Runnable() {
      @Override
      public void run() {
        if (conn != null) {
          conn.send("0");
          resetKeepAlive();
        }
      }
    };
  }

  /** Returns whether a multi-frame message is currently being accumulated. */
  private boolean isBuffering() {
    return frameReader != null;
  }

  /**
   * Reacts to the socket going away. Shuts down (notifying the delegate) unless this side already
   * closed, then drops the socket and cancels the keep-alive.
   */
  private void onClosed() {
    if (!isClosed) {
      if (logger.logsDebug()) {
        logger.debug("closing itself");
      }
      shutdown();
    }
    conn = null;
    if (keepAlive != null) {
      keepAlive.cancel(false);
    }
  }

  /** Marks this connection closed and notifies the delegate. */
  private void shutdown() {
    isClosed = true;
    delegate.onDisconnect(everConnected);
  }

  // Close methods

  /** Connect-timeout task: closes the socket if it never opened and nothing else closed it. */
  private void closeIfNeverConnected() {
    if (!everConnected && !isClosed) {
      if (logger.logsDebug()) {
        logger.debug("timed out on connect");
      }
      conn.close();
    }
  }

  /** Receiver of decoded messages and the disconnect notification. */
  public interface Delegate {

    /** Called with each fully reassembled, JSON-decoded incoming message. */
    void onMessage(Map<String, Object> message);

    /**
     * Called when the connection shuts down.
     *
     * @param wasEverConnected whether the socket had opened before shutting down
     */
    void onDisconnect(boolean wasEverConnected);
  }

  /** Minimal socket abstraction used by this class. */
  private interface WSClient {

    void connect();

    void close();

    void send(String msg);
  }

  /**
   * {@link WSClient} backed by a Tubesock {@link WebSocket}. Its event callbacks hand their work
   * off to the executor service.
   */
  private class WSClientTubesock implements WSClient, WebSocketEventHandler {

    private WebSocket ws;

    private WSClientTubesock(WebSocket ws) {
      this.ws = ws;
      this.ws.setEventHandler(this);
    }

    @Override
    public void onOpen() {
      executorService.execute(
          new Runnable() {
            @Override
            public void run() {
              // connectTimeout is assigned only after conn.connect() returns in open(), so guard
              // against an open event that is processed before that assignment.
              if (connectTimeout != null) {
                connectTimeout.cancel(false);
              }
              everConnected = true;
              if (logger.logsDebug()) {
                logger.debug("websocket opened");
              }
              resetKeepAlive();
            }
          });
    }

    @Override
    public void onMessage(WebSocketMessage msg) {
      final String str = msg.getText();
      if (logger.logsDebug()) {
        logger.debug("ws message: " + str);
      }
      executorService.execute(
          new Runnable() {
            @Override
            public void run() {
              handleIncomingFrame(str);
            }
          });
    }

    @Override
    public void onClose() {
      final String logMessage = "closed";
      executorService.execute(
          new Runnable() {
            @Override
            public void run() {
              if (logger.logsDebug()) {
                logger.debug(logMessage);
              }
              onClosed();
            }
          });
    }

    @Override
    public void onError(final WebSocketException e) {
      executorService.execute(
          new Runnable() {
            @Override
            public void run() {
              if (e.getCause() != null && e.getCause() instanceof EOFException) {
                logger.debug("WebSocket reached EOF.");
              } else {
                logger.debug("WebSocket error.", e);
              }
              onClosed();
            }
          });
    }

    @Override
    public void onLogMessage(String msg) {
      if (logger.logsDebug()) {
        logger.debug("Tubesock: " + msg);
      }
    }

    @Override
    public void send(String msg) {
      ws.send(msg);
    }

    @Override
    public void close() {
      ws.close();
    }

    /** Closes the socket and blocks until its threads have finished. */
    private void shutdown() {
      ws.close();
      try {
        ws.blockClose();
      } catch (InterruptedException e) {
        logger.error("Interrupted while shutting down websocket threads", e);
        // Restore the interrupt status so callers further up the stack can still observe it.
        Thread.currentThread().interrupt();
      }
    }

    @Override
    public void connect() {
      try {
        ws.connect();
      } catch (WebSocketException e) {
        if (logger.logsDebug()) {
          logger.debug("Error connecting", e);
        }
        shutdown();
      }
    }
  }
}