/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.runtime.operators.windowing;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.NoSuchElementException;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.util.MathUtils;
/**
* A special Hash Map implementation that can be traversed efficiently in sync with other
* hash maps.
*
* <p>The differences between this hash map and Java's "java.util.HashMap" are:
* <ul>
* <li>A different hashing scheme. This implementation uses extensible hashing, meaning that
* each hash table growth takes one more lower hash code bit into account, and values that where
* formerly in the same bucket will afterwards be in the two adjacent buckets.
* <li>This allows an efficient traversal of multiple hash maps together, even though the maps are
* of different sizes.</li>
* <li>The map offers functions such as "putIfAbsent()" and "putOrAggregate()"
* <li>The map supports no removal/shrinking.
* </ul>
*/
@Internal
public class KeyMap<K, V> implements Iterable<KeyMap.Entry<K, V>> {
/** The minimum table capacity, 64 entries. */
private static final int MIN_CAPACITY = 0x40;
/**
* The maximum possible table capacity, the largest positive power of two in the 32bit signed
* integer value range.
*/
private static final int MAX_CAPACITY = 0x40000000;
/** The number of bits used for table addressing when table is at max capacity. */
private static final int FULL_BIT_RANGE = MathUtils.log2strict(MAX_CAPACITY);
// ------------------------------------------------------------------------
/** The hash index, as an array of entries. */
private Entry<K, V>[] table;
/** The number of bits by which the hash code is shifted right, to find the bucket. */
private int shift;
/** The number of elements in the hash table. */
private int numElements;
/** The number of elements above which the hash table needs to grow. */
private int rehashThreshold;
/** The base-2 logarithm of the table capacity. */
private int log2size;
// ------------------------------------------------------------------------
/**
* Creates a new hash table with the default initial capacity.
*/
public KeyMap() {
this(0);
}
/**
* Creates a new table with a capacity tailored to the given expected number of elements.
*
* @param expectedNumberOfElements The number of elements to tailor the capacity to.
*/
public KeyMap(int expectedNumberOfElements) {
if (expectedNumberOfElements < 0) {
throw new IllegalArgumentException("Invalid capacity: " + expectedNumberOfElements);
}
// round up to the next power or two
// guard against too small capacity and integer overflows
int capacity = Integer.highestOneBit(expectedNumberOfElements) << 1;
capacity = capacity >= 0 ? Math.max(MIN_CAPACITY, capacity) : MAX_CAPACITY;
// this also acts as a sanity check
log2size = MathUtils.log2strict(capacity);
shift = FULL_BIT_RANGE - log2size;
table = allocateTable(capacity);
rehashThreshold = getRehashThreshold(capacity);
}
// ------------------------------------------------------------------------
// Gets and Puts
// ------------------------------------------------------------------------
/**
* Inserts the given value, mapped under the given key. If the table already contains a value for
* the key, the value is replaced and returned. If no value is contained, yet, the function
* returns null.
*
* @param key The key to insert.
* @param value The value to insert.
* @return The previously mapped value for the key, or null, if no value was mapped for the key.
*
* @throws java.lang.NullPointerException Thrown, if the key is null.
*/
public final V put(K key, V value) {
final int hash = hash(key);
final int slot = indexOf (hash);
// search the chain from the slot
for (Entry<K, V> e = table[slot]; e != null; e = e.next) {
Object k;
if (e.hashCode == hash && ((k = e.key) == key || key.equals(k))) {
// found match
V old = e.value;
e.value = value;
return old;
}
}
// no match, insert a new value
insertNewEntry(hash, key, value, slot);
return null;
}
/**
* Inserts a value for the given key, if no value is yet contained for that key. Otherwise,
* returns the value currently contained for the key.
*
* <p>The value that is inserted in case that the key is not contained, yet, is lazily created
* using the given factory.
*
* @param key The key to insert.
* @param factory The factory that produces the value, if no value is contained, yet, for the key.
* @return The value in the map after this operation (either the previously contained value, or the
* newly created value).
*
* @throws java.lang.NullPointerException Thrown, if the key is null.
*/
public final V putIfAbsent(K key, LazyFactory<V> factory) {
final int hash = hash(key);
final int slot = indexOf(hash);
// search the chain from the slot
for (Entry<K, V> entry = table[slot]; entry != null; entry = entry.next) {
if (entry.hashCode == hash && entry.key.equals(key)) {
// found match
return entry.value;
}
}
// no match, insert a new value
V value = factory.create();
insertNewEntry(hash, key, value, slot);
// return the created value
return value;
}
/**
* Inserts or aggregates a value into the hash map. If the hash map does not yet contain the key,
* this method inserts the value. If the table already contains the key (and a value) this
* method will use the given ReduceFunction function to combine the existing value and the
* given value to a new value, and store that value for the key.
*
* @param key The key to map the value.
* @param value The new value to insert, or aggregate with the existing value.
* @param aggregator The aggregator to use if a value is already contained.
*
* @return The value in the map after this operation: Either the given value, or the aggregated value.
*
* @throws java.lang.NullPointerException Thrown, if the key is null.
* @throws Exception The method forwards exceptions from the aggregation function.
*/
public final V putOrAggregate(K key, V value, ReduceFunction<V> aggregator) throws Exception {
final int hash = hash(key);
final int slot = indexOf(hash);
// search the chain from the slot
for (Entry<K, V> entry = table[slot]; entry != null; entry = entry.next) {
if (entry.hashCode == hash && entry.key.equals(key)) {
// found match
entry.value = aggregator.reduce(entry.value, value);
return entry.value;
}
}
// no match, insert a new value
insertNewEntry(hash, key, value, slot);
// return the original value
return value;
}
/**
* Looks up the value mapped under the given key. Returns null if no value is mapped under this key.
*
* @param key The key to look up.
* @return The value associated with the key, or null, if no value is found for the key.
*
* @throws java.lang.NullPointerException Thrown, if the key is null.
*/
public V get(K key) {
final int hash = hash(key);
final int slot = indexOf(hash);
// search the chain from the slot
for (Entry<K, V> entry = table[slot]; entry != null; entry = entry.next) {
if (entry.hashCode == hash && entry.key.equals(key)) {
return entry.value;
}
}
// not found
return null;
}
private void insertNewEntry(int hashCode, K key, V value, int position) {
Entry<K, V> e = table[position];
table[position] = new Entry<>(key, value, hashCode, e);
numElements++;
// rehash if necessary
if (numElements > rehashThreshold) {
growTable();
}
}
private int indexOf(int hashCode) {
return (hashCode >> shift) & (table.length - 1);
}
/**
* Creates an iterator over the entries of this map.
*
* @return An iterator over the entries of this map.
*/
@Override
public Iterator<Entry<K, V>> iterator() {
return new Iterator<Entry<K, V>>() {
private final Entry<K, V>[] tab = KeyMap.this.table;
private Entry<K, V> nextEntry;
private int nextPos = 0;
@Override
public boolean hasNext() {
if (nextEntry != null) {
return true;
}
else {
while (nextPos < tab.length) {
Entry<K, V> e = tab[nextPos++];
if (e != null) {
nextEntry = e;
return true;
}
}
return false;
}
}
@Override
public Entry<K, V> next() {
if (nextEntry != null || hasNext()) {
Entry<K, V> e = nextEntry;
nextEntry = nextEntry.next;
return e;
}
else {
throw new NoSuchElementException();
}
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
};
}
// ------------------------------------------------------------------------
// Properties
// ------------------------------------------------------------------------
/**
* Gets the number of elements currently in the map.
* @return The number of elements currently in the map.
*/
public int size() {
return numElements;
}
/**
* Checks whether the map is empty.
* @return True, if the map is empty, false otherwise.
*/
public boolean isEmpty() {
return numElements == 0;
}
/**
* Gets the current table capacity, i.e., the number of slots in the hash table, without
* and overflow chaining.
* @return The number of slots in the hash table.
*/
public int getCurrentTableCapacity() {
return table.length;
}
/**
* Gets the base-2 logarithm of the hash table capacity, as returned by
* {@link #getCurrentTableCapacity()}.
*
* @return The base-2 logarithm of the hash table capacity.
*/
public int getLog2TableCapacity() {
return log2size;
}
public int getRehashThreshold() {
return rehashThreshold;
}
public int getShift() {
return shift;
}
// ------------------------------------------------------------------------
// Utilities
// ------------------------------------------------------------------------
@SuppressWarnings("unchecked")
private Entry<K, V>[] allocateTable(int numElements) {
return (Entry<K, V>[]) new Entry<?, ?>[numElements];
}
private void growTable() {
final int newSize = table.length << 1;
// only grow if there is still space to grow the table
if (newSize > 0) {
final Entry<K, V>[] oldTable = this.table;
final Entry<K, V>[] newTable = allocateTable(newSize);
final int newShift = shift - 1;
final int newMask = newSize - 1;
// go over all slots from the table. since we hash to adjacent positions in
// the new hash table, this is actually cache efficient
for (Entry<K, V> entry : oldTable) {
// traverse the chain for each slot
while (entry != null) {
final int newPos = (entry.hashCode >> newShift) & newMask;
Entry<K, V> nextEntry = entry.next;
entry.next = newTable[newPos];
newTable[newPos] = entry;
entry = nextEntry;
}
}
this.table = newTable;
this.shift = newShift;
this.rehashThreshold = getRehashThreshold(newSize);
this.log2size += 1;
}
}
private static int hash(Object key) {
int code = key.hashCode();
// we need a strong hash function that generates diverse upper bits
// this hash function is more expensive than the "scramble" used by "java.util.HashMap",
// but required for this sort of hash table
code = (code + 0x7ed55d16) + (code << 12);
code = (code ^ 0xc761c23c) ^ (code >>> 19);
code = (code + 0x165667b1) + (code << 5);
code = (code + 0xd3a2646c) ^ (code << 9);
code = (code + 0xfd7046c5) + (code << 3);
return (code ^ 0xb55a4f09) ^ (code >>> 16);
}
private static int getRehashThreshold(int capacity) {
// divide before multiply, to avoid overflow
return capacity / 4 * 3;
}
// ------------------------------------------------------------------------
// Testing Utilities
// ------------------------------------------------------------------------
/**
* For testing only: Actively counts the number of entries, rather than using the
* counter variable. This method has linear complexity, rather than constant.
*
* @return The counted number of entries.
*/
int traverseAndCountElements() {
int num = 0;
for (Entry<?, ?> entry : table) {
while (entry != null) {
num++;
entry = entry.next;
}
}
return num;
}
/**
* For testing only: Gets the length of the longest overflow chain.
* This method has linear complexity.
*
* @return The length of the longest overflow chain.
*/
int getLongestChainLength() {
int maxLen = 0;
for (Entry<?, ?> entry : table) {
int thisLen = 0;
while (entry != null) {
thisLen++;
entry = entry.next;
}
maxLen = Math.max(maxLen, thisLen);
}
return maxLen;
}
// ------------------------------------------------------------------------
/**
* An entry in the hash table.
*
* @param <K> Type of the key.
* @param <V> Type of the value.
*/
public static final class Entry<K, V> {
final K key;
final int hashCode;
V value;
Entry<K, V> next;
long touchedTag;
Entry(K key, V value, int hashCode, Entry<K, V> next) {
this.key = key;
this.value = value;
this.next = next;
this.hashCode = hashCode;
}
public K getKey() {
return key;
}
public V getValue() {
return value;
}
}
// ------------------------------------------------------------------------
/**
* Performs a traversal about logical the multi-map that results from the union of the
* given maps. This method does not actually build a union of the map, but traverses the hash maps
* together.
*
* @param maps The array uf maps whose union should be traversed.
* @param visitor The visitor that is called for each key and all values.
* @param touchedTag A tag that is used to mark elements that have been touched in this specific
* traversal. Each successive traversal should supply a larger value for this
* tag than the previous one.
*
* @param <K> The type of the map's key.
* @param <V> The type of the map's value.
*/
public static <K, V> void traverseMaps(
final KeyMap<K, V>[] maps,
final TraversalEvaluator<K, V> visitor,
final long touchedTag)
throws Exception {
// we need to work on the maps in descending size
Arrays.sort(maps, CapacityDescendingComparator.INSTANCE);
final int[] shifts = new int[maps.length];
final int[] lowBitsMask = new int[maps.length];
final int numSlots = maps[0].table.length;
final int numTables = maps.length;
// figure out how much each hash table collapses the entries
for (int i = 0; i < numTables; i++) {
shifts[i] = maps[0].log2size - maps[i].log2size;
lowBitsMask[i] = (1 << shifts[i]) - 1;
}
// go over all slots (based on the largest hash table)
for (int pos = 0; pos < numSlots; pos++) {
// for each slot, go over all tables, until the table does not have that slot any more
// for tables where multiple slots collapse into one, we visit that one when we process the
// latest of all slots that collapse to that one
int mask;
for (int rootTable = 0;
rootTable < numTables && ((mask = lowBitsMask[rootTable]) & pos) == mask;
rootTable++) {
// use that table to gather keys and start collecting keys from the following tables
// go over all entries of that slot in the table
Entry<K, V> entry = maps[rootTable].table[pos >> shifts[rootTable]];
while (entry != null) {
// take only entries that have not been collected as part of other tables
if (entry.touchedTag < touchedTag) {
entry.touchedTag = touchedTag;
final K key = entry.key;
final int hashCode = entry.hashCode;
visitor.startNewKey(key);
visitor.nextValue(entry.value);
addEntriesFromChain(entry.next, visitor, key, touchedTag, hashCode);
// go over the other hash tables and collect their entries for the key
for (int followupTable = rootTable + 1; followupTable < numTables; followupTable++) {
Entry<K, V> followupEntry = maps[followupTable].table[pos >> shifts[followupTable]];
if (followupEntry != null) {
addEntriesFromChain(followupEntry, visitor, key, touchedTag, hashCode);
}
}
visitor.keyDone();
}
entry = entry.next;
}
}
}
}
private static <K, V> void addEntriesFromChain(
Entry<K, V> entry,
TraversalEvaluator<K, V> visitor,
K key,
long touchedTag,
int hashCode) throws Exception {
while (entry != null) {
if (entry.touchedTag < touchedTag && entry.hashCode == hashCode && entry.key.equals(key)) {
entry.touchedTag = touchedTag;
visitor.nextValue(entry.value);
}
entry = entry.next;
}
}
// ------------------------------------------------------------------------
/**
* Comparator that defines a descending order on maps depending on their table capacity
* and number of elements.
*/
static final class CapacityDescendingComparator implements Comparator<KeyMap<?, ?>> {
static final CapacityDescendingComparator INSTANCE = new CapacityDescendingComparator();
private CapacityDescendingComparator() {}
@Override
public int compare(KeyMap<?, ?> o1, KeyMap<?, ?> o2) {
// this sorts descending
int cmp = o2.getLog2TableCapacity() - o1.getLog2TableCapacity();
if (cmp != 0) {
return cmp;
}
else {
return o2.size() - o1.size();
}
}
}
// ------------------------------------------------------------------------
/**
* A factory for lazy/on-demand instantiation of values.
*
* @param <V> The type created by the factory.
*/
public interface LazyFactory<V> {
/**
* The factory method; creates the value.
* @return The value.
*/
V create();
}
// ------------------------------------------------------------------------
/**
* A visitor for a traversal over the union of multiple hash maps. The visitor is
* called for each key in the union of the maps and all values associated with that key
* (one per map, but multiple across maps).
*
* @param <K> The type of the key.
* @param <V> The type of the value.
*/
public interface TraversalEvaluator<K, V> {
/**
* Called whenever the traversal starts with a new key.
*
* @param key The key traversed.
* @throws Exception Method forwards all exceptions.
*/
void startNewKey(K key) throws Exception;
/**
* Called for each value found for the current key.
*
* @param value The next value.
* @throws Exception Method forwards all exceptions.
*/
void nextValue(V value) throws Exception;
/**
* Called when the traversal for the current key is complete.
*
* @throws Exception Method forwards all exceptions.
*/
void keyDone() throws Exception;
}
}