/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.common.util.concurrent; import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.util.concurrent.KeyedLock; import org.elasticsearch.test.ESTestCase; import org.hamcrest.Matchers; import java.util.Map.Entry; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.not; public class KeyedLockTests extends ESTestCase { public void testIfMapEmptyAfterLotsOfAcquireAndReleases() throws InterruptedException { ConcurrentHashMap<String, Integer> counter = new ConcurrentHashMap<>(); ConcurrentHashMap<String, AtomicInteger> safeCounter = new ConcurrentHashMap<>(); KeyedLock<String> connectionLock = new KeyedLock<String>(randomBoolean()); String[] names = new String[randomIntBetween(1, 40)]; for (int i = 0; i < names.length; i++) { names[i] = randomRealisticUnicodeOfLengthBetween(10, 20); } CountDownLatch startLatch = new CountDownLatch(1); int numThreads = randomIntBetween(3, 10); AcquireAndReleaseThread[] threads = new AcquireAndReleaseThread[numThreads]; for (int i = 0; i < numThreads; i++) { threads[i] = new AcquireAndReleaseThread(startLatch, connectionLock, names, counter, safeCounter); } for (int i = 0; i < numThreads; i++) { threads[i].start(); } startLatch.countDown(); for (int i = 0; i < numThreads; i++) { threads[i].join(); } assertThat(connectionLock.hasLockedKeys(), equalTo(false)); Set<Entry<String, Integer>> entrySet = counter.entrySet(); assertThat(counter.size(), equalTo(safeCounter.size())); for (Entry<String, Integer> entry : entrySet) { AtomicInteger atomicInteger = safeCounter.get(entry.getKey()); assertThat(atomicInteger, not(Matchers.nullValue())); assertThat(atomicInteger.get(), equalTo(entry.getValue())); } } public static class AcquireAndReleaseThread extends Thread { private CountDownLatch startLatch; KeyedLock<String> connectionLock; String[] names; ConcurrentHashMap<String, Integer> counter; ConcurrentHashMap<String, AtomicInteger> safeCounter; public AcquireAndReleaseThread(CountDownLatch startLatch, KeyedLock<String> connectionLock, String[] names, ConcurrentHashMap<String, Integer> counter, ConcurrentHashMap<String, AtomicInteger> safeCounter) { this.startLatch = startLatch; this.connectionLock = connectionLock; this.names = names; this.counter = counter; this.safeCounter = safeCounter; } @Override public void run() { try { startLatch.await(); } catch (InterruptedException e) { throw new RuntimeException(e); } int numRuns = scaledRandomIntBetween(5000, 50000); for (int i = 0; i < numRuns; i++) { String curName = names[randomInt(names.length - 1)]; assert connectionLock.isHeldByCurrentThread(curName) == false; try (Releasable ignored = connectionLock.acquire(curName)) { assert connectionLock.isHeldByCurrentThread(curName); assert connectionLock.isHeldByCurrentThread(curName + "bla") == false; Integer integer = counter.get(curName); if (integer == null) { counter.put(curName, 1); } else { counter.put(curName, integer.intValue() + 1); } } AtomicInteger atomicInteger = new AtomicInteger(0); AtomicInteger value = safeCounter.putIfAbsent(curName, atomicInteger); if (value == null) { atomicInteger.incrementAndGet(); } else { value.incrementAndGet(); } } } } }