package org.whispersystems.textsecuregcm.websocket; import com.codahale.metrics.Histogram; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import com.google.common.base.Optional; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.dispatch.DispatchChannel; import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.NoSuchUserException; import org.whispersystems.textsecuregcm.entities.CryptoEncodingException; import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.TransientPushFailureException; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.messages.WebSocketResponseMessage; import javax.annotation.Nonnull; import javax.annotation.Nullable; import javax.ws.rs.WebApplicationException; import java.io.IOException; import java.util.Iterator; import static com.codahale.metrics.MetricRegistry.name; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; public class WebSocketConnection implements DispatchChannel { private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); public static final Histogram messageTime = metricRegistry.histogram(name(MessageController.class, "message_delivery_duration")); private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); private final ReceiptSender receiptSender; private final PushSender pushSender; private final MessagesManager messagesManager; private final Account account; private final Device device; private final WebSocketClient client; public WebSocketConnection(PushSender pushSender, ReceiptSender receiptSender, MessagesManager messagesManager, Account account, Device device, WebSocketClient client) { this.pushSender = pushSender; this.receiptSender = receiptSender; this.messagesManager = messagesManager; this.account = account; this.device = device; this.client = client; } @Override public void onDispatchMessage(String channel, byte[] message) { try { PubSubMessage pubSubMessage = PubSubMessage.parseFrom(message); switch (pubSubMessage.getType().getNumber()) { case PubSubMessage.Type.QUERY_DB_VALUE: processStoredMessages(); break; case PubSubMessage.Type.DELIVER_VALUE: sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.<Long>absent(), false); break; default: logger.warn("Unknown pubsub message: " + pubSubMessage.getType().getNumber()); } } catch (InvalidProtocolBufferException e) { logger.warn("Protobuf parse error", e); } } @Override public void onDispatchUnsubscribed(String channel) { client.close(1000, "OK"); } public void onDispatchSubscribed(String channel) { processStoredMessages(); } private void sendMessage(final Envelope message, final Optional<Long> storedMessageId, final boolean requery) { try { EncryptedOutgoingMessage encryptedMessage = new EncryptedOutgoingMessage(message, device.getSignalingKey()); Optional<byte[]> body = Optional.fromNullable(encryptedMessage.toByteArray()); ListenableFuture<WebSocketResponseMessage> response = client.sendRequest("PUT", "/api/v1/message", null, body); Futures.addCallback(response, new FutureCallback<WebSocketResponseMessage>() { @Override public void onSuccess(@Nullable WebSocketResponseMessage response) { boolean isReceipt = message.getType() == Envelope.Type.RECEIPT; if (isSuccessResponse(response) && !isReceipt) { messageTime.update(System.currentTimeMillis() - message.getTimestamp()); } if (isSuccessResponse(response)) { if (storedMessageId.isPresent()) messagesManager.delete(account.getNumber(), storedMessageId.get()); if (!isReceipt) sendDeliveryReceiptFor(message); if (requery) processStoredMessages(); } else if (!isSuccessResponse(response) && !storedMessageId.isPresent()) { requeueMessage(message); } } @Override public void onFailure(@Nonnull Throwable throwable) { if (!storedMessageId.isPresent()) requeueMessage(message); } private boolean isSuccessResponse(WebSocketResponseMessage response) { return response != null && response.getStatus() >= 200 && response.getStatus() < 300; } }); } catch (CryptoEncodingException e) { logger.warn("Bad signaling key", e); } } private void requeueMessage(Envelope message) { int queueDepth = pushSender.getWebSocketSender().queueMessage(account, device, message); boolean fallback = !message.getSource().equals(account.getNumber()) && message.getType() != Envelope.Type.RECEIPT; try { pushSender.sendQueuedNotification(account, device, queueDepth, fallback); } catch (NotPushRegisteredException | TransientPushFailureException e) { logger.warn("requeueMessage", e); } } private void sendDeliveryReceiptFor(Envelope message) { try { receiptSender.sendReceipt(account, message.getSource(), message.getTimestamp(), message.hasRelay() ? Optional.of(message.getRelay()) : Optional.<String>absent()); } catch (NoSuchUserException | NotPushRegisteredException e) { logger.info("No longer registered " + e.getMessage()); } catch(IOException | TransientPushFailureException e) { logger.warn("Something wrong while sending receipt", e); } catch (WebApplicationException e) { logger.warn("Bad federated response for receipt: " + e.getResponse().getStatus()); } } private void processStoredMessages() { OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), device.getId()); Iterator<OutgoingMessageEntity> iterator = messages.getMessages().iterator(); while (iterator.hasNext()) { OutgoingMessageEntity message = iterator.next(); Envelope.Builder builder = Envelope.newBuilder() .setType(Envelope.Type.valueOf(message.getType())) .setSourceDevice(message.getSourceDevice()) .setSource(message.getSource()) .setTimestamp(message.getTimestamp()); if (message.getMessage() != null) { builder.setLegacyMessage(ByteString.copyFrom(message.getMessage())); } if (message.getContent() != null) { builder.setContent(ByteString.copyFrom(message.getContent())); } if (message.getRelay() != null && !message.getRelay().isEmpty()) { builder.setRelay(message.getRelay()); } sendMessage(builder.build(), Optional.of(message.getId()), !iterator.hasNext() && messages.hasMore()); } if (!messages.hasMore()) { client.sendRequest("PUT", "/api/v1/queue/empty", null, Optional.<byte[]>absent()); } } }