package org.jcommons.type;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import org.jcommons.common.MapFactory;
/**
 * Maintains counts of (key, value) pairs. The map is structured so that for
 * every key, one can get a counter over values. Example usage: keys might be
 * words with values being POS tags, and the count being the number of
 * occurrences of that word/tag pair. The sub-counters returned by
 * getCounter(word) would be count distributions over tags for that word.
 * <p>
 * The no-argument and boolean constructors build their maps with
 * {@code MapFactory.HashMapFactory} or {@code MapFactory.IdentityHashMapFactory};
 * {@link #parallelCounterMap()} builds them as
 * {@link java.util.concurrent.ConcurrentHashMap}s instead.
 * <p>
 * Originally from Berkeley.
 */
public class CounterMap<K, V> implements java.io.Serializable {
    private static final long serialVersionUID = 1L;

    /** Factory handed to every newly built per-key sub-counter. */
    MapFactory<V, Double> mf;
    /** Outer map from each key to its sub-counter over values. */
    Map<K, Counter<V>> counterMap;
    /** Default installed on new sub-counters and returned by getCount(key, value) for absent keys. */
    double defltVal = 0.0;

    /**
     * Computes the count (or cost) associated with an ordered pair of values.
     *
     * @param <V> the value type
     */
    public interface CountFunction<V> {
        /**
         * @return the count to record for the pair {@code (v1, v2)}
         */
        double count(V v1, V v2);
    }

    /**
     * Builds a counter map by pairing the items of the list from the outside in:
     * item {@code i} is paired with item {@code size - 1 - i} for every
     * {@code i < size / 2}, so the middle item of an odd-length list is left
     * unpaired. For each pair {@code (v1, v2)} the count function is evaluated
     * once and its result is added to both {@code count(v1, v2)} and
     * {@code count(v2, v1)}, making the result symmetric. (When {@code v1} and
     * {@code v2} are the same value, both increments hit the same cell.) A pair
     * whose count is already positive is skipped to avoid double counting; that
     * check is best-effort when equal pairs are processed at the same time.
     * <p>
     * Pairs are evaluated in parallel on a fixed pool sized to the number of
     * available processors. If the count function throws for a pair, the stack
     * trace is printed and the remaining pairs are still collected. If the
     * calling thread is interrupted while waiting, its interrupt flag is
     * restored, outstanding work is cancelled, and the partially filled map is
     * returned.
     *
     * @param items
     *            the items to iterate over
     * @param countFunction
     *            the function to count
     * @param <V>
     *            the type to count
     * @return the pairwise counter map, created by {@link #parallelCounterMap()}
     */
    public static <V> CounterMap<V, V> runPairWise(final List<V> items,
            final CountFunction<V> countFunction) {
        final CounterMap<V, V> count = parallelCounterMap();
        final int size = items.size();
        // Equivalent to a ThreadPoolExecutor with core == max threads and an
        // unbounded LinkedBlockingQueue: tasks are never rejected before shutdown.
        ExecutorService exec = Executors.newFixedThreadPool(Runtime.getRuntime()
                .availableProcessors());
        try {
            List<Future<V>> futures = new ArrayList<>(size / 2);
            for (int i = 0; i < size / 2; i++) {
                // Each task owns its indices; sharing mutable cursors between
                // tasks could mismatch the two ends of a pair under contention.
                final int lo = i;
                final int hi = size - 1 - i;
                futures.add(exec.submit(new Callable<V>() {
                    @Override
                    public V call() throws Exception {
                        V v = items.get(lo);
                        V v2 = items.get(hi);
                        // don't double count
                        if (count.getCount(v, v2) > 0) {
                            return v;
                        }
                        double cost = countFunction.count(v, v2);
                        count.incrementCount(v, v2, cost);
                        count.incrementCount(v2, v, cost);
                        return v;
                    }
                }));
            }
            for (Future<V> future : futures) {
                try {
                    future.get();
                } catch (ExecutionException e) {
                    // Best effort: report the failed pair and keep collecting the rest.
                    e.printStackTrace();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    break;
                }
            }
        } finally {
            // All futures are done unless we were interrupted; in that case this
            // cancels queued pairs and interrupts the running ones.
            exec.shutdownNow();
        }
        return count;
    }

    /**
     * Returns a counter map whose outer map, and whose factory handed to every
     * sub-counter, build {@link java.util.concurrent.ConcurrentHashMap}s.
     * Sub-counter creation in {@link #ensureCounter(Object)} is atomic for such
     * a map.
     * <p>
     * NOTE(review): whether a single {@code incrementCount} is atomic depends on
     * {@link Counter}'s implementation; confirm before relying on exact totals
     * under heavy contention on the same (key, value) cell.
     *
     * @return a new, empty concurrent counter map
     */
    public static <K, V> CounterMap<K, V> parallelCounterMap() {
        MapFactory<K, Counter<V>> outerFactory = concurrentMapFactory();
        MapFactory<V, Double> innerFactory = concurrentMapFactory();
        return new CounterMap<K, V>(outerFactory, innerFactory);
    }

    /** Returns a MapFactory whose buildMap() produces an empty ConcurrentHashMap. */
    private static <A, B> MapFactory<A, B> concurrentMapFactory() {
        return new MapFactory<A, B>() {
            private static final long serialVersionUID = 5447027920163740307L;

            @Override
            public Map<A, B> buildMap() {
                return new java.util.concurrent.ConcurrentHashMap<>();
            }
        };
    }

    /** Builds a detached, empty sub-counter carrying this map's default value. */
    private Counter<V> newCounter() {
        Counter<V> fresh = buildCounter(mf);
        fresh.setDeflt(defltVal);
        return fresh;
    }

    /**
     * Returns the sub-counter for {@code key}, creating and installing an empty
     * one (with this map's default) if none exists. When the outer map is a
     * {@link ConcurrentMap}, installation uses {@code putIfAbsent} so that
     * concurrent callers for the same key end up sharing a single sub-counter.
     */
    protected Counter<V> ensureCounter(K key) {
        Counter<V> valueCounter = counterMap.get(key);
        if (valueCounter != null) {
            return valueCounter;
        }
        Counter<V> fresh = newCounter();
        if (counterMap instanceof ConcurrentMap) {
            Counter<V> raced = ((ConcurrentMap<K, Counter<V>>) counterMap)
                    .putIfAbsent(key, fresh);
            return raced != null ? raced : fresh;
        }
        counterMap.put(key, fresh);
        return fresh;
    }

    /**
     * Returns the sub-counter for {@code key} without modifying this map: if the
     * key is absent, a detached empty sub-counter is returned instead.
     */
    private Counter<V> peekCounter(K key) {
        Counter<V> valueCounter = counterMap.get(key);
        return valueCounter != null ? valueCounter : newCounter();
    }

    /** Returns a live view of all sub-counters (the outer map's values). */
    public Collection<Counter<V>> getCounters() {
        return counterMap.values();
    }

    /**
     * Creates a new sub-counter using the given map factory. Subclasses may
     * override this to supply a specialized {@link Counter}.
     *
     * @param mf
     *            the factory passed to the new counter
     * @return a new counter
     */
    protected Counter<V> buildCounter(MapFactory<V, Double> mf) {
        return new Counter<V>(mf);
    }

    /**
     * Returns the keys that have been inserted into this CounterMap (a live view
     * of the outer map's key set).
     */
    public Set<K> keySet() {
        return counterMap.keySet();
    }

    /**
     * Sets the count for a particular (key, value) pair, creating the key's
     * sub-counter if needed.
     */
    public void setCount(K key, V value, double count) {
        ensureCounter(key).setCount(value, count);
    }

    /**
     * Increments the count for a particular (key, value) pair, creating the
     * key's sub-counter if needed.
     */
    public void incrementCount(K key, V value, double count) {
        ensureCounter(key).incrementCount(value, count);
    }

    /**
     * Gets the count of the given (key, value) entry. If the key has no
     * sub-counter this map's default value is returned; otherwise the
     * sub-counter decides. Does not create any objects.
     */
    public double getCount(K key, V value) {
        Counter<V> valueCounter = counterMap.get(key);
        if (valueCounter == null) {
            return defltVal;
        }
        return valueCounter.getCount(value);
    }

    /**
     * Gets the sub-counter for the given key. If there is none, a counter is
     * created for that key, and installed in the CounterMap. You can, for
     * example, add to the returned empty counter directly (though you
     * shouldn't). This is so whether the key is present or not, modifying the
     * returned counter has the same effect (but don't do it).
     */
    public Counter<V> getCounter(K key) {
        return ensureCounter(key);
    }

    /** Adds {@code count} to the pair (key, value) for every entry of {@code map}. */
    public void incrementAll(Map<K, V> map, double count) {
        for (Map.Entry<K, V> entry : map.entrySet()) {
            incrementCount(entry.getKey(), entry.getValue(), count);
        }
    }

    /** Adds every (key, value, count) entry of {@code cMap} into this map. */
    public void incrementAll(CounterMap<K, V> cMap) {
        for (Map.Entry<K, Counter<V>> entry : cMap.counterMap.entrySet()) {
            K key = entry.getKey();
            for (Map.Entry<V, Double> innerEntry : entry.getValue().entrySet()) {
                incrementCount(key, innerEntry.getKey(), innerEntry.getValue());
            }
        }
    }

    /**
     * Gets the total count of the given key, or 0.0 (not this map's default
     * value) if that key is not present. Does not create any objects.
     */
    public double getCount(K key) {
        Counter<V> valueCounter = counterMap.get(key);
        if (valueCounter == null) {
            return 0.0;
        }
        return valueCounter.totalCount();
    }

    /**
     * Returns the total of all counts in sub-counters. This implementation is
     * linear; it recalculates the total each time.
     */
    public double totalCount() {
        double total = 0.0;
        for (Counter<V> counter : counterMap.values()) {
            total += counter.totalCount();
        }
        return total;
    }

    /**
     * Returns the total number of (key, value) entries in the CounterMap (not
     * their total counts).
     */
    public int totalSize() {
        int total = 0;
        for (Counter<V> counter : counterMap.values()) {
            total += counter.size();
        }
        return total;
    }

    /**
     * The number of keys in this CounterMap (not the number of key-value
     * entries -- use totalSize() for that)
     */
    public int size() {
        return counterMap.size();
    }

    /**
     * True if there are no keys in the CounterMap (false does not mean
     * totalCount &gt; 0)
     */
    public boolean isEmpty() {
        return size() == 0;
    }

    /**
     * Finds the (key, value) pair with maximum count, taking each sub-counter's
     * own argMax as its candidate. One pass over the keys; on ties the first
     * candidate encountered in iteration order wins.
     *
     * @return a (key, value) pair with maximum count, or {@code null} if this
     *         map has no keys
     */
    public Pair<K, V> argMax() {
        double maxCount = Double.NEGATIVE_INFINITY;
        Pair<K, V> maxKey = null;
        for (Map.Entry<K, Counter<V>> entry : counterMap.entrySet()) {
            Counter<V> counter = entry.getValue();
            V localMax = counter.argMax();
            double localCount = counter.getCount(localMax);
            if (localCount > maxCount || maxKey == null) {
                maxKey = new Pair<K, V>(entry.getKey(), localMax);
                maxCount = localCount;
            }
        }
        return maxKey;
    }

    /**
     * Renders each key and its sub-counter (limited to {@code maxValsPerKey}
     * values) on its own line.
     */
    public String toString(int maxValsPerKey) {
        return render(null, maxValsPerKey);
    }

    @Override
    public String toString() {
        return toString(20);
    }

    /**
     * Like {@link #toString(int)} with 20 values per key, but only keys contained
     * in {@code keyFilter} are rendered; a {@code null} filter renders all keys.
     */
    public String toString(Collection<String> keyFilter) {
        return render(keyFilter, 20);
    }

    /** Shared formatter for the toString overloads; a null filter keeps every key. */
    private String render(Collection<?> keyFilter, int maxValsPerKey) {
        StringBuilder sb = new StringBuilder("[\n");
        for (Map.Entry<K, Counter<V>> entry : counterMap.entrySet()) {
            if (keyFilter != null && !keyFilter.contains(entry.getKey())) {
                continue;
            }
            sb.append(" ");
            sb.append(entry.getKey());
            sb.append(" -> ");
            sb.append(entry.getValue().toString(maxValsPerKey));
            sb.append("\n");
        }
        sb.append("]");
        return sb.toString();
    }

    /** Creates a default (non-identity) CounterMap holding a copy of {@code cm}'s counts. */
    public CounterMap(CounterMap<K, V> cm) {
        this();
        incrementAll(cm);
    }

    /** Creates an empty CounterMap using {@code MapFactory.HashMapFactory}. */
    public CounterMap() {
        this(false);
    }

    /**
     * Compares this map with {@code map} key by key using
     * {@link Counter#isEqualTo}. Every key present in either map is checked; a
     * key missing on one side is compared as an empty sub-counter carrying that
     * side's default value. Neither map is modified.
     *
     * @return true if every compared pair of sub-counters is equal
     */
    public boolean isEqualTo(CounterMap<K, V> map) {
        for (K k : keySet()) {
            if (!map.peekCounter(k).isEqualTo(peekCounter(k))) {
                return false;
            }
        }
        for (K k : map.keySet()) {
            if (counterMap.containsKey(k)) {
                continue; // already compared in the first loop
            }
            if (!map.peekCounter(k).isEqualTo(peekCounter(k))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Creates an empty CounterMap whose outer map is built by {@code outerMF}
     * and whose sub-counters receive {@code innerMF}.
     */
    public CounterMap(MapFactory<K, Counter<V>> outerMF,
            MapFactory<V, Double> innerMF) {
        mf = innerMF;
        counterMap = outerMF.buildMap();
    }

    /**
     * Creates an empty CounterMap using identity-based map factories when
     * {@code identityHashMap} is true, and hash-based ones otherwise.
     */
    public CounterMap(boolean identityHashMap) {
        this(
                identityHashMap ? new MapFactory.IdentityHashMapFactory<K, Counter<V>>()
                        : new MapFactory.HashMapFactory<K, Counter<V>>(),
                identityHashMap ? new MapFactory.IdentityHashMapFactory<V, Double>()
                        : new MapFactory.HashMapFactory<V, Double>());
    }

    public static void main(String[] args) {
        CounterMap<String, String> bigramCounterMap = new CounterMap<String, String>();
        bigramCounterMap.incrementCount("people", "run", 1);
        bigramCounterMap.incrementCount("cats", "growl", 2);
        bigramCounterMap.incrementCount("cats", "scamper", 3);
        System.out.println(bigramCounterMap);
        System.out.println("Entries for cats: "
                + bigramCounterMap.getCounter("cats"));
        System.out.println("Entries for dogs: "
                + bigramCounterMap.getCounter("dogs"));
        System.out.println("Count of cats scamper: "
                + bigramCounterMap.getCount("cats", "scamper"));
        System.out.println("Count of snakes slither: "
                + bigramCounterMap.getCount("snakes", "slither"));
        System.out.println("Total size: " + bigramCounterMap.totalSize());
        System.out.println("Total count: " + bigramCounterMap.totalCount());
        System.out.println(bigramCounterMap);
    }

    /** Calls {@link Counter#normalize()} on every sub-counter. */
    public void normalize() {
        for (K key : keySet()) {
            getCounter(key).normalize();
        }
    }

    /**
     * For every key, replaces each count {@code c} in its sub-counter with
     * {@code (c - discount) / total}, where {@code total} is the sub-counter's
     * total taken before any value is rewritten. A sub-counter whose total is 0
     * yields NaN or infinite values.
     */
    public void normalizeWithDiscount(double discount) {
        for (K key : keySet()) {
            Counter<V> ctr = getCounter(key);
            double totalCount = ctr.totalCount();
            for (V value : ctr.keySet()) {
                ctr.setCount(value, (ctr.getCount(value) - discount)
                        / totalCount);
            }
        }
    }

    /**
     * Constructs reverse CounterMap where the count of a pair (v,k) is the
     * count of (k,v) in the current CounterMap. The result is always built with
     * the no-argument constructor, regardless of how this map was configured.
     *
     * @return the inverted counter map
     */
    public CounterMap<V, K> invert() {
        CounterMap<V, K> invertCounterMap = new CounterMap<V, K>();
        for (K key : this.keySet()) {
            Counter<V> keyCounts = this.getCounter(key);
            for (V val : keyCounts.keySet()) {
                invertCounterMap.setCount(val, key, keyCounts.getCount(val));
            }
        }
        return invertCounterMap;
    }

    /**
     * Scale all entries in <code>CounterMap</code> by <code>scaleFactor</code>
     *
     * @param scaleFactor
     *            the factor passed to every sub-counter's {@code scale}
     */
    public void scale(double scaleFactor) {
        for (K key : keySet()) {
            getCounter(key).scale(scaleFactor);
        }
    }

    /** Returns true if {@code key} has a sub-counter (possibly an empty one). */
    public boolean containsKey(K key) {
        return counterMap.containsKey(key);
    }

    /**
     * Returns an iterator over every (key, value) pair stored in the sub-counters.
     * Keys whose sub-counter is empty are skipped. The iterator is read-only:
     * {@code remove()} throws {@link UnsupportedOperationException}.
     */
    public Iterator<Pair<K, V>> getPairIterator() {
        class PairIterator implements Iterator<Pair<K, V>> {
            private final Iterator<K> outerIt = keySet().iterator();
            private Iterator<V> innerIt;
            private K curKey;

            /**
             * Positions innerIt on a value, moving past any number of empty
             * sub-counters; returns false once all keys are exhausted.
             */
            private boolean advance() {
                while (innerIt == null || !innerIt.hasNext()) {
                    if (!outerIt.hasNext()) {
                        return false;
                    }
                    curKey = outerIt.next();
                    innerIt = getCounter(curKey).keySet().iterator();
                }
                return true;
            }

            @Override
            public boolean hasNext() {
                return advance();
            }

            @Override
            public Pair<K, V> next() {
                if (!advance()) {
                    throw new NoSuchElementException();
                }
                return Pair.newPair(curKey, innerIt.next());
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException(
                        "CounterMap pair iterator is read-only");
            }
        }
        return new PairIterator();
    }

    /** Returns a live view of the outer map's (key, sub-counter) entries. */
    public Set<Map.Entry<K, Counter<V>>> getEntrySet() {
        return counterMap.entrySet();
    }

    /** Removes {@code oldIndex} and its sub-counter from this map. */
    public void removeKey(K oldIndex) {
        counterMap.remove(oldIndex);
    }

    /**
     * Installs {@code counter} as the sub-counter for {@code newIndex}, replacing
     * any existing one. The counter is stored as-is; its default value is not
     * changed to this map's default.
     */
    public void setCounter(K newIndex, Counter<V> counter) {
        counterMap.put(newIndex, counter);
    }

    /**
     * Sets the default value for this map and for every existing sub-counter;
     * sub-counters created afterwards also receive it.
     */
    public void setDefault(double defltVal) {
        this.defltVal = defltVal;
        for (Counter<V> vCounter : counterMap.values()) {
            vCounter.setDeflt(defltVal);
        }
    }
}