package org.whispersystems.textsecuregcm.push; import com.codahale.metrics.Histogram; import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.RatioGauge; import com.codahale.metrics.SharedMetricRegistries; import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.InvalidProtocolBufferException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.dispatch.DispatchChannel; import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.websocket.WebSocketConnectionInfo; import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.Map.Entry; import java.util.concurrent.TimeUnit; import static com.codahale.metrics.MetricRegistry.name; import io.dropwizard.lifecycle.Managed; public class ApnFallbackManager implements Managed, Runnable, DispatchChannel { private static final Logger logger = LoggerFactory.getLogger(ApnFallbackManager.class); public static final int FALLBACK_DURATION = 15; private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private static final Meter voipOneSuccess = metricRegistry.meter(name(ApnFallbackManager.class, "voip_one_success")); private static final Meter voipOneDelivery = metricRegistry.meter(name(ApnFallbackManager.class, "voip_one_failure")); private static final Histogram voipOneSuccessHistogram = metricRegistry.histogram(name(ApnFallbackManager.class, "voip_one_success_histogram")); static { metricRegistry.register(name(ApnFallbackManager.class, "voip_one_success_ratio"), new VoipRatioGauge(voipOneSuccess, voipOneDelivery)); } private final ApnFallbackTaskQueue taskQueue = new ApnFallbackTaskQueue(); private final APNSender apnSender; private final PubSubManager pubSubManager; public ApnFallbackManager(APNSender apnSender, PubSubManager pubSubManager) { this.apnSender = apnSender; this.pubSubManager = pubSubManager; } public void schedule(final WebsocketAddress address, ApnFallbackTask task) { voipOneDelivery.mark(); if (taskQueue.put(address, task)) { pubSubManager.subscribe(new WebSocketConnectionInfo(address), this); } } private void scheduleRetry(final WebsocketAddress address, ApnFallbackTask task) { if (taskQueue.putIfMissing(address, task)) { pubSubManager.subscribe(new WebSocketConnectionInfo(address), this); } } public void cancel(WebsocketAddress address) { ApnFallbackTask task = taskQueue.remove(address); if (task != null) { pubSubManager.unsubscribe(new WebSocketConnectionInfo(address), this); voipOneSuccess.mark(); voipOneSuccessHistogram.update(System.currentTimeMillis() - task.getScheduledTime()); } } @Override public void start() throws Exception { new Thread(this).start(); } @Override public void stop() throws Exception { } @Override public void run() { while (true) { try { Entry<WebsocketAddress, ApnFallbackTask> taskEntry = taskQueue.get(); ApnFallbackTask task = taskEntry.getValue(); ApnMessage message; if (task.getAttempt() == 0) { message = new ApnMessage(task.getMessage(), task.getVoipApnId(), true, System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(FALLBACK_DURATION)); scheduleRetry(taskEntry.getKey(), new ApnFallbackTask(task.getApnId(), task.getVoipApnId(), task.getMessage(), task.getDelay(),1)); } else { message = new ApnMessage(task.getMessage(), task.getApnId(), false, ApnMessage.MAX_EXPIRATION); pubSubManager.unsubscribe(new WebSocketConnectionInfo(taskEntry.getKey()), this); } apnSender.sendMessage(message); } catch (Throwable e) { logger.warn("ApnFallbackThread", e); } } } @Override public void onDispatchMessage(String channel, byte[] message) { try { PubSubMessage notification = PubSubMessage.parseFrom(message); if (notification.getType().getNumber() == PubSubMessage.Type.CONNECTED_VALUE) { WebSocketConnectionInfo address = new WebSocketConnectionInfo(channel); cancel(address.getWebsocketAddress()); } else { logger.warn("Got strange pubsub type: " + notification.getType().getNumber()); } } catch (WebSocketConnectionInfo.FormattingException e) { logger.warn("Bad formatting?", e); } catch (InvalidProtocolBufferException e) { logger.warn("Bad protobuf", e); } } @Override public void onDispatchSubscribed(String channel) {} @Override public void onDispatchUnsubscribed(String channel) {} public static class ApnFallbackTask { private final long delay; private final long scheduledTime; private final String apnId; private final String voipApnId; private final ApnMessage message; private final int attempt; public ApnFallbackTask(String apnId, String voipApnId, ApnMessage message) { this(apnId, voipApnId, message, TimeUnit.SECONDS.toMillis(FALLBACK_DURATION), 0); } @VisibleForTesting public ApnFallbackTask(String apnId, String voipApnId, ApnMessage message, long delay, int attempt) { this.scheduledTime = System.currentTimeMillis(); this.delay = delay; this.apnId = apnId; this.voipApnId = voipApnId; this.message = message; this.attempt = attempt; } public String getApnId() { return apnId; } public String getVoipApnId() { return voipApnId; } public ApnMessage getMessage() { return message; } public long getScheduledTime() { return scheduledTime; } public long getExecutionTime() { return scheduledTime + delay; } public long getDelay() { return delay; } public int getAttempt() { return attempt; } } @VisibleForTesting public static class ApnFallbackTaskQueue { private final LinkedHashMap<WebsocketAddress, ApnFallbackTask> tasks = new LinkedHashMap<>(); public Entry<WebsocketAddress, ApnFallbackTask> get() { while (true) { long timeDelta; synchronized (tasks) { while (tasks.isEmpty()) Util.wait(tasks); Iterator<Entry<WebsocketAddress, ApnFallbackTask>> iterator = tasks.entrySet().iterator(); Entry<WebsocketAddress, ApnFallbackTask> nextTask = iterator.next(); timeDelta = nextTask.getValue().getExecutionTime() - System.currentTimeMillis(); if (timeDelta <= 0) { iterator.remove(); return nextTask; } } Util.sleep(timeDelta); } } public boolean put(WebsocketAddress address, ApnFallbackTask task) { synchronized (tasks) { ApnFallbackTask previous = tasks.put(address, task); tasks.notifyAll(); return previous == null; } } public boolean putIfMissing(WebsocketAddress address, ApnFallbackTask task) { synchronized (tasks) { if (tasks.containsKey(address)) return false; return put(address, task); } } public ApnFallbackTask remove(WebsocketAddress address) { synchronized (tasks) { return tasks.remove(address); } } } private static class VoipRatioGauge extends RatioGauge { private final Meter success; private final Meter attempts; private VoipRatioGauge(Meter success, Meter attempts) { this.success = success; this.attempts = attempts; } @Override protected Ratio getRatio() { return Ratio.of(success.getFiveMinuteRate(), attempts.getFiveMinuteRate()); } } }