package edu.berkeley.nlp.lm.collections;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import edu.berkeley.nlp.lm.util.MurmurHash;
/**
* Open address hash map with linear probing, from long keys to int values.
* Assumes values are non-negative (uses -1 internally to mark an empty slot).
* {@link #get(long, int)} returns the caller-supplied default for keys not in
* the map.
*
* @author adampauls
*
*/
public final class LongToIntHashMap
{
    /**
     * Sentinel stored in {@code values} for an unoccupied slot. It doubles as
     * the "not found" result of {@code find}, so it must never be stored as a
     * real value.
     */
    private static final int EMPTY_VAL = -1;

    /** Key held by each slot; in hash mode only meaningful where the matching value slot is occupied. */
    private long[] keys;

    /** Value held by each slot, or {@code EMPTY_VAL} for an unoccupied slot (hash mode). */
    private int[] values;

    /** Number of entries currently stored. */
    private int size = 0;

    /** Fill ratio above which {@code put} grows the table. */
    private double maxLoadFactor = 0.5;

    /** Set by {@link #toSorted()}; lookups then use binary search and mutation is rejected. */
    private boolean sorted = false;

    /** Creates a map sized for 5 entries. */
    public LongToIntHashMap() {
        this(5);
    }

    /**
     * Creates a map whose table is sized via {@code toSize(initCapacity_)}.
     *
     * @param initCapacity_
     *            number of entries the map should hold before growing
     */
    public LongToIntHashMap(int initCapacity_) {
        final int tableLength = toSize(initCapacity_);
        keys = new long[tableLength];
        values = new int[tableLength];
        Arrays.fill(values, EMPTY_VAL);
    }

    /**
     * Changes the maximum load factor and re-evaluates the table size.
     *
     * Note that the current table length (not the entry count) is passed to
     * {@link #ensureCapacity(int)}, exactly as before.
     *
     * @param loadFactor
     *            new maximum fill ratio; must be positive
     * @throws IllegalArgumentException
     *             if {@code loadFactor} is not a positive number (a
     *             non-positive factor would force a rehash on every put)
     * @throws RuntimeException
     *             if the map has been sorted
     */
    public void setLoadFactor(double loadFactor) {
        // Written as a negated comparison so that NaN is rejected too.
        if (!(loadFactor > 0.0)) {
            throw new IllegalArgumentException("Load factor must be positive: " + loadFactor);
        }
        this.maxLoadFactor = loadFactor;
        ensureCapacity(values.length);
    }

    /**
     * Renders the entries as {@code [(k1,v1) (k2,v2) ...]} in iteration order.
     */
    @Override
    public String toString() {
        final StringBuilder text = new StringBuilder("[");
        for (Entry entry : primitiveEntries()) {
            // Length 1 means only the opening bracket has been written so far.
            if (text.length() > 1) text.append(' ');
            text.append('(').append(entry.key).append(',').append(entry.value).append(')');
        }
        return text.append(']').toString();
    }

    /**
     * Converts the map into a compact, key-sorted representation. Afterwards
     * lookups use binary search and every mutating call throws.
     */
    public void toSorted() {
        final List<Entry> sortedEntries = new ArrayList<Entry>(size);
        for (Entry e : primitiveEntries()) {
            sortedEntries.add(e);
        }
        Collections.sort(sortedEntries, new Comparator<Entry>()
        {
            public int compare(Entry o1, Entry o2) {
                // Compare as longs: going through double loses precision above
                // 2^53 and could leave the array unsorted for binary search.
                return o1.key < o2.key ? -1 : (o1.key == o2.key ? 0 : 1);
            }
        });
        final long[] newKeys = new long[size];
        final int[] newValues = new int[size];
        int k = 0;
        for (Entry e : sortedEntries) {
            newKeys[k] = e.key;
            newValues[k] = e.value;
            k++;
        }
        keys = newKeys;
        values = newValues;
        sorted = true;
    }

    /**
     * Converts a desired entry count into a table length for the current load
     * factor.
     *
     * @param initCapacity_
     *            number of entries to accommodate
     * @return table length, never less than 5
     */
    private int toSize(int initCapacity_) {
        return Math.max(5, (int) (initCapacity_ / maxLoadFactor) + 1);
    }

    /**
     * Associates {@code v} with {@code k}, replacing any existing value.
     *
     * @param k
     *            key; must not be null
     * @param v
     *            value; must be non-negative (checked by assertion only)
     * @throws RuntimeException
     *             if the map has been sorted
     */
    public void put(Long k, int v) {
        checkNotImmutable();
        putPrimitive(k.longValue(), v);
    }

    /**
     * Unboxed implementation of {@link #put(Long, int)}: grows the table if
     * needed, then stores the pair.
     */
    private void putPrimitive(long k, int v) {
        checkNotImmutable();
        // The second test is the usual load-factor check. The first keeps at
        // least one slot empty whatever the load factor is; with a full table
        // the probe loop in find() would never terminate for an absent key.
        if (size + 1 >= keys.length || size / (double) keys.length > maxLoadFactor) {
            rehash();
        }
        putHelp(k, v, keys, values);
    }

    /**
     * Adds {@code d} to the value stored for {@code k}, inserting {@code d} if
     * the key is absent. A zero increment is a no-op.
     *
     * The stored value is not range-checked after the addition; a result of
     * -1 would be indistinguishable from an empty slot.
     *
     * @param k
     *            key
     * @param d
     *            amount to add
     * @throws RuntimeException
     *             if the map has been sorted
     */
    public void incrementCount(long k, int d) {
        checkNotImmutable();
        if (d == 0) return;
        final int pos = find(k, false);
        if (pos == EMPTY_VAL) {
            putPrimitive(k, d);
        } else {
            values[pos] += d;
        }
    }

    /**
     * Throws if the map may no longer be modified.
     */
    private void checkNotImmutable() {
        // No visible code path sets keys to null; the check is kept as a guard.
        if (keys == null) throw new RuntimeException("Cannot change wrapped IntCounter");
        if (sorted) throw new RuntimeException("Cannot change sorted IntCounter");
    }

    /**
     * Grows the table to twice its length plus one.
     */
    private void rehash() {
        rehash(keys.length * 2 + 1);
    }

    /**
     * Re-inserts every entry into fresh arrays of the given length.
     *
     * @param length
     *            new table length
     */
    private void rehash(final int length) {
        checkNotImmutable();
        final long[] newKeys = new long[length];
        final int[] newValues = new int[length];
        Arrays.fill(newValues, EMPTY_VAL);
        // putHelp counts each inserted entry again.
        size = 0;
        for (int i = 0; i < keys.length; ++i) {
            final int val = values[i];
            if (val != EMPTY_VAL) {
                putHelp(keys[i], val, newKeys, newValues);
            }
        }
        keys = newKeys;
        values = newValues;
    }

    /**
     * Stores {@code (k, v)} in the given arrays, which must contain at least
     * one empty slot or already contain {@code k}.
     *
     * @param k
     *            key
     * @param v
     *            value; asserted to be non-negative
     * @param keyArray
     *            key table to write into
     * @param valueArray
     *            value table to write into
     * @return true if a new entry was created, false if an existing value was
     *         overwritten
     */
    private boolean putHelp(long k, int v, long[] keyArray, int[] valueArray) {
        checkNotImmutable();
        assert v >= 0;
        final int pos = find(k, true, keyArray, valueArray);
        final boolean wasEmpty = valueArray[pos] == EMPTY_VAL;
        valueArray[pos] = v;
        if (!wasEmpty) return false;
        size++;
        keyArray[pos] = k;
        return true;
    }

    /**
     * Computes the slot at which probing for {@code k} starts.
     *
     * @param k
     *            key
     * @param length
     *            table length
     * @return index in {@code [0, length)} for a positive length
     */
    private static int getInitialPos(final long k, final int length) {
        // Not reachable from the visible callers, which pass array lengths.
        if (length < 0) return (int) k;
        final long remainder = MurmurHash.hashOneLong(k, 31) % length;
        // Take the absolute value after the modulo rather than negating the
        // hash first: -Long.MIN_VALUE is still negative and would yield a
        // negative index. For every other hash the result is unchanged.
        return (int) (remainder < 0 ? -remainder : remainder);
    }

    /**
     * Looks up the value stored for a key.
     *
     * @param k
     *            key
     * @param def
     *            value to return when the key is absent
     * @return the stored value, or {@code def} if there is none
     */
    public int get(long k, int def) {
        final int pos = find(k, false);
        return pos == EMPTY_VAL ? def : values[pos];
    }

    /** Searches the map's own arrays; see the four-argument overload. */
    private int find(long k, boolean returnLastEmpty) {
        return find(k, returnLastEmpty, keys, values);
    }

    /**
     * Locates the slot for {@code k}.
     *
     * @param k
     *            key to search for
     * @param returnLastEmpty
     *            in hash mode, whether to return the empty slot that ended
     *            the probe when the key is absent (used for insertion)
     * @param keyArray
     *            key table to search
     * @param valueArray
     *            value table to search
     * @return slot index, or {@code EMPTY_VAL} if the key is absent and no
     *         empty slot was requested
     */
    private int find(long k, boolean returnLastEmpty, long[] keyArray, int[] valueArray) {
        if (keyArray == null) {
            // Without a key table the key is used directly as an index.
            return (int) (k < valueArray.length ? k : EMPTY_VAL);
        }
        if (sorted) {
            final int found = Arrays.binarySearch(keyArray, k);
            return found < 0 ? EMPTY_VAL : found;
        }
        final int length = valueArray.length;
        int pos = getInitialPos(k, length);
        // Linear probe, wrapping around, until the key or an empty slot is hit.
        while (valueArray[pos] != EMPTY_VAL && keyArray[pos] != k) {
            pos++;
            if (pos == length) pos = 0;
        }
        if (returnLastEmpty) return pos;
        return valueArray[pos] == EMPTY_VAL ? EMPTY_VAL : pos;
    }

    /**
     * @return true if the map holds no entries
     */
    public boolean isEmpty() {
        return size == 0;
    }

    /**
     * A key/value pair read from the map. The natural ordering is by value.
     */
    public class Entry implements Map.Entry<Long, Integer>, Comparable<Entry>
    {
        /** Slot of the enclosing map this entry was read from. */
        private final int index;

        public long key;

        public int value;

        public Entry(long key, int value, int index) {
            this.key = key;
            assert value >= 0;
            this.value = value;
            this.index = index;
        }

        public Long getKey() {
            return key;
        }

        public Integer getValue() {
            return value;
        }

        /**
         * Updates this entry and writes the value through to the slot of the
         * enclosing map that the entry was read from.
         *
         * @return the value this entry held before the call
         */
        public Integer setValue(Integer value) {
            final int previous = this.value;
            this.value = value;
            values[index] = value;
            return previous;
        }

        @Override
        public int compareTo(Entry o) {
            // sortable by *value*
            return value < o.value ? -1 : (value == o.value ? 0 : 1);
        }
    }

    /** Builds an {@link Entry} for the given slot. */
    private Entry entryAt(int index) {
        return new Entry(keys == null ? index : keys[index], values[index], index);
    }

    /** Iterates over entries, typed as {@code Map.Entry}. */
    private class EntryIterator extends MapIterator<Map.Entry<Long, Integer>>
    {
        public Entry next() {
            return entryAt(nextIndex());
        }
    }

    /** Iterates over keys. */
    private class KeyIterator extends MapIterator<Long>
    {
        public Long next() {
            final int index = nextIndex();
            return keys == null ? index : keys[index];
        }
    }

    /** Iterates over entries, typed as {@link Entry}. */
    private class PrimitiveEntryIterator extends MapIterator<Entry>
    {
        public Entry next() {
            return entryAt(nextIndex());
        }
    }

    /**
     * Shared slot-walking logic for the iterators: visits occupied slots in
     * table order. Removal is not supported.
     */
    private abstract class MapIterator<E> implements Iterator<E>
    {
        /** Slot that the next call to {@code nextIndex} will return. */
        private int next;

        /** One past the last slot to visit. */
        private final int end;

        MapIterator() {
            end = keys == null ? size : values.length;
            next = -1;
            advance();
        }

        public boolean hasNext() {
            return next < end;
        }

        /**
         * Returns the current occupied slot and moves on to the following one.
         *
         * @throws java.util.NoSuchElementException
         *             if the iteration is exhausted
         */
        int nextIndex() {
            if (next >= end) throw new java.util.NoSuchElementException();
            final int curr = next;
            advance();
            return curr;
        }

        /** Moves {@code next} forward to the following occupied slot, or to {@code end}. */
        private void advance() {
            do {
                next++;
            } while (next < end && keys != null && values[next] == EMPTY_VAL);
        }

        public void remove() {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * @return the entries of the map as {@code Map.Entry} objects
     */
    public Iterable<Map.Entry<Long, Integer>> entries() {
        return Iterators.able(new EntryIterator());
    }

    /**
     * Grows the table, if necessary, so that it is sized for
     * {@code capacity} entries at the current load factor.
     *
     * @param capacity
     *            number of entries to accommodate
     * @throws RuntimeException
     *             if the map has been sorted
     */
    public void ensureCapacity(int capacity) {
        checkNotImmutable();
        final int newSize = toSize(capacity);
        if (newSize > keys.length) {
            rehash(newSize);
        }
    }

    /**
     * @return number of entries in the map
     */
    public int size() {
        return size;
    }

    /**
     * @return an iterable over the entries; every call to {@code iterator()}
     *         starts a fresh pass
     */
    public Iterable<Entry> primitiveEntries() {
        return new Iterable<Entry>()
        {
            public Iterator<Entry> iterator() {
                return new PrimitiveEntryIterator();
            }
        };
    }

    /**
     * @return the keys of the map
     */
    public Iterable<Long> keySet() {
        return Iterators.able(new KeyIterator());
    }

    /**
     * Removes all entries. A sorted map stays sorted (and therefore
     * immutable) but becomes empty.
     */
    public void clear() {
        if (sorted) {
            // In sorted mode lookups binary-search the key array, so merely
            // blanking the values would make get() return -1 for old keys.
            keys = new long[0];
            values = new int[0];
        } else {
            Arrays.fill(values, EMPTY_VAL);
        }
        size = 0;
    }

    /**
     * Collects all entries into a list ordered by value.
     *
     * @param descending
     *            if true, the ascending list is reversed
     * @return a new list of the map's entries
     */
    public List<Entry> getObjectsSortedByValue(boolean descending) {
        final List<Entry> byValue = new ArrayList<Entry>(size);
        for (final Entry entry : primitiveEntries()) {
            byValue.add(entry);
        }
        Collections.sort(byValue);
        if (descending) Collections.reverse(byValue);
        return byValue;
    }

    /**
     * @return an independent copy with the same entries, load factor and
     *         sorted state; no arrays are shared with this map
     */
    public LongToIntHashMap copy() {
        final LongToIntHashMap ret = new LongToIntHashMap();
        ret.keys = Arrays.copyOf(keys, keys.length);
        ret.values = Arrays.copyOf(values, values.length);
        ret.size = size;
        ret.sorted = sorted;
        ret.maxLoadFactor = maxLoadFactor;
        return ret;
    }
}