package org.whispersystems.textsecuregcm.storage; import com.google.protobuf.InvalidProtocolBufferException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import java.util.Arrays; import java.util.HashMap; import java.util.Map; import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; import redis.clients.jedis.BinaryJedisPubSub; import redis.clients.jedis.Jedis; import redis.clients.jedis.JedisPool; public class PubSubManager { private static final byte[] KEEPALIVE_CHANNEL = "KEEPALIVE".getBytes(); private final Logger logger = LoggerFactory.getLogger(PubSubManager.class); private final SubscriptionListener baseListener = new SubscriptionListener(); private final Map<String, PubSubListener> listeners = new HashMap<>(); private final JedisPool jedisPool; private boolean subscribed = false; public PubSubManager(JedisPool jedisPool) { this.jedisPool = jedisPool; initializePubSubWorker(); waitForSubscription(); } public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) { String serializedAddress = address.serialize(); listeners.put(serializedAddress, listener); baseListener.subscribe(serializedAddress.getBytes()); } public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) { String serializedAddress = address.serialize(); if (listeners.get(serializedAddress) == listener) { listeners.remove(serializedAddress); baseListener.unsubscribe(serializedAddress.getBytes()); } } public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) { return publish(address.serialize().getBytes(), message); } private synchronized boolean publish(byte[] channel, PubSubMessage message) { try (Jedis jedis = jedisPool.getResource()) { return jedis.publish(channel, message.toByteArray()) != 0; } } private synchronized void waitForSubscription() { try { while (!subscribed) { wait(); } } catch (InterruptedException e) { throw new AssertionError(e); } } private void initializePubSubWorker() { new Thread("PubSubListener") { @Override public void run() { for (;;) { try (Jedis jedis = jedisPool.getResource()) { jedis.subscribe(baseListener, KEEPALIVE_CHANNEL); logger.warn("**** Unsubscribed from holding channel!!! ******"); } } } }.start(); new Thread("PubSubKeepAlive") { @Override public void run() { for (;;) { try { Thread.sleep(20000); publish(KEEPALIVE_CHANNEL, PubSubMessage.newBuilder() .setType(PubSubMessage.Type.KEEPALIVE) .build()); } catch (InterruptedException e) { throw new AssertionError(e); } } } }.start(); } private class SubscriptionListener extends BinaryJedisPubSub { @Override public void onMessage(byte[] channel, byte[] message) { try { PubSubListener listener; synchronized (PubSubManager.this) { listener = listeners.get(new String(channel)); } if (listener != null) { listener.onPubSubMessage(PubSubMessage.parseFrom(message)); } } catch (InvalidProtocolBufferException e) { logger.warn("Error parsing PubSub protobuf", e); } } @Override public void onPMessage(byte[] s, byte[] s2, byte[] s3) { logger.warn("Received PMessage!"); } @Override public void onSubscribe(byte[] channel, int count) { if (Arrays.equals(KEEPALIVE_CHANNEL, channel)) { synchronized (PubSubManager.this) { subscribed = true; PubSubManager.this.notifyAll(); } } } @Override public void onUnsubscribe(byte[] s, int i) {} @Override public void onPUnsubscribe(byte[] s, int i) {} @Override public void onPSubscribe(byte[] s, int i) {} } }