package edu.berkeley.nlp.util;
import edu.berkeley.nlp.math.SloppyMath;
import java.util.*;
/**
 * Static helper methods for building and transforming {@link Counter} and
 * {@link CounterMap} instances: normalization, L2 norms, sorting keys by
 * count, and element-wise exp/log transforms.
 *
 * @author Dan Klein
 */
public class Counters {

  /**
   * Returns a new counter in which every count of {@code counter} has been
   * divided by {@code counter.totalCount()}. The input is not modified.
   *
   * <p>No guard is applied for a zero total: in that case the division
   * produces NaN or infinite values rather than an exception.
   *
   * @param counter the counter to normalize
   * @param <E> key type
   * @return a new counter holding {@code count / total} for every key
   */
  public static <E> Counter<E> normalize(Counter<E> counter) {
    Counter<E> normalizedCounter = new Counter<E>();
    double total = counter.totalCount();
    for (E key : counter.keySet()) {
      normalizedCounter.setCount(key, counter.getCount(key) / total);
    }
    return normalizedCounter;
  }

  /**
   * Builds a counter by incrementing the count of each element of
   * {@code iterable} by 1.0, once per occurrence.
   *
   * @param iterable the elements to count
   * @param <T> element type
   * @return a new counter of the elements seen
   */
  public static <T> Counter<T> counterFromCollection(Iterable<T> iterable) {
    Counter<T> counts = new Counter<T>();
    for (T item : iterable) {
      counts.incrementCount(item, 1.0);
    }
    return counts;
  }

  /**
   * Builds a counter from a sequence of collections by calling
   * {@code incrementAll(coll, 1.0)} for each collection in {@code iterable}.
   *
   * <p>The type parameter {@code C} is unused; it is retained only so that
   * existing callers supplying explicit type arguments keep compiling.
   *
   * @param iterable the collections whose elements are counted
   * @param <E> element type
   * @param <C> unused
   * @return a new counter accumulated over all collections
   */
  public static <E, C extends Iterable<?>> Counter<E> counterFromData(
      Iterable<? extends Collection<E>> iterable) {
    Counter<E> counts = new Counter<E>();
    for (Collection<E> coll : iterable) {
      counts.incrementAll(coll, 1.0);
    }
    return counts;
  }

  /**
   * Returns a new counter map in which the sub-counter of every key has been
   * normalized with {@link #normalize(Counter)}. The input is not modified.
   *
   * @param counterMap the counter map to normalize
   * @param <K> outer key type
   * @param <V> inner key type
   * @return a new counter map with per-key normalized counts
   */
  public static <K, V> CounterMap<K, V> conditionalNormalize(CounterMap<K, V> counterMap) {
    CounterMap<K, V> normalizedCounterMap = new CounterMap<K, V>();
    for (K key : counterMap.keySet()) {
      Counter<V> normalizedSubCounter = normalize(counterMap.getCounter(key));
      for (V value : normalizedSubCounter.keySet()) {
        double count = normalizedSubCounter.getCount(value);
        normalizedCounterMap.setCount(key, value, count);
      }
    }
    return normalizedCounterMap;
  }

  /**
   * Computes the Euclidean (L2) norm of the counts: the square root of the
   * sum of squared values over all entries.
   *
   * @param counts the counter to measure
   * @param <K> key type
   * @return the L2 norm; 0.0 when there are no entries
   */
  public static <K> double l2Norm(Counter<K> counts) {
    double sumOfSquares = 0.0;
    for (Map.Entry<K, Double> entry : counts.getEntrySet()) {
      double count = entry.getValue();
      sumOfSquares += count * count;
    }
    return Math.sqrt(sumOfSquares);
  }

  /**
   * Returns a new counter in which every count has been divided by the L2
   * norm of {@code counts}. The input is not modified.
   *
   * @param counts the counter to normalize
   * @param <K> key type
   * @return a new counter of {@code count / norm}; an empty counter when the
   *     norm is exactly 0.0
   */
  public static <K> Counter<K> l2Normalize(Counter<K> counts) {
    Counter<K> normalizedCounts = new Counter<K>();
    double norm = l2Norm(counts);
    if (norm == 0.0) {
      // Nothing sensible to divide by; hand back the empty counter.
      return normalizedCounts;
    }
    for (K key : counts.keySet()) {
      normalizedCounts.setCount(key, counts.getCount(key) / norm);
    }
    return normalizedCounts;
  }

  /**
   * Returns the keys of {@code counts} sorted by descending count.
   *
   * <p>Counts are compared with {@link Double#compare(double, double)} so the
   * ordering stays a valid total order even when counts are NaN or infinite
   * (NaN counts sort first). Comparing via the sign of a subtraction would
   * yield NaN for such values and break the {@link Comparator} contract.
   *
   * @param counts the counter whose keys are sorted
   * @param <L> key type
   * @return a new list of the keys, highest count first
   */
  public static <L> List<L> sortedKeys(final Counter<L> counts) {
    List<L> keys = new ArrayList<L>(counts.keySet());
    Collections.sort(keys, new Comparator<L>() {
      public int compare(L arg0, L arg1) {
        // Arguments are swapped to obtain descending order.
        return Double.compare(counts.getCount(arg1), counts.getCount(arg0));
      }
    });
    return keys;
  }

  /**
   * Returns a new counter holding {@code Math.exp(count)} for every entry of
   * {@code counts}. The input is not modified.
   *
   * @param counts the counter to exponentiate
   * @param <K> key type
   * @return a new counter of exponentiated counts
   */
  public static <K> Counter<K> exponentiate(Counter<K> counts) {
    Counter<K> exponentiated = new Counter<K>();
    for (Map.Entry<K, Double> entry : counts.entrySet()) {
      exponentiated.setCount(entry.getKey(), Math.exp(entry.getValue()));
    }
    return exponentiated;
  }

  /**
   * Replaces every count in {@code counts} with {@code Math.exp(count)}.
   *
   * @param counts the counter to modify in place
   * @param <K> key type
   */
  public static <K> void exponentiateInPlace(Counter<K> counts) {
    for (Map.Entry<K, Double> entry : counts.entrySet()) {
      entry.setValue(Math.exp(entry.getValue()));
    }
    // Values were written through the entry view rather than setCount, so flag
    // the counter as dirty, as makeProbsFromLogScoresInPlace does.
    // NOTE(review): presumably this invalidates cached aggregates in Counter - confirm.
    counts.setDirty(true);
  }

  /**
   * Replaces every count in {@code counts} with {@code Math.log(count)}.
   *
   * @param counts the counter to modify in place
   * @param <K> key type
   */
  public static <K> void logInPlace(Counter<K> counts) {
    for (Map.Entry<K, Double> entry : counts.entrySet()) {
      entry.setValue(Math.log(entry.getValue()));
    }
    // Values were written through the entry view rather than setCount, so flag
    // the counter as dirty, as makeProbsFromLogScoresInPlace does.
    // NOTE(review): presumably this invalidates cached aggregates in Counter - confirm.
    counts.setDirty(true);
  }

  /**
   * Converts log-domain scores to values of the form
   * {@code exp(logScore - logSum)} in place, where {@code logSum} is the
   * result of {@code SloppyMath.logAdd(logScores)}.
   *
   * <p>NOTE(review): {@code logAdd} is presumably the log of the sum of the
   * exponentiated scores, which would make the results sum to one - confirm
   * against SloppyMath.
   *
   * @param logScores the counter of log scores, overwritten with the results
   * @param <K> key type
   */
  public static <K> void makeProbsFromLogScoresInPlace(Counter<K> logScores) {
    double logSum = SloppyMath.logAdd(logScores);
    for (Map.Entry<K, Double> entry : logScores.entrySet()) {
      double logScore = entry.getValue();
      entry.setValue(Math.exp(logScore - logSum));
    }
    // Entries were modified directly, so mark the counter as dirty.
    logScores.setDirty(true);
  }
}