/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.utils.collections;
import hivemall.utils.codec.VariableByteCodec;
import hivemall.utils.codec.ZigZagLEB128Codec;
import hivemall.utils.math.Primes;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Arrays;
import javax.annotation.Nonnull;
/**
* An open-addressing hash table with double hashing
*
* @see http://en.wikipedia.org/wiki/Double_hashing
*/
public class Int2LongOpenHashTable implements Externalizable {
protected static final byte FREE = 0;
protected static final byte FULL = 1;
protected static final byte REMOVED = 2;
public static final int DEFAULT_SIZE = 65536;
public static final float DEFAULT_LOAD_FACTOR = 0.7f;
public static final float DEFAULT_GROW_FACTOR = 2.0f;
protected final transient float _loadFactor;
protected final transient float _growFactor;
protected int[] _keys;
protected long[] _values;
protected byte[] _states;
protected int _used;
protected int _threshold;
protected long defaultReturnValue = -1L;
/**
* Constructor for Externalizable. Should not be called otherwise.
*/
public Int2LongOpenHashTable() {// for Externalizable
this._loadFactor = DEFAULT_LOAD_FACTOR;
this._growFactor = DEFAULT_GROW_FACTOR;
}
public Int2LongOpenHashTable(int size) {
this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
}
public Int2LongOpenHashTable(int size, float loadFactor, float growFactor) {
this(size, loadFactor, growFactor, true);
}
protected Int2LongOpenHashTable(int size, float loadFactor, float growFactor, boolean forcePrime) {
if (size < 1) {
throw new IllegalArgumentException();
}
this._loadFactor = loadFactor;
this._growFactor = growFactor;
int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
this._keys = new int[actualSize];
this._values = new long[actualSize];
this._states = new byte[actualSize];
this._used = 0;
this._threshold = (int) (actualSize * _loadFactor);
}
public Int2LongOpenHashTable(@Nonnull int[] keys, @Nonnull long[] values,
@Nonnull byte[] states, int used) {
this._loadFactor = DEFAULT_LOAD_FACTOR;
this._growFactor = DEFAULT_GROW_FACTOR;
this._keys = keys;
this._values = values;
this._states = states;
this._used = used;
this._threshold = keys.length;
}
@Nonnull
public static Int2LongOpenHashTable newInstance() {
return new Int2LongOpenHashTable(DEFAULT_SIZE);
}
public void defaultReturnValue(long v) {
this.defaultReturnValue = v;
}
@Nonnull
public int[] getKeys() {
return _keys;
}
@Nonnull
public long[] getValues() {
return _values;
}
@Nonnull
public byte[] getStates() {
return _states;
}
public boolean containsKey(int key) {
return findKey(key) >= 0;
}
/**
* @return -1.f if not found
*/
public long get(int key) {
int i = findKey(key);
if (i < 0) {
return defaultReturnValue;
}
return _values[i];
}
public long put(int key, long value) {
int hash = keyHash(key);
int keyLength = _keys.length;
int keyIdx = hash % keyLength;
boolean expanded = preAddEntry(keyIdx);
if (expanded) {
keyLength = _keys.length;
keyIdx = hash % keyLength;
}
int[] keys = _keys;
long[] values = _values;
byte[] states = _states;
if (states[keyIdx] == FULL) {// double hashing
if (keys[keyIdx] == key) {
long old = values[keyIdx];
values[keyIdx] = value;
return old;
}
// try second hash
int decr = 1 + (hash % (keyLength - 2));
for (;;) {
keyIdx -= decr;
if (keyIdx < 0) {
keyIdx += keyLength;
}
if (isFree(keyIdx, key)) {
break;
}
if (states[keyIdx] == FULL && keys[keyIdx] == key) {
long old = values[keyIdx];
values[keyIdx] = value;
return old;
}
}
}
keys[keyIdx] = key;
values[keyIdx] = value;
states[keyIdx] = FULL;
++_used;
return defaultReturnValue;
}
/** Return weather the required slot is free for new entry */
protected boolean isFree(int index, int key) {
byte stat = _states[index];
if (stat == FREE) {
return true;
}
if (stat == REMOVED && _keys[index] == key) {
return true;
}
return false;
}
/** @return expanded or not */
protected boolean preAddEntry(int index) {
if ((_used + 1) >= _threshold) {// too filled
int newCapacity = Math.round(_keys.length * _growFactor);
ensureCapacity(newCapacity);
return true;
}
return false;
}
protected int findKey(int key) {
int[] keys = _keys;
byte[] states = _states;
int keyLength = keys.length;
int hash = keyHash(key);
int keyIdx = hash % keyLength;
if (states[keyIdx] != FREE) {
if (states[keyIdx] == FULL && keys[keyIdx] == key) {
return keyIdx;
}
// try second hash
int decr = 1 + (hash % (keyLength - 2));
for (;;) {
keyIdx -= decr;
if (keyIdx < 0) {
keyIdx += keyLength;
}
if (isFree(keyIdx, key)) {
return -1;
}
if (states[keyIdx] == FULL && keys[keyIdx] == key) {
return keyIdx;
}
}
}
return -1;
}
public long remove(int key) {
int[] keys = _keys;
long[] values = _values;
byte[] states = _states;
int keyLength = keys.length;
int hash = keyHash(key);
int keyIdx = hash % keyLength;
if (states[keyIdx] != FREE) {
if (states[keyIdx] == FULL && keys[keyIdx] == key) {
long old = values[keyIdx];
states[keyIdx] = REMOVED;
--_used;
return old;
}
// second hash
int decr = 1 + (hash % (keyLength - 2));
for (;;) {
keyIdx -= decr;
if (keyIdx < 0) {
keyIdx += keyLength;
}
if (states[keyIdx] == FREE) {
return defaultReturnValue;
}
if (states[keyIdx] == FULL && keys[keyIdx] == key) {
long old = values[keyIdx];
states[keyIdx] = REMOVED;
--_used;
return old;
}
}
}
return defaultReturnValue;
}
public int size() {
return _used;
}
public int capacity() {
return _keys.length;
}
public void clear() {
Arrays.fill(_states, FREE);
this._used = 0;
}
public IMapIterator entries() {
return new MapIterator();
}
@Override
public String toString() {
int len = size() * 10 + 2;
StringBuilder buf = new StringBuilder(len);
buf.append('{');
IMapIterator i = entries();
while (i.next() != -1) {
buf.append(i.getKey());
buf.append('=');
buf.append(i.getValue());
if (i.hasNext()) {
buf.append(',');
}
}
buf.append('}');
return buf.toString();
}
protected void ensureCapacity(int newCapacity) {
int prime = Primes.findLeastPrimeNumber(newCapacity);
rehash(prime);
this._threshold = Math.round(prime * _loadFactor);
}
private void rehash(int newCapacity) {
int oldCapacity = _keys.length;
if (newCapacity <= oldCapacity) {
throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
}
int[] newkeys = new int[newCapacity];
long[] newValues = new long[newCapacity];
byte[] newStates = new byte[newCapacity];
int used = 0;
for (int i = 0; i < oldCapacity; i++) {
if (_states[i] == FULL) {
used++;
int k = _keys[i];
long v = _values[i];
int hash = keyHash(k);
int keyIdx = hash % newCapacity;
if (newStates[keyIdx] == FULL) {// second hashing
int decr = 1 + (hash % (newCapacity - 2));
while (newStates[keyIdx] != FREE) {
keyIdx -= decr;
if (keyIdx < 0) {
keyIdx += newCapacity;
}
}
}
newkeys[keyIdx] = k;
newValues[keyIdx] = v;
newStates[keyIdx] = FULL;
}
}
this._keys = newkeys;
this._values = newValues;
this._states = newStates;
this._used = used;
}
private static int keyHash(int key) {
return key & 0x7fffffff;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(_threshold);
out.writeInt(_used);
final int[] keys = _keys;
final int size = keys.length;
out.writeInt(size);
final byte[] states = _states;
writeStates(states, out);
final long[] values = _values;
for (int i = 0; i < size; i++) {
if (states[i] != FULL) {
continue;
}
ZigZagLEB128Codec.writeSignedInt(keys[i], out);
ZigZagLEB128Codec.writeSignedLong(values[i], out);
}
}
@Nonnull
private static void writeStates(@Nonnull final byte[] status, @Nonnull final DataOutput out)
throws IOException {
// write empty states's indexes differentially
final int size = status.length;
int cardinarity = 0;
for (int i = 0; i < size; i++) {
if (status[i] != FULL) {
cardinarity++;
}
}
out.writeInt(cardinarity);
if (cardinarity == 0) {
return;
}
int prev = 0;
for (int i = 0; i < size; i++) {
if (status[i] != FULL) {
int diff = i - prev;
assert (diff >= 0);
VariableByteCodec.encodeUnsignedInt(diff, out);
prev = i;
}
}
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
this._threshold = in.readInt();
this._used = in.readInt();
final int size = in.readInt();
final int[] keys = new int[size];
final long[] values = new long[size];
final byte[] states = new byte[size];
readStates(in, states);
for (int i = 0; i < size; i++) {
if (states[i] != FULL) {
continue;
}
keys[i] = ZigZagLEB128Codec.readSignedInt(in);
values[i] = ZigZagLEB128Codec.readSignedLong(in);
}
this._keys = keys;
this._values = values;
this._states = states;
}
@Nonnull
private static void readStates(@Nonnull final DataInput in, @Nonnull final byte[] status)
throws IOException {
// read non-empty states differentially
final int cardinarity = in.readInt();
Arrays.fill(status, IntOpenHashTable.FULL);
int prev = 0;
for (int j = 0; j < cardinarity; j++) {
int i = VariableByteCodec.decodeUnsignedInt(in) + prev;
status[i] = IntOpenHashTable.FREE;
prev = i;
}
}
public interface IMapIterator {
public boolean hasNext();
/**
* @return -1 if not found
*/
public int next();
public int getKey();
public long getValue();
}
private final class MapIterator implements IMapIterator {
int nextEntry;
int lastEntry = -1;
MapIterator() {
this.nextEntry = nextEntry(0);
}
/** find the index of next full entry */
int nextEntry(int index) {
while (index < _keys.length && _states[index] != FULL) {
index++;
}
return index;
}
public boolean hasNext() {
return nextEntry < _keys.length;
}
public int next() {
if (!hasNext()) {
return -1;
}
int curEntry = nextEntry;
this.lastEntry = curEntry;
this.nextEntry = nextEntry(curEntry + 1);
return curEntry;
}
public int getKey() {
if (lastEntry == -1) {
throw new IllegalStateException();
}
return _keys[lastEntry];
}
public long getValue() {
if (lastEntry == -1) {
throw new IllegalStateException();
}
return _values[lastEntry];
}
}
}