/*
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/
package org.apache.bookkeeper.util.collections;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.locks.StampedLock;
import java.util.function.LongPredicate;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
/**
 * Concurrent hash map from primitive long to long.
 * <p>
 * Provides methods similar to {@code ConcurrentMap<K, V>}, but since it is an open-addressing hash map
 * with linear probing, no node allocations are required to store the values.
 * <p>
 * Keys <strong>MUST</strong> be {@code >= 0}.
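 * <p>
 * A minimal usage sketch (keys and values are illustrative):
 *
 * <pre>{@code
 * ConcurrentLongLongHashMap map = new ConcurrentLongLongHashMap(1024, 16);
 * map.put(5, 100);        // returns -1: no previous value
 * long v = map.get(5);    // 100
 * map.remove(5);          // returns 100
 * long miss = map.get(5); // -1: not present
 * }</pre>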
*/
public class ConcurrentLongLongHashMap {
private static final long EmptyKey = -1L;
private static final long DeletedKey = -2L;
private static final long ValueNotFound = -1L;
private static final float MapFillFactor = 0.66f;
private static final int DefaultExpectedItems = 256;
private static final int DefaultConcurrencyLevel = 16;
private final Section[] sections;
    public interface BiConsumerLong {
        void accept(long key, long value);
    }
    public interface LongLongFunction {
        long apply(long key);
    }
    public interface LongLongPredicate {
        boolean test(long key, long value);
    }
public ConcurrentLongLongHashMap() {
this(DefaultExpectedItems);
}
public ConcurrentLongLongHashMap(int expectedItems) {
this(expectedItems, DefaultConcurrencyLevel);
}
public ConcurrentLongLongHashMap(int expectedItems, int concurrencyLevel) {
checkArgument(expectedItems > 0);
checkArgument(concurrencyLevel > 0);
checkArgument(expectedItems >= concurrencyLevel);
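        // Note: a section is selected by bit-masking the hash, so concurrencyLevel should be a
        // power of two for the sections to be used evenly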
int numSections = concurrencyLevel;
int perSectionExpectedItems = expectedItems / numSections;
int perSectionCapacity = (int) (perSectionExpectedItems / MapFillFactor);
this.sections = new Section[numSections];
for (int i = 0; i < numSections; i++) {
sections[i] = new Section(perSectionCapacity);
}
}
public long size() {
long size = 0;
for (Section s : sections) {
size += s.size;
}
return size;
}
public long capacity() {
long capacity = 0;
for (Section s : sections) {
capacity += s.capacity;
}
return capacity;
}
public boolean isEmpty() {
for (Section s : sections) {
if (s.size != 0) {
return false;
}
}
return true;
}
long getUsedBucketCount() {
long usedBucketCount = 0;
for (Section s : sections) {
usedBucketCount += s.usedBuckets;
}
return usedBucketCount;
}
    /**
     * Get the value associated with the key.
     *
     * @param key the entry key
     * @return the value, or -1 if the key was not present
     */
public long get(long key) {
checkBiggerEqualZero(key);
long h = hash(key);
return getSection(h).get(key, (int) h);
}
public boolean containsKey(long key) {
return get(key) != ValueNotFound;
}
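    /**
     * Associate the value with the key.
     *
     * @param key the entry key
     * @param value the value to store
     * @return the previous value, or -1 if the key was not present
     */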
public long put(long key, long value) {
checkBiggerEqualZero(key);
checkBiggerEqualZero(value);
long h = hash(key);
return getSection(h).put(key, value, (int) h, false, null);
}
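    /**
     * Associate the value with the key only if the key is not already present.
     *
     * @param key the entry key
     * @param value the value to store if the key is absent
     * @return the existing value if the key was already present, or -1 if the new value was inserted
     */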
public long putIfAbsent(long key, long value) {
checkBiggerEqualZero(key);
checkBiggerEqualZero(value);
long h = hash(key);
return getSection(h).put(key, value, (int) h, true, null);
}
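    /**
     * Get the value associated with the key if present; otherwise compute it with the given provider,
     * store it and return it. Note that the computed value is expected to be {@code >= 0}, like the
     * rest of the map, but this is not re-validated here.
     * <p>
     * A short sketch (the key and the lambda are illustrative):
     *
     * <pre>{@code
     * // Returns the cached value for key 7, or computes, stores and returns 14 on first access
     * long v = map.computeIfAbsent(7, key -> key * 2);
     * }</pre>
     *
     * @param key the entry key
     * @param provider the function that computes the value when the key is absent
     * @return the existing or newly computed value
     */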
public long computeIfAbsent(long key, LongLongFunction provider) {
checkBiggerEqualZero(key);
checkNotNull(provider);
long h = hash(key);
return getSection(h).put(key, ValueNotFound, (int) h, true, provider);
}
/**
 * Atomically add the specified delta to the current value of the entry identified by the key. If the entry
 * is not in the map, a new entry is created with default value 0 and the delta is then added to it.
*
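 * <p>
 * A sketch of counter-style usage (the key and deltas are illustrative):
 *
 * <pre>{@code
 * map.addAndGet(3, 1);  // entry created with value 1
 * map.addAndGet(3, 4);  // returns 5
 * map.addAndGet(3, -5); // returns 0
 * map.addAndGet(3, -1); // throws IllegalArgumentException: value would become negative
 * }</pre>
 *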
* @param key
* the entry key
* @param delta
* the delta to add
* @return the new value of the entry
 * @throws IllegalArgumentException
 *             if the delta is invalid, e.g. if it would have caused the value to become negative
*/
public long addAndGet(long key, long delta) {
checkBiggerEqualZero(key);
long h = hash(key);
return getSection(h).addAndGet(key, delta, (int) h);
}
/**
 * Change the value for a specific key only if it matches the current value. Passing -1 as
 * {@code currentValue} means the key is expected to be absent; in that case the entry is inserted.
 *
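 * <p>
 * A sketch of conditional-update usage (keys and values are illustrative):
 *
 * <pre>{@code
 * map.put(1, 10);
 * map.compareAndSet(1, 10, 11); // true: 10 matched, the value is now 11
 * map.compareAndSet(1, 10, 12); // false: the current value is 11, not 10
 * map.compareAndSet(2, -1, 20); // true: key 2 was absent, so the entry is inserted
 * }</pre>
 *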
 * @param key the entry key
 * @param currentValue the expected current value, or -1 if the key is expected to be absent
 * @param newValue the new value to set
 * @return true if the value was updated or inserted, false otherwise
*/
public boolean compareAndSet(long key, long currentValue, long newValue) {
checkBiggerEqualZero(key);
checkBiggerEqualZero(newValue);
long h = hash(key);
return getSection(h).compareAndSet(key, currentValue, newValue, (int) h);
}
    /**
     * Remove an existing entry if found.
     *
     * @param key the entry key
     * @return the value associated with the key, or -1 if the key was not present
     */
public long remove(long key) {
checkBiggerEqualZero(key);
long h = hash(key);
return getSection(h).remove(key, ValueNotFound, (int) h);
}
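    /**
     * Remove the entry only if the key is currently mapped to the given value.
     *
     * @param key the entry key
     * @param value the expected value
     * @return true if the entry was removed
     */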
public boolean remove(long key, long value) {
checkBiggerEqualZero(key);
checkBiggerEqualZero(value);
long h = hash(key);
return getSection(h).remove(key, value, (int) h) != ValueNotFound;
}
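    /**
     * Remove all the entries whose key matches the predicate.
     * <p>
     * For instance (the predicate is illustrative):
     *
     * <pre>{@code
     * // Drop all the entries with an even key
     * int removed = map.removeIf(key -> key % 2 == 0);
     * }</pre>
     *
     * @param filter the predicate deciding, from the key alone, whether to remove the entry
     * @return the number of removed entries
     */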
public int removeIf(LongPredicate filter) {
checkNotNull(filter);
int removedCount = 0;
for (Section s : sections) {
removedCount += s.removeIf(filter);
}
return removedCount;
}
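    /**
     * Remove all the entries whose key/value pair matches the predicate, e.g.
     * {@code map.removeIf((key, value) -> value == 0)} to drop the zero-valued entries.
     *
     * @param filter the predicate deciding, from key and value, whether to remove the entry
     * @return the number of removed entries
     */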
public int removeIf(LongLongPredicate filter) {
checkNotNull(filter);
int removedCount = 0;
for (Section s : sections) {
removedCount += s.removeIf(filter);
}
return removedCount;
}
private final Section getSection(long hash) {
        // Use the 32 most significant bits of the long hash to select the section; the
        // low 32 bits pick the bucket inside the section
final int sectionIdx = (int) (hash >>> 32) & (sections.length - 1);
return sections[sectionIdx];
}
public void clear() {
for (Section s : sections) {
s.clear();
}
}
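    /**
     * Iterate over all the entries, section by section. No global snapshot is taken, so entries
     * inserted or removed concurrently may or may not be observed.
     */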
public void forEach(BiConsumerLong processor) {
for (Section s : sections) {
s.forEach(processor);
}
}
/**
* @return a new list of all keys (makes a copy)
*/
public List<Long> keys() {
List<Long> keys = Lists.newArrayList();
forEach((key, value) -> keys.add(key));
return keys;
}
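    /**
     * @return a new list of all values (makes a copy)
     */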
public List<Long> values() {
List<Long> values = Lists.newArrayList();
forEach((key, value) -> values.add(value));
return values;
}
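    /**
     * @return a new {@code Map<Long, Long>} with a copy of all the entries
     */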
public Map<Long, Long> asMap() {
Map<Long, Long> map = Maps.newHashMap();
forEach((key, value) -> map.put(key, value));
return map;
}
    // A section is a portion of the hash map that is covered by a single lock
@SuppressWarnings("serial")
private static final class Section extends StampedLock {
// Keys and values are stored interleaved in the table array
private long[] table;
private int capacity;
private volatile int size;
private int usedBuckets;
private int resizeThreshold;
Section(int capacity) {
this.capacity = alignToPowerOfTwo(capacity);
this.table = new long[2 * this.capacity];
this.size = 0;
this.usedBuckets = 0;
this.resizeThreshold = (int) (this.capacity * MapFillFactor);
Arrays.fill(table, EmptyKey);
}
long get(long key, int keyHash) {
long stamp = tryOptimisticRead();
boolean acquiredLock = false;
int bucket = signSafeMod(keyHash, capacity);
try {
while (true) {
// First try optimistic locking
long storedKey = table[bucket];
long storedValue = table[bucket + 1];
if (!acquiredLock && validate(stamp)) {
// The values we have read are consistent
if (key == storedKey) {
return storedValue;
} else if (storedKey == EmptyKey) {
// Not found
return ValueNotFound;
}
} else {
// Fallback to acquiring read lock
if (!acquiredLock) {
stamp = readLock();
acquiredLock = true;
bucket = signSafeMod(keyHash, capacity);
storedKey = table[bucket];
storedValue = table[bucket + 1];
}
if (key == storedKey) {
return storedValue;
} else if (storedKey == EmptyKey) {
// Not found
return ValueNotFound;
}
}
bucket = (bucket + 2) & (table.length - 1);
}
} finally {
if (acquiredLock) {
unlockRead(stamp);
}
}
}
long put(long key, long value, int keyHash, boolean onlyIfAbsent, LongLongFunction valueProvider) {
long stamp = writeLock();
int bucket = signSafeMod(keyHash, capacity);
// Remember where we find the first available spot
int firstDeletedKey = -1;
try {
while (true) {
long storedKey = table[bucket];
long storedValue = table[bucket + 1];
if (key == storedKey) {
                            if (!onlyIfAbsent) {
                                // Overwrite the existing value for the same key
                                table[bucket + 1] = value;
                            }
                            return storedValue;
} else if (storedKey == EmptyKey) {
// Found an empty bucket. This means the key is not in the map. If we've already seen a deleted
// key, we should write at that position
if (firstDeletedKey != -1) {
bucket = firstDeletedKey;
} else {
++usedBuckets;
}
if (value == ValueNotFound) {
value = valueProvider.apply(key);
}
                            table[bucket] = key;
                            table[bucket + 1] = value;
                            ++size;
                            // For computeIfAbsent (valueProvider != null), return the value just stored;
                            // for a plain put, return -1 since there was no previous value
                            return valueProvider != null ? value : ValueNotFound;
} else if (storedKey == DeletedKey) {
// The bucket contained a different deleted key
if (firstDeletedKey == -1) {
firstDeletedKey = bucket;
}
}
bucket = (bucket + 2) & (table.length - 1);
}
} finally {
if (usedBuckets > resizeThreshold) {
try {
rehash();
} finally {
unlockWrite(stamp);
}
} else {
unlockWrite(stamp);
}
}
}
long addAndGet(long key, long delta, int keyHash) {
long stamp = writeLock();
int bucket = signSafeMod(keyHash, capacity);
// Remember where we find the first available spot
int firstDeletedKey = -1;
try {
while (true) {
long storedKey = table[bucket];
long storedValue = table[bucket + 1];
                        if (key == storedKey) {
                            // Found the key: add the delta to the existing value
                            long newValue = storedValue + delta;
checkBiggerEqualZero(newValue);
table[bucket + 1] = newValue;
return newValue;
} else if (storedKey == EmptyKey) {
// Found an empty bucket. This means the key is not in the map. If we've already seen a deleted
// key, we should write at that position
checkBiggerEqualZero(delta);
if (firstDeletedKey != -1) {
bucket = firstDeletedKey;
} else {
++usedBuckets;
}
table[bucket] = key;
table[bucket + 1] = delta;
++size;
return delta;
} else if (storedKey == DeletedKey) {
// The bucket contained a different deleted key
if (firstDeletedKey == -1) {
firstDeletedKey = bucket;
}
}
bucket = (bucket + 2) & (table.length - 1);
}
} finally {
if (usedBuckets > resizeThreshold) {
try {
rehash();
} finally {
unlockWrite(stamp);
}
} else {
unlockWrite(stamp);
}
}
}
boolean compareAndSet(long key, long currentValue, long newValue, int keyHash) {
long stamp = writeLock();
int bucket = signSafeMod(keyHash, capacity);
// Remember where we find the first available spot
int firstDeletedKey = -1;
try {
while (true) {
long storedKey = table[bucket];
long storedValue = table[bucket + 1];
if (key == storedKey) {
if (storedValue != currentValue) {
return false;
}
                            // Overwrite the existing value for the same key
table[bucket + 1] = newValue;
return true;
                        } else if (storedKey == EmptyKey) {
                            // Found an empty bucket: the key is not in the map. Insert only if the
                            // caller expected the key to be absent (currentValue == -1)
                            if (currentValue == ValueNotFound) {
if (firstDeletedKey != -1) {
bucket = firstDeletedKey;
} else {
++usedBuckets;
}
table[bucket] = key;
table[bucket + 1] = newValue;
++size;
return true;
} else {
return false;
}
} else if (storedKey == DeletedKey) {
// The bucket contained a different deleted key
if (firstDeletedKey == -1) {
firstDeletedKey = bucket;
}
}
bucket = (bucket + 2) & (table.length - 1);
}
} finally {
if (usedBuckets > resizeThreshold) {
try {
rehash();
} finally {
unlockWrite(stamp);
}
} else {
unlockWrite(stamp);
}
}
}
private long remove(long key, long value, int keyHash) {
long stamp = writeLock();
int bucket = signSafeMod(keyHash, capacity);
try {
while (true) {
long storedKey = table[bucket];
long storedValue = table[bucket + 1];
if (key == storedKey) {
if (value == ValueNotFound || value == storedValue) {
--size;
cleanBucket(bucket);
return storedValue;
} else {
return ValueNotFound;
}
} else if (storedKey == EmptyKey) {
// Key wasn't found
return ValueNotFound;
}
bucket = (bucket + 2) & (table.length - 1);
}
} finally {
unlockWrite(stamp);
}
}
int removeIf(LongPredicate filter) {
long stamp = writeLock();
int removedCount = 0;
try {
// Go through all the buckets for this section
for (int bucket = 0; bucket < table.length; bucket += 2) {
long storedKey = table[bucket];
if (storedKey != DeletedKey && storedKey != EmptyKey) {
if (filter.test(storedKey)) {
// Removing item
--size;
++removedCount;
cleanBucket(bucket);
}
}
}
return removedCount;
} finally {
unlockWrite(stamp);
}
}
int removeIf(LongLongPredicate filter) {
long stamp = writeLock();
int removedCount = 0;
try {
// Go through all the buckets for this section
for (int bucket = 0; bucket < table.length; bucket += 2) {
long storedKey = table[bucket];
long storedValue = table[bucket + 1];
if (storedKey != DeletedKey && storedKey != EmptyKey) {
if (filter.test(storedKey, storedValue)) {
// Removing item
--size;
++removedCount;
cleanBucket(bucket);
}
}
}
return removedCount;
} finally {
unlockWrite(stamp);
}
}
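        // Free a bucket after a removal. If the next bucket in the probe chain is empty, the chain
        // ends here and this bucket can revert to EmptyKey; otherwise it must be marked DeletedKey
        // so that lookups probing for other keys keep scanning past it.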
private void cleanBucket(int bucket) {
int nextInArray = (bucket + 2) & (table.length - 1);
if (table[nextInArray] == EmptyKey) {
table[bucket] = EmptyKey;
table[bucket + 1] = ValueNotFound;
--usedBuckets;
} else {
table[bucket] = DeletedKey;
table[bucket + 1] = ValueNotFound;
}
}
void clear() {
long stamp = writeLock();
try {
Arrays.fill(table, EmptyKey);
this.size = 0;
this.usedBuckets = 0;
} finally {
unlockWrite(stamp);
}
}
public void forEach(BiConsumerLong processor) {
long stamp = tryOptimisticRead();
long[] table = this.table;
boolean acquiredReadLock = false;
try {
// Validate no rehashing
if (!validate(stamp)) {
// Fallback to read lock
stamp = readLock();
acquiredReadLock = true;
table = this.table;
}
// Go through all the buckets for this section
for (int bucket = 0; bucket < table.length; bucket += 2) {
long storedKey = table[bucket];
long storedValue = table[bucket + 1];
if (!acquiredReadLock && !validate(stamp)) {
// Fallback to acquiring read lock
stamp = readLock();
acquiredReadLock = true;
storedKey = table[bucket];
storedValue = table[bucket + 1];
}
if (storedKey != DeletedKey && storedKey != EmptyKey) {
processor.accept(storedKey, storedValue);
}
}
} finally {
if (acquiredReadLock) {
unlockRead(stamp);
}
}
}
private void rehash() {
// Expand the hashmap
int newCapacity = capacity * 2;
long[] newTable = new long[2 * newCapacity];
Arrays.fill(newTable, EmptyKey);
// Re-hash table
for (int i = 0; i < table.length; i += 2) {
long storedKey = table[i];
long storedValue = table[i + 1];
if (storedKey != EmptyKey && storedKey != DeletedKey) {
insertKeyValueNoLock(newTable, newCapacity, storedKey, storedValue);
}
}
capacity = newCapacity;
table = newTable;
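            // DeletedKey tombstones were not copied, so every used bucket holds a live entry again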
usedBuckets = size;
resizeThreshold = (int) (capacity * MapFillFactor);
}
private static void insertKeyValueNoLock(long[] table, int capacity, long key, long value) {
int bucket = signSafeMod(hash(key), capacity);
while (true) {
long storedKey = table[bucket];
if (storedKey == EmptyKey) {
// The bucket is empty, so we can use it
table[bucket] = key;
table[bucket + 1] = value;
return;
}
bucket = (bucket + 2) & (table.length - 1);
}
}
}
    private static final long HashMixer = 0xc6a4a7935bd1e995L;
private static final int R = 47;
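    // Bit mixing based on the 64-bit MurmurHash2 finalizer (multiply, xor-shift by 47, multiply),
    // spreading consecutive keys across sections and buckets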
    static final long hash(long key) {
long hash = key * HashMixer;
hash ^= hash >>> R;
hash *= HashMixer;
return hash;
}
    static final int signSafeMod(long n, int max) {
        // Modulo by the power-of-two capacity, then shift left by 1 because keys and
        // values are interleaved in the table array
        return (int) (n & (max - 1)) << 1;
    }
private static final int alignToPowerOfTwo(int n) {
return (int) Math.pow(2, 32 - Integer.numberOfLeadingZeros(n - 1));
}
private static final void checkBiggerEqualZero(long n) {
if (n < 0L) {
throw new IllegalArgumentException("Keys and values must be >= 0");
}
}
}