/**
* JBoss, Home of Professional Open Source
* Copyright Red Hat, Inc., and individual contributors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jboss.aerogear.simplepush.server.datastore;
import static org.jboss.aerogear.simplepush.util.ArgumentUtil.checkNotNull;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;
import org.jboss.aerogear.simplepush.protocol.Ack;
import org.jboss.aerogear.simplepush.protocol.impl.AckImpl;
import org.jboss.aerogear.simplepush.server.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A {@link DataStore} implementation that stores all information in memory.
*/
public class InMemoryDataStore implements DataStore {

    private static final Logger logger = LoggerFactory.getLogger(InMemoryDataStore.class);

    /** Channels keyed by channel id. */
    private final ConcurrentMap<String, MutableChannel> channels = new ConcurrentHashMap<String, MutableChannel>();
    /** The same channel instances as in {@link #channels}, keyed by endpoint token. */
    private final ConcurrentMap<String, MutableChannel> endpoints = new ConcurrentHashMap<String, MutableChannel>();
    /** Unacknowledged notifications keyed by UserAgent id (uaid). */
    private final ConcurrentMap<String, Set<Ack>> unacked = new ConcurrentHashMap<String, Set<Ack>>();
    // NOTE(review): unlike the maps above, this field is neither volatile nor guarded by a lock.
    private byte[] salt;

    /**
     * Stores the private key salt if no salt has been stored yet.
     * Once a salt is present, later calls leave it untouched.
     *
     * @param salt the salt to store.
     */
    @Override
    public void savePrivateKeySalt(final byte[] salt) {
        // The guard must test for "no salt yet"; testing for "already set" would
        // make it impossible to ever store a salt.
        if (this.salt == null) {
            this.salt = salt;
        }
    }

    /**
     * Returns the stored private key salt.
     *
     * @return the salt, or an empty array if no salt has been saved.
     */
    @Override
    public byte[] getPrivateKeySalt() {
        if (salt == null) {
            return new byte[0];
        }
        return salt;
    }

    /**
     * Saves the channel unless a channel with the same channel id already exists.
     * The endpoint token is only registered when the channel was actually stored,
     * so both maps always refer to the same channel instance.
     *
     * @param ch the channel to save.
     * @return {@code true} if the channel was stored, {@code false} if a channel
     *         with the same id already existed.
     */
    @Override
    public boolean saveChannel(final Channel ch) {
        checkNotNull(ch, "ch");
        final MutableChannel mutableChannel = new MutableChannel(ch);
        final Channel previous = channels.putIfAbsent(ch.getChannelId(), mutableChannel);
        if (previous != null) {
            // Registering the endpoint here would point it at an instance that
            // getChannel never returns, so version updates would be lost.
            return false;
        }
        endpoints.put(ch.getEndpointToken(), mutableChannel);
        return true;
    }

    /**
     * Removes the channel with the given id together with its endpoint mapping.
     *
     * @param channelId the id of the channel to remove.
     * @return {@code true} if a channel was removed.
     */
    private boolean removeChannel(final String channelId) {
        checkNotNull(channelId, "channelId");
        final MutableChannel channel = channels.remove(channelId);
        if (channel == null) {
            return false;
        }
        // Remove by token (the map key), and only while the token still maps to
        // the channel that was just removed.
        endpoints.remove(channel.getEndpointToken(), channel);
        return true;
    }

    /**
     * Looks up a channel by its id.
     *
     * @param channelId the id of the channel.
     * @return the stored channel.
     * @throws ChannelNotFoundException if no channel exists for the id.
     */
    @Override
    public Channel getChannel(final String channelId) throws ChannelNotFoundException {
        checkNotNull(channelId, "channelId");
        final Channel channel = channels.get(channelId);
        if (channel == null) {
            throw new ChannelNotFoundException("No Channel for [" + channelId + "] was found", channelId);
        }
        return channel;
    }

    /**
     * Removes all channels that belong to the given UserAgent, along with any
     * unacknowledged notifications stored for it.
     *
     * @param uaid the UserAgent id.
     */
    @Override
    public void removeChannels(final String uaid) {
        checkNotNull(uaid, "uaid");
        for (Channel channel : channels.values()) {
            if (channel.getUAID().equals(uaid)) {
                removeChannel(channel.getChannelId());
                logger.info("Removing [" + channel.getChannelId() + "] for UserAgent [" + uaid + "]");
            }
        }
        unacked.remove(uaid);
    }

    /**
     * Removes the channels with the given ids. Ids without a stored channel are ignored.
     *
     * @param channelIds the ids of the channels to remove.
     */
    @Override
    public void removeChannels(final Set<String> channelIds) {
        checkNotNull(channelIds, "channelIds");
        for (String channelId : channelIds) {
            removeChannel(channelId);
            logger.debug("Removing [" + channelId + "]");
        }
    }

    /**
     * Collects the ids of all channels that belong to the given UserAgent.
     *
     * @param uaid the UserAgent id.
     * @return a new set of channel ids; empty if the UserAgent has no channels.
     */
    @Override
    public Set<String> getChannelIds(final String uaid) {
        checkNotNull(uaid, "uaid");
        final Set<String> channelIds = new HashSet<String>();
        for (Channel channel : channels.values()) {
            if (channel.getUAID().equals(uaid)) {
                channelIds.add(channel.getChannelId());
            }
        }
        return channelIds;
    }

    /**
     * Updates the version of the channel registered for the endpoint token.
     *
     * @param endpointToken the endpoint token identifying the channel.
     * @param version the new version; must be greater than the current version.
     * @return the id of the updated channel.
     * @throws VersionException if the version is not greater than the current version.
     * @throws ChannelNotFoundException if no channel is registered for the token.
     */
    @Override
    public String updateVersion(final String endpointToken, final long version) throws VersionException, ChannelNotFoundException {
        checkNotNull(endpointToken, "endpointToken");
        final MutableChannel channel = endpoints.get(endpointToken);
        if (channel == null) {
            throw new ChannelNotFoundException("Could not find channel for endpointToken", endpointToken);
        }
        channel.updateVersion(version);
        return channel.getChannelId();
    }

    /**
     * Records a notification for the channel as unacknowledged.
     *
     * @param channelId the id of the channel that was notified.
     * @param version the version that was sent.
     * @return the id of the UserAgent that owns the channel.
     * @throws ChannelNotFoundException if no channel exists for the id.
     */
    @Override
    public String saveUnacknowledged(final String channelId, final long version) throws ChannelNotFoundException {
        checkNotNull(channelId, "channelId");
        final Channel channel = channels.get(channelId);
        if (channel == null) {
            throw new ChannelNotFoundException("Could not find channel", channelId);
        }
        final String uaid = channel.getUAID();
        final Ack ack = new AckImpl(channelId, version);
        while (true) {
            final Set<Ack> currentAcks = unacked.get(uaid);
            // Build the candidate from scratch on every attempt so that a failed
            // attempt cannot carry acks from a stale snapshot into the next one.
            final Set<Ack> newAcks = newAckSet();
            // The new ack goes in first; entries of the current set that are equal
            // to it are then not added again.
            newAcks.add(ack);
            if (currentAcks == null) {
                if (unacked.putIfAbsent(uaid, newAcks) == null) {
                    return uaid;
                }
            } else {
                newAcks.addAll(currentAcks);
                if (unacked.replace(uaid, currentAcks, newAcks)) {
                    return uaid;
                }
            }
            // Another thread changed the mapping in the meantime; retry.
        }
    }

    /**
     * Returns the unacknowledged notifications for the given UserAgent.
     *
     * @param uaid the UserAgent id.
     * @return an unmodifiable view of the unacknowledged notifications, or an
     *         empty set if there are none.
     */
    @Override
    public Set<Ack> getUnacknowledged(final String uaid) {
        checkNotNull(uaid, "uaid");
        final Set<Ack> acks = unacked.get(uaid);
        if (acks == null) {
            return Collections.emptySet();
        }
        return Collections.unmodifiableSet(acks);
    }

    /**
     * Removes the acknowledged notifications from those stored for the UserAgent.
     *
     * @param uaid the UserAgent id.
     * @param acked the notifications that have been acknowledged.
     * @return the notifications that are still unacknowledged; empty if none remain.
     */
    @Override
    public Set<Ack> removeAcknowledged(final String uaid, final Set<Ack> acked) {
        checkNotNull(uaid, "uaid");
        checkNotNull(acked, "acked");
        while (true) {
            final Set<Ack> currentAcks = unacked.get(uaid);
            if (currentAcks == null || currentAcks.isEmpty()) {
                return Collections.emptySet();
            }
            final Set<Ack> remaining = newAckSet();
            remaining.addAll(currentAcks);
            if (!remaining.removeAll(acked)) {
                // Nothing to remove, so the stored set is left as it is.
                return remaining;
            }
            if (unacked.replace(uaid, currentAcks, remaining)) {
                return remaining;
            }
            // Another thread changed the mapping in the meantime; retry.
        }
    }

    /**
     * Creates an empty set of acks backed by a {@link ConcurrentHashMap}.
     */
    private static Set<Ack> newAckSet() {
        return Collections.newSetFromMap(new ConcurrentHashMap<Ack, Boolean>());
    }

    /**
     * A Channel implementation which has a mutable version and is intended for
     * usage with the InMemoryDataStore.
     * The version is held in an {@link AtomicLong}; all other properties are
     * read from the wrapped channel.
     */
    private static class MutableChannel implements Channel {

        private final Channel delegate;
        private final AtomicLong version;

        public MutableChannel(final Channel delegate) {
            this.delegate = delegate;
            version = new AtomicLong(delegate.getVersion());
        }

        @Override
        public String getUAID() {
            return delegate.getUAID();
        }

        @Override
        public String getChannelId() {
            return delegate.getChannelId();
        }

        @Override
        public long getVersion() {
            return version.get();
        }

        /**
         * Sets the version to {@code newVersion}, retrying if the version was
         * changed concurrently.
         *
         * @param newVersion the new version; must be greater than the current version.
         * @throws VersionException if {@code newVersion} is not greater than the current version.
         */
        public void updateVersion(final long newVersion) {
            for (;;) {
                final long currentVersion = version.get();
                if (newVersion <= currentVersion) {
                    throw new VersionException("New version [" + newVersion + "] must be greater than current version [" + currentVersion + "]");
                }
                if (version.compareAndSet(currentVersion, newVersion)) {
                    break;
                }
            }
        }

        @Override
        public String getEndpointToken() {
            return delegate.getEndpointToken();
        }
    }
}