/** * JBoss, Home of Professional Open Source Copyright Red Hat, Inc., and individual contributors * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the * License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific * language governing permissions and limitations under the License. */ package org.jboss.aerogear.simplepush.server.datastore; import java.nio.charset.Charset; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import org.jboss.aerogear.simplepush.protocol.Ack; import org.jboss.aerogear.simplepush.protocol.impl.AckImpl; import org.jboss.aerogear.simplepush.server.Channel; import org.jboss.aerogear.simplepush.server.DefaultChannel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import redis.clients.jedis.Jedis; import redis.clients.jedis.JedisPool; import redis.clients.jedis.JedisPoolConfig; import redis.clients.jedis.Transaction; /** * DataStore that uses a Redis database for storage. */ public class RedisDataStore implements DataStore { private final static String CHID_LOOKUP_KEY_PREFIX = "chid:lookup:"; private final static String UAID_LOOKUP_KEY_PREFIX = "uaid:lookup:"; private final static String TOKEN_LOOKUP_KEY_PREFIX = "token:lookup:"; private final static String ACK_LOOKUP_KEY_PREFIX = "ack:"; private final static String ACKS_LOOKUP_KEY_PREFIX = "acks:"; private final static String TOKEN_KEY = "token"; private final static String UAID_KEY = "uaid"; private final Logger logger = LoggerFactory.getLogger(RedisDataStore.class); private final static Charset UTF_8 = Charset.forName("UTF-8"); private final JedisPool jedisPool; public RedisDataStore(final String host, final int port) { jedisPool = new JedisPool(new JedisPoolConfig(), host, port); } @Override public void savePrivateKeySalt(final byte[] salt) { final Jedis jedis = jedisPool.getResource(); try { jedis.set("salt", new String(salt, UTF_8)); } finally { jedisPool.returnResource(jedis); } } @Override public byte[] getPrivateKeySalt() { final Jedis jedis = jedisPool.getResource(); try { final String salt = jedis.get("salt"); return salt != null ? salt.getBytes(UTF_8) : new byte[]{}; } finally { jedisPool.returnResource(jedis); } } @Override public boolean saveChannel(final Channel channel) { final Jedis jedis = jedisPool.getResource(); try { final String uaid = channel.getUAID(); final String chid = channel.getChannelId(); if (jedis.sismember(uaidLookupKey(uaid), chid)) { return false; } final String endpointToken = channel.getEndpointToken(); final Transaction tx = jedis.multi(); tx.set(endpointToken, Long.toString(channel.getVersion())); tx.set(tokenLookupKey(endpointToken), chid); tx.hmset(chidLookupKey(chid), mapOf(endpointToken, uaid)); tx.sadd(uaidLookupKey(uaid), chid); tx.exec(); return true; } finally { jedisPool.returnResource(jedis); } } private Map<String, String> mapOf(final String endpointToken, final String uaid) { final Map<String, String> map = new HashMap<String, String>(2); map.put(TOKEN_KEY, endpointToken); map.put(UAID_KEY, uaid); return map; } private void removeChannel(final String channelId) { final Jedis jedis = jedisPool.getResource(); try { final Channel channel = getChannel(channelId); final String endpointToken = channel.getEndpointToken(); final Transaction tx = jedis.multi(); tx.del(endpointToken); tx.del(chidLookupKey(channelId)); tx.del(tokenLookupKey(endpointToken)); tx.srem(uaidLookupKey(channel.getUAID()), channelId); tx.exec(); } catch (final ChannelNotFoundException e) { logger.debug("ChannelId [" + channelId + "] was not found"); } finally { jedisPool.returnResource(jedis); } } @Override public void removeChannels(final Set<String> channelIds) { for (String channelId : channelIds) { removeChannel(channelId); } } @Override public Channel getChannel(final String channelId) throws ChannelNotFoundException { final Jedis jedis = jedisPool.getResource(); try { final List<String> endpointTokenAndUaid = jedis.hmget(chidLookupKey(channelId), TOKEN_KEY, UAID_KEY); if (!endpointTokenAndUaid.isEmpty()) { final String endpointToken = endpointTokenAndUaid.get(0); final String uaid = endpointTokenAndUaid.get(1); if (endpointToken == null || uaid == null) { throw channelNotFoundException(channelId); } return new DefaultChannel(uaid, channelId, Long.valueOf(jedis.get(endpointToken)), endpointToken); } throw channelNotFoundException(channelId); } finally { jedisPool.returnResource(jedis); } } @Override public Set<String> getChannelIds(final String uaid) { final Jedis jedis = jedisPool.getResource(); try { return jedis.smembers(uaidLookupKey(uaid)); } finally { jedisPool.returnResource(jedis); } } @Override public void removeChannels(final String uaid) { //TODO: This is not efficient. Can we do the clean up in some other way. // This is only called from the reaper thread, perhaps we can use an expiration // like Mozilla does or something equivalent. final Jedis jedis = jedisPool.getResource(); try { for (String channelId : getChannelIds(uaid)) { removeChannel(channelId); } jedis.del(uaidLookupKey(uaid)); } finally { jedisPool.returnResource(jedis); } } @Override public String updateVersion(final String endpointToken, final long newVersion) throws VersionException, ChannelNotFoundException { final Jedis jedis = jedisPool.getResource(); try { jedis.watch(endpointToken); final String versionString = jedis.get(endpointToken); if (versionString == null) { throw channelNotFoundException(endpointToken); } final long currentVersion = Long.valueOf(versionString); if (newVersion <= currentVersion) { throw new VersionException("version [" + newVersion + "] must be greater than the current version [" + currentVersion + "]"); } final Transaction tx = jedis.multi(); tx.set(endpointToken, String.valueOf(newVersion)); tx.exec(); logger.debug(tokenLookupKey(endpointToken)); return jedis.get(tokenLookupKey(endpointToken)); } finally { jedisPool.returnResource(jedis); } } @Override public String saveUnacknowledged(final String channelId, final long version) { final Jedis jedis = jedisPool.getResource(); try { jedis.set(ackLookupKey(channelId), Long.toString(version)); final List<String> hashValues = jedis.hmget(chidLookupKey(channelId), UAID_KEY); final String uaid = hashValues.get(0); jedis.sadd(acksLookupKey(uaid), channelId); return uaid; } finally { jedisPool.returnResource(jedis); } } @Override public Set<Ack> getUnacknowledged(final String uaid) { final Jedis jedis = jedisPool.getResource(); try { final Set<String> unacks = jedis.smembers(acksLookupKey(uaid)); if (unacks.isEmpty()) { return Collections.emptySet(); } final Set<Ack> acks = new HashSet<Ack>(unacks.size()); for (String channelId : unacks) { acks.add(new AckImpl(channelId, Long.valueOf(jedis.get(ackLookupKey(channelId))))); } return acks; } finally { jedisPool.returnResource(jedis); } } @Override public Set<Ack> removeAcknowledged(final String uaid, final Set<Ack> acks) { final Jedis jedis = jedisPool.getResource(); try { for (Ack ack : acks) { jedis.del(ackLookupKey(ack.getChannelId())); jedis.srem(acksLookupKey(uaid), ack.getChannelId()); } return getUnacknowledged(uaid); } finally { jedisPool.returnResource(jedis); } } private static String chidLookupKey(final String channelId) { return CHID_LOOKUP_KEY_PREFIX + channelId; } private static String tokenLookupKey(final String endpointToken) { return TOKEN_LOOKUP_KEY_PREFIX + endpointToken; } private static String uaidLookupKey(final String uaid) { return UAID_LOOKUP_KEY_PREFIX + uaid; } private static String ackLookupKey(final String channelId) { return ACK_LOOKUP_KEY_PREFIX + channelId; } private static String acksLookupKey(final String uaid) { return ACKS_LOOKUP_KEY_PREFIX + uaid; } private static ChannelNotFoundException channelNotFoundException(final String channelId) { return new ChannelNotFoundException("Could not find channel [" + channelId + "]", channelId); } }