/* (c) 2014 Open Source Geospatial Foundation - all rights reserved
* (c) 2014 Boundless
* This code is licensed under the GPL 2.0 license, available at the root
* application directory.
*/
package org.geoserver.cluster.hazelcast;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.Serializable;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.apache.commons.io.output.ByteArrayOutputStream;
import org.geoserver.catalog.Catalog;
import org.geoserver.catalog.Info;
import org.geoserver.config.util.XStreamPersister;
import org.geoserver.config.util.XStreamPersisterFactory;
import org.geoserver.util.CacheProvider;
import org.geoserver.util.DefaultCacheProvider;
import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.base.Throwables;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheStats;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.hazelcast.core.HazelcastInstance;
import com.hazelcast.core.IMap;
/**
 * {@link CacheProvider} for a cluster configuration. Looked up by interface by
 * {@link DefaultCacheProvider#findProvider()} hence shall be declared in
 * {@code applicationContext.xml}.
 * <p>
 * Each cache is backed by a Hazelcast {@link IMap} named after the requested cache. Entries are
 * written with {@link IMap#putTransient} and a time-to-live of {@code DEFAULT_TTL}
 * {@code DEFAULT_TTL_UNIT} (five minutes). The cache named {@code "catalog"} stores {@link Info}
 * objects in the XML form produced by an {@link XStreamPersister}; every other cache stores its
 * {@link Serializable} keys and values directly.
 * <p>
 * While no Hazelcast cluster is available, the returned caches behave as empty caches: lookups
 * miss, writes are dropped and {@link Cache#asMap()} returns a fresh, unconnected map.
 */
public class HzCacheProvider implements CacheProvider {

    /** Time-to-live applied to every entry written by the caches of this provider. */
    private static final long DEFAULT_TTL = 5;

    /** Unit of {@link #DEFAULT_TTL}. */
    private static final TimeUnit DEFAULT_TTL_UNIT = TimeUnit.MINUTES;

    /** Name of the cache whose {@link Info} values are stored in serialized XML form. */
    private static final String CATALOG_CACHE_NAME = "catalog";

    /** Caches handed out so far, keyed by cache name, so each name maps to a single instance. */
    private final Map<String, Cache<?, ?>> inUse = Maps.newConcurrentMap();

    /** Creates the XML persister used by the {@code "catalog"} cache. */
    private final XStreamPersisterFactory serializationFactory;

    /**
     * @param serializationFactory factory for the XML persister used to (de)serialize the
     *        {@link Info} values of the {@code "catalog"} cache
     */
    public HzCacheProvider(XStreamPersisterFactory serializationFactory) {
        this.serializationFactory = serializationFactory;
    }

    /**
     * Returns the cache registered under {@code cacheName}, creating it on the first request.
     * Subsequent requests for the same name return the same instance.
     *
     * @param cacheName name of the cache, also used as the name of the backing Hazelcast map
     * @return a catalog cache for {@code "catalog"}, a plain distributed cache otherwise
     */
    @SuppressWarnings("unchecked")
    @Override
    public synchronized <K extends Serializable, V extends Serializable> Cache<K, V> getCache(
            final String cacheName) {
        Cache<K, V> distributedCache = (Cache<K, V>) inUse.get(cacheName);
        if (distributedCache == null) {
            if (CATALOG_CACHE_NAME.equals(cacheName)) {
                distributedCache = (Cache<K, V>) new HzCatalogCache(cacheName, DEFAULT_TTL,
                        DEFAULT_TTL_UNIT, serializationFactory);
            } else {
                distributedCache = new HzCache<K, V>(cacheName, DEFAULT_TTL, DEFAULT_TTL_UNIT);
            }
            inUse.put(cacheName, distributedCache);
        }
        return distributedCache;
    }

    /**
     * Cache that stores keys and values as-is in a Hazelcast {@link IMap}. The map is resolved
     * lazily, the first time it is used while a running cluster is available.
     */
    private static final class HzCache<K extends Serializable, V extends Serializable> implements
            Cache<K, V> {

        /**
         * Backing map, {@code null} until first resolved. Volatile so a map resolved by one thread
         * is safely visible to the others.
         */
        private volatile IMap<K, V> hzMap;

        private final long ttl;

        private final TimeUnit timeunit;

        private final String mapName;

        public HzCache(String mapName, long ttl, TimeUnit ttlUnit) {
            this.mapName = mapName;
            this.hzMap = null;
            this.ttl = ttl;
            this.timeunit = ttlUnit;
        }

        /**
         * Resolves the backing map on first use and reports whether it may be used right now.
         * Like {@code HzCatalogCache}, the map is only considered usable while the cluster reports
         * itself as running.
         *
         * @return {@code true} if {@link #hzMap} is resolved and the cluster is running
         */
        private boolean available() {
            Optional<HzCluster> cluster = HzCluster.getInstanceIfAvailable();
            if (!cluster.isPresent()) {
                return false;
            }
            HzCluster hzCluster = cluster.get();
            if (hzMap == null && hzCluster.isRunning()) {
                HazelcastInstance hazelcastInstance = hzCluster.getHz();
                hzMap = hazelcastInstance.getMap(mapName);
            }
            return hzMap != null && hzCluster.isRunning();
        }

        /** @return the cached value, or {@code null} on a miss or while the cluster is unavailable */
        @Override
        public V getIfPresent(Object key) {
            if (available()) {
                return hzMap.get(key);
            }
            return null;
        }

        /**
         * Returns the cached value for {@code key}; on a miss, computes it with
         * {@code valueLoader} and caches it.
         * <p>
         * A {@code null} result from the loader is returned but not cached, because {@link IMap}
         * does not accept {@code null} values.
         *
         * @throws ExecutionException if {@code valueLoader} throws; its exception is the cause.
         *         Failures while storing the loaded value are not reported as loader failures.
         */
        @Override
        public V get(K key, Callable<? extends V> valueLoader) throws ExecutionException {
            V value = getIfPresent(key);
            if (value != null) {
                return value;
            }
            try {
                value = valueLoader.call();
            } catch (Exception e) {
                throw new ExecutionException(e);
            }
            if (value != null) {
                put(key, value);
            }
            return value;
        }

        /**
         * @return an immutable copy of the entries found for {@code keys}; empty while the cluster
         *         is unavailable
         */
        @SuppressWarnings("unchecked")
        @Override
        public ImmutableMap<K, V> getAllPresent(Iterable<?> keys) {
            if (available()) {
                Set<K> set = new HashSet<K>();
                for (Object k : keys) {
                    set.add((K) k);
                }
                Map<K, V> allPresent = hzMap.getAll(set);
                return ImmutableMap.copyOf(allPresent);
            }
            return ImmutableMap.of();
        }

        /** Stores {@code value} with the cache TTL; dropped while the cluster is unavailable. */
        @Override
        public void put(K key, V value) {
            if (available()) {
                hzMap.putTransient(key, value, ttl, timeunit);
            }
        }

        @SuppressWarnings("unchecked")
        @Override
        public void invalidate(Object key) {
            if (available()) {
                hzMap.remove((K) key);
            }
        }

        @SuppressWarnings("unchecked")
        @Override
        public void invalidateAll(Iterable<?> keys) {
            if (available()) {
                for (Object k : keys) {
                    hzMap.remove((K) k);
                }
            }
        }

        @Override
        public void invalidateAll() {
            if (available()) {
                hzMap.clear();
            }
        }

        /** @return the size of the backing map, or 0 while the cluster is unavailable */
        @Override
        public long size() {
            return available() ? hzMap.size() : 0L;
        }

        /** Statistics are not collected by this cache. */
        @Override
        public CacheStats stats() {
            throw new UnsupportedOperationException();
        }

        /**
         * @return the backing Hazelcast map itself while the cluster is available; otherwise a new,
         *         empty map whose contents are not connected to this cache
         */
        @Override
        public ConcurrentMap<K, V> asMap() {
            if (available()) {
                return hzMap;
            }
            return Maps.newConcurrentMap();
        }

        @Override
        public void cleanUp() {
            // nothing to do: expiration is handled by Hazelcast through the per-entry TTL
        }

        /**
         * Stores every entry of {@code m} the same way {@link #put} does, with the cache TTL as
         * transient entries. Entries are dropped while the cluster is unavailable.
         */
        @Override
        public void putAll(Map<? extends K, ? extends V> m) {
            if (available()) {
                for (Map.Entry<? extends K, ? extends V> entry : m.entrySet()) {
                    hzMap.putTransient(entry.getKey(), entry.getValue(), ttl, timeunit);
                }
            }
        }
    }

    /**
     * Cache of catalog {@link Info} objects. Values are stored in the Hazelcast {@link IMap} as
     * the XML bytes written by an {@link XStreamPersister} bound to the cluster's raw catalog, and
     * unmarshalled on every read.
     */
    private static final class HzCatalogCache implements Cache<String, Info> {

        /**
         * Backing map, {@code null} until first resolved. A non-null value signals that
         * {@link #persister} is ready, so it is always published after the persister.
         */
        private volatile IMap<String, byte[]> hzMap;

        private final long ttl;

        private final TimeUnit timeunit;

        private final String mapName;

        private final XStreamPersisterFactory serializationFactory;

        /**
         * Persister used to (de)serialize values; created together with {@link #hzMap}.
         * NOTE(review): one instance is shared by every thread using this cache — confirm
         * XStreamPersister is safe for concurrent load/save.
         */
        private volatile XStreamPersister persister;

        public HzCatalogCache(String mapName, long ttl, TimeUnit ttlUnit,
                XStreamPersisterFactory serializationFactory2) {
            this.mapName = mapName;
            this.ttl = ttl;
            this.timeunit = ttlUnit;
            this.serializationFactory = serializationFactory2;
            this.hzMap = null;
        }

        /**
         * Resolves the backing map and the XML persister on first use, and reports whether they
         * may be used right now.
         *
         * @return {@code true} if the map is resolved and the cluster is running
         */
        private boolean available() {
            Optional<HzCluster> cluster = HzCluster.getInstanceIfAvailable();
            if (!cluster.isPresent()) {
                return false;
            }
            HzCluster hzCluster = cluster.get();
            if (hzMap == null && hzCluster.isRunning()) {
                HazelcastInstance hazelcastInstance = hzCluster.getHz();
                Catalog catalog = hzCluster.getRawCatalog();
                XStreamPersister xmlPersister = serializationFactory.createXMLPersister();
                xmlPersister.setCatalog(catalog);
                // Publish the persister before the map: other threads take a non-null map as the
                // signal that the persister can be used. If creating the persister fails, the map
                // stays unset and initialization is retried on the next call.
                persister = xmlPersister;
                hzMap = hazelcastInstance.getMap(mapName);
            }
            return hzMap != null && hzCluster.isRunning();
        }

        /** @return the cached object, or {@code null} on a miss or while the cluster is unavailable */
        @Override
        public Info getIfPresent(Object key) {
            Info info = null;
            if (available()) {
                byte[] serialForm = hzMap.get(key);
                if (serialForm != null) {
                    info = unmarshal(serialForm);
                }
            }
            return info;
        }

        /** Reads an {@link Info} back from its stored XML bytes; I/O errors are rethrown unchecked. */
        private Info unmarshal(byte[] serialForm) {
            Info info;
            try {
                info = persister.load(new ByteArrayInputStream(serialForm), Info.class);
            } catch (IOException e) {
                throw Throwables.propagate(e);
            }
            return info;
        }

        /**
         * Returns the cached object for {@code key}; on a miss, loads it with
         * {@code valueLoader} and caches it.
         * <p>
         * A {@code null} result from the loader is returned but not cached.
         *
         * @throws ExecutionException if {@code valueLoader} throws; its exception is the cause.
         *         Failures while storing the loaded value are not reported as loader failures.
         */
        @Override
        public Info get(String key, Callable<? extends Info> valueLoader) throws ExecutionException {
            Info value = getIfPresent(key);
            if (value != null) {
                return value;
            }
            try {
                value = valueLoader.call();
            } catch (Exception e) {
                throw new ExecutionException(e);
            }
            if (value != null) {
                put(key, value);
            }
            return value;
        }

        /**
         * @return an immutable map of the unmarshalled objects found for {@code keys}; empty while
         *         the cluster is unavailable
         */
        @Override
        public ImmutableMap<String, Info> getAllPresent(Iterable<?> keys) {
            if (available()) {
                Set<String> set = new HashSet<String>();
                for (Object k : keys) {
                    set.add((String) k);
                }
                Map<String, byte[]> allPresent = hzMap.getAll(set);
                Function<byte[], Info> function = new Function<byte[], Info>() {
                    @Override
                    public Info apply(byte[] input) {
                        return unmarshal(input);
                    }
                };
                Map<String, Info> transformedValues = Maps.transformValues(allPresent, function);
                // copyOf forces the lazy view, unmarshalling each value exactly once
                return ImmutableMap.copyOf(transformedValues);
            }
            return ImmutableMap.of();
        }

        /** Serializes and stores {@code value} with the cache TTL; dropped while unavailable. */
        @Override
        public void put(String key, Info value) {
            if (available()) {
                byte[] serialForm = serialize(value);
                hzMap.putTransient(key, serialForm, ttl, timeunit);
            }
        }

        /** Writes {@code value} as XML bytes; I/O errors are rethrown unchecked. */
        private byte[] serialize(Info value) {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            try {
                persister.save(value, out);
            } catch (IOException e) {
                throw Throwables.propagate(e);
            }
            byte[] serialForm = out.toByteArray();
            return serialForm;
        }

        /**
         * Serializes and stores every entry of {@code m} the same way {@link #put} does, with the
         * cache TTL as transient entries. Entries are dropped while the cluster is unavailable.
         */
        @Override
        public void putAll(Map<? extends String, ? extends Info> m) {
            if (available()) {
                for (Map.Entry<? extends String, ? extends Info> entry : m.entrySet()) {
                    hzMap.putTransient(entry.getKey(), serialize(entry.getValue()), ttl, timeunit);
                }
            }
        }

        /** Removes the entry stored under {@code String.valueOf(key)}. */
        @Override
        public void invalidate(Object key) {
            if (available()) {
                hzMap.remove(String.valueOf(key));
            }
        }

        @Override
        public void invalidateAll(Iterable<?> keys) {
            if (available()) {
                for (Object k : keys) {
                    hzMap.remove(String.valueOf(k));
                }
            }
        }

        @Override
        public void invalidateAll() {
            if (available()) {
                hzMap.clear();
            }
        }

        /** @return the size of the backing map, or 0 while the cluster is unavailable */
        @Override
        public long size() {
            return available() ? hzMap.size() : 0L;
        }

        /** Statistics are not collected by this cache. */
        @Override
        public CacheStats stats() {
            throw new UnsupportedOperationException();
        }

        /**
         * Returns a snapshot of the cache with every value unmarshalled. Unlike a Guava
         * {@code asMap()} view, the returned map is a detached copy: changes to it are not written
         * back to the cache. While the cluster is unavailable, a new empty map is returned.
         */
        @Override
        public ConcurrentMap<String, Info> asMap() {
            if (available()) {
                Function<byte[], Info> function = new Function<byte[], Info>() {
                    @Override
                    public Info apply(byte[] input) {
                        return unmarshal(input);
                    }
                };
                Map<String, Info> transformedValues = Maps.transformValues(hzMap, function);
                return new ConcurrentHashMap<String, Info>(transformedValues);
            }
            return Maps.newConcurrentMap();
        }

        @Override
        public void cleanUp() {
            // nothing to do: expiration is handled by Hazelcast through the per-entry TTL
        }
    }

    /**
     * Cache that stores nothing: {@link #get(Object, Callable)} always invokes the loader and
     * every lookup misses. Not currently returned by {@link HzCacheProvider#getCache(String)}.
     */
    private static class NullCache<K, V> implements Cache<K, V> {

        @Override
        public V get(K key, Callable<? extends V> valueLoader) throws ExecutionException {
            try {
                V value = valueLoader.call();
                return value;
            } catch (Exception e) {
                throw new ExecutionException(e);
            }
        }

        @Override
        public V getIfPresent(Object key) {
            return null;
        }

        @Override
        public long size() {
            return 0L;
        }

        @Override
        public void invalidate(Object key) {
            //
        }

        @Override
        public void invalidateAll() {
            //
        }

        @Override
        public void put(K key, V value) {
            //
        }

        @Override
        public ImmutableMap<K, V> getAllPresent(Iterable<?> keys) {
            return ImmutableMap.of();
        }

        @Override
        public void invalidateAll(Iterable<?> keys) {
            //
        }

        @Override
        public CacheStats stats() {
            throw new UnsupportedOperationException();
        }

        @Override
        public ConcurrentMap<K, V> asMap() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void cleanUp() {
            //
        }

        @Override
        public void putAll(Map<? extends K, ? extends V> m) {
            //
        }
    }
}