package edu.berkeley.nlp.util;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
//import fig.basic.Pair;
/**
* A countermap that uses fast counters. This should most certainly be rewritten.
*
* John
*/
public class FastCounterMap<K, V> implements java.io.Serializable {
  private static final long serialVersionUID = 1L;

  /** Maps each key to the counter holding the counts of its values. */
  Map<K, FastCounter<V>> counterMap;

  /**
   * When true, sub-counters created by {@link #ensureCounter(Object)} are
   * switched to their sorted-list representation right after creation.
   */
  boolean sortedList;

  /**
   * Creates an empty counter map backed by a {@link HashMap}.
   */
  public FastCounterMap() {
    counterMap = new HashMap<K, FastCounter<V>>();
  }

  /**
   * Creates a counter map holding the same (key, value) counts as {@code cm}.
   * The counts are copied entry by entry; the {@code sortedList} setting of
   * {@code cm} is not carried over.
   *
   * @param cm the counter map whose counts are copied
   */
  public FastCounterMap(FastCounterMap<K, V> cm) {
    this();
    incrementAll(cm);
  }

  /**
   * Returns the sub-counter for the given key, creating and installing a new
   * one if the key has none yet.
   *
   * @param key the key whose sub-counter is wanted
   * @return the (possibly freshly created) sub-counter for {@code key}
   */
  protected FastCounter<V> ensureCounter(K key) {
    FastCounter<V> valueCounter = counterMap.get(key);
    if (valueCounter == null) {
      valueCounter = new FastCounter<V>();
      if (sortedList) {
        valueCounter.switchToSortedList();
      }
      counterMap.put(key, valueCounter);
    }
    return valueCounter;
  }

  /**
   * Returns the keys that have been inserted into this FastCounterMap. The
   * returned set is the key set of the backing map, not a copy.
   */
  public Set<K> keySet() {
    return counterMap.keySet();
  }

  /**
   * Sets the count for a particular (key, value) pair, creating the key's
   * sub-counter if necessary.
   */
  public void setCount(K key, V value, double count) {
    ensureCounter(key).setCount(value, count);
  }

  /**
   * Increments the count for a particular (key, value) pair, creating the
   * key's sub-counter if necessary.
   */
  public void incrementCount(K key, V value, double count) {
    ensureCounter(key).incrementCount(value, count);
  }

  /**
   * Gets the count of the given (key, value) entry, or zero if the key is
   * not present. Does not create a sub-counter for a missing key.
   */
  public double getCount(K key, V value) {
    FastCounter<V> valueCounter = counterMap.get(key);
    if (valueCounter == null) {
      return 0.0;
    }
    return valueCounter.getCount(value);
  }

  /**
   * Gets the sub-counter for the given key. If there is none, a counter is
   * created for that key and installed in this map, so that modifying the
   * returned counter has the same effect whether or not the key was already
   * present (although callers should avoid modifying it directly).
   */
  public FastCounter<V> getCounter(K key) {
    return ensureCounter(key);
  }

  /**
   * Increments by {@code count} the count of every (key, value) pair that
   * appears as an entry of {@code map}.
   *
   * @param map   the pairs to increment, one per map entry
   * @param count the amount added to each pair
   */
  public void incrementAll(Map<K, V> map, double count) {
    for (Map.Entry<K, V> entry : map.entrySet()) {
      incrementCount(entry.getKey(), entry.getValue(), count);
    }
  }

  /**
   * Adds every count stored in {@code cMap} to the corresponding
   * (key, value) pair of this map.
   *
   * @param cMap the counter map whose counts are added to this one
   */
  public void incrementAll(FastCounterMap<K, V> cMap) {
    for (Map.Entry<K, FastCounter<V>> entry : cMap.getEntrySet()) {
      K key = entry.getKey();
      FastCounter<V> source = entry.getValue();
      for (V value : source.keySet()) {
        incrementCount(key, value, source.getCount(value));
      }
    }
  }

  /**
   * Gets the total count of the given key, or zero if that key is not
   * present. Does not create a sub-counter for a missing key.
   */
  public double getCount(K key) {
    FastCounter<V> valueCounter = counterMap.get(key);
    if (valueCounter == null) {
      return 0.0;
    }
    return valueCounter.totalCount();
  }

  /**
   * Returns the total of all counts in sub-counters. The sum is recomputed
   * over all sub-counters on every call.
   */
  public double totalCount() {
    double total = 0.0;
    for (FastCounter<V> counter : counterMap.values()) {
      total += counter.totalCount();
    }
    return total;
  }

  /**
   * Returns the total number of (key, value) entries in this map (not their
   * total counts), i.e. the sum of the sizes of all sub-counters.
   */
  public int totalSize() {
    int total = 0;
    for (FastCounter<V> counter : counterMap.values()) {
      total += counter.size();
    }
    return total;
  }

  /**
   * The number of keys in this map (not the number of key-value entries --
   * use {@link #totalSize()} for that).
   */
  public int size() {
    return counterMap.size();
  }

  /**
   * True if this map has no keys. A false result does not imply
   * {@code totalCount() > 0}: a key whose sub-counter is empty still counts
   * as a key.
   */
  public boolean isEmpty() {
    return size() == 0;
  }

  /**
   * Finds the (key, value) pair with maximum count by asking each
   * sub-counter for its own argMax and keeping the best. This is linear in
   * the number of keys, and ties are broken arbitrarily.
   *
   * @return a pair with maximum count, or {@code null} if this map has no keys
   */
  public Pair<K, V> argMax() {
    double maxCount = Double.NEGATIVE_INFINITY;
    Pair<K, V> maxKey = null;
    for (Map.Entry<K, FastCounter<V>> entry : counterMap.entrySet()) {
      FastCounter<V> counter = entry.getValue();
      V localMax = counter.argMax();
      double localCount = counter.getCount(localMax);
      // The first key is always accepted so that a non-empty map never
      // yields null, even if its counts are not above negative infinity.
      if (localCount > maxCount || maxKey == null) {
        maxKey = new Pair<K, V>(entry.getKey(), localMax);
        maxCount = localCount;
      }
    }
    return maxKey;
  }

  /**
   * Renders one line per key, each showing the key followed by its
   * sub-counter formatted via {@code toString(20)}.
   */
  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder("[\n");
    for (Map.Entry<K, FastCounter<V>> entry : counterMap.entrySet()) {
      sb.append(" ");
      sb.append(entry.getKey());
      sb.append(" -> ");
      sb.append(entry.getValue().toString(20));
      sb.append("\n");
    }
    sb.append("]");
    return sb.toString();
  }

  // public boolean isEqualTo(FastCounterMap<K, V> map) {
  // boolean tmp = true;
  // FastCounterMap<K, V> bigger = map.size() > size() ? map : this;
  // for (K k : bigger.keySet()) {
  // tmp &= map.getCounter(k).isEqualTo(getCounter(k));
  // }
  // return tmp;
  // }

  /**
   * Calls {@code normalize()} on the sub-counter of every key.
   */
  public void normalize() {
    for (FastCounter<V> counter : counterMap.values()) {
      counter.normalize();
    }
  }

  /**
   * For every key, replaces each value's count {@code c} with
   * {@code (c - discount) / total}, where {@code total} is the sub-counter's
   * total count measured before any of its values are rewritten. No guard is
   * applied when that total is zero.
   *
   * @param discount the amount subtracted from each count before dividing
   */
  public void normalizeWithDiscount(double discount) {
    for (FastCounter<V> ctr : counterMap.values()) {
      double totalCount = ctr.totalCount();
      for (V value : ctr.keySet()) {
        ctr.setCount(value, (ctr.getCount(value) - discount) / totalCount);
      }
    }
  }

  /**
   * Constructs the reverse counter map, in which the count of a pair (v, k)
   * is the count of (k, v) in this map.
   *
   * @return a new counter map with keys and values swapped
   */
  public FastCounterMap<V, K> invert() {
    FastCounterMap<V, K> invertCounterMap = new FastCounterMap<V, K>();
    for (Map.Entry<K, FastCounter<V>> entry : counterMap.entrySet()) {
      K key = entry.getKey();
      FastCounter<V> keyCounts = entry.getValue();
      for (V val : keyCounts.keySet()) {
        invertCounterMap.setCount(val, key, keyCounts.getCount(val));
      }
    }
    return invertCounterMap;
  }

  /**
   * Returns true if a sub-counter is installed for the given key.
   */
  public boolean containsKey(K key) {
    return counterMap.containsKey(key);
  }

  /**
   * Returns an iterator over every (key, value) pair stored in this map.
   * Keys are visited in the iteration order of {@link #keySet()}, and keys
   * whose sub-counter holds no values are skipped. Removal through the
   * iterator is not supported.
   *
   * @return an iterator over all (key, value) pairs
   */
  public Iterator<Pair<K, V>> getPairIterator() {
    class PairIterator implements Iterator<Pair<K, V>> {
      Iterator<K> outerIt;
      Iterator<V> innerIt;
      K curKey;

      public PairIterator() {
        outerIt = keySet().iterator();
      }

      /**
       * Positions innerIt on a sub-counter that still has a value to give,
       * stepping over any number of empty sub-counters.
       *
       * @return false once all keys are exhausted
       */
      private boolean advance() {
        while (innerIt == null || !innerIt.hasNext()) {
          if (!outerIt.hasNext()) {
            return false;
          }
          curKey = outerIt.next();
          innerIt = getCounter(curKey).keySet().iterator();
        }
        return true;
      }

      @Override
      public boolean hasNext() {
        return advance();
      }

      @Override
      public Pair<K, V> next() {
        if (!advance()) {
          throw new java.util.NoSuchElementException();
        }
        return Pair.newPair(curKey, innerIt.next());
      }

      @Override
      public void remove() {
        throw new UnsupportedOperationException("remove is not supported by this iterator");
      }
    }
    return new PairIterator();
  }

  /**
   * Returns the entry set of the backing map (key to sub-counter). This is
   * a live view, not a copy.
   */
  public Set<Map.Entry<K, FastCounter<V>>> getEntrySet() {
    return counterMap.entrySet();
  }

  /**
   * Removes the given key and its sub-counter, if present.
   */
  public void removeKey(K oldIndex) {
    counterMap.remove(oldIndex);
  }

  /**
   * Installs {@code counter} as the sub-counter for the given key,
   * replacing any sub-counter previously stored under that key.
   */
  public void setCounter(K newIndex, FastCounter<V> counter) {
    counterMap.put(newIndex, counter);
  }

  /**
   * Records the sorted-list setting for sub-counters created later and
   * switches every existing sub-counter to the matching representation
   * ({@code switchToSortedList()} when true, {@code switchToHashTable()}
   * when false).
   */
  public void setSortedList(boolean sortedList) {
    this.sortedList = sortedList;
    for (FastCounter<V> ctr : counterMap.values()) {
      if (sortedList) {
        ctr.switchToSortedList();
      } else {
        ctr.switchToHashTable();
      }
    }
  }

  /**
   * Small demonstration that fills a bigram counter map and prints it along
   * with a few lookups. Note that looking up "dogs" via getCounter installs
   * an empty sub-counter for that key.
   */
  public static void main(String[] args) {
    FastCounterMap<String, String> bigramCounterMap = new FastCounterMap<String, String>();
    bigramCounterMap.incrementCount("people", "run", 1);
    bigramCounterMap.incrementCount("cats", "growl", 2);
    bigramCounterMap.incrementCount("cats", "scamper", 3);
    System.out.println(bigramCounterMap);
    System.out.println("Entries for cats: " + bigramCounterMap.getCounter("cats"));
    System.out.println("Entries for dogs: " + bigramCounterMap.getCounter("dogs"));
    System.out.println("Count of cats scamper: "
        + bigramCounterMap.getCount("cats", "scamper"));
    System.out.println("Count of snakes slither: "
        + bigramCounterMap.getCount("snakes", "slither"));
    System.out.println("Total size: " + bigramCounterMap.totalSize());
    System.out.println("Total count: " + bigramCounterMap.totalCount());
    System.out.println(bigramCounterMap);
  }
}