package edu.berkeley.nlp.lm.collections;
import java.io.Serializable;
import java.util.AbstractCollection;
import java.util.AbstractSet;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
/**
* Provides a map from objects to non-negative integers. Motivation: provides a
* specialized data structure for mapping objects to ints which is both fast
* and space efficient. Feature 1: You can switch between two representations of
* the map: - Sorted list (lookups involve binary search) - Hash table with
* linear probing (lookups involve hashing) Feature 2: Sometimes, we want
* several maps with the same set of keys. If we lock the map, we can share the
* same keys between several maps, which saves space.
* <p>
* Note: in the sorted list, keys are currently ordered by their natural
* ordering (compareTo) alone. An earlier scheme sorted the keys by hash code
* first and only compared the objects themselves for equal hash codes; that
* hash-based ordering is commented out in the code.
* <p>
* Typical usage: - Construct a map using a hash table. - To save space, switch
* to a sorted list representation.
* <p>
* A runtime exception is thrown if you try to use the sorted list and the keys
* are not comparable.
*
* @author Adam Pauls
* @author Percy Liang
*/
@SuppressWarnings({ "ucd", "rawtypes" })
public class TIntMap<T extends Comparable> extends AbstractTMap<T> implements Iterable<TIntMap<T>.Entry>, Serializable
{
/** Serialization version identifier for this class. */
protected static final long serialVersionUID = 42;
/** Values stored in parallel with the {@code keys} array: {@code values[i]} belongs to {@code keys[i]}. */
private int[] values;
/**
 * Creates an empty map with the default key functionality and the default
 * expected size.
 */
public TIntMap() {
    this(defaultExpectedSize);
}

/**
 * Creates an empty map with the default expected size.
 *
 * @param keyFunc
 *            functionality object used for the keys
 */
public TIntMap(final Functionality<T> keyFunc) {
    this(keyFunc, defaultExpectedSize);
}

/**
 * Creates an empty map with the default key functionality.
 *
 * @param expectedSize
 *            expected number of entries the map will hold
 */
public TIntMap(final int expectedSize) {
    this(AbstractTMap.<T> defaultFunctionality(), expectedSize);
}
// If keys are locked, we can share the same keys.
public TIntMap(final AbstractTMap<T> map) {
this(map.keyFunc);
this.mapType = map.mapType;
this.locked = map.locked;
this.num = map.num;
this.keys = map.locked ? map.keys : (T[]) map.keys.clone(); // Share keys! CHECKED
if (map instanceof TIntMap)
this.values = ((TIntMap<T>) map).values.clone();
else
this.values = new int[keys.length];
}
/**
* expectedSize: expected number of entries we're going to have in the map.
*/
public TIntMap(final Functionality<T> keyFunc, final int expectedSize) {
this.keyFunc = keyFunc;
this.mapType = MapType.HASH_TABLE;
this.locked = false;
this.num = 0;
allocate(getCapacity(expectedSize, false));
this.numCollisions = 0;
}
// Main operations
/**
 * Tests whether the map holds an entry for {@code key}.
 *
 * @param key
 *            key to look up
 * @return {@code true} if a slot for the key was found
 */
public boolean containsKey(final T key) {
    final int index = find(key, false);
    return index != -1;
}
/**
 * Looks up the value stored for {@code key} without modifying the map.
 *
 * @param key
 *            key to look up
 * @param defaultValue
 *            value returned when the key is absent
 * @return the stored value, or {@code defaultValue} if the key is absent
 */
public int get(final T key, final int defaultValue) {
    final int index = findHelper(key, false);
    if (index == -1) {
        return defaultValue;
    }
    return values[index];
}
/**
 * Looks up the value stored for {@code key}, insisting that the key exists.
 *
 * @param key
 *            key to look up
 * @return the stored value
 * @throws RuntimeException
 *             if the key is not in the map
 */
public int getSure(final T key) {
    final int index = find(key, false);
    if (index != -1) {
        return values[index];
    }
    throw new RuntimeException("Missing key: " + key);
}
/**
 * Stores {@code value} for {@code key}, creating the entry if it does not
 * exist yet and overwriting any previous value.
 *
 * @param key
 *            key to insert or update
 * @param value
 *            value to store
 */
public void put(final T key, final int value) {
    // The former "assert !Double.isNaN(value)" was removed: an int can never
    // be NaN, so the assertion could not fail.
    final int i = find(key, true);
    keys[i] = key;
    values[i] = value;
}
/**
 * Stores {@code value} for {@code key}, creating the entry if needed. When
 * {@code keepHigher} is set and the key already holds a larger value, the
 * existing value is kept.
 *
 * @param key
 *            key to insert or update
 * @param value
 *            candidate value
 * @param keepHigher
 *            if {@code true}, never replace an existing value with a smaller
 *            one
 */
public void put(final T key, final int value, final boolean keepHigher) {
    // No NaN assertion: an int can never be NaN.
    final int sizeBefore = num;
    final int i = find(key, true);
    keys[i] = key;
    // find() increases num only when it creates a new slot. A new slot holds
    // the placeholder -1, which must not take part in the comparison;
    // otherwise a new key with a value below -1 would be stored as -1.
    final boolean existed = num == sizeBefore;
    if (keepHigher && existed && values[i] > value) return;
    values[i] = value;
}
/**
 * Adds {@code dValue} to the value stored for {@code key}. If the key is not
 * present yet, an entry is created whose value is {@code dValue}.
 *
 * @param key
 *            key to increment
 * @param dValue
 *            amount to add
 */
public void incr(final T key, final int dValue) {
    final int sizeBefore = num;
    final int i = find(key, true);
    keys[i] = key;
    // find() increases num only when it creates a new slot, and such a slot
    // holds the placeholder -1. The previous Double.isNaN() test is always
    // false for an int, so new keys used to end up as (-1 + dValue).
    if (num != sizeBefore)
        values[i] = dValue; // New value
    else
        values[i] += dValue;
}
/**
 * Adds {@code dValue} to the value stored for {@code key}; does nothing when
 * the key is not in the map.
 *
 * @param key
 *            key to increment
 * @param dValue
 *            amount to add
 */
public void incrIfKeyExists(final T key, final int dValue) {
    final int i = find(key, false);
    if (i == -1) return;
    keys[i] = key;
    // The key already exists here, and an int is never NaN, so the former
    // "new value" branch was unreachable and has been dropped.
    values[i] += dValue;
}
/**
 * Multiplies the value stored for {@code key} by {@code dValue}; does nothing
 * when the key is not in the map.
 *
 * @param key
 *            key whose value is scaled
 * @param dValue
 *            multiplier
 */
public void scale(final T key, final int dValue) {
    // Look up without modifying. With modify == true, find() never returns -1
    // and instead opens a new slot (incrementing num) whose key was never
    // stored, which corrupted the map for absent keys.
    final int i = find(key, false);
    if (i == -1) return;
    values[i] *= dValue;
}
/**
 * @return the entry count tracked in {@code num}
 */
public int size() {
    return this.num;
}
/**
 * @return the length of the key array, i.e. the number of available slots
 */
public int capacity() {
    return this.keys.length;
}
/*
* public void clear() { // Keep the same capacity num = 0; for(int i = 0; i
* < keys.length; i++) keys[i] = null; }
*/
/**
 * Drops the reference to the value array to save memory. The keys are left
 * untouched.
 */
public void gut() {
    this.values = null;
} // Save memory
// Simple operations on values
// Implement them here for maximum efficiency.
/**
 * Adds up the values of all occupied slots.
 *
 * @return the total as a double
 */
public double sum() {
    double total = 0;
    for (int slot = 0; slot < keys.length; slot++) {
        if (keys[slot] == null) {
            continue; // empty slot
        }
        total += values[slot];
    }
    return total;
}
/**
 * Overwrites the value of every occupied slot with {@code value}.
 *
 * @param value
 *            value assigned to every existing key
 */
public void putAll(final int value) {
    for (int slot = 0; slot < keys.length; slot++) {
        if (keys[slot] == null) {
            continue; // empty slot
        }
        values[slot] = value;
    }
}
/**
 * Adds {@code dValue} to the value of every occupied slot.
 *
 * @param dValue
 *            amount added to every existing key's value
 */
public void incrAll(final int dValue) {
    for (int slot = 0; slot < keys.length; slot++) {
        if (keys[slot] == null) {
            continue; // empty slot
        }
        values[slot] += dValue;
    }
}
/**
 * Multiplies the value of every occupied slot by {@code dValue}.
 *
 * @param dValue
 *            multiplier applied to every existing key's value
 */
public void multAll(final int dValue) {
    for (int slot = 0; slot < keys.length; slot++) {
        if (keys[slot] == null) {
            continue; // empty slot
        }
        values[slot] *= dValue;
    }
}
/**
 * Finds the key with the maximum value. On ties the slot encountered first
 * wins.
 *
 * @return the key holding the largest value, or {@code null} if no slot is
 *         occupied
 */
public T argmax() {
    int bestSlot = -1;
    for (int slot = 0; slot < keys.length; slot++) {
        if (keys[slot] == null) {
            continue;
        }
        if (bestSlot == -1 || values[slot] > values[bestSlot]) {
            bestSlot = slot;
        }
    }
    if (bestSlot == -1) {
        return null;
    }
    return keys[bestSlot];
}
/**
 * Finds the maximum value over all occupied slots.
 *
 * @return the largest value as a double, or
 *         {@code Double.NEGATIVE_INFINITY} if no slot is occupied
 */
public double max() {
    int bestSlot = -1;
    for (int slot = 0; slot < keys.length; slot++) {
        if (keys[slot] == null) {
            continue;
        }
        if (bestSlot == -1 || values[slot] > values[bestSlot]) {
            bestSlot = slot;
        }
    }
    if (bestSlot == -1) {
        return Double.NEGATIVE_INFINITY;
    }
    return values[bestSlot];
}
/**
 * For each (key, value) in {@code map}, increments this map's entry for that
 * key by {@code factor * value}.
 *
 * @param map
 *            map supplying the keys and increments
 * @param factor
 *            multiplier applied to each of {@code map}'s values
 */
public void incrMap(final TIntMap<T> map, final int factor) {
    // map.keys is deliberately re-read on every iteration, exactly like the
    // loop it replaces: incr() may replace the arrays when map == this.
    for (int slot = 0; slot < map.keys.length; slot++) {
        if (map.keys[slot] == null) {
            continue;
        }
        final int delta = factor * map.values[slot];
        incr(map.keys[slot], delta);
    }
}
/**
 * Creates a copy of this map with the same map type, lock flag and entry
 * count. The value array is always cloned; the key array is shared when this
 * map is locked and cloned otherwise.
 *
 * @return the new map
 */
public TIntMap<T> copy() {
    final TIntMap<T> duplicate = new TIntMap<T>(keyFunc);
    duplicate.mapType = mapType;
    duplicate.locked = locked;
    duplicate.num = num;
    if (locked) {
        // Share keys! CHECKED
        duplicate.keys = keys;
    } else {
        duplicate.keys = (T[]) keys.clone();
    }
    duplicate.values = values.clone();
    return duplicate;
}
/**
 * Returns a new map that contains only those entries of this map whose keys
 * are in {@code set}. This map is left unchanged. The new map has the same
 * map type and lock flag as this one.
 *
 * @param set
 *            keys to keep
 * @return the restricted map
 */
public TIntMap<T> restrict(final Set<T> set) {
    final TIntMap<T> newMap = new TIntMap<T>(keyFunc);
    newMap.mapType = mapType;
    if (mapType == MapType.SORTED_LIST) {
        // Allocate the NEW map's arrays. Calling allocate() unqualified
        // replaced this map's own keys and values with empty arrays.
        newMap.allocate(newMap.getCapacity(num, false));
        // Only the first num slots of a sorted list are occupied. Skipping
        // the rest (and nulls) avoids asking the set about null keys.
        for (int i = 0; i < num; i++) {
            if (keys[i] != null && set.contains(keys[i])) {
                newMap.keys[newMap.num] = keys[i];
                newMap.values[newMap.num] = values[i];
                newMap.num++;
            }
        }
    } else if (mapType == MapType.HASH_TABLE) {
        for (int i = 0; i < keys.length; i++)
            if (keys[i] != null && set.contains(keys[i])) newMap.put(keys[i], values[i]);
    }
    newMap.locked = locked;
    return newMap;
}
/**
 * Key/value pair used for sorting the entries; ordered by key only.
 * Warning: this class has the overhead of the parent class.
 */
private class FullEntry implements Comparable<FullEntry>
{
    private final T key;

    private final int value;

    private FullEntry(final T key, final int value) {
        this.value = value;
        this.key = key;
    }

    /**
     * Compares two entries via their keys' {@code compareTo}; the values play
     * no role.
     */
    @Override
    @SuppressWarnings({ "unchecked" })
    public int compareTo(final FullEntry e) {
        final T otherKey = e.key;
        return this.key.compareTo(otherKey);
    }

    /** Not supported; always throws. */
    @Override
    public boolean equals(final Object o) {
        throw new UnsupportedOperationException();
    }
}
/**
 * Orders entries by the value currently stored at their slot, ascending.
 */
public class EntryValueComparator implements Comparator<Entry>
{
    @Override
    public int compare(final Entry e1, final Entry e2) {
        final int first = values[e1.i];
        final int second = values[e2.i];
        return Double.compare(first, second);
    }
}
/**
 * @return a new comparator that orders this map's entries by value
 */
public EntryValueComparator entryValueComparator() {
    final EntryValueComparator byValue = new EntryValueComparator();
    return byValue;
}
/**
 * View of one slot of the map, used for iterating. It stores only the slot
 * index and reads or writes the map's arrays on every call.
 */
public class Entry
{
    private final int i;

    private Entry(final int i) {
        this.i = i;
    }

    /** @return the key stored at this entry's slot */
    public T getKey() {
        return keys[this.i];
    }

    /** @return the value stored at this entry's slot */
    public int getValue() {
        return values[this.i];
    }

    /**
     * Overwrites the value stored at this entry's slot.
     *
     * @param newValue
     *            value to store
     */
    public void setValue(final int newValue) {
        values[this.i] = newValue;
    }
}
/**
 * Marks the map as locked by setting the {@code locked} flag.
 */
public void lock() {
    this.locked = true;
}
/**
 * Converts the map to the sorted-list representation.
 */
public void switchToSortedList() {
    final MapType target = MapType.SORTED_LIST;
    switchMapType(target);
}
/**
 * Converts the map to the hash-table representation.
 */
public void switchToHashTable() {
    final MapType target = MapType.HASH_TABLE;
    switchMapType(target);
}
////////////////////////////////////////////////////////////
/**
 * Set view over the map's entries. Only iteration and {@code size()} are
 * implemented here; {@code contains}, {@code remove} and {@code clear}
 * throw.
 */
public class EntrySet extends AbstractSet<Entry>
{
    @Override
    public int size() {
        return TIntMap.this.num;
    }

    @Override
    public Iterator<Entry> iterator() {
        final EntryIterator entries = new EntryIterator();
        return entries;
    }

    /** Not supported; always throws. */
    @Override
    public boolean contains(final Object o) {
        throw new UnsupportedOperationException();
    }

    /** Not supported; always throws. */
    @Override
    public boolean remove(final Object o) {
        throw new UnsupportedOperationException();
    }

    /** Not supported; always throws. */
    @Override
    public void clear() {
        throw new UnsupportedOperationException();
    }
}
/**
 * Set view over the map's keys. Iteration, {@code size()} and
 * {@code contains} are implemented here; {@code remove} and {@code clear}
 * throw.
 */
public class KeySet extends AbstractSet<T>
{
    @Override
    public int size() {
        return TIntMap.this.num;
    }

    @Override
    public Iterator<T> iterator() {
        final KeyIterator keyIterator = new KeyIterator();
        return keyIterator;
    }

    /** Casts the argument to the key type and delegates to {@code containsKey}. */
    @SuppressWarnings("unchecked")
    @Override
    public boolean contains(final Object o) {
        final T candidate = (T) o; // CHECKED
        return containsKey(candidate);
    }

    /** Not supported; always throws. */
    @Override
    public boolean remove(final Object o) {
        throw new UnsupportedOperationException();
    }

    /** Not supported; always throws. */
    @Override
    public void clear() {
        throw new UnsupportedOperationException();
    }
}
/**
 * Collection view over the map's values. Only iteration and {@code size()}
 * are implemented here; {@code contains} and {@code clear} throw.
 */
public class ValueCollection extends AbstractCollection<Integer>
{
    @Override
    public int size() {
        return TIntMap.this.num;
    }

    @Override
    public Iterator<Integer> iterator() {
        final ValueIterator valueIterator = new ValueIterator();
        return valueIterator;
    }

    /** Not supported; always throws. */
    @Override
    public boolean contains(final Object o) {
        throw new UnsupportedOperationException();
    }

    /** Not supported; always throws. */
    @Override
    public void clear() {
        throw new UnsupportedOperationException();
    }
}
/**
 * @return a new iterator over the map's entries
 */
@Override
public EntryIterator iterator() {
    final EntryIterator entries = new EntryIterator();
    return entries;
}
/**
 * @return a new set view over the map's entries
 */
public EntrySet entrySet() {
    final EntrySet view = new EntrySet();
    return view;
}
/**
 * @return a new set view over the map's keys
 */
public KeySet keySet() {
    final KeySet view = new KeySet();
    return view;
}
/**
 * @return a new collection view over the map's values
 */
public ValueCollection values() {
    final ValueCollection view = new ValueCollection();
    return view;
}
/**
 * Iterator that wraps each occupied slot in an {@code Entry}.
 * WARNING: no checks that this iterator is only used when the map is not
 * being structurally changed.
 */
private class EntryIterator extends MapIterator<Entry>
{
    @Override
    public Entry next() {
        final int slot = nextIndex();
        return new Entry(slot);
    }
}
/**
 * Iterator that yields the key of each occupied slot.
 */
private class KeyIterator extends MapIterator<T>
{
    @Override
    public T next() {
        final int slot = nextIndex();
        return keys[slot];
    }
}
/**
 * Iterator that yields the (boxed) value of each occupied slot.
 */
private class ValueIterator extends MapIterator<Integer>
{
    @Override
    public Integer next() {
        final int slot = nextIndex();
        return values[slot];
    }
}
/**
 * Shared cursor logic for the entry, key and value iterators. It walks the
 * slot indices and skips slots whose key is {@code null}. In the sorted-list
 * representation the walk covers the first {@code size()} slots, otherwise
 * all {@code capacity()} slots. Removal is not supported.
 */
private abstract class MapIterator<E> implements Iterator<E>
{
    // next: index of the slot the following call will return;
    // end: exclusive upper bound of the walk.
    private int next, end;

    public MapIterator() {
        if (mapType == MapType.SORTED_LIST)
            end = size();
        else
            end = capacity();
        // Start one before the first slot and advance to the first occupied one.
        next = -1;
        nextIndex();
    }

    @Override
    public boolean hasNext() {
        return next < end;
    }

    /**
     * Returns the index of the current slot and advances the cursor to the
     * next occupied slot.
     *
     * @return index of the slot to hand out
     * @throws java.util.NoSuchElementException
     *             if the iteration has no more elements
     */
    int nextIndex() {
        final int curr = next;
        // Honour the Iterator contract instead of handing out an index at or
        // beyond the end. The priming call in the constructor has curr == -1
        // and therefore never triggers this.
        if (curr >= end) throw new java.util.NoSuchElementException();
        do {
            next++;
        } while (next < end && keys[next] == null);
        return curr;
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException();
    }
}
////////////////////////////////////////////////////////////
/**
 * Computes how many slots to allocate for the current map type so that it
 * can hold {@code n} elements. The result is never smaller than 1.
 *
 * @param n
 *            number of elements we want to store
 * @param compact
 *            whether we want to save space and don't plan on growing; only
 *            consulted for the sorted list
 * @return the capacity to allocate
 * @throws RuntimeException
 *             if the map type is neither sorted list nor hash table
 */
private int getCapacity(final int n, final boolean compact) {
    final int wanted;
    if (mapType == MapType.SORTED_LIST) {
        if (compact) {
            wanted = n;
        } else {
            wanted = n * growFactor;
        }
    } else if (mapType == MapType.HASH_TABLE) {
        // Make sure there's enough room for n+2 more entries
        wanted = n * growFactor + 2;
    } else {
        throw new RuntimeException("Internal bug");
    }
    return wanted < 1 ? 1 : wanted;
}
/**
 * Convert the map to the given type, rebuilding the key and value arrays.
 * Also used with the current type (HASH_TABLE) to rebuild a hash table that
 * has become too full. The map must not be locked (checked by assertion
 * only).
 *
 * @param newMapType
 *            representation to switch to
 */
private synchronized void switchMapType(final MapType newMapType) {
assert !locked;
//System.out.println("switchMapType(" + newMapType + ", " + compact + ")");
// Save old keys and values, allocate space
final T[] oldKeys = keys;
final int[] oldValues = values;
// mapType is updated before getCapacity() is called because getCapacity() branches on it.
mapType = newMapType;
allocate(getCapacity(num, true));
numCollisions = 0;
if (newMapType == MapType.SORTED_LIST) {
// Sort the keys
final List<FullEntry> entries = new ArrayList<FullEntry>(num);
for (int i = 0; i < oldKeys.length; i++)
if (oldKeys[i] != null) entries.add(new FullEntry(oldKeys[i], oldValues[i]));
Collections.sort(entries);
// Populate the sorted list
for (int i = 0; i < num; i++) {
keys[i] = entries.get(i).key;
values[i] = entries.get(i).value;
}
} else if (mapType == MapType.HASH_TABLE) {
// Populate the hash table
// num is reset because put() counts every re-inserted key again.
num = 0;
for (int i = 0; i < oldKeys.length; i++) {
if (oldKeys[i] != null) put(oldKeys[i], oldValues[i]);
}
}
}
/**
 * Return the first index i for which the target key is less than or equal
 * to key i (00001111). Should insert target key at position i. If target is
 * larger than all of the elements, return size().
 *
 * @param targetKey
 *            key to locate among the first {@code num} sorted keys
 * @return insertion index in the range {@code [0, num]}
 */
@SuppressWarnings({ "unchecked" })
private int binarySearch(final T targetKey) {
    // The upper bound is exclusive (num, not num - 1). Otherwise the search
    // can never return num, and a key larger than every stored key would be
    // inserted in front of the last element, breaking the sort order.
    int l = 0, u = num;
    while (l < u) {
        // m < u <= num, so keys[m] is always an occupied slot.
        final int m = (l + u) >>> 1;
        if (targetKey.compareTo(keys[m]) <= 0)
            u = m;
        else
            l = m + 1;
    }
    return l;
}
/**
 * Modified hash (taken from HashMap.java): scrambles the key's
 * {@code hashCode()} with shifts and adds, then negates a negative result.
 * The result can still be negative for {@code Integer.MIN_VALUE}, whose
 * negation overflows back to itself.
 *
 * @param x
 *            key to hash; must not be {@code null}
 * @return the scrambled hash
 */
private int hash(final T x) {
    int mixed = x.hashCode();
    mixed += ~(mixed << 9);
    mixed ^= (mixed >>> 14);
    mixed += (mixed << 4);
    mixed ^= (mixed >>> 10);
    return mixed < 0 ? -mixed : mixed; // New
}
/**
 * Finds the slot index for {@code key}, delegating to {@code findHelper}.
 * Modifying lookups run while holding this object's monitor; read-only
 * lookups do not synchronize.
 *
 * @param key
 *            key to look up
 * @param modify
 *            whether to make room for the new key if it doesn't exist; a
 *            newly created slot gets the placeholder value -1
 * @return the slot index as returned by {@code findHelper}
 */
private int find(final T key, final boolean modify) {
    if (!modify) {
        return findHelper(key, modify);
    }
    synchronized (this) {
        return findHelper(key, modify);
    }
}
/**
 * Locates the slot for {@code key} in the current representation and, if
 * requested, makes room for it.
 * <p>
 * Sorted list: binary-searches for the key. If it is absent and
 * {@code modify} is set, the arrays are grown when full, later entries are
 * shifted one slot forward and the slot at the insertion point is opened.
 * <p>
 * Hash table: probes linearly from the hashed position, wrapping around at
 * the end of the table. If {@code modify} is set, the map is not locked and
 * the table is too full, the table is rebuilt first and the lookup is
 * retried.
 * <p>
 * A newly opened slot gets the placeholder value -1. The key itself is not
 * stored here; callers such as put and incr assign {@code keys[i]} after the
 * call.
 *
 * @param key
 *            key to look up
 * @param modify
 *            whether to make room for the key if it is absent
 * @return the index of the key's slot, or -1 if the key is absent and
 *         {@code modify} is false
 * @throws RuntimeException
 *             if a new sorted-list entry is requested while the map is
 *             locked, if the hash table becomes full, or if the map type is
 *             neither sorted list nor hash table
 */
private int findHelper(final T key, final boolean modify) {
//System.out.println("find " + key + " " + modify + " " + mapType + " " + capacity());
if (mapType == MapType.SORTED_LIST) {
// Binary search
final int i = binarySearch(key);
if (i < num && keys[i] != null && key.equals(keys[i])) return i;
if (modify) {
if (locked) throw new RuntimeException("Cannot make new entry for " + key + ", because map is locked");
if (num == capacity()) changeSortedListCapacity(getCapacity(num + 1, false));
// Shift everything forward
for (int j = num; j > i; j--) {
keys[j] = keys[j - 1];
values[j] = values[j - 1];
}
num++;
// Placeholder value for the freshly opened slot.
values[i] = -1;
return i;
} else
return -1;
} else if (mapType == MapType.HASH_TABLE) {
final int capacity = capacity();
final int keyHash = hash(key);
int i = keyHash % capacity;
if (i < 0) i = -i; // Arbitrary transformation
// Make sure big enough
if (!locked && modify && (num > loadFactor * capacity || capacity <= num + 1)) {
/*
* if(locked) throw new
* RuntimeException("Cannot make new entry for " + key +
* ", because map is locked");
*/
// Rebuild the table with fresh capacity, then retry the lookup against it.
switchMapType(MapType.HASH_TABLE);
return find(key, modify);
}
//System.out.println("!!! " + keyHash + " " + capacity);
T currKey = null;
int numCollisionsHere = 0;
while ((currKey = keys[i]) != null && !currKey.equals(key)) { // Collision
// Warning: infinite loop if the hash table is full
// (but this shouldn't happen based on the check above)
i++;
numCollisionsHere++;
if (i == capacity) i = 0;
}
numCollisions += numCollisionsHere;
if (keys[i] != null) { // Found
return i;
}
if (modify) { // Not found
num++;
if (num == capacity) throw new RuntimeException("Hash table is full: " + capacity);
// Placeholder value for the freshly claimed slot.
values[i] = -1;
return i;
} else
return -1;
} else
throw new RuntimeException("Internal bug: " + mapType);
}
/**
 * Replaces the key and value arrays with fresh arrays of length {@code n}.
 * The previous contents are not copied.
 *
 * @param n
 *            length of the new arrays
 */
private void allocate(final int n) {
    this.keys = this.keyFunc.createArray(n);
    this.values = new int[n];
}
/**
 * Resizes the sorted list to the new capacity, keeping the first
 * {@code num} keys and values.
 *
 * @param newCapacity
 *            length of the new arrays; asserted to be at least {@code num}
 */
private void changeSortedListCapacity(final int newCapacity) {
    assert mapType == MapType.SORTED_LIST;
    assert newCapacity >= num;
    final int[] previousValues = values;
    final T[] previousKeys = keys;
    allocate(newCapacity);
    // Carry the existing entries over into the fresh arrays.
    System.arraycopy(previousKeys, 0, keys, 0, num);
    System.arraycopy(previousValues, 0, values, 0, num);
}
/**
 * Renders the map as {@code [key:value, key:value, ]}; every entry,
 * including the last one, is followed by a comma and a space.
 */
@Override
public String toString() {
    final StringBuilder text = new StringBuilder();
    text.append("[");
    for (final TIntMap<T>.Entry entry : entrySet()) {
        final String rendered = entry.getKey() + ":" + entry.getValue() + ", ";
        text.append(rendered);
    }
    return text.append("]").toString();
}
}