/*
* Copyright 2015 Ben Manes. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.benmanes.caffeine.cache.testing;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.function.Function;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.LoadingCache;
import com.github.benmanes.caffeine.cache.Policy;
import com.github.benmanes.caffeine.cache.RemovalCause;
import com.github.benmanes.caffeine.cache.stats.CacheStats;
import com.github.benmanes.caffeine.cache.testing.CacheSpec.CacheWeigher;
import com.github.benmanes.caffeine.cache.testing.CacheSpec.Expire;
import com.github.benmanes.caffeine.cache.testing.CacheSpec.InitialCapacity;
import com.github.benmanes.caffeine.cache.testing.CacheSpec.Listener;
import com.github.benmanes.caffeine.cache.testing.CacheSpec.Maximum;
import com.github.benmanes.caffeine.cache.testing.CacheSpec.ReferenceType;
import com.google.common.base.Ticker;
import com.google.common.cache.AbstractCache.SimpleStatsCounter;
import com.google.common.cache.AbstractCache.StatsCounter;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.RemovalListener;
import com.google.common.cache.RemovalNotification;
import com.google.common.cache.Weigher;
import com.google.common.collect.ForwardingConcurrentMap;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.ExecutionError;
import com.google.common.util.concurrent.UncheckedExecutionException;
/**
* @author ben.manes@gmail.com (Ben Manes)
*/
public final class GuavaCacheFromContext {
private GuavaCacheFromContext() {}
/** Returns a Guava-backed cache. */
@SuppressWarnings("CheckReturnValue")
public static <K, V> Cache<K, V> newGuavaCache(CacheContext context) {
checkState(!context.isAsync(), "Guava caches are synchronous only");
CacheBuilder<Object, Object> builder = CacheBuilder.newBuilder();
context.guava = builder;
builder.concurrencyLevel(1);
if (context.initialCapacity != InitialCapacity.DEFAULT) {
builder.initialCapacity(context.initialCapacity.size());
}
if (context.isRecordingStats()) {
builder.recordStats();
}
if (context.maximumSize != Maximum.DISABLED) {
if (context.weigher == CacheWeigher.DEFAULT) {
builder.maximumSize(context.maximumSize.max());
} else {
builder.weigher(new GuavaWeigher<Object, Object>(context.weigher));
builder.maximumWeight(context.maximumWeight());
}
}
if (context.afterAccess != Expire.DISABLED) {
builder.expireAfterAccess(context.afterAccess.timeNanos(), TimeUnit.NANOSECONDS);
}
if (context.afterWrite != Expire.DISABLED) {
builder.expireAfterWrite(context.afterWrite.timeNanos(), TimeUnit.NANOSECONDS);
}
if (context.refresh != Expire.DISABLED) {
builder.refreshAfterWrite(context.refresh.timeNanos(), TimeUnit.NANOSECONDS);
}
if (context.expires() || context.refreshes()) {
builder.ticker(context.ticker());
}
if (context.keyStrength == ReferenceType.WEAK) {
builder.weakKeys();
} else if (context.keyStrength == ReferenceType.SOFT) {
throw new IllegalStateException();
}
if (context.valueStrength == ReferenceType.WEAK) {
builder.weakValues();
} else if (context.valueStrength == ReferenceType.SOFT) {
builder.softValues();
}
if (context.removalListenerType != Listener.DEFAULT) {
boolean translateZeroExpire = (context.afterAccess == Expire.IMMEDIATELY) ||
(context.afterWrite == Expire.IMMEDIATELY);
builder.removalListener(new GuavaRemovalListener<>(
translateZeroExpire, context.removalListener));
}
Ticker ticker = (context.ticker == null) ? Ticker.systemTicker() : context.ticker();
if (context.loader == null) {
context.cache = new GuavaCache<>(builder.<Integer, Integer>build(),
ticker, context.isRecordingStats());
} else if (context.loader().isBulk()) {
context.cache = new GuavaLoadingCache<>(builder.build(
new BulkLoader<Integer, Integer>(context.loader())),
ticker, context.isRecordingStats());
} else {
context.cache = new GuavaLoadingCache<>(builder.build(
new SingleLoader<Integer, Integer>(context.loader())),
ticker, context.isRecordingStats());
}
@SuppressWarnings("unchecked")
Cache<K, V> castedCache = (Cache<K, V>) context.cache;
return castedCache;
}
static class GuavaCache<K, V> implements Cache<K, V>, Serializable {
private static final long serialVersionUID = 1L;
private final com.google.common.cache.Cache<K, V> cache;
private final boolean isRecordingStats;
private final Ticker ticker;
transient StatsCounter statsCounter;
GuavaCache(com.google.common.cache.Cache<K, V> cache, Ticker ticker, boolean isRecordingStats) {
this.statsCounter = new SimpleStatsCounter();
this.isRecordingStats = isRecordingStats;
this.cache = requireNonNull(cache);
this.ticker = ticker;
}
@Override
public V getIfPresent(Object key) {
return cache.getIfPresent(key);
}
@Override
public V get(K key, Function<? super K, ? extends V> mappingFunction) {
requireNonNull(mappingFunction);
try {
return cache.get(key, () -> {
V value = mappingFunction.apply(key);
if (value == null) {
throw new CacheMissException();
}
return value;
});
} catch (UncheckedExecutionException e) {
if (e.getCause() instanceof CacheMissException) {
return null;
}
throw (RuntimeException) e.getCause();
} catch (ExecutionException e) {
throw new CompletionException(e);
} catch (ExecutionError e) {
throw (Error) e.getCause();
}
}
@Override
public Map<K, V> getAllPresent(Iterable<?> keys) {
requireNonNull(keys);
keys.forEach(Objects::requireNonNull);
return cache.getAllPresent(keys);
}
@Override
public void put(K key, V value) {
requireNonNull(key);
requireNonNull(value);
cache.put(key, value);
}
@Override
public void putAll(Map<? extends K, ? extends V> map) {
cache.putAll(map);
}
@Override
public void invalidate(Object key) {
cache.invalidate(key);
}
@Override
public void invalidateAll(Iterable<?> keys) {
keys.forEach(this::invalidate);
}
@Override
public void invalidateAll() {
cache.invalidateAll();
}
@Override
public long estimatedSize() {
return cache.size();
}
@Override
public CacheStats stats() {
com.google.common.cache.CacheStats stats = statsCounter.snapshot().plus(cache.stats());
return new CacheStats(stats.hitCount(), stats.missCount(), stats.loadSuccessCount(),
stats.loadExceptionCount(), stats.totalLoadTime(), stats.evictionCount(), 0L);
}
@Override
public ConcurrentMap<K, V> asMap() {
return new ForwardingConcurrentMap<K, V>() {
@Override
public boolean containsKey(Object key) {
requireNonNull(key);
return delegate().containsKey(key);
}
@Override
public boolean containsValue(Object value) {
requireNonNull(value);
return delegate().containsValue(value);
}
@Override
public V get(Object key) {
requireNonNull(key);
return delegate().get(key);
}
@Override
public V remove(Object key) {
requireNonNull(key);
return delegate().remove(key);
}
@Override
public boolean remove(Object key, Object value) {
requireNonNull(key);
return delegate().remove(key, value);
}
@Override
public boolean replace(K key, V oldValue, V newValue) {
requireNonNull(oldValue);
return delegate().replace(key, oldValue, newValue);
}
@Override
public V computeIfAbsent(K key, Function<? super K, ? extends V> mappingFunction) {
requireNonNull(mappingFunction);
V value = getIfPresent(key);
if (value != null) {
return value;
}
long now = ticker.read();
try {
value = mappingFunction.apply(key);
long loadTime = (ticker.read() - now);
if (value == null) {
statsCounter.recordLoadException(loadTime);
return null;
} else {
statsCounter.recordLoadSuccess(loadTime);
V v = delegate().putIfAbsent(key, value);
return (v == null) ? value : v;
}
} catch (RuntimeException | Error e) {
statsCounter.recordLoadException((ticker.read() - now));
throw e;
}
}
@Override
public V computeIfPresent(K key,
BiFunction<? super K, ? super V, ? extends V> remappingFunction) {
requireNonNull(remappingFunction);
V oldValue;
long now = ticker.read();
if ((oldValue = get(key)) != null) {
try {
V newValue = remappingFunction.apply(key, oldValue);
long loadTime = ticker.read() - now;
if (newValue == null) {
statsCounter.recordLoadException(loadTime);
remove(key);
return null;
} else {
statsCounter.recordLoadSuccess(loadTime);
put(key, newValue);
return newValue;
}
} catch (RuntimeException | Error e) {
statsCounter.recordLoadException(ticker.read() - now);
throw e;
}
} else {
return null;
}
}
@Override
public V compute(K key, BiFunction<? super K, ? super V, ? extends V> remappingFunction) {
requireNonNull(remappingFunction);
V oldValue = get(key);
long now = ticker.read();
try {
V newValue = remappingFunction.apply(key, oldValue);
if (newValue == null) {
if (oldValue != null || containsKey(key)) {
remove(key);
statsCounter.recordLoadException(ticker.read() - now);
return null;
} else {
statsCounter.recordLoadException(ticker.read() - now);
return null;
}
} else {
statsCounter.recordLoadSuccess(ticker.read() - now);
put(key, newValue);
return newValue;
}
} catch (RuntimeException | Error e) {
statsCounter.recordLoadException(ticker.read() - now);
throw e;
}
}
@Override
public V merge(K key, V value,
BiFunction<? super V, ? super V, ? extends V> remappingFunction) {
requireNonNull(remappingFunction);
requireNonNull(value);
V oldValue = get(key);
for (;;) {
if (oldValue != null) {
long now = ticker.read();
try {
V newValue = remappingFunction.apply(oldValue, value);
if (newValue != null) {
if (replace(key, oldValue, newValue)) {
statsCounter.recordLoadSuccess(ticker.read() - now);
return newValue;
}
} else if (remove(key, oldValue)) {
statsCounter.recordLoadException(ticker.read() - now);
return null;
}
} catch (RuntimeException | Error e) {
statsCounter.recordLoadException(ticker.read() - now);
throw e;
}
oldValue = get(key);
} else {
if ((oldValue = putIfAbsent(key, value)) == null) {
return value;
}
}
}
}
@Override
protected ConcurrentMap<K, V> delegate() {
return cache.asMap();
}
private void readObject(ObjectInputStream stream) throws InvalidObjectException {
statsCounter = new SimpleStatsCounter();
}
};
}
@Override
public void cleanUp() {
cache.cleanUp();
}
@Override
public Policy<K, V> policy() {
return new Policy<K, V>() {
@Override public boolean isRecordingStats() {
return isRecordingStats;
}
@Override public Optional<Eviction<K, V>> eviction() {
return Optional.empty();
}
@Override public Optional<Expiration<K, V>> expireAfterAccess() {
return Optional.empty();
}
@Override public Optional<Expiration<K, V>> expireAfterWrite() {
return Optional.empty();
}
@Override public Optional<Expiration<K, V>> refreshAfterWrite() {
return Optional.empty();
}
};
}
}
static class GuavaLoadingCache<K, V> extends GuavaCache<K, V>
implements LoadingCache<K, V>, Serializable {
private static final long serialVersionUID = 1L;
private final com.google.common.cache.LoadingCache<K, V> cache;
GuavaLoadingCache(com.google.common.cache.LoadingCache<K, V> cache,
Ticker ticker, boolean isRecordingStats) {
super(cache, ticker, isRecordingStats);
this.cache = requireNonNull(cache);
}
@Override
public V get(K key) {
try {
return cache.get(key);
} catch (UncheckedExecutionException e) {
if (e.getCause() instanceof CacheMissException) {
return null;
}
throw (RuntimeException) e.getCause();
} catch (ExecutionException e) {
throw new CompletionException(e);
} catch (ExecutionError e) {
throw (Error) e.getCause();
}
}
@Override
public Map<K, V> getAll(Iterable<? extends K> keys) {
try {
return cache.getAll(keys);
} catch (UncheckedExecutionException e) {
if (e.getCause() instanceof CacheMissException) {
return ImmutableMap.of();
}
throw (RuntimeException) e.getCause();
} catch (ExecutionException e) {
throw new CompletionException(e);
} catch (ExecutionError e) {
throw (Error) e.getCause();
}
}
@Override
public void refresh(K key) {
cache.refresh(key);
}
}
static final class GuavaWeigher<K, V> implements Weigher<K, V>, Serializable {
private static final long serialVersionUID = 1L;
final com.github.benmanes.caffeine.cache.Weigher<K, V> weigher;
GuavaWeigher(com.github.benmanes.caffeine.cache.Weigher<K, V> weigher) {
this.weigher = weigher;
}
@Override public int weigh(K key, V value) {
return weigher.weigh(key, value);
}
}
static final class GuavaRemovalListener<K, V> implements RemovalListener<K, V>, Serializable {
private static final long serialVersionUID = 1L;
final com.github.benmanes.caffeine.cache.RemovalListener<K, V> delegate;
final boolean translateZeroExpire;
GuavaRemovalListener(boolean translateZeroExpire,
com.github.benmanes.caffeine.cache.RemovalListener<K, V> delegate) {
this.translateZeroExpire = translateZeroExpire;
this.delegate = delegate;
}
@Override
public void onRemoval(RemovalNotification<K, V> notification) {
RemovalCause cause = RemovalCause.valueOf(notification.getCause().name());
if (translateZeroExpire && (cause == RemovalCause.SIZE)) {
// Guava internally uses sizing logic for null cache case
cause = RemovalCause.EXPIRED;
}
delegate.onRemoval(notification.getKey(), notification.getValue(), cause);
}
}
static class SingleLoader<K, V> extends CacheLoader<K, V> implements Serializable {
private static final long serialVersionUID = 1L;
final com.github.benmanes.caffeine.cache.CacheLoader<K, V> delegate;
SingleLoader(com.github.benmanes.caffeine.cache.CacheLoader<K, V> delegate) {
this.delegate = delegate;
}
@Override
public V load(K key) throws Exception {
V value = delegate.load(key);
if (value == null) {
throw new CacheMissException();
}
return value;
}
}
static class BulkLoader<K, V> extends SingleLoader<K, V> {
private static final long serialVersionUID = 1L;
BulkLoader(com.github.benmanes.caffeine.cache.CacheLoader<K, V> delegate) {
super(delegate);
}
@Override
public Map<K, V> loadAll(Iterable<? extends K> keys) throws Exception {
return delegate.loadAll(keys);
}
}
static final class CacheMissException extends RuntimeException {
private static final long serialVersionUID = 1L;
}
}