/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.common.cache;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.stream.Collectors;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.is;
public class CacheTests extends ESTestCase {
private int numberOfEntries;
@Before
@Override
public void setUp() throws Exception {
    super.setUp();
    // pick a randomized population size for every test; it is logged at debug level so the value used by a
    // particular run can be recovered from the test output
    final int entries = randomIntBetween(1000, 10000);
    numberOfEntries = entries;
    logger.debug("numberOfEntries: {}", entries);
}
// Fill a cache that can only hold half of the entries, then look up a mix of cached keys and keys that were never
// cached, and check that the hit, miss and eviction statistics match what was observed.
public void testCacheStats() {
    final AtomicLong evictionCount = new AtomicLong();
    // keys currently held by the cache; the removal listener drops a key as soon as it is removed
    final Set<Integer> liveKeys = new HashSet<>();
    final Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder()
        .setMaximumWeight(numberOfEntries / 2)
        .removalListener(notification -> {
            liveKeys.remove(notification.getKey());
            evictionCount.incrementAndGet();
        })
        .build();
    for (int key = 0; key < numberOfEntries; key++) {
        liveKeys.add(key);
        cache.put(key, Integer.toString(key));
    }
    long expectedHits = 0;
    long expectedMisses = 0;
    // negative keys are never put into the cache, so looking one up is always a miss
    int absentKey = 0;
    for (Integer liveKey : liveKeys) {
        absentKey--;
        if (rarely()) {
            cache.get(absentKey);
            expectedMisses++;
        } else {
            cache.get(liveKey);
            expectedHits++;
        }
    }
    assertEquals(expectedHits, cache.stats().getHits());
    assertEquals(expectedMisses, cache.stats().getMisses());
    // only numberOfEntries / 2 entries fit, so everything beyond that must have been evicted
    assertEquals(numberOfEntries - numberOfEntries / 2, evictionCount.get());
    assertEquals(evictionCount.get(), cache.stats().getEvictions());
}
// Checks that entries are evicted in least-recently-used order.
// Entries are cached in batches of size maximumWeight. After a batch is put, the even keys of that batch are read
// back, which makes them more recently used than the odd keys of the same batch. Putting the next entries pushes
// the previous batch out. The test therefore expects the eviction order to be, batch by batch, first the odd keys
// (ascending) and then the even keys (ascending). The last maximumWeight keys are put at the end and must be the
// only keys remaining in the cache.
public void testCacheEvictions() {
int maximumWeight = randomIntBetween(1, numberOfEntries);
AtomicLong evictions = new AtomicLong();
// keys in the order in which the removal listener was notified
List<Integer> evictedKeys = new ArrayList<>();
Cache<Integer, String> cache =
CacheBuilder.<Integer, String>builder()
.setMaximumWeight(maximumWeight)
.removalListener(notification -> {
evictions.incrementAndGet();
evictedKeys.add(notification.getKey());
})
.build();
// cache entries up to numberOfEntries - maximumWeight; all of these entries will ultimately be evicted in
// batches of size maximumWeight, first the odds in the batch, then the evens in the batch
List<Integer> expectedEvictions = new ArrayList<>();
// number of (possibly partial) batches needed to cover keys [0, numberOfEntries - maximumWeight)
int iterations = (int)Math.ceil((numberOfEntries - maximumWeight) / (1.0 * maximumWeight));
for (int i = 0; i < iterations; i++) {
// put batch i; its odd keys are never touched again, so they are the batch's least recently used entries
for (int j = i * maximumWeight; j < (i + 1) * maximumWeight && j < numberOfEntries - maximumWeight; j++) {
cache.put(j, Integer.toString(j));
if (j % 2 == 1) {
expectedEvictions.add(j);
}
}
// read back batch i's even keys so that they become more recently used than its odd keys
for (int j = i * maximumWeight; j < (i + 1) * maximumWeight && j < numberOfEntries - maximumWeight; j++) {
if (j % 2 == 0) {
cache.get(j);
expectedEvictions.add(j);
}
}
}
// finish filling the cache: these last maximumWeight keys push out every earlier entry
for (int i = numberOfEntries - maximumWeight; i < numberOfEntries; i++) {
cache.put(i, Integer.toString(i));
}
assertEquals(numberOfEntries - maximumWeight, evictions.get());
assertEquals(evictions.get(), cache.stats().getEvictions());
// assert that the keys were evicted in LRU order
// keys: membership checks; remainingKeys: the iteration order of cache.keys()
Set<Integer> keys = new HashSet<>();
List<Integer> remainingKeys = new ArrayList<>();
for (Integer key : cache.keys()) {
keys.add(key);
remainingKeys.add(key);
}
assertEquals(expectedEvictions.size(), evictedKeys.size());
for (int i = 0; i < expectedEvictions.size(); i++) {
assertFalse(keys.contains(expectedEvictions.get(i)));
assertEquals(expectedEvictions.get(i), evictedKeys.get(i));
}
// the test expects cache.keys() to iterate from the most recently put key (numberOfEntries - 1) downwards:
// position p of remainingKeys must hold numberOfEntries - 1 - p
for (int i = numberOfEntries - maximumWeight; i < numberOfEntries; i++) {
assertTrue(keys.contains(i));
assertEquals(
numberOfEntries - i + (numberOfEntries - maximumWeight) - 1,
(int) remainingKeys.get(i - (numberOfEntries - maximumWeight))
);
}
}
// Cache entries of a fixed, randomly chosen weight until their total weight exceeds the maximum, then check that
// the retained weight and the number of evictions are what the weight limit implies.
public void testWeigher() {
    final int maximumWeight = 2 * numberOfEntries;
    final int entryWeight = randomIntBetween(2, 10);
    final AtomicLong evictionCount = new AtomicLong();
    final Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder()
        .setMaximumWeight(maximumWeight)
        .weigher((key, value) -> entryWeight)
        .removalListener(notification -> evictionCount.incrementAndGet())
        .build();
    for (int key = 0; key < numberOfEntries; key++) {
        cache.put(key, Integer.toString(key));
    }
    // at most maximumWeight / entryWeight entries fit, so the retained weight is the largest multiple of
    // entryWeight that does not exceed maximumWeight
    final int entriesThatFit = maximumWeight / entryWeight;
    assertEquals(entryWeight * entriesThatFit, cache.weight());
    // every entry that did not fit must have been evicted
    assertEquals(numberOfEntries - entriesThatFit, evictionCount.get());
    assertEquals(evictionCount.get(), cache.stats().getEvictions());
}
// Cache entries whose weight equals their key, invalidate a random subset, and check that the cache weight is the
// sum of the keys that are still cached.
public void testWeight() {
    final Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder()
        .weigher((key, value) -> key)
        .build();
    int expectedWeight = 0;
    for (int key = 0; key < numberOfEntries; key++) {
        cache.put(key, Integer.toString(key));
        expectedWeight += key;
    }
    for (int key = 0; key < numberOfEntries; key++) {
        if (rarely() == false) {
            continue;
        }
        cache.invalidate(key);
        expectedWeight -= key;
    }
    assertEquals(expectedWeight, cache.weight());
}
// Cache entries, invalidate a random subset, and check that the cache reports the number of entries left.
public void testCount() {
    final Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder().build();
    for (int key = 0; key < numberOfEntries; key++) {
        cache.put(key, Integer.toString(key));
    }
    // every put used a distinct key, so the cache starts out with exactly numberOfEntries entries
    int expectedCount = numberOfEntries;
    for (int key = 0; key < numberOfEntries; key++) {
        if (rarely()) {
            cache.invalidate(key);
            expectedCount--;
        }
    }
    assertEquals(expectedCount, cache.count());
}
// Cache a first batch of entries at t=0 and a second batch at t=1 with a 1ns expire-after-access, then move the
// clock to t=2 and refresh. The test expects exactly the first batch to be evicted, with an EVICTED notification for
// each of its keys in insertion order, and the second batch to remain.
public void testExpirationAfterAccess() {
    AtomicLong now = new AtomicLong();
    // cache whose notion of time is driven by the test through `now`
    Cache<Integer, String> cache = new Cache<Integer, String>() {
        @Override
        protected long now() {
            return now.get();
        }
    };
    cache.setExpireAfterAccessNanos(1);
    List<Integer> evictedKeys = new ArrayList<>();
    cache.setRemovalListener(notification -> {
        assertEquals(RemovalNotification.RemovalReason.EVICTED, notification.getRemovalReason());
        evictedKeys.add(notification.getKey());
    });
    now.set(0);
    for (int i = 0; i < numberOfEntries; i++) {
        cache.put(i, Integer.toString(i));
    }
    now.set(1);
    for (int i = numberOfEntries; i < 2 * numberOfEntries; i++) {
        cache.put(i, Integer.toString(i));
    }
    now.set(2);
    cache.refresh();
    assertEquals(numberOfEntries, cache.count());
    // without this check the ordering loop below would pass vacuously if no notification was ever delivered
    assertEquals(numberOfEntries, evictedKeys.size());
    for (int i = 0; i < evictedKeys.size(); i++) {
        assertEquals(i, (int) evictedKeys.get(i));
    }
    Set<Integer> remainingKeys = new HashSet<>();
    for (Integer key : cache.keys()) {
        remainingKeys.add(key);
    }
    for (int i = numberOfEntries; i < 2 * numberOfEntries; i++) {
        assertTrue(remainingKeys.contains(i));
    }
}
// Cache entries at t=0 with a 1ns expire-after-access and check they can be read back at t=0. After the clock moves
// to t=2, every lookup must return null.
public void testSimpleExpireAfterAccess() {
    AtomicLong now = new AtomicLong();
    // cache whose notion of time is driven by the test through `now`
    Cache<Integer, String> cache = new Cache<Integer, String>() {
        @Override
        protected long now() {
            return now.get();
        }
    };
    cache.setExpireAfterAccessNanos(1);
    now.set(0);
    for (int i = 0; i < numberOfEntries; i++) {
        cache.put(i, Integer.toString(i));
    }
    for (int i = 0; i < numberOfEntries; i++) {
        // expected value first so that a failure reports expected/actual the right way round
        assertEquals(Integer.toString(i), cache.get(i));
    }
    now.set(2);
    for (int i = 0; i < numberOfEntries; i++) {
        assertNull(cache.get(i));
    }
}
// Cache a first batch of entries at t=0 and a second batch at t=1 with a 1ns expire-after-write. At t=2, the first
// batch is read and the cache refreshed. The test expects the reads not to keep the first batch alive: exactly the
// first batch is reported as EVICTED, in insertion order, and the second batch remains.
public void testExpirationAfterWrite() {
    AtomicLong now = new AtomicLong();
    // cache whose notion of time is driven by the test through `now`
    Cache<Integer, String> cache = new Cache<Integer, String>() {
        @Override
        protected long now() {
            return now.get();
        }
    };
    cache.setExpireAfterWriteNanos(1);
    List<Integer> evictedKeys = new ArrayList<>();
    cache.setRemovalListener(notification -> {
        assertEquals(RemovalNotification.RemovalReason.EVICTED, notification.getRemovalReason());
        evictedKeys.add(notification.getKey());
    });
    now.set(0);
    for (int i = 0; i < numberOfEntries; i++) {
        cache.put(i, Integer.toString(i));
    }
    now.set(1);
    for (int i = numberOfEntries; i < 2 * numberOfEntries; i++) {
        cache.put(i, Integer.toString(i));
    }
    now.set(2);
    for (int i = 0; i < numberOfEntries; i++) {
        cache.get(i);
    }
    cache.refresh();
    assertEquals(numberOfEntries, cache.count());
    // without this check the ordering loop below would pass vacuously if no notification was ever delivered
    assertEquals(numberOfEntries, evictedKeys.size());
    for (int i = 0; i < evictedKeys.size(); i++) {
        assertEquals(i, (int) evictedKeys.get(i));
    }
    Set<Integer> remainingKeys = new HashSet<>();
    for (Integer key : cache.keys()) {
        remainingKeys.add(key);
    }
    for (int i = numberOfEntries; i < 2 * numberOfEntries; i++) {
        assertTrue(remainingKeys.contains(i));
    }
}
// Cache entries at t=0, touch a random subset at t=1, then move the clock to t=2 and refresh. With a 1ns
// expire-after-access, only the touched entries may survive.
public void testPromotion() {
    final AtomicLong clock = new AtomicLong();
    final Cache<Integer, String> cache = new Cache<Integer, String>() {
        @Override
        protected long now() {
            return clock.get();
        }
    };
    cache.setExpireAfterAccessNanos(1);
    clock.set(0);
    for (int key = 0; key < numberOfEntries; key++) {
        cache.put(key, Integer.toString(key));
    }
    clock.set(1);
    // reading an entry at t=1 moves its access time forward, so it must outlive the untouched entries
    final Set<Integer> touched = new HashSet<>();
    for (int key = 0; key < numberOfEntries; key++) {
        if (rarely()) {
            cache.get(key);
            touched.add(key);
        }
    }
    clock.set(2);
    cache.refresh();
    assertEquals(touched.size(), cache.count());
    for (int key = 0; key < numberOfEntries; key++) {
        final String value = cache.get(key);
        if (touched.contains(key)) {
            assertNotNull(value);
        } else {
            assertNull(value);
        }
    }
}
// Invalidate a random subset of entries while walking the cache's own key iterator, then check that exactly the
// invalidated keys are missing.
public void testInvalidate() {
    final Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder().build();
    for (int key = 0; key < numberOfEntries; key++) {
        cache.put(key, Integer.toString(key));
    }
    final Set<Integer> invalidated = new HashSet<>();
    for (Integer key : cache.keys()) {
        if (rarely()) {
            cache.invalidate(key);
            invalidated.add(key);
        }
    }
    for (int key = 0; key < numberOfEntries; key++) {
        final String value = cache.get(key);
        if (invalidated.contains(key)) {
            assertNull(value);
        } else {
            assertNotNull(value);
        }
    }
}
// Invalidate a random subset of entries and check that an INVALIDATED removal notification is received for exactly
// those keys.
public void testNotificationOnInvalidate() {
    Set<Integer> notifications = new HashSet<>();
    Cache<Integer, String> cache =
        CacheBuilder.<Integer, String>builder()
            .removalListener(notification -> {
                assertEquals(RemovalNotification.RemovalReason.INVALIDATED, notification.getRemovalReason());
                notifications.add(notification.getKey());
            })
            .build();
    for (int i = 0; i < numberOfEntries; i++) {
        cache.put(i, Integer.toString(i));
    }
    Set<Integer> invalidated = new HashSet<>();
    for (int i = 0; i < numberOfEntries; i++) {
        if (rarely()) {
            cache.invalidate(i);
            invalidated.add(i);
        }
    }
    // expected value first, matching the other notification tests, so failure messages read correctly
    assertEquals(invalidated, notifications);
}
// After invalidateAll, the cache must report neither entries nor any remaining weight.
public void testInvalidateAll() {
    final Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder().build();
    int key = 0;
    while (key < numberOfEntries) {
        cache.put(key, Integer.toString(key));
        key++;
    }
    cache.invalidateAll();
    assertEquals(0, cache.count());
    assertEquals(0, cache.weight());
}
// invalidateAll must produce an INVALIDATED removal notification for every key that was cached.
public void testNotificationOnInvalidateAll() {
    final Set<Integer> notified = new HashSet<>();
    final Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder()
        .removalListener(notification -> {
            assertEquals(RemovalNotification.RemovalReason.INVALIDATED, notification.getRemovalReason());
            notified.add(notification.getKey());
        })
        .build();
    final Set<Integer> expected = new HashSet<>();
    for (int key = 0; key < numberOfEntries; key++) {
        cache.put(key, Integer.toString(key));
        expected.add(key);
    }
    cache.invalidateAll();
    assertEquals(expected, notified);
}
// Replace a random subset of entries with values that compare equal to the originals but weigh one unit more, then
// check that the entry count is unchanged and the total weight grew by one per replacement.
public void testReplaceRecomputesSize() {
    // value whose equality ignores its weight, so a replacement is "equal" to the value it replaces
    final class Value {
        private final String value;
        private final long weight;

        Value(String value, long weight) {
            this.value = value;
            this.weight = weight;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            return value.equals(((Value) o).value);
        }

        @Override
        public int hashCode() {
            return value.hashCode();
        }
    }
    final Cache<Integer, Value> cache = CacheBuilder.<Integer, Value>builder().weigher((key, v) -> v.weight).build();
    for (int key = 0; key < numberOfEntries; key++) {
        cache.put(key, new Value(Integer.toString(key), 1));
    }
    assertEquals(numberOfEntries, cache.count());
    assertEquals(numberOfEntries, cache.weight());
    int replacements = 0;
    for (int key = 0; key < numberOfEntries; key++) {
        if (rarely()) {
            cache.put(key, new Value(Integer.toString(key), 2));
            replacements++;
        }
    }
    assertEquals(numberOfEntries, cache.count());
    assertEquals(numberOfEntries + replacements, cache.weight());
}
// Overwrite a random subset of entries with new values and check that exactly those keys produce a REPLACED
// removal notification.
public void testNotificationOnReplace() {
    final Set<Integer> notified = new HashSet<>();
    final Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder()
        .removalListener(notification -> {
            assertEquals(RemovalNotification.RemovalReason.REPLACED, notification.getRemovalReason());
            notified.add(notification.getKey());
        })
        .build();
    for (int key = 0; key < numberOfEntries; key++) {
        cache.put(key, Integer.toString(key));
    }
    final Set<Integer> replaced = new HashSet<>();
    for (int key = 0; key < numberOfEntries; key++) {
        if (rarely() == false) {
            continue;
        }
        final String original = Integer.toString(key);
        cache.put(key, original + original);
        replaced.add(key);
    }
    assertEquals(replaced, notified);
}
// Populate the cache exclusively through computeIfAbsent and check that each key maps to the value its loader
// produced.
public void testComputeIfAbsentLoadsSuccessfully() {
    // every value handed out by the loader, keyed by the key it was loaded for
    final Map<Integer, Integer> loaded = new HashMap<>();
    final Cache<Integer, Integer> cache = CacheBuilder.<Integer, Integer>builder().build();
    try {
        for (int key = 0; key < numberOfEntries; key++) {
            cache.computeIfAbsent(key, k -> {
                final int value = randomInt();
                loaded.put(k, value);
                return value;
            });
        }
    } catch (ExecutionException e) {
        throw new AssertionError(e);
    }
    for (int key = 0; key < numberOfEntries; key++) {
        assertEquals(loaded.get(key), cache.get(key));
    }
}
// Many threads call computeIfAbsent for the same keys concurrently. The loader must run at most once per key; any
// second invocation is reported as a loader failure and collected in `failures`.
public void testComputeIfAbsentCallsOnce() throws BrokenBarrierException, InterruptedException {
    int numberOfThreads = randomIntBetween(2, 32);
    final Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder().build();
    // flags.get(key) flips from false to true the first time the loader runs for key
    AtomicReferenceArray<Boolean> flags = new AtomicReferenceArray<>(numberOfEntries);
    for (int j = 0; j < numberOfEntries; j++) {
        flags.set(j, false);
    }
    CopyOnWriteArrayList<ExecutionException> failures = new CopyOnWriteArrayList<>();
    CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
    for (int i = 0; i < numberOfThreads; i++) {
        Thread thread = new Thread(() -> {
            try {
                barrier.await();
                for (int j = 0; j < numberOfEntries; j++) {
                    try {
                        cache.computeIfAbsent(j, key -> {
                            // Report a repeated load by throwing an exception from the loader instead of calling
                            // assertTrue here. A loader exception surfaces as an ExecutionException (as
                            // testExceptionThrownDuringConcurrentComputeIfAbsent shows), so it is recorded below
                            // and this worker still reaches the second barrier.
                            // NOTE(review): an AssertionError is an Error, not an Exception, and presumably would
                            // escape the cache, killing this worker and leaving the main thread blocked on the
                            // barrier instead of failing cleanly.
                            if (flags.compareAndSet(key, false, true) == false) {
                                throw new IllegalStateException("loader invoked more than once for key [" + key + "]");
                            }
                            return Integer.toString(key);
                        });
                    } catch (ExecutionException e) {
                        failures.add(e);
                        break;
                    }
                }
                barrier.await();
            } catch (BrokenBarrierException | InterruptedException e) {
                throw new AssertionError(e);
            }
        });
        thread.start();
    }
    // wait for all threads to be ready
    barrier.await();
    // wait for all threads to finish
    barrier.await();
    assertThat(failures, is(empty()));
}
// A loader that returns null must make computeIfAbsent fail with an ExecutionException caused by a
// NullPointerException.
public void testComputeIfAbsentThrowsExceptionIfLoaderReturnsANullValue() {
    final Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder().build();
    ExecutionException thrown = null;
    try {
        cache.computeIfAbsent(1, k -> null);
    } catch (ExecutionException e) {
        thrown = e;
    }
    if (thrown == null) {
        fail("expected ExecutionException");
    }
    assertThat(thrown.getCause(), instanceOf(NullPointerException.class));
}
// Many threads call computeIfAbsent on random keys. The loader for key k reads key k / 2 from the same cache, so
// loads depend on other keys. A watchdog polls the JVM for deadlocked threads. If one of the worker threads is
// deadlocked, the watchdog releases the main thread, and the test fails.
public void testDependentKeyDeadlock() throws BrokenBarrierException, InterruptedException {
    // key whose hash code has only two distinct values, so that many distinct keys share a hash code
    class Key {
        private final int key;

        Key(int key) {
            this.key = key;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Key key1 = (Key) o;
            return key == key1.key;
        }

        @Override
        public int hashCode() {
            return key % 2;
        }
    }

    int numberOfThreads = randomIntBetween(2, 32);
    final Cache<Key, Integer> cache = CacheBuilder.<Key, Integer>builder().build();
    CopyOnWriteArrayList<ExecutionException> failures = new CopyOnWriteArrayList<>();
    CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
    // counted down once by every worker that finishes, or all at once by the watchdog on deadlock
    CountDownLatch deadlockLatch = new CountDownLatch(numberOfThreads);
    List<Thread> threads = new ArrayList<>();
    for (int i = 0; i < numberOfThreads; i++) {
        Thread thread = new Thread(() -> {
            try {
                try {
                    barrier.await();
                } catch (BrokenBarrierException | InterruptedException e) {
                    throw new AssertionError(e);
                }
                Random random = new Random(random().nextLong());
                for (int j = 0; j < numberOfEntries; j++) {
                    Key key = new Key(random.nextInt(numberOfEntries));
                    try {
                        cache.computeIfAbsent(key, k -> {
                            if (k.key == 0) {
                                return 0;
                            } else {
                                // reentrant lookup of a different key from inside the loader
                                Integer value = cache.get(new Key(k.key / 2));
                                return value != null ? value : 0;
                            }
                        });
                    } catch (ExecutionException e) {
                        failures.add(e);
                        break;
                    }
                }
            } finally {
                // successfully avoided deadlock, release the main thread
                deadlockLatch.countDown();
            }
        });
        threads.add(thread);
        thread.start();
    }
    // the set of worker threads is fixed at this point, so compute their ids once for the watchdog
    final Set<Long> ids = threads.stream().map(Thread::getId).collect(Collectors.toSet());
    AtomicBoolean deadlock = new AtomicBoolean();
    // start a watchdog service; it is shut down in the finally block even if one of the awaits below throws
    ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
    try {
        scheduler.scheduleAtFixedRate(() -> {
            ThreadMXBean mxBean = ManagementFactory.getThreadMXBean();
            long[] deadlockedThreads = mxBean.findDeadlockedThreads();
            if (!deadlock.get() && deadlockedThreads != null) {
                for (long deadlockedThread : deadlockedThreads) {
                    // ensure that we detected deadlock on our threads
                    if (ids.contains(deadlockedThread)) {
                        deadlock.set(true);
                        // release the main test thread to fail the test
                        for (int i = 0; i < numberOfThreads; i++) {
                            deadlockLatch.countDown();
                        }
                        break;
                    }
                }
            }
        }, 1, 1, TimeUnit.SECONDS);
        // everything is setup, release the hounds
        barrier.await();
        // wait for either deadlock to be detected or the threads to terminate
        deadlockLatch.await();
    } finally {
        // shutdown the watchdog service
        scheduler.shutdown();
    }
    assertThat(failures, is(empty()));
    assertFalse("deadlock", deadlock.get());
}
// Several threads perform a random mix of computeIfAbsent, invalidate and get on overlapping keys. The
// computeIfAbsent loader randomly either returns a value or throws. The only per-operation check is that a failed
// load is reported as an ExecutionException carrying the loader's exception. Otherwise the test passes as long as
// every worker completes.
public void testCachePollution() throws BrokenBarrierException, InterruptedException {
    int numberOfThreads = randomIntBetween(2, 32);
    final Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder().build();
    CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
    for (int i = 0; i < numberOfThreads; i++) {
        Thread thread = new Thread(() -> {
            try {
                barrier.await();
                Random random = new Random(random().nextLong());
                for (int j = 0; j < numberOfEntries; j++) {
                    Integer key = random.nextInt(numberOfEntries);
                    // choose one of three operations with equal probability:
                    // (true, false) -> computeIfAbsent, (false, true) -> invalidate, (false, false) -> get;
                    // (true, true) is rejected and redrawn
                    boolean first;
                    boolean second;
                    do {
                        first = random.nextBoolean();
                        second = random.nextBoolean();
                    } while (first && second);
                    if (first) {
                        try {
                            cache.computeIfAbsent(key, k -> {
                                if (random.nextBoolean()) {
                                    return Integer.toString(k);
                                } else {
                                    throw new Exception("testCachePollution");
                                }
                            });
                        } catch (ExecutionException e) {
                            assertNotNull(e.getCause());
                            assertThat(e.getCause(), instanceOf(Exception.class));
                            // expected value first so that a failure reports expected/actual the right way round
                            assertEquals("testCachePollution", e.getCause().getMessage());
                        }
                    } else if (second) {
                        cache.invalidate(key);
                    } else {
                        cache.get(key);
                    }
                }
                barrier.await();
            } catch (BrokenBarrierException | InterruptedException e) {
                throw new AssertionError(e);
            }
        });
        thread.start();
    }
    // wait for all threads to be ready
    barrier.await();
    // wait for all threads to finish
    barrier.await();
}
// Many threads repeatedly call computeIfAbsent on the same key with a loader that always throws. Every call must
// fail with an ExecutionException whose cause is the loader's RuntimeException. No call may return a value.
public void testExceptionThrownDuringConcurrentComputeIfAbsent() throws BrokenBarrierException, InterruptedException {
    int numberOfThreads = randomIntBetween(2, 32);
    final Cache<String, String> cache = CacheBuilder.<String, String>builder().build();
    CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
    final String key = randomAlphaOfLengthBetween(2, 32);
    for (int i = 0; i < numberOfThreads; i++) {
        Thread thread = new Thread(() -> {
            try {
                barrier.await();
                for (int j = 0; j < numberOfEntries; j++) {
                    try {
                        String value = cache.computeIfAbsent(key, k -> {
                            throw new RuntimeException("failed to load");
                        });
                        fail("expected exception but got: " + value);
                    } catch (ExecutionException e) {
                        assertNotNull(e.getCause());
                        assertThat(e.getCause(), instanceOf(RuntimeException.class));
                        // expected value first so that a failure reports expected/actual the right way round
                        assertEquals("failed to load", e.getCause().getMessage());
                    }
                }
                barrier.await();
            } catch (BrokenBarrierException | InterruptedException e) {
                throw new AssertionError(e);
            }
        });
        thread.start();
    }
    // wait for all threads to be ready
    barrier.await();
    // wait for all threads to finish
    barrier.await();
}
// Stress test: many threads concurrently put random keys, frequently hitting the same key, and the cache must stay
// consistent. Here be dragons: this test did catch one subtle bug during development; do not remove lightly.
public void testTorture() throws BrokenBarrierException, InterruptedException {
    final int numberOfThreads = randomIntBetween(2, 32);
    // every entry weighs 2 against a maximum weight of 1000, so at most 500 entries can be retained
    final Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder()
        .setMaximumWeight(1000)
        .weigher((k, v) -> 2)
        .build();
    final CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
    // body shared by every writer thread; each run has its own locals, including its own Random
    final Runnable writer = () -> {
        try {
            barrier.await();
            final Random random = new Random(random().nextLong());
            for (int j = 0; j < numberOfEntries; j++) {
                cache.put(random.nextInt(numberOfEntries), Integer.toString(j));
            }
            barrier.await();
        } catch (BrokenBarrierException | InterruptedException e) {
            throw new AssertionError(e);
        }
    };
    for (int t = 0; t < numberOfThreads; t++) {
        new Thread(writer).start();
    }
    // first trip releases the writers, second trip waits until they have all finished
    barrier.await();
    barrier.await();
    cache.refresh();
    assertEquals(500, cache.count());
}
}